add U-GAT-IT

This commit is contained in:
2020-08-21 16:14:30 +08:00
parent 323bf2f6ab
commit 1a1cb9b00f
18 changed files with 815 additions and 55 deletions

View File

@@ -99,9 +99,9 @@ class EpisodicDataset(Dataset):
def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set_list = []
query_set_list = []
target_list = []
support_set = []
query_set = []
target_set = []
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
@@ -113,13 +113,13 @@ class EpisodicDataset(Dataset):
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set_list.extend(support)
query_set_list.extend(query)
target_list.extend([tag] * self.num_query)
support_set.extend(support)
query_set.extend(query)
target_set.extend([tag] * self.num_query)
return {
"support": torch.stack(support_set_list),
"query": torch.stack(query_set_list),
"target": torch.tensor(target_list)
"support": torch.stack(support_set),
"query": torch.stack(query_set),
"target": torch.tensor(target_set)
}
def __repr__(self):