add attention image fuse

This commit is contained in:
2020-08-22 20:21:11 +08:00
parent ccc3d7614a
commit 58ed4524bf
4 changed files with 54 additions and 6 deletions

View File

@@ -19,7 +19,7 @@ from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid
from util.image import make_2d_grid, fuse_attention_map
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer
@@ -190,16 +190,23 @@ def get_trainer(config, logger):
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
def show_images(engine):
output = engine.state.output
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
output["img"]["generated"]["real_a"] = fuse_attention_map(
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
output["img"]["generated"]["real_b"] = fuse_attention_map(
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
tensorboard_handler.writer.add_image(
"train/a",
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
make_2d_grid([output["img"]["generated"][o] for o in image_a_order]),
engine.state.iteration
)
tensorboard_handler.writer.add_image(
"train/b",
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
engine.state.iteration
)