Merge pull request #212 from Bugjudger/master

BigGait Source Code
Dongyang Jin
2024-05-11 12:00:02 +08:00
committed by GitHub
16 changed files with 1837 additions and 1 deletions
+3 -1
@@ -12,6 +12,8 @@ OpenGait is a flexible and extensible gait recognition project provided by the [
The corresponding [paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Fan_OpenGait_Revisiting_Gait_Recognition_Towards_Better_Practicality_CVPR_2023_paper.pdf) has been accepted by CVPR2023 as a highlight paper.
## What's New
- **[May 2024]**
The code of the Large Vision Model based method [BigGait](https://arxiv.org/pdf/2402.19122) is available [here](opengait/modeling/models/BigGait.py).
- **[Apr 2024]**
Our team's latest checkpoints for projects such as DeepGaitv2, SkeletonGait, SkeletonGait++, and SwinGait will be released on [Hugging Face](https://huggingface.co/opengait/OpenGait). Additionally, previously released checkpoints will also be gradually made available on it.
- **[Mar 2024]** [Chao](https://chaofan996.github.io) gives a talk about 'Progress in Gait Recognition'. The [video](https://event.baai.ac.cn/activities/768) and [slides](https://github.com/ChaoFan996/ChaoFan996.github.io/blob/main/240315-Progress%20in%20Gait%20Recognition.pdf) are both available😊
@@ -30,7 +32,7 @@ Our team's latest checkpoints for projects such as DeepGaitv2, SkeletonGait, Ske
- [Mar 2022] Dataset [GREW](https://www.grew-benchmark.org) is supported in [datasets/GREW](./datasets/GREW). -->
## Our Publications
- [**CVPR'24**] BigGait: Learning Gait Representation You Want by Large Vision Models. [*Paper*](https://arxiv.org/pdf/2402.19122.pdf), and *Code* (coming soon). - [**CVPR'24**] BigGait: Learning Gait Representation You Want by Large Vision Models. [*Paper*](https://arxiv.org/pdf/2402.19122.pdf), and [*Code*](opengait/modeling/models/BigGait.py).
- [**AAAI'24**] SkeletonGait++: Gait Recognition Using Skeleton Maps. [*Paper*](https://arxiv.org/pdf/2311.13444.pdf), and [*Code*](opengait/modeling/models/skeletongait%2B%2B.py).
- [**AAAI'24**] Cross-Covariate Gait Recognition: A Benchmark. [*Paper*](https://arxiv.org/pdf/2312.14404.pdf), [*Dataset*](https://github.com/ShinanZou/CCGR), and [*Code*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/deepgaitv2.py).
- [**Arxiv'23**] Exploring Deep Models for Practical Gait Recognition. [*Paper*](https://arxiv.org/pdf/2303.03301.pdf), [*DeepGaitV2*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/deepgaitv2.py), and [*SwinGait*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/swingait.py).
+150
@@ -0,0 +1,150 @@
data_cfg:
  dataset_name: CCPG
  # TODO
  dataset_root: /4GPU/data/CCPG/Released/CCPG-ratio-pkl/
  dataset_partition: datasets/CCPG/CCPG.json
  data_in_use: [True, True] # images / real_ratios
  num_workers: 8
  remove_no_gallery: false # Remove a probe if it has no gallery
  test_dataset_name: CCPG

evaluator_cfg:
  enable_float16: true
  restore_ckpt_strict: True
  restore_hint: 40000
  save_name: BigGait__Dinov2_Gaitbase_Frame30
  eval_func: evaluate_CCPG
  sampler:
    batch_shuffle: false
    batch_size: 8 # number of GPUs
    sample_type: all_ordered # "all" uses the whole sequence for testing, "ordered" feeds frames in their natural order; other option: fixed_unordered
    frames_all_limit: 250 # limit the number of sampled frames to prevent out-of-memory errors
  metric: euc # cos
  transform:
    - type: BaseRgbTransform
    - type: NoOperation

loss_cfg:
  - loss_term_weight: 1.0
    margin: 0.2
    type: TripletLoss
    log_prefix: triplet
  - loss_term_weight: 1.0
    scale: 16
    type: CrossEntropyLoss
    log_prefix: softmax
    log_accuracy: true

model_cfg:
  model: BigGait__Dinov2_Gaitbase
  pretrained_dinov2: pretrained_LVMs/dinov2_vits14_pretrain.pth # DINOv2 download link: https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth
  pretrained_mask_branch: None
  # pretrained_mask_branch: pretrained_LVMs/MaskBranch_vits14.pt # load a pretrained Mask_Branch
  image_size: 224 # 448x224
  sils_size: 32 # 64x32
  Denoising_Branch:
    source_dim: 1536
    target_dim: 16
    p: 0
    softmax: True
    Relu: True
    Up: False
  Appearance_Branch:
    source_dim: 1536
    target_dim: 16
    p: 0
    softmax: False
    Relu: False
    Up: False
  Mask_Branch:
    source_dim: 384
    target_dim: 2
    p: 0.5
    softmax: True
    Relu: False
    Up: True
  AttentionFusion:
    in_channels: 64
    squeeze_ratio: 16
    feat_len: 2
  backbone_cfg:
    type: ResNet9
    block: BasicBlock
    in_channel: 1
    channels: # layer configuration for automatic model construction
      - 64
      - 128
      - 256
      - 512
    layers:
      - 1
      - 1
      - 1
      - 1
    strides:
      - 1
      - 2
      - 2
      - 1
    maxpool: false
  SeparateFCs:
    in_channels: 512
    out_channels: 256
    parts_num: 16
  SeparateBNNecks:
    class_num: 100
    in_channels: 256
    parts_num: 16
  bin_num:
    - 16

optimizer_cfg:
  lr: 0.1
  momentum: 0.9
  solver: SGD
  weight_decay: 0.0005

scheduler_cfg:
  gamma: 0.1
  milestones: # the learning rate is reduced at each milestone
    - 15000
    - 25000
    - 30000
    - 35000
  scheduler: MultiStepLR

trainer_cfg:
  find_unused_parameters: True
  enable_float16: true # half-precision floats for memory reduction and speedup
  fix_BN: false
  log_iter: 100
  with_test: true
  restore_ckpt_strict: true
  restore_hint: 0
  save_iter: 10000
  save_name: BigGait__Dinov2_Gaitbase_Frame30
  sync_BN: true
  total_iter: 40000
  sampler:
    batch_shuffle: true
    batch_size:
      - 8 # TripletSampler: batch_size[0] is the number of identities
      - 8 # batch_size[1] is the number of sequences sampled per identity
    frames_num_fixed: 30 # fixed number of frames for training
    frames_skip_num: 4
    frames_num_max: 40 # max number of frames for unfixed training
    frames_num_min: 20 # min number of frames for unfixed training
    sample_type: fixed_unordered # "fixed" fixes the number of input frames, "unordered" shuffles their order; other options: unfixed_ordered or all_ordered
    type: TripletSampler
  transform:
    - type: Compose
      trf_cfg:
        - type: RandomHorizontalFlip
        - type: BaseRgbTransform
    - type: NoOperation
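
As a rough illustration of the `optimizer_cfg` and `scheduler_cfg` values above: a multi-step schedule multiplies the learning rate by `gamma` each time a milestone is reached. The sketch below reproduces only that arithmetic in plain Python. The helper name `lr_at` is hypothetical and is not part of OpenGait or PyTorch.

```python
from bisect import bisect_right

BASE_LR = 0.1                               # optimizer_cfg.lr
GAMMA = 0.1                                 # scheduler_cfg.gamma
MILESTONES = [15000, 25000, 30000, 35000]   # scheduler_cfg.milestones


def lr_at(iteration, base_lr=BASE_LR, gamma=GAMMA, milestones=MILESTONES):
    """Learning rate in effect at a given iteration (hypothetical helper)."""
    # Count the milestones already reached; a milestone counts once
    # iteration >= milestone.
    passed = bisect_right(milestones, iteration)
    return base_lr * gamma ** passed


if __name__ == "__main__":
    for it in (0, 14999, 15000, 25000, 30000, 35000, 39999):
        print(f"iter {it:>5}: lr = {lr_at(it):.0e}")
```

Under these settings the rate drops by a factor of ten four times before `total_iter: 40000` is reached.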
+71
@@ -0,0 +1,71 @@
import os
import pickle
from multiprocessing import Pool
from time import time

import cv2
import numpy as np
from tqdm import tqdm

SRC_0 = '../SUSTech1K-Released-2023_mask/'
DST_0 = '../SUSTech1K-Released-2023_mask_256128pkl/'

SRC = SRC_0  # Path_of_RGB_rearranged
DST = DST_0  # Path_of_RGB_256128pkl_PadResized


def resize_with_padding(img, target_size):
    # Scale the image to the target height, then center-pad (or center-crop) its width.
    h, w, _ = img.shape
    target_h, target_w = target_size
    resized_img = cv2.resize(img, (int(w * target_h / h), target_h))
    padded_img = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    x_offset = (target_w - resized_img.shape[1]) // 2
    if x_offset < 0:
        # Resized image is wider than the target: keep the central columns.
        x_offset = abs(x_offset)
        padded_img = resized_img[:, x_offset:x_offset + target_w, :]
    else:
        # Resized image is narrower than the target: center it on a black canvas.
        padded_img[:, x_offset:x_offset + resized_img.shape[1]] = resized_img
    return padded_img


def job(src, id):
    for ty in sorted(os.listdir(os.path.join(src, id))):
        for vi in sorted(os.listdir(os.path.join(src, id, ty))):
            exist_file = os.path.join(DST, id, ty, vi, vi + "-aligned-rgbs.pkl")
            if os.path.exists(exist_file):
                print('Have Passed: ' + DST + '/' + id + '/' + ty)
                continue
            ratios = []
            aligned_imgs = []
            for img_file in sorted(os.listdir(os.path.join(src, id, ty, vi))):
                img_path = os.path.join(src, id, ty, vi, img_file)
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread returns None for unreadable files; skip them.
                    continue
                ratio = img.shape[1] / img.shape[0]  # original width / height
                ratios.append(ratio)
                aligned_img = np.transpose(cv2.cvtColor(resize_with_padding(img, (256, 128)), cv2.COLOR_BGR2RGB), (2, 0, 1))
                aligned_imgs.append(aligned_img)
            if len(aligned_imgs) > 0:
                output_path = os.path.join(DST, id, ty, vi)
                os.makedirs(output_path, exist_ok=True)
                with open(os.path.join(output_path, vi + "-aligned-rgbs.pkl"), "wb") as f:
                    pickle.dump(np.asarray(aligned_imgs), f)
                with open(os.path.join(output_path, vi + "-ratios.pkl"), "wb") as f:
                    pickle.dump(np.asarray(ratios), f)
                print('Successfully saved: ' + DST + '/' + id + '/' + ty + '/' + vi)


if __name__ == '__main__':
    a = time()
    po = Pool(8)
    src_path = SRC
    cnt = 0
    need_data = sorted(os.listdir(src_path))
    for id in tqdm(need_data[:]):
        po.apply_async(job, (src_path, id,))
        cnt = cnt + 1
    print('---START---')
    po.close()
    po.join()
    print(cnt)
    t = time() - a
    print('---END---{}'.format(t))
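
To make the width handling in `resize_with_padding` easier to follow, the sketch below reproduces only its arithmetic, with no image data and no OpenCV. The function name `plan_width_alignment` is hypothetical; it simply reports whether a frame of a given size would be padded or cropped, and by how many columns.

```python
def plan_width_alignment(src_h, src_w, target_h=256, target_w=128):
    """Return (resized_w, mode, offset) for one frame (hypothetical helper).

    mode is "pad" when the resized frame is no wider than the target and is
    centered on a black canvas, or "crop" when it is wider and only its
    central columns are kept.
    """
    resized_w = int(src_w * target_h / src_h)  # width after scaling to target_h
    x_offset = (target_w - resized_w) // 2     # floor division, may be negative
    if x_offset < 0:
        return resized_w, "crop", abs(x_offset)
    return resized_w, "pad", x_offset


if __name__ == "__main__":
    for h, w in [(512, 256), (400, 100), (300, 300)]:
        print((h, w), "->", plan_width_alignment(h, w))
```

A tall, narrow pedestrian crop such as 400x100 is scaled to 256x64 and padded by 32 columns on the left, while a square 300x300 crop is scaled to 256x256 and loses 64 columns on each side.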
+319
@@ -0,0 +1,319 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from ..base_model import BaseModel
from torch.nn import functional as F
from kornia import morphology as morph
import random
# import GaitBase & DINOv2_small
from .BigGait_utils.BigGait_GaitBase import Baseline
from .BigGait_utils.DINOv2 import vit_small
from .BigGait_utils.save_img import save_image, pca_image
# ######################################## BigGait ###########################################
class infoDistillation(nn.Module):
    def __init__(self, source_dim, target_dim, p, softmax, Relu, Up=True):
        super(infoDistillation, self).__init__()
        self.dropout = nn.Dropout(p=p)
        self.bn_s = nn.BatchNorm1d(source_dim, affine=False)
        self.bn_t = nn.BatchNorm1d(target_dim, affine=False)
        if Relu:
            self.down_sampling = nn.Sequential(
                nn.Linear(source_dim, source_dim//2),
                nn.BatchNorm1d(source_dim//2, affine=False),
                nn.GELU(),
                nn.Linear(source_dim//2, target_dim),
            )
            if Up:
                self.up_sampling = nn.Sequential(
                    nn.Linear(target_dim, source_dim//2),
                    nn.BatchNorm1d(source_dim//2, affine=False),
                    nn.GELU(),
                    nn.Linear(source_dim//2, source_dim),
                )
        else:
            self.down_sampling = nn.Linear(source_dim, target_dim)
            if Up:
                self.up_sampling = nn.Linear(target_dim, source_dim)
        self.softmax = softmax
        self.mse = nn.MSELoss()
        self.Up = Up

    def forward(self, x):
        # [n, c]
        d_x = self.down_sampling(self.bn_s(self.dropout(x)))
        if self.softmax:
            d_x = F.softmax(d_x, dim=1)
            if self.Up:
                u_x = self.up_sampling(d_x)
                return d_x, torch.mean(self.mse(u_x, x))
            else:
                return d_x, None
        else:
            if self.Up:
                u_x = self.up_sampling(d_x)
                return torch.sigmoid(self.bn_t(d_x)), torch.mean(self.mse(u_x, x))
            else:
                return torch.sigmoid(self.bn_t(d_x)), None
def padding_resize(x, ratios, target_h, target_w):
    n, h, w = x.size(0), target_h, target_w
    ratios = ratios.view(-1)
    need_w = (h * ratios).int()
    need_padding_mask = need_w < w
    pad_left = torch.where(need_padding_mask, (w - need_w) // 2, torch.tensor(0).to(x.device))
    pad_right = torch.where(need_padding_mask, w - need_w - pad_left, torch.tensor(0).to(x.device)).tolist()
    need_w = need_w.tolist()
    pad_left = pad_left.tolist()
    # Per sample: resize to (h, need_w), then pad the width if it is too narrow, otherwise crop it to w.
    x = torch.concat([
        F.pad(F.interpolate(x[i:i+1,...], (h, need_w[i]), mode="bilinear", align_corners=False), (pad_left[i], pad_right[i]))
        if need_padding_mask[i]
        else F.interpolate(x[i:i+1,...], (h, need_w[i]), mode="bilinear", align_corners=False)[...,pad_left[i]:pad_left[i]+w]
        for i in range(n)
    ], dim=0)
    return x
class BigGait__Dinov2_Gaitbase(BaseModel):
    def build_network(self, model_cfg):
        # get pretrained models
        self.pretrained_dinov2 = model_cfg["pretrained_dinov2"]
        self.pretrained_mask_branch = model_cfg["pretrained_mask_branch"]

        # set input size
        self.image_size = model_cfg["image_size"]
        self.sils_size = model_cfg["sils_size"]

        # set feature dim
        self.f4_dim = model_cfg["Mask_Branch"]['source_dim']
        self.fc_dim = self.f4_dim*4
        self.mask_dim = model_cfg["Mask_Branch"]['target_dim']
        self.app_dim = model_cfg["Appearance_Branch"]['target_dim']
        self.denoising_dim = model_cfg["Denoising_Branch"]['target_dim']

        # init submodules
        self.Denoising_Branch = infoDistillation(**model_cfg["Denoising_Branch"])
        self.Appearance_Branch = infoDistillation(**model_cfg["Appearance_Branch"])
        self.Mask_Branch = infoDistillation(**model_cfg["Mask_Branch"])
        self.gait_net = Baseline(model_cfg)

    def init_DINOv2(self):
        self.backbone = vit_small(logger = self.msg_mgr)
        self.msg_mgr.log_info(f'load model from: {self.pretrained_dinov2}')
        pretrain_dict = torch.load(self.pretrained_dinov2)
        msg = self.backbone.load_state_dict(pretrain_dict, strict=True)
        n_parameters = sum(p.numel() for p in self.backbone.parameters())
        self.msg_mgr.log_info('Missing keys: {}'.format(msg.missing_keys))
        self.msg_mgr.log_info('Unexpected keys: {}'.format(msg.unexpected_keys))
        self.msg_mgr.log_info(f"=> loaded successfully '{self.pretrained_dinov2}'")
        self.msg_mgr.log_info('DINOv2 Count: {:.5f}M'.format(n_parameters / 1e6))

    def init_Mask_Branch(self):
        self.msg_mgr.log_info(f'load model from: {self.pretrained_mask_branch}')
        load_dict = torch.load(self.pretrained_mask_branch, map_location=torch.device("cpu"))['model']
        msg = self.Mask_Branch.load_state_dict(load_dict, strict=True)
        n_parameters = sum(p.numel() for p in self.Mask_Branch.parameters())
        self.msg_mgr.log_info('Missing keys: {}'.format(msg.missing_keys))
        self.msg_mgr.log_info('Unexpected keys: {}'.format(msg.unexpected_keys))
        self.msg_mgr.log_info(f"=> loaded successfully '{self.pretrained_mask_branch}'")
        self.msg_mgr.log_info('SegmentationBranch Count: {:.5f}M'.format(n_parameters / 1e6))

    def init_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
                if m.affine:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                    nn.init.constant_(m.bias.data, 0.0)

        n_parameters = sum(p.numel() for p in self.parameters())
        self.msg_mgr.log_info('Expect backbone Count: {:.5f}M'.format(n_parameters / 1e6))

        # The DINOv2 backbone is frozen; the Mask_Branch starts out trainable.
        self.init_DINOv2()
        self.backbone.eval()
        self.backbone.requires_grad_(False)
        self.Mask_Branch.train()
        self.Mask_Branch.requires_grad_(True)

        n_parameters = sum(p.numel() for p in self.parameters())
        self.msg_mgr.log_info('All Backbone Count: {:.5f}M'.format(n_parameters / 1e6))
        self.msg_mgr.log_info("=> init successfully")

    # resize image
    def preprocess(self, sils, image_size, mode='bilinear'):
        # shape: [nxs,c,h,w] / [nxs,c,224,112]
        return F.interpolate(sils, (image_size*2, image_size), mode=mode, align_corners=False)

    def min_max_norm(self, x):
        return (x - x.min())/(x.max() - x.min())

    # cal foreground
    def get_body(self, mask):
        # value: [0,1] shape: [nxs, h, w, c]
        def judge_edge(image, edge=1):
            # [nxs,h,w]
            edge_pixel_count = image[:, :edge, :].sum(dim=(1,2)) + image[:, -edge:, :].sum(dim=(1,2))
            return edge_pixel_count > (image.size(2)) * edge
        condition_mask = torch.round(mask[...,0]) - mask[...,0].detach() + mask[...,0]
        condition_mask = judge_edge(condition_mask, 5)
        mask[condition_mask, :, :, 0] = mask[condition_mask, :, :, 1]
        return mask[...,0]

    def connect_loss(self, images, n, s, c):
        images = images.view(n*s,c,self.sils_size*2,self.sils_size)
        # Sobel-style horizontal and vertical gradients
        gradient_x = F.conv2d(images, torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])[None,None,...].repeat(1,c,1,1).to(images.dtype).to(images.device), padding=1)
        gradient_y = F.conv2d(images, torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])[None,None,...].repeat(1,c,1,1).to(images.dtype).to(images.device), padding=1)
        loss_connectivity = (torch.sum(torch.abs(gradient_x)) + torch.sum(torch.abs(gradient_y))) / (n*s*c*self.sils_size*2*self.sils_size)
        return loss_connectivity

    # Binarization and Closing operations to enhance foreground
    def get_edge(self, sils, threshold=1):
        mask_sils = torch.round(sils * threshold)
        kernel = torch.ones((3,3))
        dilated_mask = morph.dilation(mask_sils, kernel.to(sils.device)).detach() # Dilation
        kernel = torch.ones((5,5))
        eroded_mask = morph.erosion(dilated_mask, kernel.to(sils.device)).detach() # Erosion
        edge_mask = (dilated_mask > 0.5) ^ (eroded_mask > 0.5)
        sils = edge_mask * sils + (eroded_mask > 0.5) * torch.ones_like(sils, dtype=sils.dtype, device=sils.device)
        return sils

    def diversity_loss(self, images, max_p):
        # [ns, hw, c]
        p = torch.sum(images, dim=1) / (torch.sum(images, dim=(1,2)) + 1e-6).view(-1,1).repeat(1,max_p)
        entropies = -torch.sum(p * torch.log2(p + 1e-6), dim=1)
        max_p = torch.Tensor([1/max_p]).repeat(max_p).to(images.dtype).to(images.device)
        max_entropies = -torch.sum(max_p * torch.log2(max_p), dim=0)
        return torch.mean(max_entropies - entropies)

    def forward(self, inputs):
        if self.training:
            if self.iteration == 500 and '.pt' in self.pretrained_mask_branch:
                self.init_Mask_Branch()
            if self.iteration >= 500:
                self.Mask_Branch.eval()
                self.Mask_Branch.requires_grad_(False)

        ipts, labs, ty, vi, seqL = inputs
        sils = ipts[0]   # input_images; shape: [n,s,c,h,w];
        ratios = ipts[1] # real_image_ratios shape: [n,s,ratio]; ratio: w/h, e.g. 112/224=0.5;
        del ipts

        with torch.no_grad():
            n,s,c,h,w = sils.size()
            sils = rearrange(sils, 'n s c h w -> (n s) c h w').contiguous()
            if h == 2*w:
                outs = self.preprocess(sils, self.image_size) # [ns,c,448,224] if pad_resize has already been applied to the input images
            else:
                outs = self.preprocess(padding_resize(sils, ratios, 256, 128), self.image_size) # [ns,c,448,224] if pad_resize has not been applied to the input images
            outs = self.backbone(outs, is_training=True) # [ns,h*w,c]
            outs_last1 = outs["x_norm_patchtokens"].contiguous()
            outs_last4 = outs["x_norm_patchtokens_mid4"].contiguous()

            outs_last1 = rearrange(outs_last1.view(n, s, self.image_size//7, self.image_size//14, -1), 'n s h w c -> (n s) c h w').contiguous()
            outs_last4 = rearrange(outs_last4.view(n, s, self.image_size//7, self.image_size//14, -1), 'n s h w c -> (n s) c h w').contiguous()
            outs_last1 = self.preprocess(outs_last1, self.sils_size) # [ns,c,64,32]
            outs_last4 = self.preprocess(outs_last4, self.sils_size) # [ns,c,64,32]
            outs_last1 = rearrange(outs_last1.view(n, s, -1, self.sils_size*2, self.sils_size), 'n s c h w -> (n s) (h w) c').contiguous()
            outs_last4 = rearrange(outs_last4.view(n, s, -1, self.sils_size*2, self.sils_size), 'n s c h w -> (n s) (h w) c').contiguous()

        # get foreground
        mask = torch.ones_like(outs_last1[...,0], device=outs_last1.device, dtype=outs_last1.dtype).view(n*s,1,self.sils_size*2,self.sils_size)
        mask = padding_resize(mask, ratios, self.sils_size*2, self.sils_size)
        foreground = outs_last1.view(-1, self.f4_dim)[mask.view(-1) != 0]
        fore_feat, loss_mse1 = self.Mask_Branch(foreground)
        foreground = torch.zeros_like(mask, dtype=fore_feat.dtype, device=fore_feat.device).view(-1,1).repeat(1,self.mask_dim)
        foreground[mask.view(-1) != 0] = fore_feat
        loss_connectivity_shape = self.connect_loss(foreground, n, s, self.mask_dim)
        foreground = foreground.detach().clone()
        foreground = self.get_body(foreground.view(n*s,self.sils_size*2,self.sils_size,self.mask_dim)).view(n*s,-1) # [n*s,h*w]
        foreground = self.get_edge(foreground.view(n*s,1,self.sils_size*2,self.sils_size)).view(n*s,-1) # [n*s,h*w]
        del fore_feat, mask

        # get denoising features
        denosing = outs_last4.view(-1, self.fc_dim)[foreground.view(-1) != 0]
        den_feat, _ = self.Denoising_Branch(denosing)
        denosing = torch.zeros_like(foreground, dtype=den_feat.dtype, device=den_feat.device).view(-1,1).repeat(1,self.denoising_dim)
        denosing[foreground.view(-1) != 0] = den_feat
        loss_connectivity_part = self.connect_loss(denosing.view(n*s,-1,self.denoising_dim)[...,:-1].permute(0,2,1), n, s, (self.denoising_dim-1))
        loss_diversity_part = self.diversity_loss(denosing.view(n*s,-1,self.denoising_dim), self.denoising_dim)
        del den_feat

        # get appearance features
        appearance = outs_last4.view(-1, self.fc_dim)[foreground.view(-1) != 0]
        app_feat, _ = self.Appearance_Branch(appearance)
        appearance = torch.zeros_like(foreground, dtype=app_feat.dtype, device=app_feat.device).view(-1,1).repeat(1,self.app_dim)
        appearance[foreground.view(-1) != 0] = app_feat
        appearance = appearance.view(n*s,-1,self.app_dim)
        del app_feat

        # vis
        if self.training:
            try:
                vis_num = min(5, n*s)
                vis_mask = foreground.view(n*s, self.sils_size*2*self.sils_size, -1)[:vis_num].detach().cpu().numpy()
                vis_denosing = pca_image(data={'embeddings':denosing.view(n*s, self.sils_size*2*self.sils_size, -1)[:vis_num].detach().cpu().numpy()}, mask=vis_mask, root=None, model_name=None, dataset=None, n_components=3, is_return=True) # n s c h w
                vis_appearance = pca_image(data={'embeddings':appearance.view(n*s, self.sils_size*2*self.sils_size, -1)[:vis_num].detach().cpu().numpy()}, mask=vis_mask, root=None, model_name=None, dataset=None, n_components=3, is_return=True) # n s c h w
            except Exception:
                # Fall back to blank visualizations if the PCA step fails.
                vis_denosing = torch.ones_like(foreground).view(n,s,1,self.sils_size*2,self.sils_size).detach().cpu().numpy()
                vis_appearance = torch.ones_like(foreground).view(n,s,1,self.sils_size*2,self.sils_size).detach().cpu().numpy()

        # Black DA: zero out one randomly chosen branch for about 20% of the sequences in the batch
        if self.training:
            mask_idx = random.sample(list(range(n)), int(round(n*0.2)))
            feat_list = [denosing.view(n,s,-1), appearance.view(n,s,-1)]
            for i in mask_idx:
                idx = random.sample(list(range(2)), 1)
                for j in idx:
                    feat_list[j][i] = torch.zeros_like(feat_list[j][i], device=feat_list[j].device, dtype=feat_list[j].dtype)

        # get embedding
        embed_1, logits = self.gait_net(
            denosing.view(n,s,self.sils_size*2,self.sils_size,self.denoising_dim).permute(0, 4, 1, 2, 3).contiguous(),
            appearance.view(n,s,self.sils_size*2,self.sils_size,self.app_dim).permute(0, 4, 1, 2, 3).contiguous(),
            seqL,
        )

        if self.training:
            retval = {
                'training_feat': {
                    'shape_connect':loss_connectivity_shape*0.02,
                    'shape_mse': loss_mse1,
                    'part_connect':loss_connectivity_part*0.01,
                    'part_diversity':loss_diversity_part*5,
                    'triplet': {'embeddings': embed_1, 'labels': labs},
                    'softmax': {'logits': logits, 'labels': labs},
                },
                'visual_summary': {
                    'image/input': sils.view(n*s, c, h, w),
                    'image/foreground': self.min_max_norm(rearrange(foreground.view(n, s, self.sils_size*2, self.sils_size, -1), 'n s h w c -> (n s) c h w').contiguous()),
                    'image/denosing':self.min_max_norm(rearrange(torch.from_numpy(vis_denosing).float(), 'n s c h w -> (n s) c h w').contiguous()),
                    'image/appearance': self.min_max_norm(rearrange(torch.from_numpy(vis_appearance).float(), 'n s c h w -> (n s) c h w').contiguous()),
                },
                'inference_feat': {
                    'embeddings': embed_1
                }
            }
        else:
            retval = {
                'training_feat': {},
                'visual_summary': {},
                'inference_feat': {'embeddings': embed_1}
            }
        return retval
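
The shape comments in `forward` can be checked with a little arithmetic. The sketch below is a plain-Python illustration, not part of the commit: it assumes a patch size of 14 (inferred from the checkpoint name `dinov2_vits14`) and an embedding width of 384 (the `Mask_Branch` `source_dim` in the config). The helper names `token_grid` and `feature_shapes` are hypothetical.

```python
def token_grid(image_size, patch_size=14):
    """Patch-token grid (rows, cols) for a (2*image_size, image_size) input.

    patch_size=14 is an assumption taken from the checkpoint name.
    """
    height, width = 2 * image_size, image_size
    return height // patch_size, width // patch_size


def feature_shapes(n, s, image_size=224, sils_size=32, embed_dim=384):
    """Tensor shapes at the main stages of the forward pass, as tuples."""
    rows, cols = token_grid(image_size)
    pixels = 2 * sils_size * sils_size  # tokens after resizing to 64x32
    return {
        "backbone_input": (n * s, 3, 2 * image_size, image_size),
        "patch_tokens": (n * s, rows * cols, embed_dim),
        "resized_last1": (n * s, pixels, embed_dim),       # width f4_dim
        "resized_last4": (n * s, pixels, 4 * embed_dim),   # width fc_dim = f4_dim * 4
    }


if __name__ == "__main__":
    print(token_grid(224))
    for name, shape in feature_shapes(n=2, s=3).items():
        print(f"{name:>15}: {shape}")
```

With `image_size: 224` the grid is 32x16, which matches the `self.image_size//7` and `self.image_size//14` reshapes in the code, and the widths 384 and 1536 match the `source_dim` values of the three branches.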
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from ...modules import SetBlockWrapper, SeparateFCs, SeparateBNNecks, PackSequenceWrapper, HorizontalPoolingPyramid
from torch.nn import functional as F
# ######################################## GaitBase ###########################################
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


class AttentionFusion(nn.Module):
    def __init__(self, in_channels, squeeze_ratio, feat_len):
        super(AttentionFusion, self).__init__()
        hidden_dim = int(in_channels / squeeze_ratio)
        self.feat_len = feat_len
        self.conv = SetBlockWrapper(
            nn.Sequential(
                conv1x1(in_channels * feat_len, hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                conv3x3(hidden_dim, hidden_dim),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                conv1x1(hidden_dim, in_channels * feat_len),
            )
        )

    def forward(self, feat_list):
        '''
            sil_feat: [n, c, s, h, w]
            map_feat: [n, c, s, h, w]
            ...
        '''
        feats = torch.cat(feat_list, dim=1)
        score = self.conv(feats) # [n, 2 * c, s, h, w]
        score = rearrange(score, 'n (c d) s h w -> n c d s h w', d=self.feat_len)
        score = F.softmax(score, dim=2)  # normalize the weights across the fused branches
        fused = feat_list[0]*score[:,:,0]
        for i in range(1, self.feat_len):
            fused += feat_list[i]*score[:,:,i]
        return fused
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
from ...modules import BasicConv2d
block_map = {'BasicBlock': BasicBlock,
'Bottleneck': Bottleneck}
class Pre_ResNet9(ResNet):
    def __init__(self, type, block, channels=[32, 64, 128, 256], in_channel=1, layers=[1, 2, 2, 1], strides=[1, 2, 2, 1], maxpool=True):
        if block in block_map.keys():
            block = block_map[block]
        else:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        self.maxpool_flag = maxpool
        super(Pre_ResNet9, self).__init__(block, layers)

        # Not used #
        self.fc = None
        self.layer2 = None
        self.layer3 = None
        self.layer4 = None
        ############
        self.inplanes = channels[0]
        self.bn1 = nn.BatchNorm2d(self.inplanes)

        self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1)

        self.layer1 = self._make_layer(
            block, channels[0], layers[0], stride=strides[0], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        if blocks >= 1:
            layer = super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)
        else:
            def layer(x): return x
        return layer

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if self.maxpool_flag:
            x = self.maxpool(x)

        x = self.layer1(x)
        return x


class Post_ResNet9(ResNet):
    def __init__(self, type, block, channels=[32, 64, 128, 256], in_channel=1, layers=[1, 2, 2, 1], strides=[1, 2, 2, 1], maxpool=True):
        if block in block_map.keys():
            block = block_map[block]
        else:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        super(Post_ResNet9, self).__init__(block, layers)
        # Not used #
        self.fc = None
        self.conv1 = None
        self.bn1 = None
        self.relu = None
        self.layer1 = None
        ############
        self.inplanes = channels[0]
        self.layer2 = self._make_layer(
            block, channels[1], layers[1], stride=strides[1], dilate=False)
        self.layer3 = self._make_layer(
            block, channels[2], layers[2], stride=strides[2], dilate=False)
        self.layer4 = self._make_layer(
            block, channels[3], layers[3], stride=strides[3], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        if blocks >= 1:
            layer = super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)
        else:
            def layer(x): return x
        return layer

    def forward(self, x):
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
from ... import backbones
class Baseline(nn.Module):
    def __init__(self, model_cfg):
        super(Baseline, self).__init__()
        model_cfg['backbone_cfg']['in_channel'] = model_cfg['Denoising_Branch']['target_dim']
        self.pre_part = SetBlockWrapper(Pre_ResNet9(**model_cfg['backbone_cfg']))

        model_cfg['backbone_cfg']['in_channel'] = model_cfg['Appearance_Branch']['target_dim']
        self.pre_rgb = SetBlockWrapper(Pre_ResNet9(**model_cfg['backbone_cfg']))

        self.post_backbone = SetBlockWrapper(Post_ResNet9(**model_cfg['backbone_cfg']))
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

        self.fusion = AttentionFusion(**model_cfg['AttentionFusion'])

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model."""
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            Backbone = nn.ModuleList([self.get_backbone(cfg)
                                      for cfg in backbone_cfg])
            return Backbone
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def vis_forward(self, denosing, appearance, seqL):
        denosing = self.pre_part(denosing)     # [n, c, s, h, w]
        appearance = self.pre_rgb(appearance)  # [n, c, s, h, w]
        outs = self.fusion([denosing, appearance])
        return denosing, appearance, outs

    def forward(self, denosing, appearance, seqL):
        denosing = self.pre_part(denosing)     # [n, c, s, h, w]
        appearance = self.pre_rgb(appearance)  # [n, c, s, h, w]
        outs = self.fusion([denosing, appearance])
        # heat_mapt = rearrange(outs, 'n c s h w -> n s h w c')
        del denosing, appearance
        outs = self.post_backbone(outs)

        # Temporal Pooling, TP
        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]

        # Horizontal Pooling Matching, HPM
        outs = self.HPP(outs)  # [n, c, p]

        embed_1 = self.FCs(outs)  # [n, c, p]
        _, logits = self.BNNecks(embed_1)  # [n, c, p]

        # return embed_1, logits, heat_mapt
        return embed_1, logits
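
The fusion step used by `Baseline` weights each branch with a softmax over per-branch scores and then sums the weighted features. The sketch below shows that idea for a single channel at a single position, in plain Python without PyTorch. The scores here are hand-picked numbers standing in for the output of the small convolutional scoring network, and the helper names `softmax` and `fuse` are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(v - m) for v in scores]
    total = sum(exps)
    return [e / total for e in exps]


def fuse(features, scores):
    """Weighted sum of per-branch feature values using softmax weights."""
    weights = softmax(scores)
    return sum(w * f for w, f in zip(weights, features))


if __name__ == "__main__":
    denoising_value, appearance_value = 2.0, 4.0
    # Equal scores give an even blend of the two branches.
    print(fuse([denoising_value, appearance_value], [0.0, 0.0]))
    # A much larger score for the first branch makes it dominate.
    print(fuse([denoising_value, appearance_value], [10.0, 0.0]))
```

In the model the same weighting is applied independently for every channel, frame, and spatial position, with `feat_len: 2` branches.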
@@ -0,0 +1,343 @@
from functools import partial
import math
from typing import Sequence, Tuple, Union, Callable
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from .dino_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
# ######################################## DINO ###########################################
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
logger = None
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (float): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks (int): split block sequence into block_chunks units for FSDP wrap
logger: required in practice; object with a log_info(str) method, used to report the chosen FFN layer
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
self.patch_embed.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.log_info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.log_info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.log_info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError(f"unknown ffn_layer: {ffn_layer!r}")
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
)
for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode="bicubic",
)
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
return x
def forward_features_list(self, x_list, masks_list):
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
for blk in self.blocks:
x = blk(x)
all_x = x
output = []
for x, masks in zip(all_x, masks_list):
x_norm = self.norm(x)
output.append(
{
"x_norm_clstoken": x_norm[:, 0],
"x_norm_patchtokens": x_norm[:, 1:],
"x_prenorm": x,
"masks": masks,
}
)
return output
def forward_features(self, x, masks=None):
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
x_mid4 = []
# Tap the last block of each quarter of the network, e.g. [2, 5, 8, 11] for 12 blocks.
idx_mid4 = [int(i * len(self.blocks) / 4 + len(self.blocks) / 4 - 1) for i in range(4)]
assert len(idx_mid4) == 4
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in idx_mid4:
x_mid4.append(x)
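# Parameter-free LayerNorm over the concatenated features of the four tapped blocks.
x_mid4 = torch.cat(x_mid4, dim=-1)
x_mid4 = nn.functional.layer_norm(x_mid4, x_mid4.shape[-1:], eps=1e-6)
x_norm = self.norm(x)
return {
"x_norm_clstoken": x_norm[:, 0],  # read by forward() when is_training=False
"x_norm_patchtokens": x_norm[:, 1:],
"x_norm_patchtokens_mid4": x_mid4[:, 1:],
}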
# def forward_features(self, x, masks=None):
# if isinstance(x, list):
# return self.forward_features_list(x, masks)
# x = self.prepare_tokens_with_masks(x, masks)
# for blk in self.blocks:
# x = blk(x)
# x_norm = self.norm(x)
# return {
# "x_norm_clstoken": x_norm[:, 0],
# "x_norm_patchtokens": x_norm[:, 1:],
# "x_prenorm": x,
# "masks": masks,
# }
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for block_chunk in self.blocks:
for blk in block_chunk[i:]: # Passing the nn.Identity()
x = blk(x)
if i in blocks_to_take:
output.append(x)
i += 1
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
if self.chunked_blocks:
outputs = self._get_intermediate_layers_chunked(x, n)
else:
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1:] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward(self, *args, is_training=False, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return self.head(ret["x_norm_clstoken"])
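# NOTE: the patch_size argument is ignored; the patch size is fixed to 14 below.
# Pass logger=... through kwargs, since the constructor calls logger.log_info().
def vit_small(patch_size=16, **kwargs):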
model = DinoVisionTransformer(
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp",
block_chunks=0,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
**kwargs,
)
return model
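# NOTE: the patch_size argument is ignored; the patch size is fixed to 14 below.
# Pass logger=... through kwargs, since the constructor calls logger.log_info().
def vit_large(patch_size=16, **kwargs):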
model = DinoVisionTransformer(
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp",
block_chunks=0,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
**kwargs,
)
return model
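# A minimal sketch (not part of the committed file) of two small computations used above:
# which block outputs forward_features taps, and how many patch tokens the default
# vit_small / vit_large geometry (img_size=518, patch_size=14) produces.
# The helper names mid4_indices and patch_grid are hypothetical. The index formula is
# copied from idx_mid4 and assumes block_chunks=0, so that len(self.blocks) equals depth.

```python
def mid4_indices(num_blocks: int) -> list:
    """Indices of the four evenly spaced blocks whose outputs are concatenated."""
    return [int(i * num_blocks / 4 + num_blocks / 4 - 1) for i in range(4)]


def patch_grid(img_size: int, patch_size: int) -> tuple:
    """Side length of the square patch grid and the total number of patch tokens."""
    side = img_size // patch_size
    return side, side * side


print(mid4_indices(12))     # depth used by vit_small
print(mid4_indices(24))     # depth used by vit_large
print(patch_grid(518, 14))  # geometry hard-coded in both factory functions
```

# With these defaults, pos_embed holds one entry per patch token plus one for the class token.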
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
from torch import Tensor
from torch import nn
logger = logging.getLogger("dinov2")
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
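# A minimal pure-Python sketch of what Attention.forward computes for a single head,
# leaving out the qkv/proj linear layers and dropout. The function name
# single_head_attention is hypothetical and the list-of-lists layout is an assumption
# made for readability; the real module works on batched tensors.

```python
import math


def single_head_attention(q, k, v):
    """softmax((q * scale) @ k^T) @ v for one head, with scale = head_dim ** -0.5."""
    scale = len(q[0]) ** -0.5
    out = []
    for qi in q:
        # Scaled dot product of this query with every key.
        scores = [sum(a * scale * b for a, b in zip(qi, kj)) for kj in k]
        # Numerically stable softmax over the key axis.
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors.
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out


# A zero query scores every key equally, so the output is the mean of the values.
print(single_head_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 4.0]]))
```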
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Callable, List, Any, Tuple, Dict
import torch
from torch import nn, Tensor
from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
x = x + self.drop_path2(ffn_residual_func(x))
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor,
residual_func: Callable[[Tensor], Tensor],
sample_drop_ratio: float = 0.0,
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
this will perform the index select, cat the tensors, and provide the attn_bias from cache
"""
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[Tensor],
residual_func: Callable[[Tensor, Any], Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> Tensor:
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
return outputs
class NestedTensorBlock(Block):
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
"""
x_list contains a list of tensors to nest together and run
"""
assert isinstance(self.attn, MemEffAttention)
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
)
return x_list
else:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
return self.forward_nested(x_or_x_list)
else:
raise AssertionError
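# A minimal sketch of the batch-level stochastic depth arithmetic used by
# drop_add_residual_stochastic_depth and get_branges_scales above. The residual branch
# runs on a random subset of the batch and is scaled up by batch_size / subset_size.
# Each sample is picked with probability subset_size / batch_size, so the expected
# contribution per sample equals the unscaled residual.
# The function name subset_size_and_scale is hypothetical.

```python
def subset_size_and_scale(batch_size: int, sample_drop_ratio: float) -> tuple:
    """Number of samples that go through the residual branch, and the residual scale."""
    subset_size = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    return subset_size, batch_size / subset_size


size, scale = subset_size_and_scale(8, 0.25)
print(size, scale)

# At least one sample is always kept, even for very high drop ratios.
print(subset_size_and_scale(4, 0.9))
```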
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm
class DINOHead(nn.Module):
def __init__(
self,
in_dim,
out_dim,
use_bn=False,
nlayers=3,
hidden_dim=2048,
bottleneck_dim=256,
mlp_bias=True,
):
super().__init__()
nlayers = max(nlayers, 1)
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
self.apply(self._init_weights)
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
x = self.last_layer(x)
return x
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
if nlayers == 1:
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
else:
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
return nn.Sequential(*layers)
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: float = 0.0):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
from torch import Tensor
from torch import nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
import torch.nn as nn
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=(patch_HW[0]//2, patch_HW[1]//2))
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
# assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
# assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
from torch import Tensor, nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)
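# A minimal sketch of the hidden-width rule in SwiGLUFFNFused.__init__ above: the
# requested hidden size is cut to two thirds and then rounded up to a multiple of 8.
# The function name swiglu_hidden_features is hypothetical. The inputs are
# embed_dim * mlp_ratio for the vit_small (384 * 4) and vit_large (1024 * 4) settings,
# which is an assumption here because both factories select ffn_layer="mlp" by default.

```python
def swiglu_hidden_features(hidden_features: int) -> int:
    """Two thirds of the requested width, rounded up to the next multiple of 8."""
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8


print(swiglu_hidden_features(384 * 4))
print(swiglu_hidden_features(1024 * 4))
```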
@@ -0,0 +1,100 @@
from os import path as osp
import os
import pickle
from PIL import Image
import imageio
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import minmax_scale
import cv2
def pca_image(data, mask, root, model_name, dataset, n_components=3, is_return=False):
features = data['embeddings']
ns,hw,c = features.shape
features = features.reshape(ns*hw,c)
mask = mask.reshape(ns*hw)
pca = PCA(n_components=n_components)
pca_features = pca.fit_transform(features[mask != 0])
pca_features = minmax_scale(pca_features, (0,255), axis=1)
# pca_features = minmax_scale(pca_features, (0,255), axis=0)
norm_features = np.zeros_like(mask,dtype=np.uint8).reshape(ns*hw,1).repeat(n_components,axis=1)
norm_features[mask != 0] = pca_features
if is_return:
norm_features = norm_features.reshape(1,ns,64,32,n_components)[...,:3].transpose(0,1,4,2,3) #
return norm_features
s = 20
assert ns % s == 0
norm_features = norm_features.reshape(ns//s,s,64,32,n_components)[...,:3].transpose(0,1,4,2,3)
data['embeddings'] = norm_features
save_image(data, root, model_name, dataset, need='image')
def save_image(data, root, model_name, dataset, need='image', mask=None):
images, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views'] # n s c h w
if "image" in need:
root_path = os.path.join(root, dataset, model_name+'_image')
os.makedirs(os.path.join(root_path),exist_ok=True)
for i, id in enumerate(label[:]):
tmp = os.path.join(root_path, str(id).zfill(5), str(seq_type[i]), str(view[i]))
os.makedirs(tmp, exist_ok=True)
mb = None if mask is None else mask[i]
save_func(tmp, images[i], need, mb)
save_gif(tmp, tmp, str(view[i]))
if 'pkl' in need:
root_path = os.path.join(root, dataset, model_name+'_pkl')
os.makedirs(os.path.join(root_path),exist_ok=True)
for i, id in enumerate(label[:]):
tmp = os.path.join(root_path, str(id).zfill(5), str(seq_type[i]), str(view[i]))
os.makedirs(tmp, exist_ok=True)
mb = None if mask is None else mask[i]
save_func(tmp, images[i], 'pkl', mb)
if 'w' in need:
root_path = os.path.join(root, dataset, model_name+'_w')
os.makedirs(os.path.join(root_path),exist_ok=True)
for i, id in enumerate(label[:]):
tmp = os.path.join(root_path, str(id).zfill(5), str(seq_type[i]), str(view[i]))
os.makedirs(tmp, exist_ok=True)
mb = None if mask is None else mask[i]
save_func(tmp, data['w'], 'w', mb)
return
def save_func(tmp, data, ipts_type='image', mask=None):
if 'image' in ipts_type :
for i, con in enumerate(data):
if con.shape[0] == 1:
if 'jet' in ipts_type :
im = ((cv2.applyColorMap(con[0], cv2.COLORMAP_JET) * 0.5)[...,::-1] + 1.0*mask[i])
# im = mask[i]
im = np.clip(im,0,255).astype(np.uint8)
im = Image.fromarray(im, mode='RGB') # [h,w,c]
else:
im = Image.fromarray(con[0], mode='L')
else:
im = Image.fromarray(con.transpose(1,2,0), mode='RGB')
im.save(os.path.join(tmp, '%03d.png' % i))
elif ipts_type == 'pkl':
with open(os.path.join(tmp,'00.pkl'), 'wb') as f:
pickle.dump(data[:,0,:,:], f)
elif ipts_type == 'w':
for i in range(len(data)):
with open(os.path.join(tmp, str(i).zfill(2) + '.pkl'), 'wb') as f:
pickle.dump(data[i], f)
def save_gif(image_folder, save_folder, name="movie"):
images = []
filenames = sorted(glob(osp.join(image_folder, '*.png')))
# print(filenames)
for filename in filenames:
images.append(imageio.imread(filename))
imageio.mimsave(os.path.join(save_folder, f'{name}.gif'), images, duration=50, loop=0)
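# A minimal pure-Python sketch of what the minmax_scale(pca_features, (0,255), axis=1)
# call in pca_image does: each row (the PCA components of one foreground pixel) is
# rescaled on its own, so its smallest component maps to 0 and its largest to 255.
# The function name rowwise_minmax is hypothetical, and mapping a constant row to the
# lower bound is an assumption of this sketch rather than documented library behavior.

```python
def rowwise_minmax(rows, lo=0.0, hi=255.0):
    """Rescale every row independently to the range [lo, hi]."""
    scaled = []
    for row in rows:
        mn, mx = min(row), max(row)
        if mx == mn:
            # Assumption: a constant row collapses to the lower bound.
            scaled.append([lo for _ in row])
        else:
            scaled.append([lo + (x - mn) / (mx - mn) * (hi - lo) for x in row])
    return scaled


print(rowwise_minmax([[1.0, 2.0, 3.0], [-4.0, 0.0, 4.0]]))
```

# Because the scaling is per pixel, colors encode the relative ordering of the components
# within a pixel rather than their absolute magnitude across the image.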