Almost the same as mmedit

This commit is contained in:
2020-08-08 13:17:26 +08:00
parent 7cf235781d
commit a5133e6795
3 changed files with 61 additions and 18 deletions

View File

@@ -40,14 +40,16 @@ def get_trainer(config, logger):
config.optimizers.discriminator)
milestones_values = [
(config.data.train.scheduler.start, config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr),
(0, config.optimizers.generator.lr),
(100, config.optimizers.generator.lr),
(200, config.data.train.scheduler.target_lr)
]
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
milestones_values = [
(config.data.train.scheduler.start, config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr),
(0, config.optimizers.discriminator.lr),
(100, config.optimizers.discriminator.lr),
(200, config.data.train.scheduler.target_lr)
]
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
@@ -73,13 +75,14 @@ def get_trainer(config, logger):
discriminator_a.requires_grad_(False)
discriminator_b.requires_grad_(False)
loss_g = dict(
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
)
if config.loss.id.weight > 0:
loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
sum(loss_g.values()).backward()
optimizer_g.step()
@@ -116,8 +119,8 @@ def get_trainer(config, logger):
trainer = Engine(_step)
trainer.logger = logger
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
@@ -129,10 +132,12 @@ def get_trainer(config, logger):
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
)
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
filename_prefix=config.name, to_save=to_save,
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
save_interval_event=Events.ITERATION_COMPLETED(
every=config.checkpoints.interval) | Events.COMPLETED)
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
def terminate(engine):
@@ -147,12 +152,23 @@ def get_trainer(config, logger):
def global_step_transform(*args, **kwargs):
return trainer.state.iteration
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
tb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="loss",
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
global_step_transform=global_step_transform,
output_transform=output_transform
),
event_name=Events.ITERATION_COMPLETED(every=50)
)