update tester
This commit is contained in:
@@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
@@ -101,6 +101,31 @@ class TSITEngineKernel(EngineKernel):
|
||||
)
|
||||
|
||||
|
||||
class TSITTestEngineKernel(TestEngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
def build_generators(self) -> dict:
|
||||
generators = dict(
|
||||
main=build_model(self.config.model.generator)
|
||||
)
|
||||
return generators
|
||||
|
||||
def to_load(self):
|
||||
return {f"generator_{k}": self.generators[k] for k in self.generators}
|
||||
|
||||
def inference(self, batch):
|
||||
with torch.no_grad():
|
||||
fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
|
||||
return {"a": fake.detach()}
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = TSITEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
if task == "train":
|
||||
kernel = TSITEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
elif task == "test":
|
||||
kernel = TSITTestEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
@@ -204,13 +204,21 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
return trainer
|
||||
|
||||
|
||||
def save_images_helper(output_dir, paths, images_list):
|
||||
batch_size = len(paths)
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
img_list = [img[i] for img in images_list]
|
||||
torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
|
||||
normalize=True, range=(-1, 1))
|
||||
|
||||
|
||||
def get_tester(config, kernel: TestEngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
|
||||
def _step(engine, batch):
|
||||
real_a, path = convert_tensor(batch, idist.device())
|
||||
fake = kernel.inference({"a": real_a})["a"]
|
||||
return {"path": path, "img": [real_a.detach(), fake.detach()]}
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
return {"batch": batch, "generated": kernel.inference(batch)}
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
@@ -227,13 +235,14 @@ def get_tester(config, kernel: TestEngineKernel):
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output["img"]
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
|
||||
if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
|
||||
images, paths = engine.state.output["batch"]
|
||||
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
|
||||
|
||||
else:
|
||||
for k in engine.state.output['generated']:
|
||||
images, paths = engine.state.output["batch"][k]
|
||||
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
|
||||
|
||||
return tester
|
||||
|
||||
@@ -264,7 +273,7 @@ def run_kernel(task, config, kernel):
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
|
||||
test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, kernel)
|
||||
|
||||
Reference in New Issue
Block a user