TAFG
This commit is contained in:
@@ -58,20 +58,20 @@ class MUNITEngineKernel(EngineKernel):
|
||||
styles = dict()
|
||||
contents = dict()
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
for phase in "ab":
|
||||
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
|
||||
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
|
||||
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
|
||||
|
||||
for phase in "ab":
|
||||
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
|
||||
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
|
||||
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
|
||||
|
||||
for phase in ("a2b", "b2a"):
|
||||
# images["a2b"] = Gb.decode(content_a, random_style_b)
|
||||
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
|
||||
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
|
||||
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
|
||||
if self.config.loss.recon.cycle.weight > 0:
|
||||
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
|
||||
return dict(styles=styles, contents=contents, images=images)
|
||||
for phase in ("a2b", "b2a"):
|
||||
# images["a2b"] = Gb.decode(content_a, random_style_b)
|
||||
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
|
||||
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
|
||||
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
|
||||
if self.config.loss.recon.cycle.weight > 0:
|
||||
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
|
||||
return dict(styles=styles, contents=contents, images=images)
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
|
||||
106
engine/TAFG.py
106
engine/TAFG.py
@@ -3,8 +3,7 @@ from itertools import chain
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.engine import Events
|
||||
from omegaconf import read_write, OmegaConf
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
@@ -21,17 +20,14 @@ class TAFGEngineKernel(EngineKernel):
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
style_loss_cfg = OmegaConf.to_container(config.loss.style)
|
||||
style_loss_cfg.pop("weight")
|
||||
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
|
||||
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
|
||||
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
|
||||
self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
|
||||
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
|
||||
idist.device())
|
||||
@@ -67,47 +63,67 @@ class TAFGEngineKernel(EngineKernel):
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
generator = self.generators["main"]
|
||||
batch = self._process_batch(batch, inference)
|
||||
|
||||
styles = dict()
|
||||
contents = dict()
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
fake = dict(
|
||||
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
|
||||
b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
|
||||
)
|
||||
return fake
|
||||
for ph in "ab":
|
||||
contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
|
||||
for ph in ("a2b", "b2a"):
|
||||
images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
|
||||
contents["recon_a"], styles["recon_b"] = generator.encode(
|
||||
self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
|
||||
images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
|
||||
images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
|
||||
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
|
||||
return dict(styles=styles, contents=contents, images=images)
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
batch = self._process_batch(batch)
|
||||
loss = dict()
|
||||
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||
_, loss_style = self.style_loss(generated["a"], batch["a"])
|
||||
loss["style"] = self.config.loss.style.weight * loss_style
|
||||
loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
|
||||
for phase in "ab":
|
||||
pred_fake = self.discriminators[phase](generated[phase])
|
||||
loss[f"gan_{phase}"] = 0
|
||||
|
||||
for ph in "ab":
|
||||
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
|
||||
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
|
||||
|
||||
pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
|
||||
loss[f"gan_{ph}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
|
||||
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
|
||||
loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
|
||||
generated["contents"]["a"], generated["contents"]["recon_a"]
|
||||
)
|
||||
loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
generated["styles"]["b"], generated["styles"]["recon_b"]
|
||||
)
|
||||
|
||||
if self.config.loss.fm.weight > 0 and phase == "b":
|
||||
pred_real = self.discriminators[phase](batch[phase])
|
||||
loss_fm = 0
|
||||
num_scale_discriminator = len(pred_fake)
|
||||
for i in range(num_scale_discriminator):
|
||||
# last output is the final prediction, so we exclude it
|
||||
num_intermediate_outputs = len(pred_fake[i]) - 1
|
||||
for j in range(num_intermediate_outputs):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
|
||||
loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
|
||||
for ph in ("a2b", "b2a"):
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
|
||||
)
|
||||
if self.config.loss.edge.weight > 0:
|
||||
loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
|
||||
)
|
||||
loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["fake_a"], batch["b"]["edge"]
|
||||
)
|
||||
|
||||
if self.config.loss.cycle.weight > 0:
|
||||
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
batch["a"]["img"], generated["images"]["cycle_a"]
|
||||
)
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
# batch = self._process_batch(batch)
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase])
|
||||
pred_fake = self.discriminators[phase](generated[phase].detach())
|
||||
pred_real = self.discriminators[phase](batch[phase]["img"])
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
|
||||
@@ -122,17 +138,25 @@ class TAFGEngineKernel(EngineKernel):
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
batch = self._process_batch(batch)
|
||||
edge = batch["edge_a"][:, 0:1, :, :]
|
||||
return dict(
|
||||
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
|
||||
generated["b"].detach()]
|
||||
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
|
||||
batch["a"]["img"].detach(),
|
||||
generated["images"]["a2a"].detach(),
|
||||
generated["images"]["fake_b"].detach(),
|
||||
generated["images"]["cycle_a"].detach(),
|
||||
],
|
||||
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
|
||||
batch["b"]["img"].detach(),
|
||||
generated["images"]["b2b"].detach(),
|
||||
generated["images"]["fake_a"].detach()]
|
||||
)
|
||||
|
||||
def change_engine(self, config, trainer):
|
||||
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
|
||||
def change_config(engine):
|
||||
with read_write(config):
|
||||
config.loss.perceptual.weight = 5
|
||||
pass
|
||||
# @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
|
||||
# def change_config(engine):
|
||||
# with read_write(config):
|
||||
# config.loss.perceptual.weight = 5
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
|
||||
@@ -132,9 +132,12 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
generated = kernel.forward(batch)
|
||||
|
||||
if kernel.train_generator_first:
|
||||
# simultaneous update: train G first, before D is updated in this iteration
|
||||
loss_g = train_generators(batch, generated)
|
||||
loss_d = train_discriminators(batch, generated)
|
||||
else:
|
||||
# sequential (non-simultaneous) update: train the discriminators first.
|
||||
# train G with updated discriminators
|
||||
loss_d = train_discriminators(batch, generated)
|
||||
loss_g = train_generators(batch, generated)
|
||||
|
||||
@@ -152,8 +155,8 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
|
||||
kernel.change_engine(config, trainer)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
|
||||
to_save = dict(trainer=trainer)
|
||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
@@ -188,7 +191,13 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
for i in range(random_start, random_start + 10):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
batch[k] = batch[k].view(1, *batch[k].size())
|
||||
if isinstance(batch[k], torch.Tensor):
|
||||
batch[k] = batch[k].unsqueeze(0)
|
||||
elif isinstance(batch[k], dict):
|
||||
for kk in batch[k]:
|
||||
if isinstance(batch[k][kk], torch.Tensor):
|
||||
batch[k][kk] = batch[k][kk].unsqueeze(0)
|
||||
|
||||
generated = kernel.forward(batch)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user