add lmdb dataset support and EpisodicDataset
This commit is contained in:
@@ -14,40 +14,28 @@ from ignite.contrib.handlers import ProgressBar
|
||||
from util.build import build_model, build_optimizer
|
||||
from util.handler import setup_common_handlers
|
||||
from data.transform import transform_pipeline
|
||||
from data.dataset import LMDBDataset
|
||||
|
||||
|
||||
def baseline_trainer(config, logger, val_loader):
|
||||
def baseline_trainer(config, logger):
|
||||
model = build_model(config.model, config.distributed.model)
|
||||
optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True)
|
||||
trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
|
||||
output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
|
||||
trainer.logger = logger
|
||||
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
|
||||
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
|
||||
Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
|
||||
ProgressBar(ncols=0).attach(trainer)
|
||||
|
||||
val_metrics = {
|
||||
"accuracy": Accuracy(),
|
||||
"nll": Loss(loss_fn)
|
||||
}
|
||||
evaluator = create_supervised_evaluator(model, val_metrics, idist.device())
|
||||
ProgressBar(ncols=0).attach(evaluator)
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED)
|
||||
def log_training_loss(engine):
|
||||
logger.info(f"Epoch[{engine.state.epoch}] Loss: {engine.state.output:.2f}")
|
||||
evaluator.run(val_loader)
|
||||
metrics = evaluator.state.metrics
|
||||
logger.info("Training Results - Avg accuracy: {:.2f} Avg loss: {:.2f}"
|
||||
.format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))
|
||||
|
||||
if idist.get_rank() == 0:
|
||||
GpuInfo().attach(trainer, name='gpu')
|
||||
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
tb_logger.attach(
|
||||
evaluator,
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="val",
|
||||
tag="train",
|
||||
metric_names='all',
|
||||
global_step_transform=global_step_from_engine(trainer),
|
||||
),
|
||||
@@ -70,8 +58,7 @@ def baseline_trainer(config, logger, val_loader):
|
||||
to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
|
||||
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
|
||||
save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
|
||||
metrics_to_print=["loss"])
|
||||
save_best_model_by_val_score(config.output_dir, evaluator, model, "accuracy", 1, trainer)
|
||||
metrics_to_print=["loss", "acc"])
|
||||
return trainer
|
||||
|
||||
|
||||
@@ -80,14 +67,13 @@ def run(task, config, logger):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger.info(f"start task {task}")
|
||||
if task == "baseline":
|
||||
train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
|
||||
transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
|
||||
val_dataset = ImageFolder(config.baseline.data.dataset.val.path,
|
||||
transform=transform_pipeline(config.baseline.data.dataset.val.pipeline))
|
||||
train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
|
||||
pipeline=config.baseline.data.dataset.train.pipeline)
|
||||
# train_dataset = ImageFolder(config.baseline.data.dataset.train.path,
|
||||
# transform=transform_pipeline(config.baseline.data.dataset.train.pipeline))
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
|
||||
val_data_loader = idist.auto_dataloader(val_dataset, **config.baseline.data.dataloader)
|
||||
trainer = baseline_trainer(config, logger, val_data_loader)
|
||||
trainer = baseline_trainer(config, logger)
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=400)
|
||||
except Exception:
|
||||
|
||||
Reference in New Issue
Block a user