TAHG 0.0.3
This commit is contained in:
@@ -44,7 +44,7 @@ def build_lr_schedulers(optimizers, config):
|
||||
)
|
||||
|
||||
|
||||
def get_trainer(config, logger):
|
||||
def get_trainer(config, logger, train_data_loader):
|
||||
generator = build_model(config.model.generator, config.distributed.model)
|
||||
discriminators = dict(
|
||||
a=build_model(config.model.discriminator, config.distributed.model),
|
||||
@@ -85,11 +85,12 @@ def get_trainer(config, logger):
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
content_img = batch["edge"]
|
||||
fake = dict(
|
||||
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
|
||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||
)
|
||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = dict()
|
||||
@@ -99,8 +100,10 @@ def get_trainer(config, logger):
|
||||
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
||||
_, t = perceptual_loss(fake[d], real[d])
|
||||
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
|
||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
|
||||
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
||||
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
|
||||
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
|
||||
@@ -118,7 +121,10 @@ def get_trainer(config, logger):
|
||||
optimizers["d"].step()
|
||||
|
||||
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
||||
generated_img["rec_b"] = rec_b.detach()
|
||||
generated_img["rec_bb"] = rec_b.detach()
|
||||
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
||||
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
|
||||
return {
|
||||
"loss": {
|
||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||
@@ -153,20 +159,21 @@ def get_trainer(config, logger):
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
|
||||
iter_per_epoch = len(train_data_loader)
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
|
||||
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
image_order = dict(
|
||||
a=["real_a", "fake_a"],
|
||||
b=["real_b", "fake_b"]
|
||||
a=["edge_a", "real_a", "fake_a", "fake_b"],
|
||||
b=["edge_b", "real_b", "rec_b", "rec_bb"]
|
||||
)
|
||||
for k in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
@@ -175,6 +182,42 @@ def get_trainer(config, logger):
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
|
||||
test_images = dict(
|
||||
a=[[], [], [], []],
|
||||
b=[[], [], [], []]
|
||||
)
|
||||
for i in range(random_start, random_start + 10):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
batch[k] = batch[k].view(1, *batch[k].size())
|
||||
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
fake = dict(
|
||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||
)
|
||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||
|
||||
test_images["a"][0].append(batch["edge_a"])
|
||||
test_images["a"][1].append(batch["a"])
|
||||
test_images["a"][2].append(fake["a"])
|
||||
test_images["a"][3].append(fake["b"])
|
||||
test_images["b"][0].append(batch["edge_b"])
|
||||
test_images["b"][1].append(batch["b"])
|
||||
test_images["b"][2].append(rec_b)
|
||||
test_images["b"][3].append(rec_bb)
|
||||
for n in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{n}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
@@ -189,7 +232,7 @@ def run(task, config, logger):
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
trainer = get_trainer(config, logger, train_data_loader)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
|
||||
Reference in New Issue
Block a user