almost 0.1
@@ -1,23 +1,21 @@
-from itertools import chain
 import logging
+from itertools import chain
 from pathlib import Path
 
+import ignite.distributed as idist
 import torch
 import torchvision
 
-import ignite.distributed as idist
+from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
 from ignite.engine import Events, Engine
 from ignite.metrics import RunningAverage
 from ignite.utils import convert_tensor
-from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
 
 from math import ceil
 from omegaconf import read_write, OmegaConf
 
-from util.image import make_2d_grid
-from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
-from engine.util.build import build_optimizer
-
 import data
+from engine.util.build import build_optimizer
+from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
+from util.image import make_2d_grid
 
 
 def build_lr_schedulers(optimizers, config):
@@ -59,6 +57,7 @@ class EngineKernel(object):
         self.config = config
         self.logger = logging.getLogger(config.name)
         self.generators, self.discriminators = self.build_models()
+        self.train_generator_first = True
 
     def build_models(self) -> (dict, dict):
         raise NotImplemented
@@ -69,7 +68,7 @@ class EngineKernel(object):
         to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
         return to_save
 
-    def setup_before_d(self):
+    def setup_after_g(self):
         raise NotImplemented
 
     def setup_before_g(self):
@@ -93,6 +92,9 @@ class EngineKernel(object):
         """
         raise NotImplemented
 
+    def change_engine(self, config, engine: Engine):
+        pass
+
 
 def get_trainer(config, kernel: EngineKernel):
     logger = logging.getLogger(config.name)
@@ -106,26 +108,37 @@ def get_trainer(config, kernel: EngineKernel):
     lr_schedulers = build_lr_schedulers(optimizers, config)
     logger.info(f"build lr_schedulers:\n{lr_schedulers}")
 
-    image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
+    iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
 
+    def train_generators(batch, generated):
+        kernel.setup_before_g()
+        optimizers["g"].zero_grad()
+        loss_g = kernel.criterion_generators(batch, generated)
+        sum(loss_g.values()).backward()
+        optimizers["g"].step()
+        kernel.setup_after_g()
+        return loss_g
+
+    def train_discriminators(batch, generated):
+        optimizers["d"].zero_grad()
+        loss_d = kernel.criterion_discriminators(batch, generated)
+        sum(loss_d.values()).backward()
+        optimizers["d"].step()
+        return loss_d
+
     def _step(engine, batch):
         batch = convert_tensor(batch, idist.device())
 
         generated = kernel.forward(batch)
 
-        kernel.setup_before_g()
-        optimizers["g"].zero_grad()
-        loss_g = kernel.criterion_generators(batch, generated)
-        sum(loss_g.values()).backward()
-        optimizers["g"].step()
-
-        kernel.setup_before_d()
-        optimizers["d"].zero_grad()
-        loss_d = kernel.criterion_discriminators(batch, generated)
-        sum(loss_d.values()).backward()
-        optimizers["d"].step()
+        if kernel.train_generator_first:
+            loss_g = train_generators(batch, generated)
+            loss_d = train_discriminators(batch, generated)
+        else:
+            loss_d = train_discriminators(batch, generated)
+            loss_g = train_generators(batch, generated)
 
-        if engine.state.iteration % image_per_iteration == 0:
+        if engine.state.iteration % iteration_per_image == 0:
             return {
                 "loss": dict(g=loss_g, d=loss_d),
                 "img": kernel.intermediate_images(batch, generated)
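The hunk above pulls the two optimization passes into train_generators / train_discriminators and lets kernel.train_generator_first choose the update order. A minimal standalone sketch of the same pattern, with toy models and illustrative stand-in losses in place of the kernel's criteria (no Ignite involved; every name here is hypothetical):

    import torch

    # toy one-layer "GAN" purely to exercise the update-order pattern
    g = torch.nn.Linear(1, 1)
    d = torch.nn.Linear(1, 1)
    opt_g = torch.optim.SGD(g.parameters(), lr=0.1)
    opt_d = torch.optim.SGD(d.parameters(), lr=0.1)

    def train_g(x):
        opt_g.zero_grad()
        loss = d(g(x)).mean()            # stand-in generator loss
        loss.backward()
        opt_g.step()
        return loss

    def train_d(x):
        opt_d.zero_grad()
        loss = -d(g(x).detach()).mean()  # stand-in discriminator loss
        loss.backward()
        opt_d.step()
        return loss

    train_generator_first = True
    x = torch.randn(4, 1)
    if train_generator_first:
        loss_g, loss_d = train_g(x), train_d(x)
    else:
        loss_d, loss_g = train_d(x), train_g(x)

Unlike the diff, which computes `generated` once and reuses it for both passes, this sketch reruns the generator inside each helper; the ordering logic is the same.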
@@ -137,6 +150,8 @@ def get_trainer(config, kernel: EngineKernel):
     for lr_shd in lr_schedulers.values():
         trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
 
+    kernel.change_engine(config, trainer)
+
     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
     to_save = dict(trainer=trainer)
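For reference, the two RunningAverage attachments follow the usual ignite.metrics pattern: output_transform picks a scalar out of each _step output, and the running mean then appears in engine.state.metrics under the attached name. A minimal self-contained illustration; the fake step output mimics the {"loss": {"g": ..., "d": ...}} structure above, and the values are made up:

    from ignite.engine import Engine, Events
    from ignite.metrics import RunningAverage

    def fake_step(engine, batch):
        # mimic the trainer output: nested dicts of named losses
        return {"loss": {"g": {"adv": 0.5}, "d": {"adv": 0.7}}}

    eng = Engine(fake_step)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(eng, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(eng, "loss_d")

    @eng.on(Events.EPOCH_COMPLETED)
    def report(e):
        print(e.state.metrics["loss_g"], e.state.metrics["loss_d"])  # 0.5 0.7

    eng.run([None] * 3, max_epochs=1)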
@@ -150,7 +165,7 @@ def get_trainer(config, kernel: EngineKernel):
     tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
     if tensorboard_handler is not None:
         basic_image_event = Events.ITERATION_COMPLETED(
-            every=image_per_iteration)
+            every=iteration_per_image)
         pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
 
         @trainer.on(basic_image_event)
@@ -227,7 +242,7 @@ def run_kernel(task, config, kernel):
     logger = logging.getLogger(config.name)
     with read_write(config):
         real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
-        config.max_iteration = config.max_pairs // real_batch_size + 1
+        config.max_iteration = ceil(config.max_pairs / real_batch_size)
 
     if task == "train":
         train_dataset = data.DATASET.build_with(config.data.train.dataset)
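The switch from floor-divide-plus-one to ceil fixes an off-by-one: `max_pairs // real_batch_size + 1` schedules one extra iteration whenever the batch size divides max_pairs exactly, while `ceil` rounds up only when there is a remainder. The next hunk applies the same correction to max_epochs. With illustrative numbers:

    from math import ceil

    max_pairs, real_batch_size = 1000, 50
    print(max_pairs // real_batch_size + 1)   # 21 -> one iteration too many
    print(ceil(max_pairs / real_batch_size))  # 20 -> exactly enough for 1000 pairs

    max_pairs = 1001
    print(max_pairs // real_batch_size + 1)   # 21
    print(ceil(max_pairs / real_batch_size))  # 21 -> both agree when there is a remainder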
@@ -243,7 +258,7 @@ def run_kernel(task, config, kernel):
         test_dataset = data.DATASET.build_with(config.data.test.dataset)
         trainer.state.test_dataset = test_dataset
         try:
-            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
+            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
         except Exception:
             import traceback
             print(traceback.format_exc())