almost same as mmedit
This commit is contained in:
@@ -40,14 +40,16 @@ def get_trainer(config, logger):
|
||||
config.optimizers.discriminator)
|
||||
|
||||
milestones_values = [
|
||||
(config.data.train.scheduler.start, config.optimizers.generator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
||||
(0, config.optimizers.generator.lr),
|
||||
(100, config.optimizers.generator.lr),
|
||||
(200, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
milestones_values = [
|
||||
(config.data.train.scheduler.start, config.optimizers.discriminator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
||||
(0, config.optimizers.discriminator.lr),
|
||||
(100, config.optimizers.discriminator.lr),
|
||||
(200, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
@@ -73,13 +75,14 @@ def get_trainer(config, logger):
|
||||
discriminator_a.requires_grad_(False)
|
||||
discriminator_b.requires_grad_(False)
|
||||
loss_g = dict(
|
||||
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
||||
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
||||
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
|
||||
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
|
||||
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
|
||||
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
|
||||
)
|
||||
if config.loss.id.weight > 0:
|
||||
loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
||||
loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizer_g.step()
|
||||
|
||||
@@ -116,8 +119,8 @@ def get_trainer(config, logger):
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
||||
@@ -129,10 +132,12 @@ def get_trainer(config, logger):
|
||||
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
||||
)
|
||||
|
||||
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
|
||||
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
|
||||
resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
|
||||
save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
|
||||
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
|
||||
filename_prefix=config.name, to_save=to_save,
|
||||
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
|
||||
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
|
||||
save_interval_event=Events.ITERATION_COMPLETED(
|
||||
every=config.checkpoints.interval) | Events.COMPLETED)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
def terminate(engine):
|
||||
@@ -147,12 +152,23 @@ def get_trainer(config, logger):
|
||||
def global_step_transform(*args, **kwargs):
|
||||
return trainer.state.iteration
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
for tl in output["loss"]:
|
||||
if isinstance(output["loss"][tl], dict):
|
||||
for l in output["loss"][tl]:
|
||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||
else:
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="loss",
|
||||
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
|
||||
global_step_transform=global_step_transform,
|
||||
output_transform=output_transform
|
||||
),
|
||||
event_name=Events.ITERATION_COMPLETED(every=50)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user