update
This commit is contained in:
69
tool/verify_loss.py
Normal file
69
tool/verify_loss.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from ignite.utils import convert_tensor
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from data.dataset import SingleFolderDataset
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
|
||||
import ignite.distributed as idist
|
||||
|
||||
CONFIG = """
|
||||
loss:
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'NL2'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
match_data:
|
||||
root: "/tmp/generated/"
|
||||
pipeline:
|
||||
- Load
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
not_match_data:
|
||||
root: "/data/i2i/selfie2anime/trainB/"
|
||||
pipeline:
|
||||
- Load
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
"""
|
||||
|
||||
config = OmegaConf.create(CONFIG)
|
||||
dataset = SingleFolderDataset(**config.match_data)
|
||||
data_loader = DataLoader(dataset, 1, False, num_workers=1)
|
||||
|
||||
perceptual_loss = PerceptualLoss(**config.loss.perceptual).to("cuda:0")
|
||||
|
||||
pls = []
|
||||
for batch in data_loader:
|
||||
with torch.no_grad():
|
||||
batch = convert_tensor(batch, "cuda:0")
|
||||
x, t = torch.chunk(batch, 2, -1)
|
||||
pl, _ = perceptual_loss(x, t)
|
||||
print(pl)
|
||||
pls.append(pl)
|
||||
|
||||
torch.save(torch.stack(pls).cpu(), "verify_loss.match.pt")
|
||||
|
||||
dataset = SingleFolderDataset(**config.not_match_data)
|
||||
data_loader = DataLoader(dataset, 4, False, num_workers=1)
|
||||
pls = []
|
||||
for batch in data_loader:
|
||||
with torch.no_grad():
|
||||
batch = convert_tensor(batch, "cuda:0")
|
||||
for i, j in [(0, 1), (1, 2), (2, 3), (3, 0)]:
|
||||
x, t = batch[i].unsqueeze(dim=0), batch[j].unsqueeze(dim=0)
|
||||
pl, _ = perceptual_loss(x, t)
|
||||
print(pl)
|
||||
pls.append(pl)
|
||||
torch.save(torch.stack(pls).cpu(), "verify_loss.not_match.pt")
|
||||
Reference in New Issue
Block a user