This commit is contained in:
2020-09-17 09:34:53 +08:00
parent 2ff4a91057
commit 61e04de8a5
9 changed files with 168 additions and 288 deletions

View File

@@ -58,20 +58,20 @@ class MUNITEngineKernel(EngineKernel):
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
for phase in "ab":
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
for phase in "ab":
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
for phase in ("a2b", "b2a"):
# images["a2b"] = Gb.decode(content_a, random_style_b)
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
if self.config.loss.recon.cycle.weight > 0:
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
return dict(styles=styles, contents=contents, images=images)
for phase in ("a2b", "b2a"):
# images["a2b"] = Gb.decode(content_a, random_style_b)
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
if self.config.loss.recon.cycle.weight > 0:
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
return dict(styles=styles, contents=contents, images=images)
def criterion_generators(self, batch, generated) -> dict:
loss = dict()

View File

@@ -3,8 +3,7 @@ from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from ignite.engine import Events
from omegaconf import read_write, OmegaConf
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
@@ -21,17 +20,14 @@ class TAFGEngineKernel(EngineKernel):
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
style_loss_cfg = OmegaConf.to_container(config.loss.style)
style_loss_cfg.pop("weight")
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
idist.device())
@@ -67,47 +63,67 @@ class TAFGEngineKernel(EngineKernel):
def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
)
return fake
for ph in "ab":
contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
for ph in ("a2b", "b2a"):
images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
contents["recon_a"], styles["recon_b"] = generator.encode(
self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
return dict(styles=styles, contents=contents, images=images)
def criterion_generators(self, batch, generated) -> dict:
batch = self._process_batch(batch)
loss = dict()
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
_, loss_style = self.style_loss(generated["a"], batch["a"])
loss["style"] = self.config.loss.style.weight * loss_style
loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
loss[f"gan_{phase}"] = 0
for ph in "ab":
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
loss[f"gan_{ph}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
generated["contents"]["a"], generated["contents"]["recon_a"]
)
loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
generated["styles"]["b"], generated["styles"]["recon_b"]
)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])
loss_fm = 0
num_scale_discriminator = len(pred_fake)
for i in range(num_scale_discriminator):
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
for ph in ("a2b", "b2a"):
if self.config.loss.perceptual.weight > 0:
loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
)
if self.config.loss.edge.weight > 0:
loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
)
loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["fake_a"], batch["b"]["edge"]
)
if self.config.loss.cycle.weight > 0:
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
batch["a"]["img"], generated["images"]["cycle_a"]
)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
# batch = self._process_batch(batch)
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
pred_fake = self.discriminators[phase](generated[phase].detach())
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -122,17 +138,25 @@ class TAFGEngineKernel(EngineKernel):
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
batch = self._process_batch(batch)
edge = batch["edge_a"][:, 0:1, :, :]
return dict(
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
generated["b"].detach()]
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
generated["images"]["fake_b"].detach(),
generated["images"]["cycle_a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
generated["images"]["fake_a"].detach()]
)
def change_engine(self, config, trainer):
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
def change_config(engine):
with read_write(config):
config.loss.perceptual.weight = 5
pass
# @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
# def change_config(engine):
# with read_write(config):
# config.loss.perceptual.weight = 5
def run(task, config, _):

View File

@@ -132,9 +132,12 @@ def get_trainer(config, kernel: EngineKernel):
generated = kernel.forward(batch)
if kernel.train_generator_first:
# simultaneous, train G with simultaneous D
loss_g = train_generators(batch, generated)
loss_d = train_discriminators(batch, generated)
else:
# update discriminators first, not simultaneous.
# train G with updated discriminators
loss_d = train_discriminators(batch, generated)
loss_g = train_generators(batch, generated)
@@ -152,8 +155,8 @@ def get_trainer(config, kernel: EngineKernel):
kernel.change_engine(config, trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
@@ -188,7 +191,13 @@ def get_trainer(config, kernel: EngineKernel):
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].unsqueeze(0)
elif isinstance(batch[k], dict):
for kk in batch[k]:
if isinstance(batch[k][kk], torch.Tensor):
batch[k][kk] = batch[k][kk].unsqueeze(0)
generated = kernel.forward(batch)
images = kernel.intermediate_images(batch, generated)