add U-GAT-IT
This commit is contained in:
@@ -99,9 +99,9 @@ class EpisodicDataset(Dataset):
|
||||
|
||||
def __getitem__(self, _):
|
||||
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
|
||||
support_set_list = []
|
||||
query_set_list = []
|
||||
target_list = []
|
||||
support_set = []
|
||||
query_set = []
|
||||
target_set = []
|
||||
for tag, c in enumerate(random_classes):
|
||||
image_list = self.origin.classes_list[c]
|
||||
|
||||
@@ -113,13 +113,13 @@ class EpisodicDataset(Dataset):
|
||||
|
||||
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
|
||||
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
|
||||
support_set_list.extend(support)
|
||||
query_set_list.extend(query)
|
||||
target_list.extend([tag] * self.num_query)
|
||||
support_set.extend(support)
|
||||
query_set.extend(query)
|
||||
target_set.extend([tag] * self.num_query)
|
||||
return {
|
||||
"support": torch.stack(support_set_list),
|
||||
"query": torch.stack(query_set_list),
|
||||
"target": torch.tensor(target_list)
|
||||
"support": torch.stack(support_set),
|
||||
"query": torch.stack(query_set),
|
||||
"target": torch.tensor(target_set)
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
Reference in New Issue
Block a user