change
This commit is contained in:
@@ -1,32 +1,23 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
|
||||
from model import MODEL
|
||||
from omegaconf import read_write, OmegaConf
|
||||
|
||||
from util.image import make_2d_grid
|
||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.build import build_optimizer
|
||||
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from engine.util.build import build_optimizer
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
|
||||
model = MODEL.build_with(cfg)
|
||||
if bn_to_sync_bn:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
return idist.auto_model(model)
|
||||
import data
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
|
||||
@@ -47,10 +38,26 @@ def build_lr_schedulers(optimizers, config):
|
||||
)
|
||||
|
||||
|
||||
class EngineKernel(object):
|
||||
def __init__(self, config, logger):
|
||||
class TestEngineKernel(object):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.logger = logging.getLogger(config.name)
|
||||
self.generators = self.build_generators()
|
||||
|
||||
def build_generators(self) -> dict:
|
||||
raise NotImplemented
|
||||
|
||||
def to_load(self):
|
||||
return {f"generator_{k}": self.generators[k] for k in self.generators}
|
||||
|
||||
def inference(self, batch):
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
class EngineKernel(object):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(config.name)
|
||||
self.generators, self.discriminators = self.build_models()
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
@@ -87,39 +94,43 @@ class EngineKernel(object):
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
def get_trainer(config, kernel: EngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
generators, discriminators = ek.generators, ek.discriminators
|
||||
generators, discriminators = kernel.generators, kernel.discriminators
|
||||
optimizers = dict(
|
||||
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
|
||||
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
|
||||
)
|
||||
logger.info("build optimizers", optimizers)
|
||||
logger.info(f"build optimizers:\n{optimizers}")
|
||||
|
||||
lr_schedulers = build_lr_schedulers(optimizers, config)
|
||||
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
||||
|
||||
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
|
||||
generated = ek.forward(batch)
|
||||
generated = kernel.forward(batch)
|
||||
|
||||
ek.setup_before_g()
|
||||
kernel.setup_before_g()
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = ek.criterion_generators(batch, generated)
|
||||
loss_g = kernel.criterion_generators(batch, generated)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
|
||||
ek.setup_before_d()
|
||||
kernel.setup_before_d()
|
||||
optimizers["d"].zero_grad()
|
||||
loss_d = ek.criterion_discriminators(batch, generated)
|
||||
loss_d = kernel.criterion_discriminators(batch, generated)
|
||||
sum(loss_d.values()).backward()
|
||||
optimizers["d"].step()
|
||||
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"img": ek.intermediate_images(batch, generated)
|
||||
}
|
||||
if engine.state.iteration % image_per_iteration == 0:
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"img": kernel.intermediate_images(batch, generated)
|
||||
}
|
||||
return {"loss": dict(g=loss_g, d=loss_d)}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
@@ -131,33 +142,22 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
to_save = dict(trainer=trainer)
|
||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
to_save.update(ek.to_save())
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
||||
to_save.update(kernel.to_save())
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
|
||||
set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
|
||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
for tl in output["loss"]:
|
||||
if isinstance(output["loss"][tl], dict):
|
||||
for l in output["loss"][tl]:
|
||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||
else:
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
)
|
||||
basic_image_event = Events.ITERATION_COMPLETED(
|
||||
every=image_per_iteration)
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
||||
@trainer.on(basic_image_event)
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
test_images = {}
|
||||
|
||||
for k in output["img"]:
|
||||
image_list = output["img"][k]
|
||||
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
|
||||
@@ -174,8 +174,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
batch[k] = batch[k].view(1, *batch[k].size())
|
||||
generated = ek.forward(batch)
|
||||
images = ek.intermediate_images(batch, generated)
|
||||
generated = kernel.forward(batch)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
|
||||
for k in test_images:
|
||||
for j in range(len(images[k])):
|
||||
@@ -187,3 +187,78 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, kernel: TestEngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
|
||||
def _step(engine, batch):
|
||||
real_a, path = convert_tensor(batch, idist.device())
|
||||
fake = kernel.inference({"a": real_a})["a"]
|
||||
return {"path": path, "img": [real_a.detach(), fake.detach()]}
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
|
||||
setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
def mkdir(engine):
|
||||
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
|
||||
engine.state.img_output_dir = Path(img_output_dir)
|
||||
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
|
||||
engine.state.img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output["img"]
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run_kernel(task, config, kernel):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger = logging.getLogger(config.name)
|
||||
with read_write(config):
|
||||
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
config.max_iteration = config.max_pairs // real_batch_size + 1
|
||||
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
|
||||
dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
|
||||
with read_write(config):
|
||||
config.iterations_per_epoch = len(train_data_loader)
|
||||
|
||||
trainer = get_trainer(config, kernel)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, kernel)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
|
||||
Reference in New Issue
Block a user