UGATIT pipeline

This commit is contained in:
2020-08-28 08:15:29 +08:00
parent 09db0a413f
commit 42d6253a1d
5 changed files with 211 additions and 4 deletions

View File

@@ -219,12 +219,12 @@ def get_trainer(config, logger):
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0]
test_images = dict(
a=[[], [], [], []],
b=[[], [], [], []]
)
for i in indices:
for i in range(random_start, random_start+10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
@@ -278,7 +278,6 @@ def get_tester(config, logger):
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
# image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png"
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors))
@@ -308,7 +307,7 @@ def run(task, config, logger):
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.dataset)
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, logger)