working
This commit is contained in:
@@ -189,34 +189,33 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
for i in range(len(image_list)):
|
||||
test_images[k].append([])
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed + engine.state.epoch
|
||||
if config.handler.test.random else config.misc.random_seed)
|
||||
random_start = \
|
||||
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
|
||||
for i in range(random_start, random_start + config.handler.test.images):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].unsqueeze(0)
|
||||
elif isinstance(batch[k], dict):
|
||||
for kk in batch[k]:
|
||||
if isinstance(batch[k][kk], torch.Tensor):
|
||||
batch[k][kk] = batch[k][kk].unsqueeze(0)
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed + engine.state.epoch
|
||||
if config.handler.test.random else config.misc.random_seed)
|
||||
random_start = \
|
||||
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
|
||||
for i in range(random_start, random_start + config.handler.test.images):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].unsqueeze(0)
|
||||
elif isinstance(batch[k], dict):
|
||||
for kk in batch[k]:
|
||||
if isinstance(batch[k][kk], torch.Tensor):
|
||||
batch[k][kk] = batch[k][kk].unsqueeze(0)
|
||||
|
||||
generated = kernel.forward(batch)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
generated = kernel.forward(batch, inference=True)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
|
||||
for k in test_images:
|
||||
for j in range(len(images[k])):
|
||||
test_images[k][j].append(images[k][j])
|
||||
for k in test_images:
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{k}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
for j in range(len(images[k])):
|
||||
test_images[k][j].append(images[k][j])
|
||||
for k in test_images:
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{k}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user