UGATIT version 0.1
This commit is contained in:
@@ -17,7 +17,7 @@ def empty_cuda_cache(_):
|
||||
|
||||
|
||||
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
|
||||
to_save=None, metrics_to_print=None, end_event=None):
|
||||
to_save=None, end_event=None, set_epoch_for_dist_sampler=True):
|
||||
"""
|
||||
Helper method to setup trainer with common handlers.
|
||||
1. TerminateOnNan
|
||||
@@ -30,21 +30,21 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
:param clear_cuda_cache:
|
||||
:param use_profiler:
|
||||
:param to_save:
|
||||
:param metrics_to_print:
|
||||
:param end_event:
|
||||
:param set_epoch_for_dist_sampler:
|
||||
:return:
|
||||
"""
|
||||
|
||||
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
|
||||
if set_epoch_for_dist_sampler:
|
||||
@trainer.on(Events.EPOCH_STARTED)
|
||||
def distrib_set_epoch(engine):
|
||||
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
|
||||
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
@trainer.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def print_dataloader_size(engine):
|
||||
@trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
|
||||
def print_info(engine):
|
||||
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
|
||||
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
|
||||
|
||||
if stop_on_nan:
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||
@@ -62,20 +62,8 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
def log_intermediate_results():
|
||||
profiler.print_results(profiler.get_results())
|
||||
|
||||
print_interval_event = Events.ITERATION_COMPLETED(every=config.interval.print_per_iteration) | Events.COMPLETED
|
||||
|
||||
ProgressBar(ncols=0).attach(trainer, "all")
|
||||
|
||||
if metrics_to_print is not None:
|
||||
@trainer.on(print_interval_event)
|
||||
def print_interval(engine):
|
||||
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
|
||||
for m in metrics_to_print:
|
||||
if m not in engine.state.metrics:
|
||||
continue
|
||||
print_str += f"{m}={engine.state.metrics[m]:.3f} "
|
||||
engine.logger.debug(print_str)
|
||||
|
||||
if to_save is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
|
||||
n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
|
||||
@@ -86,6 +74,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||
trainer.logger.info(f"load state_dict for {ckp.keys()}")
|
||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
|
||||
@@ -5,6 +5,21 @@ import warnings
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
|
||||
def attention_colored_map(attentions, size=None, cmap_name="jet"):
|
||||
assert attentions.dim() == 4 and attentions.size(1) == 1
|
||||
|
||||
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
attentions -= min_attentions
|
||||
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
|
||||
if size is not None and attentions.size()[-2:] != size:
|
||||
assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
|
||||
attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
|
||||
cmap = get_cmap(cmap_name)
|
||||
ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3]
|
||||
return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
|
||||
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
"""
|
||||
|
||||
@@ -20,18 +35,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
if attentions.size(1) != 1:
|
||||
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
|
||||
return images
|
||||
|
||||
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
attentions -= min_attentions
|
||||
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
|
||||
if images.size() != attentions.size():
|
||||
attentions = interpolate(attentions, images.size()[-2:])
|
||||
colored_attentions = torch.zeros_like(images)
|
||||
cmap = get_cmap(cmap_name)
|
||||
for i, at in enumerate(attentions):
|
||||
ca = cmap(at[0].cpu().numpy())[:, :, :3]
|
||||
colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size())
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
|
||||
return images * alpha + colored_attentions * (1 - alpha)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user