almost 0.1

This commit is contained in:
2020-09-06 10:34:52 +08:00
parent e3c760d0c5
commit ab545843bf
15 changed files with 308 additions and 680 deletions

View File

@@ -1,23 +1,21 @@
from itertools import chain
import logging
from itertools import chain
from pathlib import Path
import ignite.distributed as idist
import torch
import torchvision
import ignite.distributed as idist
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from math import ceil
from omegaconf import read_write, OmegaConf
from util.image import make_2d_grid
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
import data
from engine.util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from util.image import make_2d_grid
def build_lr_schedulers(optimizers, config):
@@ -59,6 +57,7 @@ class EngineKernel(object):
self.config = config
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
self.train_generator_first = True
def build_models(self) -> (dict, dict):
raise NotImplemented
@@ -69,7 +68,7 @@ class EngineKernel(object):
to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
return to_save
def setup_before_d(self):
def setup_after_g(self):
raise NotImplemented
def setup_before_g(self):
@@ -93,6 +92,9 @@ class EngineKernel(object):
"""
raise NotImplemented
def change_engine(self, config, engine: Engine):
pass
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
@@ -106,26 +108,37 @@ def get_trainer(config, kernel: EngineKernel):
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
def train_generators(batch, generated):
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
kernel.setup_after_g()
return loss_g
def train_discriminators(batch, generated):
optimizers["d"].zero_grad()
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return loss_d
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
generated = kernel.forward(batch)
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
if kernel.train_generator_first:
loss_g = train_generators(batch, generated)
loss_d = train_discriminators(batch, generated)
else:
loss_d = train_discriminators(batch, generated)
loss_g = train_generators(batch, generated)
kernel.setup_before_d()
optimizers["d"].zero_grad()
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
if engine.state.iteration % image_per_iteration == 0:
if engine.state.iteration % iteration_per_image == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"img": kernel.intermediate_images(batch, generated)
@@ -137,6 +150,8 @@ def get_trainer(config, kernel: EngineKernel):
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
kernel.change_engine(config, trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
@@ -150,7 +165,7 @@ def get_trainer(config, kernel: EngineKernel):
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
if tensorboard_handler is not None:
basic_image_event = Events.ITERATION_COMPLETED(
every=image_per_iteration)
every=iteration_per_image)
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
@trainer.on(basic_image_event)
@@ -227,7 +242,7 @@ def run_kernel(task, config, kernel):
logger = logging.getLogger(config.name)
with read_write(config):
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
config.max_iteration = config.max_pairs // real_batch_size + 1
config.max_iteration = ceil(config.max_pairs / real_batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
@@ -243,7 +258,7 @@ def run_kernel(task, config, kernel):
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())