update
This commit is contained in:
@@ -90,10 +90,10 @@ class TAFGEngineKernel(EngineKernel):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
|
||||
loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
)
|
||||
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
# self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
# self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
# )
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
|
||||
@@ -148,6 +148,9 @@ class UGATITTestEngineKernel(TestEngineKernel):
|
||||
def run(task, config, _):
|
||||
if task == "train":
|
||||
kernel = UGATITEngineKernel(config)
|
||||
if task == "test":
|
||||
run_kernel(task, config, kernel)
|
||||
elif task == "test":
|
||||
kernel = UGATITTestEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
run_kernel(task, config, kernel)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
@@ -160,7 +160,7 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
|
||||
for k in output["img"]:
|
||||
image_list = output["img"][k]
|
||||
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
|
||||
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list, range=(-1, 1)),
|
||||
engine.state.iteration * pairs_per_iteration)
|
||||
test_images[k] = []
|
||||
for i in range(len(image_list)):
|
||||
@@ -183,7 +183,7 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
for k in test_images:
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{k}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
return trainer
|
||||
@@ -218,14 +218,12 @@ def get_tester(config, kernel: TestEngineKernel):
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors))
|
||||
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run_kernel(task, config, kernel):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger = logging.getLogger(config.name)
|
||||
with read_write(config):
|
||||
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
|
||||
Reference in New Issue
Block a user