update tester
This commit is contained in:
@@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
return trainer
|
||||
|
||||
|
||||
def save_images_helper(output_dir, paths, images_list):
|
||||
batch_size = len(paths)
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
img_list = [img[i] for img in images_list]
|
||||
torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
|
||||
normalize=True, range=(-1, 1))
|
||||
|
||||
|
||||
def get_tester(config, kernel: TestEngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
|
||||
def _step(engine, batch):
|
||||
real_a, path = convert_tensor(batch, idist.device())
|
||||
fake = kernel.inference({"a": real_a})["a"]
|
||||
return {"path": path, "img": [real_a.detach(), fake.detach()]}
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
return {"batch": batch, "generated": kernel.inference(batch)}
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
@@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel):
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output["img"]
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
|
||||
if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
|
||||
images, paths = engine.state.output["batch"]
|
||||
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
|
||||
|
||||
else:
|
||||
for k in engine.state.output['generated']:
|
||||
images, paths = engine.state.output["batch"][k]
|
||||
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
|
||||
|
||||
return tester
|
||||
|
||||
@@ -264,7 +273,7 @@ def run_kernel(task, config, kernel):
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
|
||||
test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, kernel)
|
||||
|
||||
Reference in New Issue
Block a user