Support SUSTech1K
This commit is contained in:
@@ -3,11 +3,14 @@
|
|||||||
<div align="center"><img src="./assets/nm.gif" width = "100" height = "100" alt="nm" /><img src="./assets/bg.gif" width = "100" height = "100" alt="bg" /><img src="./assets/cl.gif" width = "100" height = "100" alt="cl" /></div>
|
<div align="center"><img src="./assets/nm.gif" width = "100" height = "100" alt="nm" /><img src="./assets/bg.gif" width = "100" height = "100" alt="bg" /><img src="./assets/cl.gif" width = "100" height = "100" alt="cl" /></div>
|
||||||
|
|
||||||
------------------------------------------
|
------------------------------------------
|
||||||
|
📣📣📣 **[*SUSTech1K*](https://lidargait.github.io) released, please check the [tutorial](datasets/SUSTech1K/README.md).** 📣📣📣
|
||||||
|
|
||||||
🎉🎉🎉 **[*OpenGait*](https://openaccess.thecvf.com/content/CVPR2023/papers/Fan_OpenGait_Revisiting_Gait_Recognition_Towards_Better_Practicality_CVPR_2023_paper.pdf) has been accepted by CVPR2023 as a highlight paper!** 🎉🎉🎉
|
🎉🎉🎉 **[*OpenGait*](https://openaccess.thecvf.com/content/CVPR2023/papers/Fan_OpenGait_Revisiting_Gait_Recognition_Towards_Better_Practicality_CVPR_2023_paper.pdf) has been accepted by CVPR2023 as a highlight paper!** 🎉🎉🎉
|
||||||
|
|
||||||
OpenGait is a flexible and extensible gait recognition project provided by the [Shiqi Yu Group](https://faculty.sustech.edu.cn/yusq/) and supported in part by [WATRIX.AI](http://www.watrix.ai).
|
OpenGait is a flexible and extensible gait recognition project provided by the [Shiqi Yu Group](https://faculty.sustech.edu.cn/yusq/) and supported in part by [WATRIX.AI](http://www.watrix.ai).
|
||||||
|
|
||||||
## What's New
|
## What's New
|
||||||
|
- **[July 2023]** [SUSTech1K](datasets/SUSTech1K/README.md) is released and supported by OpenGait.
|
||||||
- **[May 2023]** A real gait recognition system [All-in-One-Gait](https://github.com/jdyjjj/All-in-One-Gait) provided by [Dongyang Jin](https://github.com/jdyjjj) is available.
|
- **[May 2023]** A real gait recognition system [All-in-One-Gait](https://github.com/jdyjjj/All-in-One-Gait) provided by [Dongyang Jin](https://github.com/jdyjjj) is available.
|
||||||
- [Apr 2023] [CASIA-E](datasets/CASIA-E/README.md) is supported by OpenGait.
|
- [Apr 2023] [CASIA-E](datasets/CASIA-E/README.md) is supported by OpenGait.
|
||||||
- [Feb 2023] [HID 2023 competition](https://hid2023.iapr-tc4.org/) is open, welcome to participate. Additionally, tutorial for the competition has been updated in [datasets/HID/](./datasets/HID).
|
- [Feb 2023] [HID 2023 competition](https://hid2023.iapr-tc4.org/) is open, welcome to participate. Additionally, tutorial for the competition has been updated in [datasets/HID/](./datasets/HID).
|
||||||
@@ -50,7 +53,7 @@ Results and models are available in the [model zoo](docs/1.model_zoo.md).
|
|||||||
## Authors:
|
## Authors:
|
||||||
**Open Gait Team (OGT)**
|
**Open Gait Team (OGT)**
|
||||||
- [Chao Fan (樊超)](https://chaofan996.github.io), 12131100@mail.sustech.edu.cn
|
- [Chao Fan (樊超)](https://chaofan996.github.io), 12131100@mail.sustech.edu.cn
|
||||||
- [Chuanfu Shen (沈川福)](https://faculty.sustech.edu.cn/?p=95396&tagid=yusq&cat=2&iscss=1&snapid=1&orderby=date), 11950016@mail.sustech.edu.cn
|
- [Chuanfu Shen (沈川福)](https://chuanfushen.github.io), 11950016@mail.sustech.edu.cn
|
||||||
- [Junhao Liang (梁峻豪)](https://faculty.sustech.edu.cn/?p=95401&tagid=yusq&cat=2&iscss=1&snapid=1&orderby=date), 12132342@mail.sustech.edu.cn
|
- [Junhao Liang (梁峻豪)](https://faculty.sustech.edu.cn/?p=95401&tagid=yusq&cat=2&iscss=1&snapid=1&orderby=date), 12132342@mail.sustech.edu.cn
|
||||||
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
|
|||||||
@@ -0,0 +1,101 @@
|
|||||||
|
data_cfg:
|
||||||
|
dataset_name: SUSTech1K
|
||||||
|
dataset_root: your_path_of_SUSTech1K-Released-pkl
|
||||||
|
dataset_partition: ./datasets/SUSTech1K/SUSTech1K.json
|
||||||
|
num_workers: 4
|
||||||
|
data_in_use: [false, true, false, false, false, false, false, false, false, false, false, false, false, false, false, false]
|
||||||
|
remove_no_gallery: false # Remove probe if no gallery for it
|
||||||
|
test_dataset_name: SUSTech1K
|
||||||
|
|
||||||
|
evaluator_cfg:
|
||||||
|
enable_float16: true
|
||||||
|
restore_ckpt_strict: true
|
||||||
|
restore_hint: 40000
|
||||||
|
save_name: LidarGait
|
||||||
|
eval_func: evaluate_indoor_dataset #evaluate_Gait3D
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: false
|
||||||
|
batch_size: 4
|
||||||
|
sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered
|
||||||
|
frames_all_limit: 720 # limit the number of sampled frames to prevent out of memory
|
||||||
|
metric: euc # cos
|
||||||
|
transform:
|
||||||
|
- type: BaseSilTransform
|
||||||
|
|
||||||
|
loss_cfg:
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
margin: 0.2
|
||||||
|
type: TripletLoss
|
||||||
|
log_prefix: triplet
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
scale: 16
|
||||||
|
type: CrossEntropyLoss
|
||||||
|
log_prefix: softmax
|
||||||
|
log_accuracy: true
|
||||||
|
|
||||||
|
model_cfg:
|
||||||
|
model: Baseline
|
||||||
|
backbone_cfg:
|
||||||
|
type: ResNet9
|
||||||
|
in_channel: 3
|
||||||
|
block: BasicBlock
|
||||||
|
channels: # Layers configuration for automatically model construction
|
||||||
|
- 64
|
||||||
|
- 128
|
||||||
|
- 256
|
||||||
|
- 512
|
||||||
|
layers:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
strides:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
maxpool: false
|
||||||
|
SeparateFCs:
|
||||||
|
in_channels: 512
|
||||||
|
out_channels: 256
|
||||||
|
parts_num: 16
|
||||||
|
SeparateBNNecks:
|
||||||
|
class_num: 250
|
||||||
|
in_channels: 256
|
||||||
|
parts_num: 16
|
||||||
|
bin_num:
|
||||||
|
- 16
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
lr: 0.1
|
||||||
|
momentum: 0.9
|
||||||
|
solver: SGD
|
||||||
|
weight_decay: 0.0005
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
gamma: 0.1
|
||||||
|
milestones: # Learning Rate Reduction at each milestones
|
||||||
|
- 20000
|
||||||
|
- 30000
|
||||||
|
scheduler: MultiStepLR
|
||||||
|
trainer_cfg:
|
||||||
|
enable_float16: true # half-precision floats for memory reduction and speedup
|
||||||
|
fix_BN: false
|
||||||
|
with_test: true #true
|
||||||
|
log_iter: 100
|
||||||
|
restore_ckpt_strict: true
|
||||||
|
restore_hint: 0
|
||||||
|
save_iter: 5000
|
||||||
|
save_name: LidarGait
|
||||||
|
sync_BN: true
|
||||||
|
total_iter: 40000
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: true
|
||||||
|
batch_size:
|
||||||
|
- 8 # TripletSampler, batch_size[0] indicates Number of Identity
|
||||||
|
- 8 # batch_size[1] indicates the number of sample sequences for each identity
|
||||||
|
frames_num_fixed: 10 # fixed frames number for training
|
||||||
|
sample_type: fixed_unordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered
|
||||||
|
type: TripletSampler
|
||||||
|
transform:
|
||||||
|
- type: BaseSilTransform
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
# Tutorial for [SUSTech1K](https://lidargait.github.io)
|
||||||
|
|
||||||
|
## Download the SUSTech1K dataset
|
||||||
|
Download the dataset from the [link](https://lidargait.github.io).
|
||||||
|
Decompress the downloaded file with the following command:
|
||||||
|
```shell
|
||||||
|
unzip -P password SUSTech1K-pkl.zip | xargs -n1 tar xzvf
|
||||||
|
```
|
||||||
|
password should be obtained by signing [agreement](https://lidargait.github.io/static/resources/SUSTech1KAgreement.pdf) and sending to email (shencf2019@mail.sustech.edu.cn)
|
||||||
|
|
||||||
|
## Train the dataset
|
||||||
|
Modify the `dataset_root` in `configs/lidargait/lidargait_sustech1k.yaml`, and then run this command:
|
||||||
|
```shell
|
||||||
|
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 opengait/main.py --cfgs configs/lidargait/lidargait_sustech1k.yaml --phase train
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Process from RAW dataset
|
||||||
|
|
||||||
|
### Preprocess the dataset (Optional)
|
||||||
|
Download the raw dataset from the [official link](https://lidargait.github.io). You will get a checksum file and two compressed files, i.e. `DATASET_DOWNLOAD.md5`, `SUSTech1K-RAW.zip`, and `SUSTech1K-pkl.zip`.
|
||||||
|
We recommend using our provided pickle files for convenience, or process raw dataset into pickle by this command:
|
||||||
|
```shell
|
||||||
|
python datasets/SUSTech1K/pretreatment_SUSTech1K.py -i SUSTech1K-Released-2023 -o SUSTech1K-pkl -n 8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Projecting PointCloud into Depth image (Optional)
|
||||||
|
You can use our processed depth images, or you can process via the command:
|
||||||
|
```shell
|
||||||
|
python datasets/SUSTech1K/point2depth.py -i SUSTech1K-Released-2023/ -o SUSTech1K-Released-2023/ -n 8
|
||||||
|
```
|
||||||
|
We recommend using our provided depth images for convenience.
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,279 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
import open3d as o3d
|
||||||
|
# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from collections import defaultdict
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def align_img(img: np.ndarray, img_size: int = 64) -> np.ndarray:
    """Crop an image to its occupied rows, resize to `img_size` high, and
    return a horizontally centered `img_size`-wide uint8 window.

    Args:
        img (np.ndarray): (H, W, C) image to align.
        img_size (int, optional): Output height (and target width). Defaults to 64.

    Returns:
        np.ndarray: Aligned uint8 image.
    """
    # Nearly-empty frames (total intensity <= 10000) are kept whole
    # instead of being cropped vertically.
    if img.sum() <= 10000:
        top, bottom = 0, img.shape[0]
    else:
        # Locate the first and last rows that contain any signal.
        row_mass = img.sum(axis=2).sum(axis=1)
        occupied = row_mass != 0
        top = occupied.argmax(axis=0)
        bottom = occupied.cumsum(axis=0).argmax(axis=0)

    img = img[top:bottom, :, :]

    # A person is taller than wide, so resize by height and keep aspect ratio.
    aspect = img.shape[1] / img.shape[0]
    img = cv2.resize(img, (int(img_size * aspect), img_size),
                     interpolation=cv2.INTER_CUBIC)

    # The x-center is the first column where the cumulative mass
    # exceeds half of the total (i.e. the median column).
    col_cumsum = img.sum(axis=2).sum(axis=0).cumsum()
    center = img.shape[1] // 2
    half_mass = img.sum() / 2
    for col, mass in enumerate(col_cumsum):
        if mass > half_mass:
            center = col
            break

    # Take an img_size-wide window around the center; if the window would
    # run off either edge, shift it one half-window to the right
    # (preserves the original behavior).
    half = img_size // 2
    left = center - half
    right = center + half
    if left <= 0 or right >= img.shape[1]:
        left += half
        right += half

    return img[:, left:right, :].astype('uint8')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def lidar_to_2d_front_view(points,
                           v_res,
                           h_res,
                           v_fov,
                           val="depth",
                           cmap="jet",
                           saveto=None,
                           y_fudge=0.0
                           ):
    """Project 3D LIDAR points to a 2D "front view" image and save it.

    Args:
        points: (np array)
            Nx4 lidar points; each point is (x, y, z, reflectance).
        v_res: (float)
            Vertical resolution of the lidar sensor, in degrees.
        h_res: (float)
            Horizontal resolution of the lidar sensor, in degrees.
        v_fov: (tuple of two floats)
            (minimum_negative_angle, max_positive_angle).
        val: (str)
            Value used to encode the plotted points.
            One of {"depth", "height", "reflectance"}.
        cmap: (str)
            Matplotlib color map used to color code `val` (e.g. "jet", "gray").
        saveto: (str or None)
            Output filename; if None, the image is only shown.
        y_fudge: (float)
            Hacky fudge factor when the theoretical vertical range does not
            match the data (set to 5 for a Velodyne HDL 64E).
    """
    # Sanity-check arguments early.
    assert len(v_fov) == 2, "v_fov must be list/tuple of length 2"
    assert v_fov[0] <= 0, "first element in v_fov must be 0 or negative"
    assert val in {"depth", "height", "reflectance"}, \
        'val must be one of {"depth", "height", "reflectance"}'

    # Negate x/y so the rendered view faces the sensor.
    x_lidar = - points[:, 0]
    y_lidar = - points[:, 1]
    z_lidar = points[:, 2]
    # Distance relative to origin when viewed from the top.
    d_lidar = np.sqrt(x_lidar ** 2 + y_lidar ** 2)

    v_fov_total = -v_fov[0] + v_fov[1]

    # Convert angular resolutions to radians for the arctan2 projection.
    v_res_rad = v_res * (np.pi / 180)
    h_res_rad = h_res * (np.pi / 180)

    # PROJECT INTO IMAGE COORDINATES
    x_img = np.arctan2(-y_lidar, x_lidar) / h_res_rad
    y_img = np.arctan2(z_lidar, d_lidar) / v_res_rad

    # SHIFT COORDINATES TO MAKE 0,0 THE MINIMUM
    x_min = -360.0 / h_res / 2  # Theoretical min x value based on sensor specs
    x_img -= x_min              # Shift
    x_max = 360.0 / h_res       # Theoretical max x value after shifting

    y_min = v_fov[0] / v_res    # Theoretical min y value based on sensor specs
    y_img -= y_min              # Shift
    y_max = v_fov_total / v_res # Theoretical max y value after shifting
    y_max += y_fudge            # Compensate spec-sheet vs. data mismatch

    # WHAT DATA TO USE TO ENCODE THE VALUE FOR EACH PIXEL
    if val == "reflectance":
        # BUG FIX: the original left pixel_values undefined for this case,
        # raising NameError at the scatter call; use the 4th column.
        pixel_values = points[:, 3]
    elif val == "height":
        pixel_values = z_lidar
    else:
        pixel_values = -d_lidar

    # PLOT THE IMAGE.
    # BUG FIX: the original shadowed the `cmap` parameter with a hard-coded
    # "jet"; the parameter (default "jet") is now honored.
    dpi = 100  # Image resolution
    fig, ax = plt.subplots(figsize=(x_max / dpi, y_max / dpi), dpi=dpi)
    ax.scatter(x_img, y_img, s=1, c=pixel_values, linewidths=0, alpha=1, cmap=cmap)
    ax.set_facecolor((0, 0, 0))   # Set regions with no points to black
    ax.axis('scaled')             # {equal, scaled}
    ax.xaxis.set_visible(False)   # Do not draw axis tick marks
    ax.yaxis.set_visible(False)   # Do not draw axis tick marks
    plt.xlim([0, x_max])  # prevent drawing empty space outside of horizontal FOV
    plt.ylim([0, y_max])  # prevent drawing empty space outside of vertical FOV

    if saveto is None:
        # Docstring contract: just show the image when no filename is given
        # (the original crashed here on saveto.replace).
        plt.show()
        plt.close()
        return

    saveto = saveto.replace('.pcd', '.png')
    fig.savefig(saveto, dpi=dpi, bbox_inches='tight', pad_inches=0.0)
    plt.close()

    # Re-load the rendered depth image, center-align it, and store the
    # aligned copy next to the offline one.
    img = cv2.imread(saveto)
    img = align_img(img)

    aligned_path = saveto.replace('offline', 'aligned')
    os.makedirs(os.path.dirname(aligned_path), exist_ok=True)
    cv2.imwrite(aligned_path, img)
|
||||||
|
|
||||||
|
|
||||||
|
def pcd2depth(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
    """Convert one (sid, seq, view) group of PCD files to front-view depth images.

    Args:
        img_groups (Tuple): ((sid, seq, view, ...), list of .pcd file paths).
        output_path (Path): Output root path.
        img_size (int, optional): Image resizing size. Defaults to 64. (unused here)
        verbose (bool, optional): Display debug info. Defaults to False. (unused)
        dataset (str, optional): Dataset name. Defaults to 'CASIAB'. (unused)
    """
    sinfo = img_groups[0]
    img_paths = img_groups[1]

    # Sensor projection constants — loop-invariant, hoisted out of the
    # per-file loop (the original re-assigned them for every frame).
    HRES = 0.19188        # horizontal resolution (assuming 20Hz setting)
    VRES = 0.2            # vertical resolution
    VFOV = (-25.0, 15.0)  # Field of view (-ve, +ve) along vertical axis
    Y_FUDGE = 0           # y fudge factor for velodyne HDL 64E

    # The output directory is the same for every frame of this group.
    dst_dir = os.path.join(output_path, *sinfo)
    os.makedirs(dst_dir, exist_ok=True)

    for img_file in sorted(img_paths):
        pcd_name = img_file.split('/')[-1]
        pcd = o3d.io.read_point_cloud(img_file)
        points = np.asarray(pcd.points)
        dst_path = os.path.join(dst_dir, pcd_name)
        lidar_to_2d_front_view(points, v_res=VRES, h_res=HRES, v_fov=VFOV,
                               val="depth", saveto=dst_path, y_fudge=Y_FUDGE)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
    """Walk the raw dataset tree and convert every PCD folder to depth images.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
    """
    logging.info(f'Listing {input_path}')
    img_groups = defaultdict(list)
    total_files = 0

    # Expected layout: <root>/<sid>/<seq>/<view>/PCDs/<frame>.pcd
    for sid in tqdm(sorted(os.listdir(input_path))):
        sid_dir = os.path.join(input_path, sid)
        for seq in os.listdir(sid_dir):
            seq_dir = os.path.join(sid_dir, seq)
            for view in os.listdir(seq_dir):
                pcd_dir = os.path.join(seq_dir, view, 'PCDs')
                for frame in os.listdir(pcd_dir):
                    key = (sid, seq, view, 'PCDs_offline_depths')
                    img_groups[key].append(os.path.join(pcd_dir, frame))
                    total_files += 1

    logging.info(f'Total files listed: {total_files}')

    progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder')
    job = partial(pcd2depth, output_path=output_path, img_size=img_size,
                  verbose=verbose, dataset=dataset)
    with mp.Pool(workers) as pool:
        logging.info(f'Start pretreating {input_path}')
        for _ in pool.imap_unordered(job, img_groups.items()):
            progress.update(1)
    logging.info('Done')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
    parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.')
    parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.')
    parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
    parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4')
    parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default 64')
    parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
    args = parser.parse_args()

    # Log to a file so the console stays clean for the progress bars.
    logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w',
                        format='[%(asctime)s - %(levelname)s]: %(message)s')

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info('Verbose mode is on.')
        for name, value in args.__dict__.items():
            logging.debug(f'{name}: {value}')

    pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path),
             img_size=args.img_size, workers=args.n_workers,
             verbose=args.verbose, dataset=args.dataset)
|
||||||
@@ -0,0 +1,221 @@
|
|||||||
|
# This source is based on https://github.com/AbnerHqC/GaitSet/blob/master/pretreatment.py
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from collections import defaultdict
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import json
|
||||||
|
import open3d as o3d
|
||||||
|
|
||||||
|
def compare_pcd_rgb_timestamp(pcd_file, rgb_file):
    """Extract comparable timestamps (in seconds) from a PCD and an RGB path.

    The PCD basename is already a float string; 0.05 s is added, presumably
    to compensate a fixed LiDAR/camera capture offset — TODO confirm.
    The RGB basename is a digit string whose first 10 characters are the
    integer seconds and the remainder the fractional part.

    Returns:
        (pcd_time, rgb_time) as floats.
    """
    pcd_stem = pcd_file.split('/')[-1].replace('.pcd', '')
    pcd_time = float(pcd_stem) + 0.05

    rgb_stem = rgb_file.split('/')[-1].replace('.jpg', '')
    rgb_time = float(rgb_stem[:10] + '.' + rgb_stem[10:])

    return pcd_time, rgb_time
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _load_modality(modality, data_files, img_size):
    """Load one modality's frames into the payload that gets pickled.

    Returns:
        (data, HWs): `data` is the stacked payload (None for an unrecognized
        modality) and `HWs` is the per-frame (H, W) array for 'RGB_raw'
        (None for every other modality).
    """
    HWs = None
    if modality == 'PCDs':
        data = [np.asarray(o3d.io.read_point_cloud(p).points) for p in data_files]
    elif modality == 'RGB_raw':
        imgs = [cv2.imread(f) for f in data_files]
        imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs]
        HWs = np.asarray([img.shape[:2] for img in imgs])
        data = np.asarray([cv2.resize(img, (img_size, img_size),
                                      interpolation=cv2.INTER_CUBIC)
                           for img in imgs])
    elif modality in ('Sils_raw', 'Sils_aligned'):
        sils = [cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in data_files]
        data = np.asarray([cv2.resize(s, (img_size, img_size),
                                      interpolation=cv2.INTER_CUBIC)
                           for s in sils])
    elif modality == 'Pose':
        data = np.asarray([json.load(open(pose)) for pose in data_files])
    elif modality == 'PCDs_depths':
        imgs = [cv2.imread(f) for f in data_files]
        imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs]
        # transpose to (C, H, W)
        data = np.asarray([img.transpose(2, 0, 1) for img in imgs])
    elif modality == 'PCDs_sils':
        data = np.asarray([cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in data_files])
    else:
        # Unknown modality: the original silently re-pickled the previous
        # iteration's stale `data`; signal it to the caller instead.
        data = None
    return data, HWs


def _dump_modality(dst_path, seq_view, cnt, modality, data, HWs, prefix=''):
    """Pickle one modality (plus the RGB ratio array) and return the next cnt.

    File naming matches the original: '{cnt:02d}-{prefix}{view}-{sensor}-{modality}.pkl',
    with an extra '...-Camera-Ratios-HW.pkl' dumped first for RGB_raw.
    """
    if modality == 'RGB_raw':
        pkl_path = os.path.join(dst_path, f'{cnt:02d}-{prefix}{seq_view}-Camera-Ratios-HW.pkl')
        with open(pkl_path, 'wb') as f:
            pickle.dump(HWs, f)
        cnt += 1
    sensor = 'LiDAR' if 'PCDs' in modality else 'Camera'
    pkl_path = os.path.join(dst_path, f'{cnt:02d}-{prefix}{seq_view}-{sensor}-{modality}.pkl')
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)
    return cnt + 1


def imgs2pickle(img_groups: Tuple, output_path: Path, img_size: int = 64, verbose: bool = False, dataset='CASIAB') -> None:
    """Reads a group of images and saves the data in pickle format.

    Two passes over the modalities: first every frame is pickled as-is, then
    LiDAR and camera frames are time-synchronized (nearest timestamps within
    20 ms) and pickled again with a 'sync-' tag.

    Args:
        img_groups (Tuple): ((id, type, view), list of (modality, file paths)).
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        verbose (bool, optional): Display debug info. Defaults to False. (unused)
        dataset (str, optional): Dataset name. Defaults to 'CASIAB'. (unused)
    """
    sinfo, img_paths = img_groups
    cnt = 0
    pcd_list, rgb_list = [], []
    threshold = 0.020  # max LiDAR/camera timestamp gap: 20 ms

    dst_path = os.path.join(output_path, *sinfo)
    os.makedirs(dst_path, exist_ok=True)

    # Pass 1: pickle every modality unsynchronized, remembering the PCD and
    # RGB file lists for the timestamp matching below.
    for modality, data_files in img_paths:
        if modality == 'PCDs':
            pcd_list = data_files
        elif modality == 'RGB_raw':
            rgb_list = data_files
        data, HWs = _load_modality(modality, data_files, img_size)
        if data is None:
            logging.warning(f'{sinfo}: skipping unknown modality {modality}')
            continue
        cnt = _dump_modality(dst_path, sinfo[2], cnt, modality, data, HWs)

    # Match each point cloud to the RGB frame with the closest timestamp;
    # keep a pair only when the gap is within the threshold.
    pcd_indexs, rgb_indexs = [], []
    for pcd_index in range(len(pcd_list)):
        time_diff = 1
        tmp = pcd_index, 0
        for rgb_index in range(len(rgb_list)):
            pcd_t, rgb_t = compare_pcd_rgb_timestamp(pcd_list[pcd_index], rgb_list[rgb_index])
            diff = abs(pcd_t - rgb_t)
            if diff < time_diff:
                tmp = pcd_index, rgb_index
                time_diff = diff
        if time_diff <= threshold:
            pcd_indexs.append(tmp[0])
            rgb_indexs.append(tmp[1])

    # BUG FIX: the original checked pcd_indexs for duplicates (impossible —
    # each pcd_index is appended at most once) and printed a tautology.
    # Only the matched RGB indices can repeat.
    if len(set(rgb_indexs)) != len(rgb_indexs):
        print(sinfo, pcd_indexs, rgb_indexs, 'duplicated RGB matches')

    # Pass 2: pickle the synchronized subset with a 'sync-' prefix.
    for modality, data_files in img_paths:
        keep = pcd_indexs if 'PCDs' in modality else rgb_indexs
        data_files = [data_files[i] for i in keep]
        data, HWs = _load_modality(modality, data_files, img_size)
        if data is None:
            continue
        cnt = _dump_modality(dst_path, sinfo[2], cnt, modality, data, HWs, prefix='sync-')
|
||||||
|
|
||||||
|
|
||||||
|
def pretreat(input_path: Path, output_path: Path, img_size: int = 64, workers: int = 4, verbose: bool = False, dataset: str = 'CASIAB') -> None:
    """Walk the raw dataset tree and pickle every modality of every sequence.

    Args:
        input_path (Path): Dataset root path.
        output_path (Path): Output path.
        img_size (int, optional): Image resizing size. Defaults to 64.
        workers (int, optional): Number of worker processes. Defaults to 4.
        verbose (bool, optional): Display debug info. Defaults to False.
    """
    logging.info(f'Listing {input_path}')
    img_groups = defaultdict(list)
    total_files = 0

    # Expected layout: <root>/<id>/<type>/<view>/<modality>/<files>
    for id_ in tqdm(sorted(os.listdir(input_path))):
        for type_ in os.listdir(os.path.join(input_path, id_)):
            for view_ in os.listdir(os.path.join(input_path, id_, type_)):
                view_dir = os.path.join(input_path, id_, type_, view_)
                for modality in sorted(os.listdir(view_dir)):
                    modality_path = os.path.join(view_dir, modality)
                    file_names = [os.path.join(modality_path, name)
                                  for name in sorted(os.listdir(modality_path))]
                    img_groups[(id_, type_, view_)].append((modality, file_names))
                    total_files += 1

    logging.info(f'Total files listed: {total_files}')

    progress = tqdm(total=len(img_groups), desc='Pretreating', unit='folder')
    job = partial(imgs2pickle, output_path=output_path, img_size=img_size,
                  verbose=verbose, dataset=dataset)
    with mp.Pool(workers) as pool:
        logging.info(f'Start pretreating {input_path}')
        for _ in pool.imap_unordered(job, img_groups.items()):
            progress.update(1)
    logging.info('Done')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='OpenGait dataset pretreatment module.')
    parser.add_argument('-i', '--input_path', default='', type=str, help='Root path of raw dataset.')
    parser.add_argument('-o', '--output_path', default='', type=str, help='Output path of pickled dataset.')
    parser.add_argument('-l', '--log_file', default='./pretreatment.log', type=str, help='Log file path. Default: ./pretreatment.log')
    parser.add_argument('-n', '--n_workers', default=4, type=int, help='Number of thread workers. Default: 4')
    parser.add_argument('-r', '--img_size', default=64, type=int, help='Image resizing size. Default 64')
    parser.add_argument('-d', '--dataset', default='CASIAB', type=str, help='Dataset for pretreatment.')
    parser.add_argument('-v', '--verbose', default=False, action='store_true', help='Display debug info.')
    args = parser.parse_args()

    # Log to a file so the console stays clean for the progress bars.
    logging.basicConfig(level=logging.INFO, filename=args.log_file, filemode='w',
                        format='[%(asctime)s - %(levelname)s]: %(message)s')

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)
        logging.info('Verbose mode is on.')
        for name, value in args.__dict__.items():
            logging.debug(f'{name}: {value}')

    pretreat(input_path=Path(args.input_path), output_path=Path(args.output_path),
             img_size=args.img_size, workers=args.n_workers,
             verbose=args.verbose, dataset=args.dataset)
|
||||||
@@ -74,46 +74,59 @@ def single_view_gallery_evaluation(feature, label, seq_type, view, dataset, metr
|
|||||||
'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2', ],
|
'CASIA-E': {'NM': ['H-scene2-nm-1', 'H-scene2-nm-2', 'L-scene2-nm-1', 'L-scene2-nm-2', 'H-scene3-nm-1', 'H-scene3-nm-2', 'L-scene3-nm-1', 'L-scene3-nm-2', 'H-scene3_s-nm-1', 'H-scene3_s-nm-2', 'L-scene3_s-nm-1', 'L-scene3_s-nm-2', ],
|
||||||
'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'],
|
'BG': ['H-scene2-bg-1', 'H-scene2-bg-2', 'L-scene2-bg-1', 'L-scene2-bg-2', 'H-scene3-bg-1', 'H-scene3-bg-2', 'L-scene3-bg-1', 'L-scene3-bg-2', 'H-scene3_s-bg-1', 'H-scene3_s-bg-2', 'L-scene3_s-bg-1', 'L-scene3_s-bg-2'],
|
||||||
'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2']
|
'CL': ['H-scene2-cl-1', 'H-scene2-cl-2', 'L-scene2-cl-1', 'L-scene2-cl-2', 'H-scene3-cl-1', 'H-scene3-cl-2', 'L-scene3-cl-1', 'L-scene3-cl-2', 'H-scene3_s-cl-1', 'H-scene3_s-cl-2', 'L-scene3_s-cl-1', 'L-scene3_s-cl-2']
|
||||||
}
|
},
|
||||||
|
'SUSTech1K': {'Normal': ['01-nm'], 'Bag': ['bg'], 'Clothing': ['cl'], 'Carrying':['cr'], 'Umberalla': ['ub'], 'Uniform': ['uf'], 'Occlusion': ['oc'],'Night': ['nt'], 'Overall': ['01','02','03','04']}
|
||||||
}
|
}
|
||||||
gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'],
|
gallery_seq_dict = {'CASIA-B': ['nm-01', 'nm-02', 'nm-03', 'nm-04'],
|
||||||
'OUMVLP': ['01'],
|
'OUMVLP': ['01'],
|
||||||
'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2']}
|
'CASIA-E': ['H-scene1-nm-1', 'H-scene1-nm-2', 'L-scene1-nm-1', 'L-scene1-nm-2'],
|
||||||
|
'SUSTech1K': ['00-nm'],}
|
||||||
msg_mgr = get_msg_mgr()
|
msg_mgr = get_msg_mgr()
|
||||||
acc = {}
|
acc = {}
|
||||||
view_list = sorted(np.unique(view))
|
view_list = sorted(np.unique(view))
|
||||||
|
num_rank = 1
|
||||||
if dataset == 'CASIA-E':
|
if dataset == 'CASIA-E':
|
||||||
view_list.remove("270")
|
view_list.remove("270")
|
||||||
|
if dataset == 'SUSTech1K':
|
||||||
|
num_rank = 5
|
||||||
view_num = len(view_list)
|
view_num = len(view_list)
|
||||||
num_rank = 1
|
|
||||||
for (type_, probe_seq) in probe_seq_dict[dataset].items():
|
for (type_, probe_seq) in probe_seq_dict[dataset].items():
|
||||||
acc[type_] = np.zeros((view_num, view_num)) - 1.
|
acc[type_] = np.zeros((view_num, view_num, num_rank)) - 1.
|
||||||
for (v1, probe_view) in enumerate(view_list):
|
for (v1, probe_view) in enumerate(view_list):
|
||||||
pseq_mask = np.isin(seq_type, probe_seq) & np.isin(
|
pseq_mask = np.isin(seq_type, probe_seq) & np.isin(
|
||||||
view, probe_view)
|
view, probe_view)
|
||||||
|
pseq_mask = pseq_mask if 'SUSTech1K' not in dataset else np.any(np.asarray(
|
||||||
|
[np.char.find(seq_type, probe)>=0 for probe in probe_seq]), axis=0
|
||||||
|
) & np.isin(view, probe_view) # For SUSTech1K only
|
||||||
probe_x = feature[pseq_mask, :]
|
probe_x = feature[pseq_mask, :]
|
||||||
probe_y = label[pseq_mask]
|
probe_y = label[pseq_mask]
|
||||||
|
|
||||||
for (v2, gallery_view) in enumerate(view_list):
|
for (v2, gallery_view) in enumerate(view_list):
|
||||||
gseq_mask = np.isin(seq_type, gallery_seq_dict[dataset]) & np.isin(
|
gseq_mask = np.isin(seq_type, gallery_seq_dict[dataset]) & np.isin(
|
||||||
view, [gallery_view])
|
view, [gallery_view])
|
||||||
|
gseq_mask = gseq_mask if 'SUSTech1K' not in dataset else np.any(np.asarray(
|
||||||
|
[np.char.find(seq_type, gallery)>=0 for gallery in gallery_seq_dict[dataset]]), axis=0
|
||||||
|
) & np.isin(view, [gallery_view]) # For SUSTech1K only
|
||||||
gallery_y = label[gseq_mask]
|
gallery_y = label[gseq_mask]
|
||||||
gallery_x = feature[gseq_mask, :]
|
gallery_x = feature[gseq_mask, :]
|
||||||
dist = cuda_dist(probe_x, gallery_x, metric)
|
dist = cuda_dist(probe_x, gallery_x, metric)
|
||||||
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
|
idx = dist.topk(num_rank, largest=False)[1].cpu().numpy()
|
||||||
acc[type_][v1, v2] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx], 1) > 0,
|
acc[type_][v1, v2, :] = np.round(np.sum(np.cumsum(np.reshape(probe_y, [-1, 1]) == gallery_y[idx[:, 0:num_rank]], 1) > 0,
|
||||||
0) * 100 / dist.shape[0], 2)
|
0) * 100 / dist.shape[0], 2)
|
||||||
|
|
||||||
result_dict = {}
|
result_dict = {}
|
||||||
msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===')
|
msg_mgr.log_info('===Rank-1 (Exclude identical-view cases)===')
|
||||||
out_str = ""
|
out_str = ""
|
||||||
for type_ in probe_seq_dict[dataset].keys():
|
for rank in range(num_rank):
|
||||||
sub_acc = de_diag(acc[type_], each_angle=True)
|
out_str = ""
|
||||||
msg_mgr.log_info(f'{type_}: {sub_acc}')
|
for type_ in probe_seq_dict[dataset].keys():
|
||||||
result_dict[f'scalar/test_accuracy/{type_}'] = np.mean(sub_acc)
|
sub_acc = de_diag(acc[type_][:,:,rank], each_angle=True)
|
||||||
out_str += f"{type_}: {np.mean(sub_acc):.2f}%\t"
|
if rank == 0:
|
||||||
msg_mgr.log_info(out_str)
|
msg_mgr.log_info(f'{type_}@R{rank+1}: {sub_acc}')
|
||||||
|
result_dict[f'scalar/test_accuracy/{type_}@R{rank+1}'] = np.mean(sub_acc)
|
||||||
|
out_str += f"{type_}@R{rank+1}: {np.mean(sub_acc):.2f}%\t"
|
||||||
|
msg_mgr.log_info(out_str)
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -122,7 +135,7 @@ def evaluate_indoor_dataset(data, dataset, metric='euc', cross_view_gallery=Fals
|
|||||||
label = np.array(label)
|
label = np.array(label)
|
||||||
view = np.array(view)
|
view = np.array(view)
|
||||||
|
|
||||||
if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E'):
|
if dataset not in ('CASIA-B', 'OUMVLP', 'CASIA-E', 'SUSTech1K'):
|
||||||
raise KeyError("DataSet %s hasn't been supported !" % dataset)
|
raise KeyError("DataSet %s hasn't been supported !" % dataset)
|
||||||
if cross_view_gallery:
|
if cross_view_gallery:
|
||||||
return cross_view_gallery_evaluation(
|
return cross_view_gallery_evaluation(
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import torch
|
|||||||
from ..base_model import BaseModel
|
from ..base_model import BaseModel
|
||||||
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
|
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks
|
||||||
|
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
class Baseline(BaseModel):
|
class Baseline(BaseModel):
|
||||||
|
|
||||||
@@ -20,6 +21,8 @@ class Baseline(BaseModel):
|
|||||||
sils = ipts[0]
|
sils = ipts[0]
|
||||||
if len(sils.size()) == 4:
|
if len(sils.size()) == 4:
|
||||||
sils = sils.unsqueeze(1)
|
sils = sils.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
sils = rearrange(sils, 'n s c h w -> n c s h w')
|
||||||
|
|
||||||
del ipts
|
del ipts
|
||||||
outs = self.Backbone(sils) # [n, c, s, h, w]
|
outs = self.Backbone(sils) # [n, c, s, h, w]
|
||||||
@@ -33,17 +36,16 @@ class Baseline(BaseModel):
|
|||||||
embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
|
embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
|
||||||
embed = embed_1
|
embed = embed_1
|
||||||
|
|
||||||
n, _, s, h, w = sils.size()
|
|
||||||
retval = {
|
retval = {
|
||||||
'training_feat': {
|
'training_feat': {
|
||||||
'triplet': {'embeddings': embed_1, 'labels': labs},
|
'triplet': {'embeddings': embed_1, 'labels': labs},
|
||||||
'softmax': {'logits': logits, 'labels': labs}
|
'softmax': {'logits': logits, 'labels': labs}
|
||||||
},
|
},
|
||||||
'visual_summary': {
|
'visual_summary': {
|
||||||
'image/sils': sils.view(n*s, 1, h, w)
|
'image/sils': rearrange(sils,'n c s h w -> (n s) c h w')
|
||||||
},
|
},
|
||||||
'inference_feat': {
|
'inference_feat': {
|
||||||
'embeddings': embed
|
'embeddings': embed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return retval
|
return retval
|
||||||
Reference in New Issue
Block a user