This commit is contained in:
2020-07-16 16:07:03 +08:00
parent 598bd9e0f1
commit 7d720c181b
3 changed files with 7 additions and 7 deletions

View File

@@ -73,7 +73,7 @@ def test(lmdb_path, import_path):
with torch.no_grad():
for item in data_loader:
st = time.time()
print("load", time.time() - load_st)
# print("load", time.time() - load_st)
item = convert_tensor(item, device, non_blocking=True)
# item["query"]: B x NK x 3 x W x H
# item["support"]: B x NK x 3 x W x H
@@ -81,12 +81,11 @@ def test(lmdb_path, import_path):
batch_size = item["target"].size(0)
query_batch = extractor(item["query"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N * K, -1)
support_batch = extractor(item["support"].view([-1, *item["query"].shape[-3:]])).view(batch_size, N, K, -1)
print("compute", time.time() - st)
# print("compute", time.time() - st)
load_st = time.time()
accs.append(evaluate(query_batch, item["target"], support_batch))
print(torch.tensor(accs).mean().item())
print("time: ", time.time() - st)
if __name__ == '__main__':
@@ -94,8 +93,8 @@ if __name__ == '__main__':
defined_path = ["/data/few-shot/lmdb/mini-imagenet/val.lmdb",
"/data/few-shot/lmdb/CUB_200_2011/data.lmdb",
"/data/few-shot/lmdb/STANFORD-CARS/train.lmdb",
# "/data/few-shot/lmdb/Plantae/data.lmdb",
# "/data/few-shot/lmdb/Places365/val.lmdb"
"/data/few-shot/lmdb/Plantae/data.lmdb",
"/data/few-shot/lmdb/Places365/val.lmdb"
]
parser = argparse.ArgumentParser(description="test")
parser.add_argument('-i', "--import_path", required=True)