test
This commit is contained in:
@@ -8,11 +8,13 @@ from io import BytesIO
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets.folder import default_loader
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import functional
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class CARS(Dataset):
|
||||
class _CARS(Dataset):
|
||||
def __init__(self, root, loader=default_loader, transform=None):
|
||||
self.root = Path(root)
|
||||
self.transform = transform
|
||||
@@ -32,7 +34,7 @@ class CARS(Dataset):
|
||||
sample = self.loader(self.root / "cars_train" / file_name)
|
||||
if self.transform is not None:
|
||||
sample = self.transform(sample)
|
||||
return sample
|
||||
return sample, target
|
||||
|
||||
|
||||
class ImprovedImageFolder(ImageFolder):
|
||||
@@ -44,7 +46,7 @@ class ImprovedImageFolder(ImageFolder):
|
||||
assert len(self.classes_list) == len(self.classes)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return super().__getitem__(item)[0]
|
||||
return super().__getitem__(item)
|
||||
|
||||
|
||||
class LMDBDataset(Dataset):
|
||||
@@ -61,16 +63,10 @@ class LMDBDataset(Dataset):
|
||||
|
||||
def __getitem__(self, i):
|
||||
with self.db.begin(write=False) as txn:
|
||||
sample = Image.open(BytesIO(txn.get("{}".format(i).encode())))
|
||||
if sample.mode != "RGB":
|
||||
sample = sample.convert("RGB")
|
||||
sample, target = pickle.loads(txn.get("{}".format(i).encode()))
|
||||
if self.transform is not None:
|
||||
try:
|
||||
sample = self.transform(sample)
|
||||
except RuntimeError as re:
|
||||
print(sample.format, sample.size, sample.mode)
|
||||
raise re
|
||||
return sample
|
||||
sample = self.transform(sample)
|
||||
return sample, target
|
||||
|
||||
|
||||
class EpisodicDataset(Dataset):
|
||||
@@ -81,6 +77,24 @@ class EpisodicDataset(Dataset):
|
||||
self.num_set = num_set # K
|
||||
self.num_episodes = num_episodes
|
||||
|
||||
self.t0 = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
self.transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
def apply_transform(self, img):
|
||||
# img1 = self.transform(img)
|
||||
# img2 = self.transform(img)
|
||||
# return [self.t0(img), self.t0(functional.hflip(img))]
|
||||
return [self.t0(img)]
|
||||
|
||||
def __len__(self):
|
||||
return self.num_episodes
|
||||
|
||||
@@ -95,8 +109,11 @@ class EpisodicDataset(Dataset):
|
||||
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
|
||||
else:
|
||||
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
|
||||
support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
|
||||
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
|
||||
support = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
|
||||
query = [self.origin[image_list[idx]][0] for idx in idx_list[:self.num_set]]
|
||||
|
||||
support_set_list.extend(sum(map(self.apply_transform, support), list()))
|
||||
query_set_list.extend(sum(map(self.apply_transform, query), list()))
|
||||
target_list.extend([i] * self.num_set)
|
||||
return {
|
||||
"support": torch.stack(support_set_list),
|
||||
|
||||
Reference in New Issue
Block a user