add test handler
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
@@ -175,7 +177,7 @@ def get_trainer(config, logger):
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
to_save.update({f"generator_{k}": generators[k] for k in generators})
|
||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
@@ -247,6 +249,43 @@ def get_trainer(config, logger):
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
|
||||
generator_a2b = build_model(config.model.generator, config.distributed.model)
|
||||
|
||||
def _step(engine, batch):
|
||||
real_a, path = convert_tensor(batch, idist.device())
|
||||
with torch.no_grad():
|
||||
fake_b = generator_a2b(real_a)[0]
|
||||
return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
|
||||
to_load = dict(generator_a2b=generator_a2b)
|
||||
setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
def mkdir(engine):
|
||||
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
|
||||
engine.state.img_output_dir = Path(img_output_dir)
|
||||
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
|
||||
engine.state.img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output["img"]
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
# image_name = f"{engine.state.iteration * batch_size - batch_size + i + 1}.png"
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
@@ -267,5 +306,16 @@ def run(task, config, logger):
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, logger)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
|
||||
Reference in New Issue
Block a user