update tester

This commit is contained in:
2020-09-10 18:34:52 +08:00
parent 7ea9c6d0d5
commit 72d09aa483
7 changed files with 79 additions and 44 deletions

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
@@ -101,6 +101,31 @@ class TSITEngineKernel(EngineKernel):
)
class TSITTestEngineKernel(TestEngineKernel):
def __init__(self, config):
super().__init__(config)
def build_generators(self) -> dict:
generators = dict(
main=build_model(self.config.model.generator)
)
return generators
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
with torch.no_grad():
fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
return {"a": fake.detach()}
def run(task, config, _):
kernel = TSITEngineKernel(config)
run_kernel(task, config, kernel)
if task == "train":
kernel = TSITEngineKernel(config)
run_kernel(task, config, kernel)
elif task == "test":
kernel = TSITTestEngineKernel(config)
run_kernel(task, config, kernel)
else:
raise NotImplemented

View File

@@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel):
return trainer
def save_images_helper(output_dir, paths, images_list):
batch_size = len(paths)
for i in range(batch_size):
image_name = Path(paths[i]).name
img_list = [img[i] for img in images_list]
torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
normalize=True, range=(-1, 1))
def get_tester(config, kernel: TestEngineKernel):
logger = logging.getLogger(config.name)
def _step(engine, batch):
real_a, path = convert_tensor(batch, idist.device())
fake = kernel.inference({"a": real_a})["a"]
return {"path": path, "img": [real_a.detach(), fake.detach()]}
batch = convert_tensor(batch, idist.device())
return {"batch": batch, "generated": kernel.inference(batch)}
tester = Engine(_step)
tester.logger = logger
@@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel):
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output["img"]
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
images, paths = engine.state.output["batch"]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
else:
for k in engine.state.output['generated']:
images, paths = engine.state.output["batch"][k]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
return tester
@@ -264,7 +273,7 @@ def run_kernel(task, config, kernel):
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, kernel)