TAFG update
This commit is contained in:
@@ -58,6 +58,10 @@ class EngineKernel(object):
|
||||
self.logger = logging.getLogger(config.name)
|
||||
self.generators, self.discriminators = self.build_models()
|
||||
self.train_generator_first = True
|
||||
self.engine = None
|
||||
|
||||
def bind_engine(self, engine):
|
||||
self.engine = engine
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
raise NotImplemented
|
||||
@@ -154,6 +158,7 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
|
||||
|
||||
kernel.change_engine(config, trainer)
|
||||
kernel.bind_engine(trainer)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
|
||||
@@ -186,9 +191,11 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
|
||||
for i in range(random_start, random_start + 10):
|
||||
g.manual_seed(config.misc.random_seed + engine.state.epoch
|
||||
if config.handler.test.random else config.misc.random_seed)
|
||||
random_start = \
|
||||
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
|
||||
for i in range(random_start, random_start + config.handler.test.images):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
|
||||
Reference in New Issue
Block a user