TAFG
This commit is contained in:
@@ -132,9 +132,12 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
generated = kernel.forward(batch)
|
||||
|
||||
if kernel.train_generator_first:
|
||||
# simultaneous, train G with simultaneous D
|
||||
loss_g = train_generators(batch, generated)
|
||||
loss_d = train_discriminators(batch, generated)
|
||||
else:
|
||||
# update discriminators first, not simultaneous.
|
||||
# train G with updated discriminators
|
||||
loss_d = train_discriminators(batch, generated)
|
||||
loss_g = train_generators(batch, generated)
|
||||
|
||||
@@ -152,8 +155,8 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
|
||||
kernel.change_engine(config, trainer)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
|
||||
to_save = dict(trainer=trainer)
|
||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
@@ -188,7 +191,13 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
for i in range(random_start, random_start + 10):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
batch[k] = batch[k].view(1, *batch[k].size())
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].unsqueeze(0)
|
||||
elif isinstance(batch[k], dict):
|
||||
for kk in batch[k]:
|
||||
if isinstance(batch[k][kk], torch.Tensor):
|
||||
batch[k][kk] = batch[k][kk].unsqueeze(0)
|
||||
|
||||
generated = kernel.forward(batch)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user