From 66971ea129097ef7cb2d5cfc82d2d96a57fc46d2 Mon Sep 17 00:00:00 2001
From: chuanfushen <11950016@mail.sustech.edu.cn>
Date: Sat, 15 Jul 2023 17:12:13 +0800
Subject: [PATCH] Support SUSTech1K
---
README.md | 5 +-
configs/lidargait/lidargait_sustech1k.yaml | 101 ++
datasets/SUSTech1K/README.md | 33 +
datasets/SUSTech1K/SUSTech1K.json | 1056 ++++++++++++++++++
datasets/SUSTech1K/point2depth.py | 279 +++++
datasets/SUSTech1K/pretreatment_SUSTech1K.py | 221 ++++
opengait/evaluation/evaluator.py | 39 +-
opengait/modeling/models/baseline.py | 8 +-
8 files changed, 1725 insertions(+), 17 deletions(-)
create mode 100644 configs/lidargait/lidargait_sustech1k.yaml
create mode 100644 datasets/SUSTech1K/README.md
create mode 100644 datasets/SUSTech1K/SUSTech1K.json
create mode 100644 datasets/SUSTech1K/point2depth.py
create mode 100644 datasets/SUSTech1K/pretreatment_SUSTech1K.py
diff --git a/README.md b/README.md
index de3e4d0..ba001c5 100644
--- a/README.md
+++ b/README.md
@@ -3,11 +3,14 @@
------------------------------------------
+π£π£π£ **[*SUSTech1K*](https://lidargait.github.io) released, please check the [tutorial](datasets/SUSTech1K/README.md).** π£π£π£
+
πππ **[*OpenGait*](https://openaccess.thecvf.com/content/CVPR2023/papers/Fan_OpenGait_Revisiting_Gait_Recognition_Towards_Better_Practicality_CVPR_2023_paper.pdf) has been accpected by CVPR2023 as a highlight paperοΌ** πππ
OpenGait is a flexible and extensible gait recognition project provided by the [Shiqi Yu Group](https://faculty.sustech.edu.cn/yusq/) and supported in part by [WATRIX.AI](http://www.watrix.ai).
## What's New
+- **[July 2023]** [SUSTech1K](datasets/SUSTech1K/README.md) is released and supported by OpenGait.
- **[May 2023]** A real gait recognition system [All-in-One-Gait](https://github.com/jdyjjj/All-in-One-Gait) provided by [Dongyang Jin](https://github.com/jdyjjj) is avaliable.
- [Apr 2023] [CASIA-E](datasets/CASIA-E/README.md) is supported by OpenGait.
- [Feb 2023] [HID 2023 competition](https://hid2023.iapr-tc4.org/) is open, welcome to participate. Additionally, tutorial for the competition has been updated in [datasets/HID/](./datasets/HID).
@@ -50,7 +53,7 @@ Results and models are available in the [model zoo](docs/1.model_zoo.md).
## Authors:
**Open Gait Team (OGT)**
- [Chao Fan (ζ¨θΆ
)](https://chaofan996.github.io), 12131100@mail.sustech.edu.cn
-- [Chuanfu Shen (ζ²ε·η¦)](https://faculty.sustech.edu.cn/?p=95396&tagid=yusq&cat=2&iscss=1&snapid=1&orderby=date), 11950016@mail.sustech.edu.cn
+- [Chuanfu Shen (ζ²ε·η¦)](https://chuanfushen.github.io), 11950016@mail.sustech.edu.cn
- [Junhao Liang (ζ’ε³»θ±ͺ)](https://faculty.sustech.edu.cn/?p=95401&tagid=yusq&cat=2&iscss=1&snapid=1&orderby=date), 12132342@mail.sustech.edu.cn
## Acknowledgement
diff --git a/configs/lidargait/lidargait_sustech1k.yaml b/configs/lidargait/lidargait_sustech1k.yaml
new file mode 100644
index 0000000..d1c73b9
--- /dev/null
+++ b/configs/lidargait/lidargait_sustech1k.yaml
@@ -0,0 +1,101 @@
+data_cfg:
+ dataset_name: SUSTech1K
+ dataset_root: your_path_of_SUSTech1K-Released-pkl
+ dataset_partition: ./datasets/SUSTech1K/SUSTech1K.json
+ num_workers: 4
+ data_in_use: [false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false]
+ remove_no_gallery: false # Remove probe if no gallery for it
+ test_dataset_name: SUSTech1K
+
+evaluator_cfg:
+ enable_float16: true
+ restore_ckpt_strict: true
+ restore_hint: 40000
+ save_name: LidarGait
+ eval_func: evaluate_indoor_dataset #evaluate_Gait3D
+ sampler:
+ batch_shuffle: false
+ batch_size: 4
+ sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered
+ frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory
+ metric: euc # cos
+ transform:
+ - type: BaseSilTransform
+
+loss_cfg:
+ - loss_term_weight: 1.0
+ margin: 0.2
+ type: TripletLoss
+ log_prefix: triplet
+ - loss_term_weight: 1.0
+ scale: 16
+ type: CrossEntropyLoss
+ log_prefix: softmax
+ log_accuracy: true
+
+model_cfg:
+ model: Baseline
+ backbone_cfg:
+ type: ResNet9
+ in_channel: 3
+ block: BasicBlock
+ channels: # Layers configuration for automatically model construction
+ - 64
+ - 128
+ - 256
+ - 512
+ layers:
+ - 1
+ - 1
+ - 1
+ - 1
+ strides:
+ - 1
+ - 2
+ - 2
+ - 1
+ maxpool: false
+ SeparateFCs:
+ in_channels: 512
+ out_channels: 256
+ parts_num: 16
+ SeparateBNNecks:
+ class_num: 250
+ in_channels: 256
+ parts_num: 16
+ bin_num:
+ - 16
+
+optimizer_cfg:
+ lr: 0.1
+ momentum: 0.9
+ solver: SGD
+ weight_decay: 0.0005
+
+scheduler_cfg:
+ gamma: 0.1
+ milestones: # Learning Rate Reduction at each milestones
+ - 20000
+ - 30000
+ scheduler: MultiStepLR
+trainer_cfg:
+ enable_float16: true # half-precision float for memory reduction and speedup
+ fix_BN: false
+ with_test: true #true
+ log_iter: 100
+ restore_ckpt_strict: true
+ restore_hint: 0
+ save_iter: 5000
+ save_name: LidarGait
+ sync_BN: true
+ total_iter: 40000
+ sampler:
+ batch_shuffle: true
+ batch_size:
+ - 8 # TripletSampler, batch_size[0] indicates Number of Identity
+ - 8 # batch_size[1] indicates the number of sequences sampled for each Identity
+ frames_num_fixed: 10 # fixed frames number for training
+ sample_type: fixed_unordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered
+ type: TripletSampler
+ transform:
+ - type: BaseSilTransform
\ No newline at end of file
diff --git a/datasets/SUSTech1K/README.md b/datasets/SUSTech1K/README.md
new file mode 100644
index 0000000..925c5fa
--- /dev/null
+++ b/datasets/SUSTech1K/README.md
@@ -0,0 +1,33 @@
+# Tutorial for [SUSTech1K](https://lidargait.github.io)
+
+## Download the SUSTech1K dataset
+Download the dataset from the [link](https://lidargait.github.io).
decompress the downloaded file with the following command:
+```shell
+unzip -P password SUSTech1K-pkl.zip | xargs -n1 tar xzvf
+```
+The password can be obtained by signing the [agreement](https://lidargait.github.io/static/resources/SUSTech1KAgreement.pdf) and sending it to the email (shencf2019@mail.sustech.edu.cn).
+
+## Train the dataset
+Modify the `dataset_root` in `configs/lidargait/lidargait_sustech1k.yaml`, and then run this command:
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 opengait/main.py --cfgs configs/lidargait/lidargait_sustech1k.yaml --phase train
+```
+
+
+## Process from RAW dataset
+
+### Preprocess the dataset (Optional)
+Download the raw dataset from the [official link](https://lidargait.github.io). You will get three files, i.e. `DATASET_DOWNLOAD.md5`, `SUSTeck1K-RAW.zip`, and `SUSTeck1K-pkl.zip`.
+We recommend using our provided pickle files for convenience, or you can process the raw dataset into pickle files with this command:
+```shell
+python datasets/SUSTech1K/pretreatment_SUSTech1K.py -i SUSTech1K-Released-2023 -o SUSTech1K-pkl -n 8
+```
+
+### Projecting PointCloud into Depth image (Optional)
+You can use our processed depth images, or you can process via the command:
+```shell
+python datasets/SUSTech1K/point2depth.py -i SUSTech1K-Released-2023/ -o SUSTech1K-Released-2023/ -n 8
+```
+We recommend using our provided depth images for convenience.
+
diff --git a/datasets/SUSTech1K/SUSTech1K.json b/datasets/SUSTech1K/SUSTech1K.json
new file mode 100644
index 0000000..dd80292
--- /dev/null
+++ b/datasets/SUSTech1K/SUSTech1K.json
@@ -0,0 +1,1056 @@
+{
+ "TRAIN_SET": [
+ "0000",
+ "0002",
+ "0005",
+ "0018",
+ "0021",
+ "0026",
+ "0033",
+ "0039",
+ "0040",
+ "0044",
+ "0046",
+ "0047",
+ "0050",
+ "0052",
+ "0060",
+ "0062",
+ "0066",
+ "0073",
+ "0075",
+ "0078",
+ "0080",
+ "0092",
+ "0093",
+ "0096",
+ "0097",
+ "0102",
+ "0103",
+ "0116",
+ "0118",
+ "0126",
+ "0128",
+ "0144",
+ "0149",
+ "0151",
+ "0153",
+ "0154",
+ "0156",
+ "0157",
+ "0158",
+ "0164",
+ "0165",
+ "0168",
+ "0169",
+ "0174",
+ "0180",
+ "0183",
+ "0193",
+ "0194",
+ "0196",
+ "0199",
+ "0201",
+ "0203",
+ "0204",
+ "0208",
+ "0212",
+ "0220",
+ "0226",
+ "0227",
+ "0231",
+ "0232",
+ "0243",
+ "0253",
+ "0257",
+ "0267",
+ "0268",
+ "0282",
+ "0284",
+ "0287",
+ "0289",
+ "0291",
+ "0293",
+ "0295",
+ "0297",
+ "0301",
+ "0306",
+ "0311",
+ "0315",
+ "0318",
+ "0320",
+ "0321",
+ "0323",
+ "0325",
+ "0327",
+ "0341",
+ "0347",
+ "0351",
+ "0356",
+ "0357",
+ "0360",
+ "0367",
+ "0372",
+ "0379",
+ "0380",
+ "0395",
+ "0402",
+ "0412",
+ "0421",
+ "0425",
+ "0434",
+ "0438",
+ "0440",
+ "0441",
+ "0444",
+ "0452",
+ "0457",
+ "0458",
+ "0464",
+ "0468",
+ "0473",
+ "0474",
+ "0478",
+ "0488",
+ "0493",
+ "0495",
+ "0496",
+ "0497",
+ "0502",
+ "0514",
+ "0522",
+ "0521",
+ "0525",
+ "0533",
+ "0535",
+ "0537",
+ "0540",
+ "0544",
+ "0545",
+ "0546",
+ "0547",
+ "0551",
+ "0552",
+ "0553",
+ "0555",
+ "0557",
+ "0559",
+ "0577",
+ "0581",
+ "0583",
+ "0584",
+ "0585",
+ "0591",
+ "0597",
+ "0600",
+ "0605",
+ "0610",
+ "0611",
+ "0612",
+ "0616",
+ "0631",
+ "0632",
+ "0634",
+ "0636",
+ "0637",
+ "0641",
+ "0649",
+ "0653",
+ "0655",
+ "0664",
+ "0665",
+ "0671",
+ "0675",
+ "0677",
+ "0687",
+ "0695",
+ "0701",
+ "0702",
+ "0707",
+ "0717",
+ "0720",
+ "0721",
+ "0723",
+ "0726",
+ "0731",
+ "0756",
+ "0757",
+ "0759",
+ "0760",
+ "0767",
+ "0770",
+ "0773",
+ "0779",
+ "0780",
+ "0783",
+ "0791",
+ "0792",
+ "0796",
+ "0805",
+ "0810",
+ "0811",
+ "0823",
+ "0828",
+ "0830",
+ "0839",
+ "0841",
+ "0844",
+ "0845",
+ "0846",
+ "0850",
+ "0853",
+ "0860",
+ "0862",
+ "0863",
+ "0865",
+ "0868",
+ "0869",
+ "0876",
+ "0883",
+ "0884",
+ "0888",
+ "0897",
+ "0904",
+ "0906",
+ "0907",
+ "0908",
+ "0918",
+ "0922",
+ "0923",
+ "0925",
+ "0933",
+ "0938",
+ "0942",
+ "0943",
+ "0944",
+ "0948",
+ "0951",
+ "0959",
+ "0965",
+ "0966",
+ "0969",
+ "0970",
+ "0973",
+ "0978",
+ "0979",
+ "0982",
+ "0996",
+ "0997",
+ "1002",
+ "1004",
+ "1011",
+ "1013",
+ "1015",
+ "1019",
+ "1024",
+ "1026",
+ "1027",
+ "1036",
+ "1038",
+ "1046",
+ "1056",
+ "1057"
+ ],
+ "TEST_SET": [
+ "0001",
+ "0003",
+ "0004",
+ "0006",
+ "0007",
+ "0008",
+ "0009",
+ "0010",
+ "0011",
+ "0012",
+ "0013",
+ "0014",
+ "0015",
+ "0016",
+ "0017",
+ "0019",
+ "0020",
+ "0022",
+ "0023",
+ "0024",
+ "0025",
+ "0027",
+ "0028",
+ "0029",
+ "0030",
+ "0031",
+ "0032",
+ "0034",
+ "0035",
+ "0036",
+ "0037",
+ "0038",
+ "0041",
+ "0042",
+ "0043",
+ "0045",
+ "0048",
+ "0049",
+ "0051",
+ "0053",
+ "0054",
+ "0055",
+ "0056",
+ "0057",
+ "0058",
+ "0059",
+ "0061",
+ "0063",
+ "0064",
+ "0065",
+ "0067",
+ "0068",
+ "0069",
+ "0070",
+ "0071",
+ "0072",
+ "0074",
+ "0076",
+ "0077",
+ "0079",
+ "0081",
+ "0082",
+ "0083",
+ "0084",
+ "0085",
+ "0086",
+ "0087",
+ "0088",
+ "0089",
+ "0090",
+ "0091",
+ "0094",
+ "0095",
+ "0098",
+ "0099",
+ "0100",
+ "0101",
+ "0104",
+ "0105",
+ "0106",
+ "0107",
+ "0108",
+ "0109",
+ "0110",
+ "0111",
+ "0112",
+ "0113",
+ "0114",
+ "0115",
+ "0117",
+ "0119",
+ "0120",
+ "0121",
+ "0122",
+ "0123",
+ "0124",
+ "0125",
+ "0127",
+ "0129",
+ "0130",
+ "0131",
+ "0132",
+ "0133",
+ "0134",
+ "0135",
+ "0136",
+ "0137",
+ "0138",
+ "0139",
+ "0140",
+ "0141",
+ "0142",
+ "0143",
+ "0145",
+ "0146",
+ "0147",
+ "0148",
+ "0150",
+ "0152",
+ "0155",
+ "0159",
+ "0160",
+ "0161",
+ "0162",
+ "0163",
+ "0166",
+ "0167",
+ "0170",
+ "0171",
+ "0172",
+ "0173",
+ "0175",
+ "0176",
+ "0177",
+ "0178",
+ "0179",
+ "0181",
+ "0182",
+ "0184",
+ "0185",
+ "0186",
+ "0187",
+ "0188",
+ "0189",
+ "0190",
+ "0191",
+ "0192",
+ "0195",
+ "0197",
+ "0198",
+ "0200",
+ "0202",
+ "0205",
+ "0206",
+ "0207",
+ "0209",
+ "0210",
+ "0211",
+ "0213",
+ "0214",
+ "0215",
+ "0216",
+ "0217",
+ "0218",
+ "0219",
+ "0221",
+ "0222",
+ "0223",
+ "0224",
+ "0225",
+ "0228",
+ "0229",
+ "0230",
+ "0233",
+ "0234",
+ "0235",
+ "0236",
+ "0237",
+ "0238",
+ "0239",
+ "0240",
+ "0241",
+ "0242",
+ "0244",
+ "0245",
+ "0246",
+ "0247",
+ "0248",
+ "0249",
+ "0250",
+ "0251",
+ "0252",
+ "0254",
+ "0255",
+ "0256",
+ "0258",
+ "0259",
+ "0260",
+ "0261",
+ "0262",
+ "0263",
+ "0264",
+ "0265",
+ "0266",
+ "0269",
+ "0270",
+ "0271",
+ "0272",
+ "0273",
+ "0274",
+ "0275",
+ "0276",
+ "0277",
+ "0278",
+ "0279",
+ "0280",
+ "0281",
+ "0283",
+ "0285",
+ "0286",
+ "0288",
+ "0290",
+ "0292",
+ "0294",
+ "0296",
+ "0298",
+ "0299",
+ "0300",
+ "0302",
+ "0303",
+ "0304",
+ "0305",
+ "0307",
+ "0308",
+ "0309",
+ "0310",
+ "0312",
+ "0313",
+ "0314",
+ "0316",
+ "0317",
+ "0319",
+ "0322",
+ "0324",
+ "0326",
+ "0328",
+ "0329",
+ "0330",
+ "0331",
+ "0332",
+ "0333",
+ "0334",
+ "0335",
+ "0336",
+ "0337",
+ "0338",
+ "0339",
+ "0340",
+ "0342",
+ "0343",
+ "0344",
+ "0345",
+ "0346",
+ "0348",
+ "0349",
+ "0350",
+ "0353",
+ "0354",
+ "0355",
+ "0358",
+ "0359",
+ "0361",
+ "0362",
+ "0363",
+ "0364",
+ "0365",
+ "0366",
+ "0368",
+ "0369",
+ "0370",
+ "0371",
+ "0373",
+ "0374",
+ "0375",
+ "0376",
+ "0377",
+ "0378",
+ "0381",
+ "0382",
+ "0383",
+ "0384",
+ "0385",
+ "0386",
+ "0387",
+ "0388",
+ "0389",
+ "0390",
+ "0391",
+ "0392",
+ "0393",
+ "0394",
+ "0396",
+ "0397",
+ "0398",
+ "0399",
+ "0400",
+ "0401",
+ "0403",
+ "0404",
+ "0405",
+ "0406",
+ "0407",
+ "0408",
+ "0409",
+ "0410",
+ "0411",
+ "0413",
+ "0414",
+ "0415",
+ "0416",
+ "0417",
+ "0418",
+ "0419",
+ "0420",
+ "0422",
+ "0423",
+ "0424",
+ "0426",
+ "0427",
+ "0428",
+ "0429",
+ "0430",
+ "0431",
+ "0432",
+ "0433",
+ "0435",
+ "0436",
+ "0437",
+ "0439",
+ "0442",
+ "0443",
+ "0445",
+ "0446",
+ "0447",
+ "0448",
+ "0449",
+ "0450",
+ "0451",
+ "0453",
+ "0454",
+ "0455",
+ "0456",
+ "0459",
+ "0460",
+ "0461",
+ "0462",
+ "0463",
+ "0465",
+ "0466",
+ "0467",
+ "0469",
+ "0470",
+ "0471",
+ "0472",
+ "0475",
+ "0476",
+ "0477",
+ "0479",
+ "0480",
+ "0481",
+ "0482",
+ "0483",
+ "0484",
+ "0485",
+ "0486",
+ "0487",
+ "0489",
+ "0490",
+ "0491",
+ "0492",
+ "0494",
+ "0498",
+ "0499",
+ "0500",
+ "0501",
+ "0503",
+ "0504",
+ "0505",
+ "0506",
+ "0507",
+ "0508",
+ "0509",
+ "0510",
+ "0511",
+ "0512",
+ "0513",
+ "0515",
+ "0516",
+ "0517",
+ "0518",
+ "0519",
+ "0520",
+ "0523",
+ "0524",
+ "0526",
+ "0527",
+ "0528",
+ "0529",
+ "0530",
+ "0531",
+ "0532",
+ "0534",
+ "0536",
+ "0538",
+ "0539",
+ "0541",
+ "0542",
+ "0543",
+ "0548",
+ "0549",
+ "0550",
+ "0554",
+ "0556",
+ "0558",
+ "0560",
+ "0561",
+ "0562",
+ "0563",
+ "0564",
+ "0565",
+ "0566",
+ "0567",
+ "0568",
+ "0569",
+ "0570",
+ "0571",
+ "0572",
+ "0573",
+ "0574",
+ "0575",
+ "0576",
+ "0578",
+ "0579",
+ "0580",
+ "0582",
+ "0586",
+ "0587",
+ "0588",
+ "0589",
+ "0590",
+ "0592",
+ "0593",
+ "0594",
+ "0595",
+ "0596",
+ "0598",
+ "0599",
+ "0601",
+ "0602",
+ "0603",
+ "0604",
+ "0606",
+ "0607",
+ "0608",
+ "0609",
+ "0613",
+ "0614",
+ "0615",
+ "0617",
+ "0618",
+ "0619",
+ "0620",
+ "0621",
+ "0622",
+ "0623",
+ "0624",
+ "0625",
+ "0626",
+ "0627",
+ "0628",
+ "0629",
+ "0630",
+ "0633",
+ "0635",
+ "0638",
+ "0639",
+ "0640",
+ "0642",
+ "0643",
+ "0644",
+ "0645",
+ "0646",
+ "0647",
+ "0648",
+ "0650",
+ "0651",
+ "0652",
+ "0654",
+ "0656",
+ "0657",
+ "0658",
+ "0659",
+ "0660",
+ "0661",
+ "0662",
+ "0663",
+ "0666",
+ "0667",
+ "0668",
+ "0669",
+ "0670",
+ "0672",
+ "0673",
+ "0674",
+ "0676",
+ "0678",
+ "0679",
+ "0680",
+ "0681",
+ "0682",
+ "0683",
+ "0684",
+ "0685",
+ "0686",
+ "0688",
+ "0689",
+ "0690",
+ "0691",
+ "0692",
+ "0693",
+ "0694",
+ "0696",
+ "0697",
+ "0698",
+ "0699",
+ "0700",
+ "0703",
+ "0704",
+ "0705",
+ "0706",
+ "0708",
+ "0709",
+ "0710",
+ "0711",
+ "0712",
+ "0713",
+ "0714",
+ "0715",
+ "0716",
+ "0718",
+ "0719",
+ "0722",
+ "0724",
+ "0725",
+ "0727",
+ "0728",
+ "0729",
+ "0730",
+ "0732",
+ "0733",
+ "0734",
+ "0735",
+ "0736",
+ "0737",
+ "0738",
+ "0739",
+ "0740",
+ "0741",
+ "0742",
+ "0743",
+ "0744",
+ "0745",
+ "0746",
+ "0747",
+ "0748",
+ "0749",
+ "0750",
+ "0751",
+ "0752",
+ "0753",
+ "0754",
+ "0755",
+ "0758",
+ "0761",
+ "0762",
+ "0763",
+ "0764",
+ "0765",
+ "0766",
+ "0768",
+ "0769",
+ "0771",
+ "0772",
+ "0774",
+ "0775",
+ "0776",
+ "0777",
+ "0778",
+ "0781",
+ "0782",
+ "0784",
+ "0785",
+ "0786",
+ "0787",
+ "0788",
+ "0789",
+ "0790",
+ "0793",
+ "0794",
+ "0795",
+ "0797",
+ "0798",
+ "0799",
+ "0800",
+ "0801",
+ "0802",
+ "0803",
+ "0804",
+ "0806",
+ "0807",
+ "0808",
+ "0809",
+ "0812",
+ "0813",
+ "0814",
+ "0815",
+ "0816",
+ "0817",
+ "0818",
+ "0819",
+ "0820",
+ "0821",
+ "0822",
+ "0824",
+ "0825",
+ "0826",
+ "0827",
+ "0829",
+ "0831",
+ "0832",
+ "0833",
+ "0834",
+ "0835",
+ "0836",
+ "0837",
+ "0838",
+ "0840",
+ "0842",
+ "0843",
+ "0847",
+ "0848",
+ "0849",
+ "0851",
+ "0852",
+ "0854",
+ "0855",
+ "0856",
+ "0857",
+ "0858",
+ "0859",
+ "0861",
+ "0864",
+ "0866",
+ "0867",
+ "0870",
+ "0871",
+ "0872",
+ "0873",
+ "0874",
+ "0875",
+ "0877",
+ "0878",
+ "0879",
+ "0880",
+ "0881",
+ "0882",
+ "0885",
+ "0886",
+ "0887",
+ "0889",
+ "0890",
+ "0891",
+ "0892",
+ "0893",
+ "0894",
+ "0895",
+ "0896",
+ "0898",
+ "0899",
+ "0900",
+ "0901",
+ "0902",
+ "0903",
+ "0905",
+ "0909",
+ "0910",
+ "0911",
+ "0912",
+ "0913",
+ "0914",
+ "0915",
+ "0916",
+ "0917",
+ "0919",
+ "0920",
+ "0921",
+ "0924",
+ "0926",
+ "0927",
+ "0928",
+ "0929",
+ "0930",
+ "0931",
+ "0932",
+ "0934",
+ "0935",
+ "0936",
+ "0937",
+ "0939",
+ "0940",
+ "0941",
+ "0945",
+ "0946",
+ "0947",
+ "0949",
+ "0950",
+ "0952",
+ "0953",
+ "0954",
+ "0955",
+ "0956",
+ "0957",
+ "0958",
+ "0960",
+ "0961",
+ "0962",
+ "0963",
+ "0964",
+ "0967",
+ "0968",
+ "0971",
+ "0972",
+ "0974",
+ "0975",
+ "0976",
+ "0977",
+ "0980",
+ "0981",
+ "0983",
+ "0984",
+ "0985",
+ "0986",
+ "0987",
+ "0988",
+ "0989",
+ "0990",
+ "0991",
+ "0992",
+ "0993",
+ "0994",
+ "0995",
+ "0998",
+ "0999",
+ "1000",
+ "1001",
+ "1003",
+ "1005",
+ "1006",
+ "1007",
+ "1008",
+ "1009",
+ "1010",
+ "1012",
+ "1014",
+ "1016",
+ "1017",
+ "1018",
+ "1020",
+ "1021",
+ "1022",
+ "1023",
+ "1025",
+ "1028",
+ "1029",
+ "1030",
+ "1031",
+ "1032",
+ "1033",
+ "1034",
+ "1035",
+ "1037",
+ "1039",
+ "1040",
+ "1041",
+ "1042",
+ "1043",
+ "1044",
+ "1045",
+ "1049",
+ "1055"
+ ]
+}
\ No newline at end of file
diff --git a/datasets/SUSTech1K/point2depth.py b/datasets/SUSTech1K/point2depth.py
new file mode 100644
index 0000000..c685a0a
--- /dev/null
+++ b/datasets/SUSTech1K/point2depth.py
@@ -0,0 +1,279 @@
+import matplotlib.pyplot as plt
+
+import open3d as o3d
+# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
+import argparse
+import logging
+import multiprocessing as mp
+import os
+import pickle
+from collections import defaultdict
+from functools import partial
+from pathlib import Path
+from typing import Tuple
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+def align_img(img: np.ndarray, img_size: int = 64) -> np.ndarray:
+ """Aligns the image to the center.
+ Args:
+ img (np.ndarray): Image to align.
+ img_size (int, optional): Image resizing size. Defaults to 64.
+ Returns:
+ np.ndarray: Aligned image.
+ """
+ if img.sum() <= 10000:
+ y_top = 0
+ y_btm = img.shape[0]
+ else:
+ # Get the upper and lower points
+ # img.sum
+ y_sum = img.sum(axis=2).sum(axis=1)
+ y_top = (y_sum != 0).argmax(axis=0)
+ y_btm = (y_sum != 0).cumsum(axis=0).argmax(axis=0)
+
+ img = img[y_top: y_btm, :,:]
+
+ # As the height of a person is larger than the width,
+ # use the height to calculate resize ratio.
+ ratio = img.shape[1] / img.shape[0]
+ img = cv2.resize(img, (int(img_size * ratio), img_size), interpolation=cv2.INTER_CUBIC)
+
+ # Get the median of the x-axis and take it as the person's x-center.
+ x_csum = img.sum(axis=2).sum(axis=0).cumsum()
+ x_center = img.shape[1] // 2
+ for idx, csum in enumerate(x_csum):
+ if csum > img.sum() / 2:
+ x_center = idx
+ break
+
+ # if not x_center:
+ # logging.warning(f'{img_file} has no center.')
+ # continue
+
+ # Get the left and right points
+ half_width = img_size // 2
+ left = x_center - half_width
+ right = x_center + half_width
+ if left <= 0 or right >= img.shape[1]:
+ left += half_width
+ right += half_width
+ # _ = np.zeros((img.shape[0], half_width,3))
+ # img = np.concatenate([_, img, _], axis=1)
+
+ img = img[:, left: right,:].astype('uint8')
+ return img
+
+
+
+
+
+def lidar_to_2d_front_view(points,
+ v_res,
+ h_res,
+ v_fov,
+ val="depth",
+ cmap="jet",
+ saveto=None,
+ y_fudge=0.0
+ ):
+ """ Takes points in 3D space from LIDAR data and projects them to a 2D
+ "front view" image, and saves that image.
+
+ Args:
+ points: (np array)
+ The numpy array containing the lidar points.
+ The shape should be Nx4
+ - Where N is the number of points, and
+ - each point is specified by 4 values (x, y, z, reflectance)
+ v_res: (float)
+ vertical resolution of the lidar sensor used.
+ h_res: (float)
+ horizontal resolution of the lidar sensor used.
+ v_fov: (tuple of two floats)
+ (minimum_negative_angle, max_positive_angle)
+ val: (str)
+ What value to use to encode the points that get plotted.
+ One of {"depth", "height", "reflectance"}
+ cmap: (str)
+ Color map to use to color code the `val` values.
+ NOTE: Must be a value accepted by matplotlib's scatter function
+ Examples: "jet", "gray"
+ saveto: (str or None)
+ If a string is provided, it saves the image as this filename.
+ If None, then it just shows the image.
+ y_fudge: (float)
+ A hacky fudge factor to use if the theoretical calculations of
+ vertical range do not match the actual data.
+
+ For a Velodyne HDL 64E, set this value to 5.
+ """
+
+ # DUMMY PROOFING
+ assert len(v_fov) ==2, "v_fov must be list/tuple of length 2"
+ assert v_fov[0] <= 0, "first element in v_fov must be 0 or negative"
+ assert val in {"depth", "height", "reflectance"}, \
+ 'val must be one of {"depth", "height", "reflectance"}'
+
+
+ x_lidar = - points[:, 0]
+ y_lidar = - points[:, 1]
+ z_lidar = points[:, 2]
+ # Distance relative to origin when looked from top
+ d_lidar = np.sqrt(x_lidar ** 2 + y_lidar ** 2)
+ # Absolute distance relative to origin
+ # d_lidar = np.sqrt(x_lidar ** 2 + y_lidar ** 2, z_lidar ** 2)
+
+ v_fov_total = -v_fov[0] + v_fov[1]
+
+ # Convert to Radians
+ v_res_rad = v_res * (np.pi/180)
+ h_res_rad = h_res * (np.pi/180)
+
+ # PROJECT INTO IMAGE COORDINATES
+ x_img = np.arctan2(-y_lidar, x_lidar)/ h_res_rad
+ y_img = np.arctan2(z_lidar, d_lidar)/ v_res_rad
+
+ # SHIFT COORDINATES TO MAKE 0,0 THE MINIMUM
+ x_min = -360.0 / h_res / 2 # Theoretical min x value based on sensor specs
+ x_img -= x_min # Shift
+ x_max = 360.0 / h_res # Theoretical max x value after shifting
+
+ y_min = v_fov[0] / v_res # theoretical min y value based on sensor specs
+ y_img -= y_min # Shift
+ y_max = v_fov_total / v_res # Theoretical max y value after shifting
+
+ y_max += y_fudge # Fudge factor if the calculations based on
+ # spec sheet do not match the range of
+ # angles collected in the data.
+
+ # WHAT DATA TO USE TO ENCODE THE VALUE FOR EACH PIXEL
+ if val == "reflectance":
+ pass
+ elif val == "height":
+ pixel_values = z_lidar
+ else:
+ pixel_values = -d_lidar
+ # pixel_values = 'w'
+
+ # PLOT THE IMAGE
+ cmap = "jet" # Color map to use
+ dpi = 100 # Image resolution
+ fig, ax = plt.subplots(figsize=(x_max/dpi, y_max/dpi), dpi=dpi)
+ ax.scatter(x_img,y_img, s=1, c=pixel_values, linewidths=0, alpha=1, cmap=cmap)
+ ax.set_facecolor((0, 0, 0)) # Set regions with no points to black
+ ax.axis('scaled') # {equal, scaled}
+ ax.xaxis.set_visible(False) # Do not draw axis tick marks
+ ax.yaxis.set_visible(False) # Do not draw axis tick marks
+ plt.xlim([0, x_max]) # prevent drawing empty space outside of horizontal FOV
+ plt.ylim([0, y_max]) # prevent drawing empty space outside of vertical FOV
+
+ saveto = saveto.replace('.pcd','.png')
+ fig.savefig(saveto, dpi=dpi, bbox_inches='tight', pad_inches=0.0)
+ plt.close()
+ img = cv2.imread(saveto)
+ img = align_img(img)
+
+ aligned_path = saveto.replace('offline','aligned')
+ os.makedirs(os.path.dirname(aligned_path), exist_ok=True)
+ cv2.imwrite(aligned_path, img)
+ # fig, ax = plt.subplots(figsize=(x_max/dpi, y_max/dpi), dpi=dpi)
+ # ax.scatter(x_img,y_img, s=1, c='white', linewidths=0, alpha=1)
+ # ax.set_facecolor((0, 0, 0)) # Set regions with no points to black
+ # ax.axis('scaled') # {equal, scaled}
+ # ax.xaxis.set_visible(False) # Do not draw axis tick marks
+ # ax.yaxis.set_visible(False) # Do not draw axis tick marks
+ # plt.xlim([0, x_max]) # prevent drawing empty space outside of horizontal FOV
+ # plt.ylim([0, y_max]) # prevent drawing empty space outside of vertical FOV
+
+ # fig.savefig(saveto.replace('depth','sils'), dpi=dpi, bbox_inches='tight', pad_inches=0.0)
+ # plt.close()
+
+
+def pcd2depth(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
+ """Reads a group of images and saves the data in pickle format.
+ Args:
+ img_groups (Tuple): Tuple of (sid, seq, view) and list of image paths.
+ output_path (Path): Output path.
+ img_size (int, optional): Image resizing size. Defaults to 64.
+ verbose (bool, optional): Display debug info. Defaults to False.
+ """
+ sinfo = img_groups[0]
+ img_paths = img_groups[1]
+ for img_file in sorted(img_paths):
+ pcd_name = img_file.split('/')[-1]
+ pcd = o3d.io.read_point_cloud(img_file)
+ points = np.asarray(pcd.points)
+ HRES = 0.19188 # horizontal resolution (assuming 20Hz setting)
+ VRES = 0.2
+ VFOV = (-25.0, 15.0) # Field of view (-ve, +ve) along vertical axis
+ Y_FUDGE = 0 # y fudge factor for velodyne HDL 64E
+ dst_path = os.path.join(output_path, *sinfo)
+ os.makedirs(dst_path, exist_ok=True)
+ dst_path = os.path.join(dst_path,pcd_name)
+ lidar_to_2d_front_view(points, v_res=VRES, h_res=HRES, v_fov=VFOV, val="depth",
+ saveto=dst_path, y_fudge=Y_FUDGE)
+ # if len(points) == 0:
+ # print(img_file)
+ # to_pickle.append(points)
+ # dst_path = os.path.join(output_path, *sinfo)
+ # os.makedirs(dst_path, exist_ok=True)
+ # pkl_path = os.path.join(dst_path, f'pcd-{sinfo[2]}.pkl')
+ # pickle.dump(to_pickle, open(pkl_path, 'wb'))
+ # if len(to_pickle) < 5:
+ # logging.warning(f'{sinfo} has less than 5 valid data.')
+
+
+
+def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
+ """Reads a dataset and saves the data in pickle format.
+ Args:
+ input_path (Path): Dataset root path.
+ output_path (Path): Output path.
+ img_size (int, optional): Image resizing size. Defaults to 64.
+ workers (int, optional): Number of thread workers. Defaults to 4.
+ verbose (bool, optional): Display debug info. Defaults to False.
+ """
+ img_groups = defaultdict(list)
+ logging.info(f'Listing {input_path}')
+ total_files = 0
+ for sid in tqdm(sorted(os.listdir(input_path))):
+ for seq in os.listdir(os.path.join(input_path,sid)):
+ for view in os.listdir(os.path.join(input_path,sid,seq)):
+ for img_path in os.listdir(os.path.join(input_path,sid,seq,view,'PCDs')):
+ img_groups[(sid, seq, view,'PCDs_offline_depths')].append(os.path.join(input_path,sid,seq,view, 'PCDs',img_path))
+ total_files += 1
+
+ logging.info(f'Total files listed: {total_files}')
+
+ progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder')
+
+ with mp.Pool(workers) as pool:
+ logging.info(f'Start pretreating {input_path}')
+ for _ in pool.imap_unordered(partial(pcd2depth, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset), img_groups.items()):
+ progress.update(1)
+ logging.info('Done')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
+ parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.')
+ parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.')
+ parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
+ parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4')
+ parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default 64')
+ parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
+ parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
+ args = parser.parse_args()
+
+ logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')
+
+ if args.verbose:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.info('Verbose mode is on.')
+ for k, v in args.__dict__.items():
+ logging.debug(f'{k}: {v}')
+
+ pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
diff --git a/datasets/SUSTech1K/pretreatment_SUSTech1K.py b/datasets/SUSTech1K/pretreatment_SUSTech1K.py
new file mode 100644
index 0000000..e379ffd
--- /dev/null
+++ b/datasets/SUSTech1K/pretreatment_SUSTech1K.py
@@ -0,0 +1,221 @@
+# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
+import argparse
+import logging
+import multiprocessing as mp
+import os
+import pickle
+from collections import defaultdict
+from functools import partial
+from pathlib import Path
+from typing import Tuple
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+import json
+import open3d as o3d
+
def compare_pcd_rgb_timestamp(pcd_file, rgb_file):
    """Extract comparable epoch timestamps from a PCD and an RGB file path.

    The PCD file is named '<epoch_seconds.fraction>.pcd'; a fixed 0.05 s
    offset is added to it (presumably the LiDAR/camera capture offset —
    TODO confirm against the recording setup). The RGB file is named with
    a fraction-less digit string '<10-digit seconds><fraction>.jpg', so a
    decimal point is re-inserted after the 10th digit.

    Args:
        pcd_file (str): Path to a '.pcd' frame named by its timestamp.
        rgb_file (str): Path to a '.jpg' frame named by its timestamp.

    Returns:
        tuple: (pcd_time, rgb_time), both floats in epoch seconds.
    """
    # os.path.basename instead of split('/') so the helper is not tied to a
    # hard-coded path separator.
    pcd_stamp = os.path.basename(pcd_file).replace('.pcd', '')
    rgb_stamp = os.path.basename(rgb_file).replace('.jpg', '')
    pcd_time = float(pcd_stamp) + 0.05
    rgb_time = float(rgb_stamp[:10] + '.' + rgb_stamp[10:])
    return pcd_time, rgb_time
+
+
+
def _load_modality(modality, data_files, img_size):
    """Load one modality's frame files into memory.

    Args:
        modality (str): Modality folder name (e.g. 'PCDs', 'RGB_raw').
        data_files (list): Paths of this modality's frame files.
        img_size (int): Target square size for resized image modalities.

    Returns:
        tuple: (data, HWs). `data` is a list of per-frame point arrays for
        'PCDs' (frames have varying point counts, so they are not stacked)
        and an ndarray for image/pose modalities; `HWs` is the (N, 2) array
        of original (H, W) per RGB frame for 'RGB_raw', else None.
    """
    HWs = None
    if modality == 'PCDs':
        # Kept as a list: point clouds differ in length frame-to-frame.
        data = [np.asarray(o3d.io.read_point_cloud(f).points) for f in data_files]
    elif modality == 'RGB_raw':
        imgs = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in data_files]
        # Original sizes are needed later to undo the square resize.
        HWs = np.asarray([img.shape[:2] for img in imgs])
        data = np.asarray([cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for img in imgs])
    elif modality in ('Sils_raw', 'Sils_aligned'):
        sils = [cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in data_files]
        data = np.asarray([cv2.resize(s, (img_size, img_size), interpolation=cv2.INTER_CUBIC) for s in sils])
    elif modality == 'Pose':
        data = np.asarray([json.load(open(f)) for f in data_files])
    elif modality == 'PCDs_depths':
        imgs = [cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB) for f in data_files]
        data = np.asarray([img.transpose(2, 0, 1) for img in imgs])  # to (C, H, W)
    elif modality == 'PCDs_sils':
        data = np.asarray([cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in data_files])
    else:
        # Unknown modality: dump an explicit None instead of silently
        # re-pickling stale data from a previous iteration.
        data = None
    return data, HWs


def _dump_modality(dst_path, view, modality, data, HWs, cnt, sync):
    """Pickle one modality (plus the RGB size ratios) and return the next cnt.

    File names follow '<cnt>[-sync]-<view>-<LiDAR|Camera>-<modality>.pkl';
    `cnt` is incremented once per pickle written.
    """
    tag = 'sync-' if sync else ''
    if modality == 'RGB_raw':
        hw_path = os.path.join(dst_path, f'{cnt:02d}-{tag}{view}-Camera-Ratios-HW.pkl')
        pickle.dump(HWs, open(hw_path, 'wb'))
        cnt += 1
    sensor = 'LiDAR' if 'PCDs' in modality else 'Camera'
    pkl_path = os.path.join(dst_path, f'{cnt:02d}-{tag}{view}-{sensor}-{modality}.pkl')
    pickle.dump(data, open(pkl_path, 'wb'))
    return cnt + 1


def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
    """Reads a group of sequence frames and saves the data in pickle format.

    For each modality of the sequence this dumps (1) every frame as-is and
    (2) a LiDAR/RGB time-synchronised subset (names carry a 'sync-' tag),
    pairing each point cloud with its closest RGB frame within 20 ms.

    Args:
        img_groups (Tuple): ((sid, seq, view), [(modality, file list), ...]).
        output_path (Path): Output root path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name; kept for interface parity.
    """
    sinfo, img_paths = img_groups
    view = sinfo[2]
    threshold = 0.020  # max accepted LiDAR/RGB timestamp gap: 20 ms
    dst_path = os.path.join(output_path, *sinfo)
    os.makedirs(dst_path, exist_ok=True)

    # Pass 1: dump every modality unmodified.
    cnt = 0
    pcd_list, rgb_list = [], []
    for modality, data_files in img_paths:
        if modality == 'PCDs':
            pcd_list = data_files
        elif modality == 'RGB_raw':
            rgb_list = data_files
        data, HWs = _load_modality(modality, data_files, img_size)
        cnt = _dump_modality(dst_path, view, modality, data, HWs, cnt, sync=False)

    # Greedy matching: for each point cloud pick the nearest RGB frame and
    # keep the pair only when the gap is within the threshold.
    pcd_indexs, rgb_indexs = [], []
    for pcd_index in range(len(pcd_list)):
        best_diff = 1  # 1 s sentinel, far above the threshold
        best_pair = (pcd_index, 0)
        for rgb_index in range(len(rgb_list)):
            pcd_t, rgb_t = compare_pcd_rgb_timestamp(pcd_list[pcd_index], rgb_list[rgb_index])
            diff = abs(pcd_t - rgb_t)
            if diff < best_diff:
                best_pair = (pcd_index, rgb_index)
                best_diff = diff
        if best_diff <= threshold:
            pcd_indexs.append(best_pair[0])
            rgb_indexs.append(best_pair[1])

    if len(set(pcd_indexs)) != len(pcd_indexs):
        # Fixed: the original printed `len(pcd_indexs) == len(pcd_indexs)`
        # (always True); report whether duplicates actually exist.
        print(sinfo, pcd_indexs, rgb_indexs, len(set(pcd_indexs)) == len(pcd_indexs))

    # Pass 2: dump the synchronised subset of every modality.
    # NOTE(review): every camera modality is subsampled by the RGB match
    # indices — assumes those modalities are frame-aligned with RGB_raw;
    # verify against the capture pipeline.
    for modality, data_files in img_paths:
        keep = pcd_indexs if 'PCDs' in modality else rgb_indexs
        synced_files = [data_files[i] for i in keep]
        data, HWs = _load_modality(modality, synced_files, img_size)
        cnt = _dump_modality(dst_path, view, modality, data, HWs, cnt, sync=True)
+
+
def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
    """Reads a dataset and saves the data in pickle format.

    Walks `input_path` as <id>/<type>/<view>/<modality>/<frame files>,
    groups the per-modality file lists by (id, type, view), and hands each
    group to `imgs2pickle` in a multiprocessing pool.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
        dataset (str, optional): Dataset name forwarded to the workers.
    """
    img_groups = defaultdict(list)
    logging.info(f'Listing {input_path}')
    total_files = 0
    # Sort every directory level (the original sorted only ids and
    # modalities) so the group keys and each group's modality order are
    # deterministic regardless of the filesystem's listing order.
    for id_ in tqdm(sorted(os.listdir(input_path))):
        id_path = os.path.join(input_path, id_)
        for type_ in sorted(os.listdir(id_path)):
            type_path = os.path.join(id_path, type_)
            for view_ in sorted(os.listdir(type_path)):
                view_path = os.path.join(type_path, view_)
                for modality in sorted(os.listdir(view_path)):
                    modality_path = os.path.join(view_path, modality)
                    file_names = [os.path.join(modality_path, f) for f in sorted(os.listdir(modality_path))]
                    img_groups[(id_, type_, view_)].append((modality, file_names))
                    total_files += 1

    logging.info(f'Total files listed: {total_files}')

    progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder')
    with mp.Pool(workers) as pool:
        logging.info(f'Start pretreating {input_path}')
        worker = partial(imgs2pickle, output_path=output_path, img_size=img_size, verbose=verbose, dataset=dataset)
        # imap_unordered so the bar advances as soon as any folder finishes.
        for _ in pool.imap_unordered(worker, img_groups.items()):
            progress.update(1)
    progress.close()  # release the bar so subsequent log output is clean
    logging.info('Done')
+
+
if __name__ == '__main__':
    # Command-line entry point: parse options, set up file logging, then run
    # the SUSTech1K pretreatment pipeline over the whole dataset.
    parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
    parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.')
    parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.')
    parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
    parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4')
    parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default 64')
    parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
    args = parser.parse_args()

    # filemode='w' truncates the log file on every run.
    logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w', format='[%(asctime)s - %(levelname)s]: %(message)s')

    if args.verbose:
        # Raise the root logger to DEBUG and echo every parsed argument.
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info('Verbose mode is on.')
        for k, v in args.__dict__.items():
            logging.debug(f'{k}: {v}')

    # Paths are converted to pathlib.Path before entering the pipeline.
    pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path), img_size=args.img_size, workers=args.n_workers, verbose=args.verbose, dataset=args.dataset)
\ No newline at end of file
diff --git a/opengait/evaluation/evaluator.py b/opengait/evaluation/evaluator.py
index 896e4ae..3546531 100644
--- a/opengait/evaluation/evaluator.py
+++ b/opengait/evaluation/evaluator.py
@@ -74,46 +74,59 @@ def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metr
'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2', ],
'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'],
'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2']
- }
-
+ },
+ 'SUSTech1K': {'Normal': ['01-nm'], 'Bag': ['bg'], 'Clothing': ['cl'], 'Carrying':['cr'], 'Umberalla': ['ub'], 'Uniform': ['uf'], 'Occlusion': ['oc'],'Night': ['nt'], 'Overall': ['01','02','03','04']}
}
gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'],
'OUMVLP': ['01'],
- 'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2']}
+ 'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2'],
+ 'SUSTech1K': ['00-nm'],}
msg_mgr = get_msg_mgr()
acc = {}
view_list = sorted(np.unique(view))
+ num_rank = 1
if dataset == 'CASIA-E':
view_list.remove("270")
+ if dataset == 'SUSTech1K':
+ num_rank = 5
view_num = len(view_list)
- num_rank = 1
+
for (type_, probe_seq) in probe_seq_dict[dataset].items():
- acc[type_] = np.zeros((view_num, view_num)) - 1.
+ acc[type_] = np.zeros((view_num, view_num, num_rank)) - 1.
for (v1, probe_view) in enumerate(view_list):
pseq_mask = np.isin(seq_type, probe_seq) & np.isin(
view, probe_view)
+ pseq_mask = pseq_mask if 'SUSTech1K' not in dataset else np.any(np.asarray(
+ [np.char.find(seq_type, probe)>=0 for probe in probe_seq]), axis=0
+ ) & np.isin(view, probe_view) # For SUSTech1K only
probe_x = feature[pseq_mask, :]
probe_y = label[pseq_mask]
for (v2, gallery_view) in enumerate(view_list):
gseq_mask = np.isin(seq_type, gallery_seq_dict[dataset]) & np.isin(
view, [gallery_view])
+ gseq_mask = gseq_mask if 'SUSTech1K' not in dataset else np.any(np.asarray(
+ [np.char.find(seq_type, gallery)>=0 for gallery in gallery_seq_dict[dataset]]), axis=0
+ ) & np.isin(view, [gallery_view]) # For SUSTech1K only
gallery_y = label[gseq_mask]
gallery_x = feature[gseq_mask, :]
dist = cuda_dist(probe_x, gallery_x, metric)
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
- acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx], 1) > 0,
+ acc[type_][v1, v2, :] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
0) * 100 / dist.shape[0], 2)
result_dict = {}
msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===')
out_str = ""
- for type_ in probe_seq_dict[dataset].keys():
- sub_acc = de_diag(acc[type_], each_angle=True)
- msg_mgr.log_info(f'{type_}: {sub_acc}')
- result_dict[f'scalar/test_accuracy/{type_}'] = np.mean(sub_acc)
- out_str += f"{type_}: {np.mean(sub_acc):.2f}%\t"
- msg_mgr.log_info(out_str)
+ for rank in range(num_rank):
+ out_str = ""
+ for type_ in probe_seq_dict[dataset].keys():
+ sub_acc = de_diag(acc[type_][:,:,rank], each_angle=True)
+ if rank == 0:
+ msg_mgr.log_info(f'{type_}@R{rank+1}: {sub_acc}')
+ result_dict[f'scalar/test_accuracy/{type_}@R{rank+1}'] = np.mean(sub_acc)
+ out_str += f"{type_}@R{rank+1}: {np.mean(sub_acc):.2f}%\t"
+ msg_mgr.log_info(out_str)
return result_dict
@@ -122,7 +135,7 @@ def evaluate_indoor_dataset(data, dataset, metric='euc', cross_view_gallery=Fals
label = np.array(label)
view = np.array(view)
- if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E'):
+ if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E', 'SUSTech1K'):
raise KeyError("DataSet %s hasn't been supported !" % dataset)
if cross_view_gallery:
return cross_view_gallery_evaluation(
diff --git a/opengait/modeling/models/baseline.py b/opengait/modeling/models/baseline.py
index 4e1c72f..ba130d1 100644
--- a/opengait/modeling/models/baseline.py
+++ b/opengait/modeling/models/baseline.py
@@ -3,6 +3,7 @@ import torch
from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
+from einops import rearrange
class Baseline(BaseModel):
@@ -20,6 +21,8 @@ class Baseline(BaseModel):
sils = ipts[0]
if len(sils.size()) == 4:
sils = sils.unsqueeze(1)
+ else:
+ sils = rearrange(sils, 'n s c h w -> n c s h w')
del ipts
outs = self.Backbone(sils) # [n, c, s, h, w]
@@ -33,17 +36,16 @@ class Baseline(BaseModel):
embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
embed = embed_1
- n, _, s, h, w = sils.size()
retval = {
'training_feat': {
'triplet': {'embeddings': embed_1, 'labels': labs},
'softmax': {'logits': logits, 'labels': labs}
},
'visual_summary': {
- 'image/sils': sils.view(n*s, 1, h, w)
+ 'image/sils': rearrange(sils,'n c s h w -> (n s) c h w')
},
'inference_feat': {
'embeddings': embed
}
}
- return retval
+ return retval
\ No newline at end of file