UGATIT pipeline
This commit is contained in:
@@ -219,12 +219,12 @@ def get_trainer(config, logger):
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
|
||||
random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0]
|
||||
test_images = dict(
|
||||
a=[[], [], [], []],
|
||||
b=[[], [], [], []]
|
||||
)
|
||||
for i in indices:
|
||||
for i in range(random_start, random_start+10):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
|
||||
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
|
||||
@@ -278,7 +278,6 @@ def get_tester(config, logger):
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
# image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png"
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors))
|
||||
@@ -308,7 +307,7 @@ def run(task, config, logger):
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, logger)
|
||||
|
||||
Reference in New Issue
Block a user