This commit is contained in:
2020-07-23 22:32:28 +08:00
parent 3a72dcb5f0
commit ead93c1b0e
5 changed files with 78 additions and 34 deletions

View File

@@ -8,11 +8,13 @@ from io import BytesIO
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.transforms import functional
from pathlib import Path
from collections import defaultdict
class CARS(Dataset):
class _CARS(Dataset):
def __init__(self, root, loader=default_loader, transform=None):
self.root = Path(root)
self.transform = transform
@@ -32,7 +34,7 @@ class CARS(Dataset):
sample = self.loader(self.root / "cars_train" / file_name)
if self.transform is not None:
sample = self.transform(sample)
return sample
return sample, target
class ImprovedImageFolder(ImageFolder):
@@ -44,7 +46,7 @@ class ImprovedImageFolder(ImageFolder):
assert len(self.classes_list) == len(self.classes)
def __getitem__(self, item):
return super().__getitem__(item)[0]
return super().__getitem__(item)
class LMDBDataset(Dataset):
@@ -61,16 +63,10 @@ class LMDBDataset(Dataset):
def __getitem__(self, i):
    """Return ``(sample, target)`` for record *i* stored in the LMDB env.

    NOTE(review): this span is a diff fragment whose +/- markers were lost;
    the raw-image-decoding lines and the try/except block appear to be the
    pre-change version, the ``pickle.loads`` line and the plain transform
    the post-change version — confirm against the repository history.
    """
    with self.db.begin(write=False) as txn:
        # pre-change (apparently removed): values were raw image bytes.
        sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
        if sample.mode != "RGB":
            sample = sample.convert("RGB")
        # post-change (apparently added): values are pickled (sample, target)
        # pairs. NOTE(review): pickle.loads runs arbitrary code if the LMDB
        # file is untrusted — confirm the DB is produced by a trusted pipeline.
        sample, target = pickle.loads(txn.get("{}".format(i).encode()))
    if self.transform is not None:
        # pre-change: transform wrapped in a RuntimeError debug handler that
        # printed image metadata before re-raising.
        try:
            sample = self.transform(sample)
        except RuntimeError as re:
            print(sample.format, sample.size, sample.mode)
            raise re
        return sample
        # post-change: plain transform; (sample, target) returned below.
        sample = self.transform(sample)
    return sample, target
class EpisodicDataset(Dataset):
@@ -81,6 +77,24 @@ class EpisodicDataset(Dataset):
self.num_set = num_set # K
self.num_episodes = num_episodes
self.t0 = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def apply_transform(self, img):
    """Return the list of tensor views of *img* used for one episode slot.

    Only the deterministic pipeline ``self.t0`` (resize -> ToTensor ->
    normalize, per the Compose built in ``__init__``) is applied. The
    commented-out lines are alternative augmentation schemes (two jittered
    views via ``self.transform``, or the image plus its horizontal flip)
    that appear to have been tried and disabled — confirm before deleting.
    """
    # img1 = self.transform(img)
    # img2 = self.transform(img)
    # return [self.t0(img), self.t0(functional.hflip(img))]
    return [self.t0(img)]
def __len__(self):
    """Dataset length is the configured number of episodes per epoch."""
    return self.num_episodes
@@ -95,8 +109,11 @@ class EpisodicDataset(Dataset):
# NOTE(review): diff fragment with lost +/- markers — the first pair of
# .extend(...) lines looks like the pre-change version, the support/query
# lines the post-change version; confirm against the repository.
# Draw 2*K indices for this class (without replacement when possible).
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
else:
# Fewer than 2*K images in the class: sample with replacement instead.
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
# pre-change: first K indices -> support, remaining K -> query.
support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
# post-change: keep only the image part ([0]) and expand each image via
# apply_transform before extending the episode lists.
support = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
# NOTE(review): suspected copy-paste bug — `query` slices
# idx_list[:self.num_set], the same K indices as `support`, whereas the
# replaced line above used idx_list[self.num_set:]. As written the query
# set duplicates the support set; confirm whether this should be
# idx_list[self.num_set:].
query = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
support_set_list.extend(sum(map(self.apply_transform, support), list()))
query_set_list.extend(sum(map(self.apply_transform, query), list()))
target_list.extend([i] * self.num_set)
return {
"support": torch.stack(support_set_list),