TAFG update

This commit is contained in:
2020-09-18 12:03:44 +08:00
parent 61e04de8a5
commit b01016edb5
6 changed files with 91 additions and 59 deletions

View File

@@ -58,6 +58,10 @@ class EngineKernel(object):
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
self.train_generator_first = True
self.engine = None
def bind_engine(self, engine):
self.engine = engine
def build_models(self) -> (dict, dict):
raise NotImplemented
@@ -154,6 +158,7 @@ def get_trainer(config, kernel: EngineKernel):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
kernel.change_engine(config, trainer)
kernel.bind_engine(trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
@@ -186,9 +191,11 @@ def get_trainer(config, kernel: EngineKernel):
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
for i in range(random_start, random_start + 10):
g.manual_seed(config.misc.random_seed + engine.state.epoch
if config.handler.test.random else config.misc.random_seed)
random_start = \
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
for i in range(random_start, random_start + config.handler.test.images):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
if isinstance(batch[k], torch.Tensor):