From e69fb6f439255824555dc7a85fdbf4f993456131 Mon Sep 17 00:00:00 2001 From: darkliang <12132342@mail.sustech.edu.cn> Date: Sun, 9 Apr 2023 17:15:42 +0800 Subject: [PATCH] support CASIA-E dataset --- configs/gaitbase/gaitbase_casiae.yaml | 102 +++ configs/gaitpart/gaitpart_casiae.yaml | 82 ++ configs/gaitset/gaitset_casiae.yaml | 77 ++ datasets/CASIA-E/CASIA-E.json | 1020 +++++++++++++++++++++++++ datasets/CASIA-E/README.md | 43 ++ datasets/CASIA-E/extractor.py | 98 +++ docs/1.model_zoo.md | 8 + opengait/evaluation/evaluator.py | 28 +- 8 files changed, 1449 insertions(+), 9 deletions(-) create mode 100644 configs/gaitbase/gaitbase_casiae.yaml create mode 100644 configs/gaitpart/gaitpart_casiae.yaml create mode 100644 configs/gaitset/gaitset_casiae.yaml create mode 100644 datasets/CASIA-E/CASIA-E.json create mode 100644 datasets/CASIA-E/README.md create mode 100644 datasets/CASIA-E/extractor.py diff --git a/configs/gaitbase/gaitbase_casiae.yaml b/configs/gaitbase/gaitbase_casiae.yaml new file mode 100644 index 0000000..3d05c3a --- /dev/null +++ b/configs/gaitbase/gaitbase_casiae.yaml @@ -0,0 +1,102 @@ +data_cfg: + dataset_name: CASIA-E + dataset_root: your_path + dataset_partition: ./datasets/CASIA-E/CASIA-E.json + num_workers: 1 + remove_no_gallery: false # Remove probe if no gallery for it + test_dataset_name: CASIA-E + +evaluator_cfg: + enable_float16: true + restore_ckpt_strict: true + restore_hint: 60000 + save_name: GaitBase + #eval_func: GREW_submission + sampler: + batch_shuffle: false + batch_size: 16 + sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered + frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory + metric: euc # cos + transform: + - type: BaseSilCuttingTransform + +loss_cfg: + - loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + - loss_term_weight: 1.0 + scale: 16 + type: CrossEntropyLoss + log_prefix: softmax + log_accuracy: true + +model_cfg: + model: Baseline + backbone_cfg: + type: ResNet9 + block: BasicBlock + channels: # Layers configuration for automatically model construction + - 64 + - 128 + - 256 + - 512 + layers: + - 1 + - 1 + - 1 + - 1 + strides: + - 1 + - 2 + - 2 + - 1 + maxpool: false + SeparateFCs: + in_channels: 512 + out_channels: 256 + parts_num: 16 + SeparateBNNecks: + class_num: 200 + in_channels: 256 + parts_num: 16 + bin_num: + - 16 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: # Learning Rate Reduction at each milestones + - 20000 + - 40000 + - 50000 + scheduler: MultiStepLR +trainer_cfg: + enable_float16: true # half_percesion float for memory reduction and speedup + fix_BN: false + with_test: false + log_iter: 100 + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 20000 + save_name: GaitBase + sync_BN: true + total_iter: 60000 + sampler: + batch_shuffle: true + batch_size: + - 8 # TripletSampler, batch_size[0] indicates Number of Identity + - 32 # batch_size[1] indicates Samples sequqnce for each Identity + frames_num_fixed: 30 # fixed frames number for training + frames_num_max: 40 # max frames number for unfixed training + frames_num_min: 20 # min frames number for unfixed traing + sample_type: fixed_unordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered + type: TripletSampler + transform: + - type: BaseSilCuttingTransform \ No newline at end of file diff --git a/configs/gaitpart/gaitpart_casiae.yaml b/configs/gaitpart/gaitpart_casiae.yaml new file mode 100644 index 0000000..7bb5569 --- /dev/null +++ b/configs/gaitpart/gaitpart_casiae.yaml @@ -0,0 +1,82 @@ +data_cfg: + dataset_name: CASIA-E + dataset_root: your_path + dataset_partition: ./datasets/CASIA-E/CASIA-E.json + num_workers: 4 + remove_no_gallery: false + test_dataset_name: CASIA-E + +evaluator_cfg: + enable_float16: false + restore_ckpt_strict: true + restore_hint: 120000 + save_name: GaitPart + sampler: + batch_size: 4 + sample_type: all_ordered + type: InferenceSampler + metric: euc # cos + +loss_cfg: + loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + +model_cfg: + model: GaitPart + backbone_cfg: + in_channels: 1 + layers_cfg: + - BC-32 + - BC-32 + - M + - BC-64 + - BC-64 + - M + - FC-128-3 + - FC-128-3 + - FC-256-3 + - FC-256-3 + type: Plain + SeparateFCs: + in_channels: 256 + out_channels: 256 + parts_num: 16 + bin_num: + - 16 + +optimizer_cfg: + lr: 0.0001 + momentum: 0.9 + solver: Adam + weight_decay: 0.0 + +scheduler_cfg: + gamma: 0.1 + milestones: + - 100000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true + fix_BN: false + log_iter: 100 + with_test: false + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 120000 + save_name: GaitPart + sync_BN: false + total_iter: 120000 + sampler: + batch_shuffle: false + batch_size: + - 8 + - 32 + frames_num_fixed: 30 + frames_num_max: 50 + frames_num_min: 25 + frames_skip_num: 10 + sample_type: fixed_ordered + type: TripletSampler diff --git a/configs/gaitset/gaitset_casiae.yaml b/configs/gaitset/gaitset_casiae.yaml new file mode 100644 index 0000000..f21c286 --- /dev/null +++ b/configs/gaitset/gaitset_casiae.yaml @@ -0,0 +1,77 @@ +data_cfg: + dataset_name: CASIA-E + dataset_root: your_path + dataset_partition: ./datasets/CASIA-E/CASIA-E.json + num_workers: 1 + remove_no_gallery: false + test_dataset_name: CASIA-E + +evaluator_cfg: + enable_float16: false + restore_ckpt_strict: true + restore_hint: 60000 + save_name: GaitSet + sampler: + batch_size: 16 + sample_type: all_ordered + type: InferenceSampler + metric: euc # cos + +loss_cfg: + loss_term_weight: 1.0 + margin: 0.2 + type: TripletLoss + log_prefix: triplet + +model_cfg: + model: GaitSet + in_channels: + - 1 + - 64 + - 128 + - 256 + SeparateFCs: + in_channels: 256 + out_channels: 256 + parts_num: 62 + bin_num: + - 16 + - 8 + - 4 + - 2 + - 1 + +optimizer_cfg: + lr: 0.1 + momentum: 0.9 + solver: SGD + weight_decay: 0.0005 + +scheduler_cfg: + gamma: 0.1 + milestones: + - 20000 + - 40000 + - 50000 + scheduler: MultiStepLR + +trainer_cfg: + enable_float16: true + log_iter: 100 + with_test: false + restore_ckpt_strict: true + restore_hint: 0 + save_iter: 60000 + save_name: GaitSet + sync_BN: false + total_iter: 60000 + sampler: + batch_shuffle: false + batch_size: + - 8 + - 32 + frames_num_fixed: 30 + frames_num_max: 50 + frames_num_min: 25 + sample_type: fixed_unordered + type: TripletSampler diff --git a/datasets/CASIA-E/CASIA-E.json b/datasets/CASIA-E/CASIA-E.json new file mode 100644 index 0000000..15fde26 --- /dev/null +++ b/datasets/CASIA-E/CASIA-E.json @@ -0,0 +1,1020 @@ +{ + "TRAIN_SET": [ + "001", + "002", + "003", + "004", + "005", + "006", + "007", + "008", + "009", + "010", + "011", + "012", + "013", + "014", + "015", + "016", + "017", + "018", + "019", + "020", + "021", + "022", + "023", + "024", + "025", + "026", + "027", + "028", + "029", + "030", + "031", + "032", + "033", + "034", + "035", + "036", + "037", + "038", + "039", + "040", + "041", + "042", + "043", + "044", + "045", + "046", + "047", + "048", + "049", + "050", + "051", + "052", + "053", + "054", + "055", + "056", + "057", + "058", + "059", + "060", + "061", + "062", + "063", + "064", + "065", + "066", + "067", + "068", + "069", + "070", + "071", + "072", + "073", + "074", + "075", + "076", + "077", + "078", + "079", + "080", + "081", + "082", + "083", + "084", + "085", + "086", + "087", + "088", + "089", + "090", + "091", + "092", + "093", + "094", + "095", + "096", + "097", + "098", + "099", + "100", + "101", + "102", + "103", + "104", + "105", + "106", + "107", + "108", + "109", + "110", + "111", + "112", + "113", + "114", + "115", + "116", + "117", + "118", + "119", + "120", + "121", + "122", + "123", + "124", + "125", + "126", + "127", + "128", + "129", + "130", + "131", + "132", + "133", + "134", + "135", + "136", + "137", + "138", + "139", + "140", + "141", + "142", + "143", + "144", + "145", + "146", + "147", + "148", + "149", + "150", + "151", + "152", + "153", + "154", + "155", + "156", + "157", + "158", + "159", + "160", + "161", + "162", + "163", + "164", + "165", + "166", + "167", + "168", + "169", + "170", + "171", + "172", + "173", + "174", + "175", + "176", + "177", + "178", + "179", + "180", + "181", + "182", + "183", + "184", + "185", + "186", + "187", + "188", + "189", + "190", + "191", + "192", + "193", + "194", + "195", + "196", + "197", + "198", + "199", + "200" + ], + "TEST_SET": [ + "201", + "202", + "203", + "204", + "205", + "206", + "207", + "208", + "209", + "210", + "211", + "212", + "213", + "214", + "215", + "216", + "217", + "218", + "219", + "220", + "221", + "222", + "223", + "224", + "225", + "226", + "227", + "228", + "229", + "230", + "231", + "232", + "233", + "234", + "235", + "236", + "237", + "238", + "239", + "240", + "241", + "242", + "243", + "244", + "245", + "246", + "247", + "248", + "249", + "250", + "251", + "252", + "253", + "254", + "255", + "256", + "257", + "258", + "259", + "260", + "261", + "262", + "263", + "264", + "265", + "266", + "267", + "268", + "269", + "270", + "271", + "272", + "273", + "274", + "275", + "276", + "277", + "278", + "279", + "280", + "281", + "282", + "283", + "284", + "285", + "286", + "287", + "288", + "289", + "290", + "291", + "292", + "293", + "294", + "295", + "296", + "297", + "298", + "299", + "300", + "301", + "302", + "303", + "304", + "305", + "306", + "307", + "308", + "309", + "310", + "311", + "312", + "313", + "314", + "315", + "316", + "317", + "318", + "319", + "320", + "321", + "322", + "323", + "324", + "325", + "326", + "327", + "328", + "329", + "330", + "331", + "332", + "333", + "334", + "335", + "336", + "337", + "338", + "339", + "340", + "341", + "342", + "343", + "344", + "345", + "346", + "347", + "348", + "349", + "350", + "351", + "352", + "353", + "354", + "355", + "356", + "357", + "358", + "359", + "360", + "361", + "362", + "363", + "364", + "365", + "366", + "367", + "368", + "369", + "370", + "371", + "372", + "373", + "374", + "375", + "376", + "377", + "378", + "379", + "380", + "381", + "382", + "383", + "384", + "385", + "386", + "387", + "388", + "389", + "390", + "391", + "392", + "393", + "394", + "395", + "396", + "397", + "398", + "399", + "400", + "401", + "402", + "403", + "404", + "405", + "406", + "407", + "408", + "409", + "410", + "411", + "412", + "413", + "414", + "415", + "416", + "417", + "418", + "419", + "420", + "421", + "422", + "423", + "424", + "425", + "426", + "427", + "428", + "429", + "430", + "431", + "432", + "433", + "434", + "435", + "436", + "437", + "438", + "439", + "440", + "441", + "442", + "443", + "444", + "445", + "446", + "447", + "448", + "449", + "450", + "451", + "452", + "453", + "454", + "455", + "456", + "457", + "458", + "459", + "460", + "461", + "462", + "463", + "464", + "465", + "466", + "467", + "468", + "469", + "470", + "471", + "472", + "473", + "474", + "475", + "476", + "477", + "478", + "479", + "480", + "481", + "482", + "483", + "484", + "485", + "486", + "487", + "488", + "489", + "490", + "491", + "492", + "493", + "494", + "495", + "496", + "497", + "498", + "499", + "500", + "501", + "502", + "503", + "504", + "505", + "506", + "507", + "508", + "509", + "510", + "511", + "512", + "513", + "514", + "515", + "516", + "517", + "518", + "519", + "520", + "521", + "522", + "523", + "524", + "525", + "526", + "527", + "528", + "529", + "530", + "531", + "532", + "533", + "534", + "535", + "536", + "537", + "538", + "539", + "540", + "541", + "542", + "543", + "544", + "545", + "546", + "547", + "548", + "549", + "550", + "551", + "552", + "553", + "554", + "555", + "556", + "557", + "558", + "559", + "560", + "561", + "562", + "563", + "564", + "565", + "566", + "567", + "568", + "569", + "570", + "571", + "572", + "573", + "574", + "575", + "576", + "577", + "578", + "579", + "580", + "581", + "582", + "583", + "584", + "585", + "586", + "587", + "588", + "589", + "590", + "591", + "592", + "593", + "594", + "595", + "596", + "597", + "598", + "599", + "600", + "601", + "602", + "603", + "604", + "605", + "606", + "607", + "608", + "609", + "610", + "611", + "612", + "613", + "614", + "615", + "616", + "617", + "618", + "619", + "620", + "621", + "622", + "623", + "624", + "625", + "626", + "627", + "628", + "629", + "630", + "631", + "632", + "633", + "634", + "635", + "636", + "637", + "638", + "639", + "640", + "641", + "642", + "643", + "644", + "645", + "646", + "647", + "648", + "649", + "650", + "651", + "652", + "653", + "654", + "655", + "656", + "657", + "658", + "659", + "660", + "661", + "662", + "663", + "664", + "665", + "666", + "667", + "668", + "669", + "670", + "671", + "672", + "673", + "674", + "675", + "676", + "677", + "678", + "679", + "680", + "681", + "682", + "683", + "684", + "685", + "686", + "687", + "688", + "689", + "690", + "691", + "692", + "693", + "694", + "695", + "696", + "697", + "698", + "699", + "700", + "701", + "702", + "703", + "704", + "705", + "706", + "707", + "708", + "709", + "710", + "711", + "712", + "713", + "714", + "715", + "716", + "717", + "718", + "719", + "720", + "721", + "722", + "723", + "724", + "725", + "726", + "727", + "728", + "729", + "730", + "731", + "732", + "733", + "734", + "735", + "736", + "737", + "738", + "739", + "740", + "741", + "742", + "743", + "744", + "745", + "746", + "747", + "748", + "749", + "750", + "751", + "752", + "753", + "754", + "755", + "756", + "757", + "758", + "759", + "760", + "761", + "762", + "763", + "764", + "765", + "766", + "767", + "768", + "769", + "770", + "771", + "772", + "773", + "774", + "775", + "776", + "777", + "778", + "779", + "780", + "781", + "782", + "783", + "784", + "785", + "786", + "787", + "788", + "789", + "790", + "791", + "792", + "793", + "794", + "795", + "796", + "797", + "798", + "799", + "800", + "801", + "802", + "803", + "804", + "805", + "806", + "807", + "808", + "809", + "810", + "811", + "812", + "813", + "814", + "815", + "816", + "817", + "818", + "819", + "820", + "821", + "822", + "823", + "824", + "825", + "826", + "827", + "828", + "829", + "830", + "831", + "832", + "833", + "834", + "835", + "836", + "837", + "838", + "839", + "840", + "841", + "842", + "843", + "844", + "845", + "846", + "847", + "848", + "849", + "850", + "851", + "852", + "853", + "854", + "855", + "856", + "857", + "858", + "859", + "860", + "861", + "862", + "863", + "864", + "865", + "866", + "867", + "868", + "869", + "870", + "871", + "872", + "873", + "874", + "875", + "876", + "877", + "878", + "879", + "880", + "881", + "882", + "883", + "884", + "885", + "886", + "887", + "888", + "889", + "890", + "891", + "892", + "893", + "894", + "895", + "896", + "897", + "898", + "899", + "900", + "901", + "902", + "903", + "904", + "905", + "906", + "907", + "908", + "909", + "910", + "911", + "912", + "913", + "914", + "915", + "916", + "917", + "918", + "919", + "920", + "921", + "922", + "923", + "924", + "925", + "926", + "927", + "928", + "929", + "930", + "931", + "932", + "933", + "934", + "935", + "936", + "937", + "938", + "939", + "940", + "941", + "942", + "943", + "944", + "945", + "946", + "947", + "948", + "949", + "950", + "951", + "952", + "953", + "954", + "955", + "956", + "957", + "958", + "959", + "960", + "961", + "962", + "963", + "964", + "965", + "966", + "967", + "968", + "969", + "970", + "971", + "972", + "973", + "974", + "975", + "976", + "977", + "978", + "979", + "980", + "981", + "982", + "983", + "984", + "985", + "986", + "987", + "988", + "989", + "990", + "991", + "992", + "993", + "994", + "995", + "996", + "997", + "998", + "999", + "1000", + "1001", + "1002", + "1003", + "1004", + "1005", + "1006", + "1007", + "1008", + "1009", + "1010", + "1011", + "1012", + "1013", + "1014" + ] +} \ No newline at end of file diff --git a/datasets/CASIA-E/README.md b/datasets/CASIA-E/README.md new file mode 100644 index 0000000..811b89f --- /dev/null +++ b/datasets/CASIA-E/README.md @@ -0,0 +1,43 @@ +# CASIA-E +Application URL: https://www.scidb.cn/en/detail?dataSetId=57be0e918db743279baf44a38d013a06 +- Original + ``` + test615-1014.zip + train001-500.zip + val501-614.zip + ``` +- Run `python datasets/CASIA-E/extractor.py --input_path CASIA-E/ --output_path CASIA-E-processed/ -n 8 -s 64`. \ + `n` is number of workers. `s` is the target image size. +- Processed + ``` + CASIA-E-processed + forTrain # raw images + 001 (subject) + H (height) + scene1 (scene) + bg (walking condition) + 000 (view) + 1 (sequence number) + xxx.jpg (images) + ...... + ...... + ...... + ...... + ...... + ...... + ...... + + opengait # pickle file + 001 (subject) + H_scene1_bg_1 (type) + 000 (view) + 000.pkl (contains all frames) + ...... + ...... + ...... + ``` + +## Evaluation +Compared with the settings in the original paper, we only used 200 people for training, and the rest were used as the test set, and the division of gallery and probe is more practical and difficult. +For specific experimental settings, please refer to configs/gaitbase/gaitbase_casiae.yaml. +For the specific division of the probe and gallery, please refer to opengait/evaluation/evaluator.py. \ No newline at end of file diff --git a/datasets/CASIA-E/extractor.py b/datasets/CASIA-E/extractor.py new file mode 100644 index 0000000..23ad778 --- /dev/null +++ b/datasets/CASIA-E/extractor.py @@ -0,0 +1,98 @@ +import argparse +import os +from pathlib import Path +import tqdm +import cv2 +import tarfile +import zipfile +from functools import partial +import numpy as np +import pickle +import multiprocessing as mp + + +def make_pkl_for_one_person(id_, output_path, img_size=64): + if id_.split(".")[-1] != "tar" or not os.path.exists(os.path.join(output_path, id_)): + return + with tarfile.TarFile(os.path.join(output_path, id_)) as f: + f.extractall(output_path) + os.remove(os.path.join(output_path, id_)) + id_path = id_.split(".")[0] + input_path = os.path.join(output_path, "forTrain", id_path) + base_pkl_path = os.path.join(output_path, "opengait", id_path) + if not os.path.isdir(input_path): + print("Path not found: "+input_path) + for height in sorted(os.listdir(input_path)): + height_path = os.path.join(input_path, height) + for scene in sorted(os.listdir(height_path)): + scene_path = os.path.join(height_path, scene) + for type_ in sorted(os.listdir(scene_path)): + type_path = os.path.join(scene_path, type_) + for view in sorted(os.listdir(type_path)): + view_path = os.path.join(type_path, view) + for num in sorted(os.listdir(view_path)): + num_path = os.path.join(view_path, num) + imgs = [] + for file_ in sorted(os.listdir(num_path)): + img = cv2.imread(os.path.join( + num_path, file_), cv2.IMREAD_GRAYSCALE) + if img_size != img.shape[0]: + img = cv2.resize( + img, dsize=(img_size, img_size)) + imgs.append(img) + if len(imgs) > 5: + pkl_path = os.path.join( + base_pkl_path, f"{height}-{scene}-{type_}-{num}", view) + os.makedirs(pkl_path, exist_ok=True) + pickle.dump(np.asarray(imgs), open( + os.path.join(pkl_path, f"{view}.pkl"), "wb")) + else: + print("No enough imgs: "+num_path) + + +def extractall(base_path: Path, output_path: Path, workers=1, img_size=64) -> None: + """Extract all archives in base_path to output_path. + + Args: + base_path (Path): Path to the directory containing the archives. + output_path (Path): Path to the directory to extract the archives to. + """ + + os.makedirs(output_path, exist_ok=True) + print("Unzipping train set...") + with open(os.path.join(base_path, 'train001-500.zip'), 'rb') as f: + z = zipfile.ZipFile(f) + z.extractall(output_path) + print("Unzipping validation set...") + with open(os.path.join(base_path, 'val501-614.zip'), 'rb') as f: + z = zipfile.ZipFile(f) + z.extractall(output_path) + print("Unzipping test set...") + with open(os.path.join(base_path, 'test615-1014.zip'), 'rb') as f: + z = zipfile.ZipFile(f) + z.extractall(output_path) + print("Extracting tar file...") + os.makedirs(os.path.join(output_path,"forTrain")) + os.makedirs(os.path.join(output_path,"opengait")) + ids = os.listdir(os.path.join(output_path)) + progress = tqdm.tqdm(total=len(ids), desc='Pretreating', unit='id') + + with mp.Pool(workers) as pool: + for _ in pool.imap_unordered(partial(make_pkl_for_one_person, output_path=output_path, img_size=img_size), ids): + progress.update(1) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='CASIA-E extractor') + parser.add_argument('-b', '--input_path', type=str, + required=True, help='Base path to CASIA-E zip files') + parser.add_argument('-o', '--output_path', type=str, + required=True, help='Output path for extracted files. The pickle files are generated in ${output_path}/opengait/') + parser.add_argument('-s', '--img_size', default=64, + type=int, help='Image resizing size. Default 64') + parser.add_argument('-n', '--num_workers', + type=int, default=1, help='Number of workers') + args = parser.parse_args() + + extractall(Path(args.input_path), Path(args.output_path), + args.num_workers, args.img_size) diff --git a/docs/1.model_zoo.md b/docs/1.model_zoo.md index ee5ff9a..740d3cf 100644 --- a/docs/1.model_zoo.md +++ b/docs/1.model_zoo.md @@ -56,6 +56,14 @@ | [DeepGaitV2-P3D](https://arxiv.org/pdf/2303.03301.pdf) | 74.4 | - | 64x44 | - | - | | [SwinGait(Transformer-based)](https://arxiv.org/pdf/2303.03301.pdf) | 75.0 | - | 64x44 | - | - | +## [CASIA-E](https://www.scidb.cn/en/detail?dataSetId=57be0e918db743279baf44a38d013a06) + +| Model | `Rank@1.NM` | `Rank@1.BG` | `Rank@1.CL` | Input size| Configuration | +| :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | :---------: | :---------:| :----: | :-------------: | :--------------------------------------------------------------: | +| GaitSet | 82.54 | 75.26 | 62.53 | 64x44 | configs/gaitset/gaitset_casiae.yaml +| GaitPart | 82.92 | 74.36 | 60.48 | 64x44 | configs/gaitpart/gaitpart_casiae.yaml +| GaitBase | 91.59 | 86.65 | 74.73 | 64x44 | configs/gaitbase/gaitbase_casiae.yaml + ------------------------------------------ The results in the parentheses are mentioned in the papers. diff --git a/opengait/evaluation/evaluator.py b/opengait/evaluation/evaluator.py index 251f707..c030339 100644 --- a/opengait/evaluation/evaluator.py +++ b/opengait/evaluation/evaluator.py @@ -70,12 +70,21 @@ def cross_view_gallery_evaluation(feature, label, seq_type, view, dataset, metri def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metric): probe_seq_dict = {'CASIA-B': {'NM': ['nm-05', 'nm-06'], 'BG': ['bg-01', 'bg-02'], 'CL': ['cl-01', 'cl-02']}, - 'OUMVLP': {'NM': ['00']}} + 'OUMVLP': {'NM': ['00']}, + 'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2',], + 'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'], + 'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2'] + } + + } gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'], - 'OUMVLP': ['01']} + 'OUMVLP': ['01'], + 'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2']} msg_mgr = get_msg_mgr() acc = {} view_list = sorted(np.unique(view)) + if dataset == 'CASIA-E': + view_list.remove("270") view_num = len(view_list) num_rank = 1 for (type_, probe_seq) in probe_seq_dict[dataset].items(): @@ -92,8 +101,8 @@ def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metr gallery_y = label[gseq_mask] gallery_x = feature[gseq_mask, :] dist = cuda_dist(probe_x, gallery_x, metric) - idx = dist.cpu().sort(1)[1].numpy() - acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0, + idx = dist.topk(num_rank, largest=False)[1].cpu().numpy() + acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx], 1) > 0, 0) * 100 / dist.shape[0], 2) result_dict = {} @@ -113,7 +122,7 @@ def evaluate_indoor_dataset(data, dataset, metric='euc', cross_view_gallery=Fals label = np.array(label) view = np.array(view) - if dataset not in ('CASIA-B', 'OUMVLP'): + if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E'): raise KeyError("DataSet %s hasn't been supported !" % dataset) if cross_view_gallery: return cross_view_gallery_evaluation( @@ -145,7 +154,7 @@ def evaluate_real_scene(data, dataset, metric='euc'): probe_y = label[pseq_mask] dist = cuda_dist(probe_x, gallery_x, metric) - idx = dist.cpu().sort(1)[1].numpy() + idx = dist.topk(num_rank, largest=False)[1].cpu().numpy() acc = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0, 0) * 100 / dist.shape[0], 2) msg_mgr.log_info('==Rank-1==') @@ -173,8 +182,9 @@ def GREW_submission(data, dataset, metric='euc'): probe_x = feature[pseq_mask, :] probe_y = view[pseq_mask] + num_rank = 20 dist = cuda_dist(probe_x, gallery_x, metric) - idx = dist.cpu().sort(1)[1].numpy() + idx = dist.topk(num_rank, largest=False)[1].cpu().numpy() save_path = os.path.join( "GREW_result/"+strftime('%Y-%m%d-%H%M%S', localtime())+".csv") @@ -182,8 +192,8 @@ def GREW_submission(data, dataset, metric='euc'): with open(save_path, "w") as f: f.write("videoId,rank1,rank2,rank3,rank4,rank5,rank6,rank7,rank8,rank9,rank10,rank11,rank12,rank13,rank14,rank15,rank16,rank17,rank18,rank19,rank20\n") for i in range(len(idx)): - r_format = [int(idx) for idx in gallery_y[idx[i, 0:20]]] - output_row = '{}'+',{}'*20+'\n' + r_format = [int(idx) for idx in gallery_y[idx[i, 0:num_rank]]] + output_row = '{}'+',{}'*num_rank+'\n' f.write(output_row.format(probe_y[i], *r_format)) print("GREW result saved to {}/{}".format(os.getcwd(), save_path)) return