This commit is contained in:
2020-07-11 20:51:22 +08:00
parent 07c63abb30
commit 598bd9e0f1
4 changed files with 60 additions and 33 deletions

View File

@@ -40,15 +40,17 @@ class ImprovedImageFolder(ImageFolder):
self.classes_list = defaultdict(list)
for i in range(len(self)):
self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)
def __getitem__(self, item):
    # ImageFolder yields an (image, target) pair; expose only the image.
    image, _ = super().__getitem__(item)
    return image
class LMDBDataset(Dataset):
    """Read-only dataset backed by an LMDB database created by ``dataset_to_lmdb``.

    The database stores one torch-serialized sample per stringified integer key,
    plus two pickled metadata entries: ``b"classes_list"`` (per-class sample
    indices) and ``b"__len__"`` (total sample count).
    """

    # NOTE(review): the scraped diff contained both the pre- and post-change
    # __init__; only the post-change version (with `transform`) is kept here.
    def __init__(self, lmdb_path, transform=None):
        """Open the LMDB environment at *lmdb_path*.

        Args:
            lmdb_path: Path to an LMDB directory or file; ``subdir`` is chosen
                automatically based on whether the path is a directory.
            transform: Optional callable applied to each loaded sample.
        """
        self.db = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path),
                            readonly=True, lock=False, readahead=False, meminit=False)
        self.transform = transform
        # Load the metadata written by dataset_to_lmdb once, up front.
        with self.db.begin(write=False) as txn:
            self.classes_list = pickle.loads(txn.get(b"classes_list"))
            self._len = pickle.loads(txn.get(b"__len__"))
@@ -58,7 +60,10 @@ class LMDBDataset(Dataset):
def __getitem__(self, i):
    """Deserialize sample *i* from the LMDB store and apply the optional transform.

    Keys are the stringified integer indices written by ``dataset_to_lmdb``.
    """
    # NOTE(review): the scraped diff kept the pre-change `return torch.load(...)`
    # line ahead of the post-change lines, leaving them unreachable; only the
    # post-change logic is kept here.
    with self.db.begin(write=False) as txn:
        sample = torch.load(BytesIO(txn.get("{}".format(i).encode())))
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
class EpisodicDataset(Dataset):
@@ -73,7 +78,7 @@ class EpisodicDataset(Dataset):
return self.num_episodes
# Samples one random few-shot episode; the index argument is ignored by design.
def __getitem__(self, _):
# Pre-change line (diff scrape): randint could pick the same class twice.
random_classes = torch.randint(high=len(self.origin.classes_list), size=(self.num_class,)).tolist()
# Post-change line: randperm guarantees num_class distinct classes.
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set_list = []
query_set_list = []
target_list = []
@@ -83,8 +88,8 @@ class EpisodicDataset(Dataset):
# With-replacement fallback is presumably for classes holding fewer than
# num_set * 2 images — TODO confirm against the elided branch above.
idx_list = torch.randperm(len(image_list))[:self.num_set * 2].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_set * 2,)).tolist()
# Pre-change lines (diff scrape): indexed self.origin with positions local to
# image_list, i.e. the wrong dataset indices.
support_set_list.extend([self.origin[idx] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[idx] for idx in idx_list[self.num_set:]])
# Post-change lines: map the local positions through image_list to get the
# true dataset indices before indexing self.origin.
support_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[:self.num_set]])
query_set_list.extend([self.origin[image_list[idx]] for idx in idx_list[self.num_set:]])
target_list.extend([i] * self.num_set)
# NOTE(review): the returned dict is truncated by the diff hunk here; a
# "query" (and likely "target") entry presumably follows — not visible.
return {
"support": torch.stack(support_set_list),

View File

@@ -1,17 +1,19 @@
import torch
import lmdb
import os
import pickle
from io import BytesIO
import argparse
import torch
import lmdb
from data.dataset import CARS, ImprovedImageFolder
import torchvision
from tqdm import tqdm
# Serializes every sample of *dataset* into an LMDB database at *lmdb_path*,
# one torch-serialized value per stringified integer key.
def dataset_to_lmdb(dataset, lmdb_path):
# Pre-change line (diff scrape): map_size was 2 TiB.
env = lmdb.open(lmdb_path, map_size=1099511627776 * 2, subdir=os.path.isdir(lmdb_path))
# Post-change line: map_size reduced to 1 TiB.
env = lmdb.open(lmdb_path, map_size=1099511627776, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
# Pre-change line (diff scrape), replaced by the ncols=50 variant below.
for i in tqdm(range(len(dataset))):
for i in tqdm(range(len(dataset)), ncols=50):
buffer = BytesIO()
torch.save(dataset[i], buffer)
txn.put("{}".format(i).encode(), buffer.getvalue())
@@ -19,17 +21,23 @@ def dataset_to_lmdb(dataset, lmdb_path):
# NOTE(review): one unchanged line between the two hunks is not visible here —
# presumably txn.put(b"classes_list", ...), since LMDBDataset reads that key.
txn.put(b"__len__", pickle.dumps(len(dataset)))
# NOTE(review): the scraped diff interleaved the pre-change `main()` (hard-coded
# CUB_200_2011 paths) with the post-change `transform()`; only the post-change
# version is kept here.
def transform(save_path, dataset_path):
    """Build an ImprovedImageFolder over *dataset_path*, preprocess each image
    (resize to 256x256, center-crop to 224, ImageNet normalization), and write
    the result to an LMDB database at *save_path*.
    """
    print(save_path, dataset_path)
    dt = torchvision.transforms.Compose([
        torchvision.transforms.Resize((256, 256)),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        # Standard ImageNet channel statistics.
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # origin_dataset = CARS("/data/few-shot/STANFORD-CARS/", transform=dt)
    origin_dataset = ImprovedImageFolder(dataset_path, transform=dt)
    dataset_to_lmdb(origin_dataset, save_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="transform dataset to lmdb database")
    parser.add_argument('--save', required=True)
    parser.add_argument('--dataset', required=True)
    args = parser.parse_args()
    transform(args.save, args.dataset)