add attention image fuse
This commit is contained in:
@@ -19,7 +19,7 @@ from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from util.image import make_2d_grid
|
||||
from util.image import make_2d_grid, fuse_attention_map
|
||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.build import build_model, build_optimizer
|
||||
|
||||
@@ -190,16 +190,23 @@ def get_trainer(config, logger):
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
|
||||
image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
|
||||
|
||||
output["img"]["generated"]["real_a"] = fuse_attention_map(
|
||||
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
|
||||
output["img"]["generated"]["real_b"] = fuse_attention_map(
|
||||
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
|
||||
|
||||
tensorboard_handler.writer.add_image(
|
||||
"train/a",
|
||||
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
|
||||
make_2d_grid([output["img"]["generated"][o] for o in image_a_order]),
|
||||
engine.state.iteration
|
||||
)
|
||||
tensorboard_handler.writer.add_image(
|
||||
"train/b",
|
||||
make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
|
||||
make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user