This commit is contained in:
2020-09-05 22:00:17 +08:00
parent 39c754374c
commit e3c760d0c5
12 changed files with 122 additions and 43 deletions

View File

@@ -90,10 +90,10 @@ class TAFGEngineKernel(EngineKernel):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
self.generators["main"].module.style_encoders["b"](batch["b"]),
self.generators["main"].module.style_encoders["b"](generated["b"])
)
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
# self.generators["main"].module.style_encoders["b"](batch["b"]),
# self.generators["main"].module.style_encoders["b"](generated["b"])
# )
return loss
def criterion_discriminators(self, batch, generated) -> dict:

View File

@@ -148,6 +148,9 @@ class UGATITTestEngineKernel(TestEngineKernel):
def run(task, config, _):
if task == "train":
kernel = UGATITEngineKernel(config)
if task == "test":
run_kernel(task, config, kernel)
elif task == "test":
kernel = UGATITTestEngineKernel(config)
run_kernel(task, config, kernel)
run_kernel(task, config, kernel)
else:
raise NotImplemented

View File

@@ -160,7 +160,7 @@ def get_trainer(config, kernel: EngineKernel):
for k in output["img"]:
image_list = output["img"][k]
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list, range=(-1, 1)),
engine.state.iteration * pairs_per_iteration)
test_images[k] = []
for i in range(len(image_list)):
@@ -183,7 +183,7 @@ def get_trainer(config, kernel: EngineKernel):
for k in test_images:
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
engine.state.iteration * pairs_per_iteration
)
return trainer
@@ -218,14 +218,12 @@ def get_tester(config, kernel: TestEngineKernel):
for i in range(batch_size):
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors))
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
return tester
def run_kernel(task, config, kernel):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger = logging.getLogger(config.name)
with read_write(config):
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()