2020-09-05 10:33:35 +08:00
parent 2469bf15fe
commit 39c754374c
21 changed files with 550 additions and 705 deletions


@@ -1,32 +1,23 @@
from itertools import chain
from math import ceil
from pathlib import Path
import logging
from pathlib import Path
import torch
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from model import MODEL
from omegaconf import read_write, OmegaConf
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
from omegaconf import OmegaConf
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return idist.auto_model(model)
import data
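For reference, a minimal sketch of the conversion build_model performs with the "_bn_to_sync_bn" flag; the Sequential model and the config fragment below are stand-ins for MODEL.build_with(cfg) and the real config:

import torch
import ignite.distributed as idist

cfg = {"_bn_to_sync_bn": True}  # illustrative config fragment
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
if bn_to_sync_bn:
    # SyncBatchNorm shares batch statistics across processes; it is only
    # meaningful once a torch.distributed process group is initialized.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = idist.auto_model(model)  # moves to device and wraps in DDP if distributed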
def build_lr_schedulers(optimizers, config):
@@ -47,10 +38,26 @@ def build_lr_schedulers(optimizers, config):
)
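The body of build_lr_schedulers is elided above; given that PiecewiseLinear is the only scheduler imported here, a plausible per-optimizer construction looks like the sketch below (milestone values are illustrative, and optimizer/trainer stand for the objects built in get_trainer):

from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

# Hold lr constant, then decay linearly to zero (illustrative milestones).
scheduler = PiecewiseLinear(
    optimizer, param_name="lr",
    milestones_values=[(0, 2e-4), (50_000, 2e-4), (100_000, 0.0)],
)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)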
class EngineKernel(object):
def __init__(self, config, logger):
class TestEngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logger
self.logger = logging.getLogger(config.name)
self.generators = self.build_generators()
def build_generators(self) -> dict:
raise NotImplementedError
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
raise NotImplementedError
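A hypothetical concrete subclass, to show the contract (MyTestKernel and MyGenerator are placeholders, not names from this repo):

class MyTestKernel(TestEngineKernel):
    def build_generators(self) -> dict:
        # dict keys become checkpoint keys via to_load(), e.g. "generator_a"
        return {"a": idist.auto_model(MyGenerator())}

    def inference(self, batch):
        with torch.no_grad():
            return {"a": self.generators["a"](batch["a"])}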
class EngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
def build_models(self) -> "tuple[dict, dict]":
@@ -87,39 +94,43 @@ class EngineKernel(object):
raise NotImplementedError
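A hypothetical concrete training kernel, following the interface get_trainer relies on below (forward, setup_before_g/d, criterion_generators/discriminators, intermediate_images); Generator, Discriminator, adversarial_loss and d_loss are placeholders:

class Pix2PixKernel(EngineKernel):
    def build_models(self):
        g = {"a": idist.auto_model(Generator())}
        d = {"a": idist.auto_model(Discriminator())}
        return g, d

    def forward(self, batch):
        return {"b": self.generators["a"](batch["a"])}

    def setup_before_g(self):
        for m in self.discriminators.values():
            m.requires_grad_(False)  # freeze D while G is updated

    def setup_before_d(self):
        for m in self.discriminators.values():
            m.requires_grad_(True)

    def criterion_generators(self, batch, generated):
        return {"adv": adversarial_loss(self.discriminators["a"](generated["b"]))}

    def criterion_discriminators(self, batch, generated):
        return {"adv": d_loss(self.discriminators["a"], batch["b"], generated["b"].detach())}

    def intermediate_images(self, batch, generated):
        return {"a_to_b": [batch["a"], generated["b"], batch["b"]]}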
def get_trainer(config, ek: EngineKernel, iter_per_epoch):
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = ek.generators, ek.discriminators
generators, discriminators = kernel.generators, kernel.discriminators
optimizers = dict(
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info("build optimizers", optimizers)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
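The chain(*[...]) pattern in the optimizers dict above gives one optimizer every parameter of its model group; a self-contained illustration:

from itertools import chain
import torch

gens = {"a2b": torch.nn.Linear(4, 4), "b2a": torch.nn.Linear(4, 4)}
params = chain(*[m.parameters() for m in gens.values()])
opt_g = torch.optim.Adam(params, lr=2e-4, betas=(0.5, 0.999))  # hyperparameters illustrative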
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
generated = ek.forward(batch)
generated = kernel.forward(batch)
ek.setup_before_g()
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = ek.criterion_generators(batch, generated)
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
ek.setup_before_d()
kernel.setup_before_d()
optimizers["d"].zero_grad()
loss_d = ek.criterion_discriminators(batch, generated)
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return {
"loss": dict(g=loss_g, d=loss_d),
"img": ek.intermediate_images(batch, generated)
}
if engine.state.iteration % image_per_iteration == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=loss_g, d=loss_d)}
trainer = Engine(_step)
trainer.logger = logger
@@ -131,33 +142,22 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update(ek.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
to_save.update(kernel.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
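The end_event argument relies on ignite's event filters; assuming setup_common_handlers terminates the run on end_event, the two filter forms used in this file behave like so:

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: None)

@engine.on(Events.ITERATION_COMPLETED(every=100))    # fires every 100th iteration
def periodic(e):
    ...

@engine.on(Events.ITERATION_COMPLETED(once=10_000))  # fires exactly once
def finish(e):
    e.terminate()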
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
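For reference, the nested-to-flat mapping output_transform performs (values illustrative):

out = {"loss": {"g": {"adv": 0.7, "cycle": 1.2}, "d": 0.4}}
# output_transform(out) -> {"g_adv": 0.7, "g_cycle": 1.2, "d": 0.4}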
pairs_per_iteration = config.data.train.dataloader.batch_size
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
basic_image_event = Events.ITERATION_COMPLETED(
every=image_per_iteration)
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
@trainer.on(basic_image_event)
def show_images(engine):
output = engine.state.output
test_images = {}
for k in output["img"]:
image_list = output["img"][k]
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
@@ -174,8 +174,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
generated = ek.forward(batch)
images = ek.intermediate_images(batch, generated)
generated = kernel.forward(batch)
images = kernel.intermediate_images(batch, generated)
for k in test_images:
for j in range(len(images[k])):
@@ -187,3 +187,78 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
engine.state.iteration * pairs_per_iteration
)
return trainer
def get_tester(config, kernel: TestEngineKernel):
logger = logging.getLogger(config.name)
def _step(engine, batch):
real_a, path = convert_tensor(batch, idist.device())
fake = kernel.inference({"a": real_a})["a"]
return {"path": path, "img": [real_a.detach(), fake.detach()]}
tester = Engine(_step)
tester.logger = logger
setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())
@tester.on(Events.STARTED)
def mkdir(engine):
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
engine.state.img_output_dir = Path(img_output_dir)
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
engine.state.img_output_dir.mkdir()
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output["img"]
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors))
return tester
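save_image accepts a list of equally-shaped tensors and grids them, so with nrow=len(img_tensors) each saved file is one row, real | fake. A standalone example:

import torch
import torchvision

real = torch.rand(3, 64, 64)
fake = torch.rand(3, 64, 64)
torchvision.utils.save_image([real, fake], "pair.png", nrow=2)  # one row: real | fake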
def run_kernel(task, config, kernel):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger = logging.getLogger(config.name)
with read_write(config):
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
config.max_iteration = config.max_pairs // real_batch_size + 1
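A worked example with illustrative numbers:

# max_pairs=1_000_000, per-process batch_size=8, world_size=4:
real_batch_size = 8 * 4              # 32 pairs consumed per iteration
max_iteration = 1_000_000 // 32 + 1  # 31_251 iterations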
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
with read_write(config):
config.iterations_per_epoch = len(train_data_loader)
trainer = get_trainer(config, kernel)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
except Exception:
import traceback
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, kernel)
try:
tester.run(test_data_loader, max_epochs=1)
except Exception:
import traceback
print(traceback.format_exc())
else:
raise NotImplementedError(f"invalid task: {task}")
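A hypothetical launch wrapper (MyKernel and the config path are illustrative; idist.Parallel spawns and joins the distributed processes, passing each one its local rank):

def training(local_rank):
    config = OmegaConf.load("config/train.yaml")  # illustrative path
    kernel = MyKernel(config)                     # hypothetical EngineKernel subclass
    run_kernel("train", config, kernel)

if __name__ == "__main__":
    with idist.Parallel(backend="nccl") as parallel:
        parallel.run(training)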