@@ -12,6 +12,8 @@ OpenGait is a flexible and extensible gait recognition project provided by the [
|
|||||||
The corresponding [paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Fan_OpenGait_Revisiting_Gait_Recognition_Towards_Better_Practicality_CVPR_2023_paper.pdf) has been accepted by CVPR2023 as a highlight paper.
|
The corresponding [paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Fan_OpenGait_Revisiting_Gait_Recognition_Towards_Better_Practicality_CVPR_2023_paper.pdf) has been accepted by CVPR2023 as a highlight paper.
|
||||||
|
|
||||||
## What's New
|
## What's New
|
||||||
|
- **[May 2024]**
|
||||||
|
The code of the Large Vision Model based method [BigGait](https://arxiv.org/pdf/2402.19122) is available [here](opengait/modeling/models/BigGait.py).
|
||||||
- **[Apr 2024]**
|
- **[Apr 2024]**
|
||||||
Our team's latest checkpoints for projects such as DeepGaitv2, SkeletonGait, SkeletonGait++, and SwinGait will be released on [Hugging Face](https://huggingface.co/opengait/OpenGait). Additionally, previously released checkpoints will also be gradually made available on it.
|
Our team's latest checkpoints for projects such as DeepGaitv2, SkeletonGait, SkeletonGait++, and SwinGait will be released on [Hugging Face](https://huggingface.co/opengait/OpenGait). Additionally, previously released checkpoints will also be gradually made available on it.
|
||||||
- **[Mar 2024]** [Chao](https://chaofan996.github.io) gives a talk about 'Progress in Gait Recognition'. The [video](https://event.baai.ac.cn/activities/768) and [slides](https://github.com/ChaoFan996/ChaoFan996.github.io/blob/main/240315-Progress%20in%20Gait%20Recognition.pdf) are both available😊
|
- **[Mar 2024]** [Chao](https://chaofan996.github.io) gives a talk about 'Progress in Gait Recognition'. The [video](https://event.baai.ac.cn/activities/768) and [slides](https://github.com/ChaoFan996/ChaoFan996.github.io/blob/main/240315-Progress%20in%20Gait%20Recognition.pdf) are both available😊
|
||||||
@@ -30,7 +32,7 @@ Our team's latest checkpoints for projects such as DeepGaitv2, SkeletonGait, Ske
|
|||||||
- [Mar 2022] Dataset [GREW](https://www.grew-benchmark.org) is supported in [datasets/GREW](./datasets/GREW). -->
|
- [Mar 2022] Dataset [GREW](https://www.grew-benchmark.org) is supported in [datasets/GREW](./datasets/GREW). -->
|
||||||
|
|
||||||
## Our Publications
|
## Our Publications
|
||||||
- [**CVPR'24**] BigGait: Learning Gait Representation You Want by Large Vision Models. [*Paper*](https://arxiv.org/pdf/2402.19122.pdf), and *Code* (coming soon).
|
- [**CVPR'24**] BigGait: Learning Gait Representation You Want by Large Vision Models. [*Paper*](https://arxiv.org/pdf/2402.19122.pdf), and [*Code*](opengait/modeling/models/BigGait.py).
|
||||||
- [**AAAI'24**] SkeletonGait++: Gait Recognition Using Skeleton Maps. [*Paper*](https://arxiv.org/pdf/2311.13444.pdf), and [*Code*](opengait/modeling/models/skeletongait%2B%2B.py).
|
- [**AAAI'24**] SkeletonGait++: Gait Recognition Using Skeleton Maps. [*Paper*](https://arxiv.org/pdf/2311.13444.pdf), and [*Code*](opengait/modeling/models/skeletongait%2B%2B.py).
|
||||||
- [**AAAI'24**] Cross-Covariate Gait Recognition: A Benchmark. [*Paper*](https://arxiv.org/pdf/2312.14404.pdf), [*Dataset*](https://github.com/ShinanZou/CCGR), and [*Code*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/deepgaitv2.py).
|
- [**AAAI'24**] Cross-Covariate Gait Recognition: A Benchmark. [*Paper*](https://arxiv.org/pdf/2312.14404.pdf), [*Dataset*](https://github.com/ShinanZou/CCGR), and [*Code*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/deepgaitv2.py).
|
||||||
- [**Arxiv'23**] Exploring Deep Models for Practical Gait Recognition. [*Paper*](https://arxiv.org/pdf/2303.03301.pdf), [*DeepGaitV2*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/deepgaitv2.py), and [*SwinGait*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/swingait.py).
|
- [**Arxiv'23**] Exploring Deep Models for Practical Gait Recognition. [*Paper*](https://arxiv.org/pdf/2303.03301.pdf), [*DeepGaitV2*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/deepgaitv2.py), and [*SwinGait*](https://github.com/ShiqiYu/OpenGait/blob/master/opengait/modeling/models/swingait.py).
|
||||||
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
data_cfg:
|
||||||
|
dataset_name: CCPG
|
||||||
|
# TODO
|
||||||
|
dataset_root: /4GPU/data/CCPG/Released/CCPG-ratio-pkl/
|
||||||
|
dataset_partition: datasets/CCPG/CCPG.json
|
||||||
|
data_in_use: [True, True] # images / real_ratios
|
||||||
|
num_workers: 8
|
||||||
|
remove_no_gallery: false # Remove probe if no gallery for it
|
||||||
|
test_dataset_name: CCPG
|
||||||
|
|
||||||
|
evaluator_cfg:
|
||||||
|
enable_float16: true
|
||||||
|
restore_ckpt_strict: True
|
||||||
|
restore_hint: 40000
|
||||||
|
save_name: BigGait__Dinov2_Gaitbase_Frame30
|
||||||
|
eval_func: evaluate_CCPG
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: false
|
||||||
|
batch_size: 8 # GPUs number
|
||||||
|
sample_type: all_ordered # all indicates whole sequence used to test, while ordered means input sequence by its natural order; Other options: fixed_unordered
|
||||||
|
frames_all_limit: 250 # limit the number of sampled frames to prevent out of memory
|
||||||
|
metric: euc # cos
|
||||||
|
transform:
|
||||||
|
- type: BaseRgbTransform
|
||||||
|
- type: NoOperation
|
||||||
|
|
||||||
|
loss_cfg:
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
margin: 0.2
|
||||||
|
type: TripletLoss
|
||||||
|
log_prefix: triplet
|
||||||
|
- loss_term_weight: 1.0
|
||||||
|
scale: 16
|
||||||
|
type: CrossEntropyLoss
|
||||||
|
log_prefix: softmax
|
||||||
|
log_accuracy: true
|
||||||
|
|
||||||
|
model_cfg:
|
||||||
|
model: BigGait__Dinov2_Gaitbase
|
||||||
|
pretrained_dinov2: pretrained_LVMs/dinov2_vits14_pretrain.pth # DINOv2 Download Link: https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth
|
||||||
|
pretrained_mask_branch: None
|
||||||
|
# pretrained_mask_branch: pretrained_LVMs/MaskBranch_vits14.pt # load pretrain Mask_Branch
|
||||||
|
image_size: 224 # 448x224
|
||||||
|
sils_size: 32 # 64x32
|
||||||
|
|
||||||
|
Denoising_Branch:
|
||||||
|
source_dim: 1536
|
||||||
|
target_dim: 16
|
||||||
|
p: 0
|
||||||
|
softmax: True
|
||||||
|
Relu: True
|
||||||
|
Up: False
|
||||||
|
|
||||||
|
Appearance_Branch:
|
||||||
|
source_dim: 1536
|
||||||
|
target_dim: 16
|
||||||
|
p: 0
|
||||||
|
softmax: False
|
||||||
|
Relu: False
|
||||||
|
Up: False
|
||||||
|
|
||||||
|
Mask_Branch:
|
||||||
|
source_dim: 384
|
||||||
|
target_dim: 2
|
||||||
|
p: 0.5
|
||||||
|
softmax: True
|
||||||
|
Relu: False
|
||||||
|
Up: True
|
||||||
|
|
||||||
|
AttentionFusion:
|
||||||
|
in_channels: 64
|
||||||
|
squeeze_ratio: 16
|
||||||
|
feat_len: 2
|
||||||
|
|
||||||
|
backbone_cfg:
|
||||||
|
type: ResNet9
|
||||||
|
block: BasicBlock
|
||||||
|
in_channel: 1
|
||||||
|
channels: # Layers configuration for automatically model construction
|
||||||
|
- 64
|
||||||
|
- 128
|
||||||
|
- 256
|
||||||
|
- 512
|
||||||
|
layers:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
strides:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
maxpool: false
|
||||||
|
SeparateFCs:
|
||||||
|
in_channels: 512
|
||||||
|
out_channels: 256
|
||||||
|
parts_num: 16
|
||||||
|
SeparateBNNecks:
|
||||||
|
class_num: 100
|
||||||
|
in_channels: 256
|
||||||
|
parts_num: 16
|
||||||
|
bin_num:
|
||||||
|
- 16
|
||||||
|
|
||||||
|
optimizer_cfg:
|
||||||
|
lr: 0.1
|
||||||
|
momentum: 0.9
|
||||||
|
solver: SGD
|
||||||
|
weight_decay: 0.0005
|
||||||
|
|
||||||
|
scheduler_cfg:
|
||||||
|
gamma: 0.1
|
||||||
|
milestones: # Learning Rate Reduction at each milestones
|
||||||
|
- 15000
|
||||||
|
- 25000
|
||||||
|
- 30000
|
||||||
|
- 35000
|
||||||
|
scheduler: MultiStepLR
|
||||||
|
|
||||||
|
|
||||||
|
trainer_cfg:
|
||||||
|
find_unused_parameters: True
|
||||||
|
enable_float16: true # half-precision floats for memory reduction and speedup
|
||||||
|
fix_BN: false
|
||||||
|
log_iter: 100
|
||||||
|
with_test: true
|
||||||
|
restore_ckpt_strict: true
|
||||||
|
restore_hint: 0
|
||||||
|
save_iter: 10000
|
||||||
|
save_name: BigGait__Dinov2_Gaitbase_Frame30
|
||||||
|
sync_BN: true
|
||||||
|
total_iter: 40000
|
||||||
|
sampler:
|
||||||
|
batch_shuffle: true
|
||||||
|
batch_size:
|
||||||
|
- 8 # TripletSampler, batch_size[0] indicates Number of Identity
|
||||||
|
- 8 # batch_size[1] indicates samples sequence for each identity
|
||||||
|
frames_num_fixed: 30 # fixed frames number for training
|
||||||
|
frames_skip_num: 4
|
||||||
|
frames_num_max: 40 # max frames number for unfixed training
|
||||||
|
frames_num_min: 20 # min frames number for unfixed training
|
||||||
|
sample_type: fixed_unordered # fixed control input frames number, unordered for controlling order of input tensor; Other options: unfixed_ordered or all_ordered
|
||||||
|
type: TripletSampler
|
||||||
|
transform:
|
||||||
|
- type: Compose
|
||||||
|
trf_cfg:
|
||||||
|
- type: RandomHorizontalFlip
|
||||||
|
- type: BaseRgbTransform
|
||||||
|
- type: NoOperation
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
import os
|
||||||
|
from time import time
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from tqdm import tqdm
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# Source / destination roots for the SUSTech1K RGB preprocessing run.
SRC_0 = '../SUSTech1K-Released-2023_mask/'
DST_0 = '../SUSTech1K-Released-2023_mask_256128pkl/'

SRC = SRC_0 # Path_of_RGB_rearranged
DST = DST_0 # Path_of_RGB_256128pkl_PadResized
|
||||||
|
|
||||||
|
def resize_with_padding(img, target_size):
    """Scale *img* to the target height (aspect preserved), then centre
    pad -- or centre crop, when too wide -- to the target width.

    img: HxWx3 uint8 BGR image; target_size: (target_h, target_w).
    Returns an array of shape (target_h, target_w, 3), dtype uint8.
    """
    src_h, src_w, _ = img.shape
    out_h, out_w = target_size
    # Width that keeps the aspect ratio at the requested height.
    new_w = int(src_w * out_h / src_h)
    scaled = cv2.resize(img, (new_w, out_h))
    offset = (out_w - scaled.shape[1]) // 2
    if offset >= 0:
        # Narrower than target: paste centred onto a black canvas.
        canvas = np.zeros((out_h, out_w, 3), dtype=np.uint8)
        canvas[:, offset:offset + scaled.shape[1]] = scaled
        return canvas
    # Wider than target: crop a centred window of the target width.
    crop = -offset
    return scaled[:, crop:crop + out_w, :]
|
||||||
|
|
||||||
|
def job(src, id):
    """Convert one identity's frames into pickled, pad-resized RGB sequences.

    Walks src/<id>/<type>/<view>/, resizes every frame to 256x128 (aspect
    preserved, centre pad/crop via resize_with_padding) and writes two
    pickles per view under DST:
      <view>-aligned-rgbs.pkl -- uint8 array [frames, 3, 256, 128] (RGB, CHW)
      <view>-ratios.pkl       -- float array of original w/h ratios per frame
    Views whose output pickle already exists are skipped (resumable runs).
    """
    for ty in sorted(os.listdir(os.path.join(src, id))):
        for vi in sorted(os.listdir(os.path.join(src, id, ty))):
            exist_file = os.path.join(DST, id, ty, vi, vi+"-aligned-rgbs.pkl")
            if os.path.exists(exist_file):
                print('Have Passed: ' + DST + '/' + id + '/' + ty)
                continue
            ratios = []
            aligned_imgs = []
            for img_file in sorted(os.listdir(os.path.join(src, id, ty, vi))):
                img_path = os.path.join(src, id, ty, vi, img_file)
                img = cv2.imread(img_path)
                if img is None:
                    # Unreadable / non-image file: skip it instead of crashing
                    # the worker (apply_async would swallow the exception).
                    print('Skipped unreadable file: ' + img_path)
                    continue
                ratio = img.shape[1]/img.shape[0]  # w/h of the raw frame
                ratios.append(ratio)
                # BGR -> RGB, then HWC -> CHW for downstream torch loaders.
                aligned_img = np.transpose(cv2.cvtColor(resize_with_padding(img, (256, 128)), cv2.COLOR_BGR2RGB), (2, 0, 1))
                aligned_imgs.append(aligned_img)
            if len(aligned_imgs) > 0:
                output_path = os.path.join(DST, id, ty, vi)
                os.makedirs(output_path, exist_ok=True)
                # Context managers close the handles even on error (the
                # original `pickle.dump(..., open(...))` leaked them).
                with open(os.path.join(output_path, vi+"-aligned-rgbs.pkl"), "wb") as f:
                    pickle.dump(np.asarray(aligned_imgs), f)
                with open(os.path.join(output_path, vi+"-ratios.pkl"), "wb") as f:
                    pickle.dump(np.asarray(ratios), f)
                print('Successfully saved: ' + DST + '/' + id + '/' + ty + '/' + vi)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Fan the per-identity conversion jobs out over 8 worker processes.
    a = time()  # wall-clock start
    po = Pool(8)
    src_path = SRC

    cnt = 0
    need_data = sorted(os.listdir(src_path))
    for id in tqdm(need_data[:]):
        # NOTE(review): the AsyncResult is discarded, so exceptions raised
        # inside `job` are silently swallowed -- verify outputs after a run.
        po.apply_async(job,(src_path, id,))
        cnt = cnt + 1

    print('---START---')
    po.close()
    po.join()
    print(cnt)  # number of identities submitted

    t = time() - a
    print('---END---{}'.format(t))  # elapsed seconds
|
||||||
@@ -0,0 +1,319 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# References:
|
||||||
|
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from einops import rearrange
|
||||||
|
from ..base_model import BaseModel
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from kornia import morphology as morph
|
||||||
|
import random
|
||||||
|
|
||||||
|
# import GaitBase & DINOv2_small
|
||||||
|
from .BigGait_utils.BigGait_GaitBase import Baseline
|
||||||
|
from .BigGait_utils.DINOv2 import vit_small
|
||||||
|
from .BigGait_utils.save_img import save_image, pca_image
|
||||||
|
|
||||||
|
# ######################################## BigGait ###########################################
|
||||||
|
|
||||||
|
class infoDistillation(nn.Module):
    """Bottleneck that distills a feature vector to a lower dimension.

    The input is dropout + batch-normed, then projected down (either a
    single Linear, or a two-layer GELU MLP when ``Relu`` is set). The
    compressed code is normalized with softmax (``softmax=True``) or with
    sigmoid(BN) otherwise. When ``Up`` is set, a mirrored up-projection
    reconstructs the input and an MSE reconstruction loss is returned as
    the second output (``None`` otherwise).
    """

    def __init__(self, source_dim, target_dim, p, softmax, Relu, Up=True):
        super(infoDistillation, self).__init__()
        self.dropout = nn.Dropout(p=p)
        self.bn_s = nn.BatchNorm1d(source_dim, affine=False)
        self.bn_t = nn.BatchNorm1d(target_dim, affine=False)
        hidden = source_dim // 2
        if Relu:
            # Two-layer MLP projections with a GELU nonlinearity.
            self.down_sampling = nn.Sequential(
                nn.Linear(source_dim, hidden),
                nn.BatchNorm1d(hidden, affine=False),
                nn.GELU(),
                nn.Linear(hidden, target_dim),
            )
            if Up:
                self.up_sampling = nn.Sequential(
                    nn.Linear(target_dim, hidden),
                    nn.BatchNorm1d(hidden, affine=False),
                    nn.GELU(),
                    nn.Linear(hidden, source_dim),
                )
        else:
            # Plain linear projections.
            self.down_sampling = nn.Linear(source_dim, target_dim)
            if Up:
                self.up_sampling = nn.Linear(target_dim, source_dim)
        self.softmax = softmax
        self.mse = nn.MSELoss()
        self.Up = Up

    def forward(self, x):
        """x: [n, c]. Returns (code, reconstruction_loss_or_None)."""
        compressed = self.down_sampling(self.bn_s(self.dropout(x)))
        if self.softmax:
            # Softmax code feeds the up-projection directly.
            compressed = F.softmax(compressed, dim=1)
            out = compressed
        else:
            # Sigmoid(BN) output; the raw code feeds the up-projection.
            out = torch.sigmoid(self.bn_t(compressed))
        if not self.Up:
            return out, None
        reconstructed = self.up_sampling(compressed)
        return out, torch.mean(self.mse(reconstructed, x))
|
||||||
|
|
||||||
|
|
||||||
|
def padding_resize(x, ratios, target_h, target_w):
    """Resize a batch of images to (target_h, target_w), one by one,
    preserving each image's real aspect ratio.

    x: [n, c, h, w]; ratios: per-image w/h ratios (flattened to [n]).
    Each image is first bilinearly resized to height ``target_h`` and width
    ``target_h * ratio``; if that width is narrower than ``target_w`` it is
    zero-padded symmetrically, otherwise a ``target_w``-wide window is taken
    from the left edge.
    """
    batch = x.size(0)
    out_h, out_w = target_h, target_w
    flat_ratios = ratios.view(-1)
    widths = (out_h * flat_ratios).int()
    needs_pad = widths < out_w
    zero = torch.tensor(0).to(x.device)
    left = torch.where(needs_pad, (out_w - widths) // 2, zero)
    right = torch.where(needs_pad, out_w - widths - left, zero).tolist()
    widths = widths.tolist()
    left = left.tolist()
    pieces = []
    for i in range(batch):
        resized = F.interpolate(x[i:i+1, ...], (out_h, widths[i]), mode="bilinear", align_corners=False)
        if needs_pad[i]:
            pieces.append(F.pad(resized, (left[i], right[i])))
        else:
            # left[i] is 0 here, so this takes the leftmost out_w columns.
            pieces.append(resized[..., left[i]:left[i] + out_w])
    return torch.concat(pieces, dim=0)
|
||||||
|
|
||||||
|
class BigGait__Dinov2_Gaitbase(BaseModel):
    """BigGait: frozen DINOv2 ViT-S features distilled into gait representations,
    classified by a GaitBase head.

    Pipeline (see forward): RGB frames -> frozen DINOv2 patch tokens ->
    Mask / Denoising / Appearance distillation branches (infoDistillation) ->
    GaitBase recognition network. The DINOv2 backbone stays frozen; the
    branches and the gait network are trained.
    """

    def build_network(self, model_cfg):
        """Build submodules and cache sizes/dims from `model_cfg` (BaseModel hook)."""
        # paths to pretrained weights
        self.pretrained_dinov2 = model_cfg["pretrained_dinov2"]
        self.pretrained_mask_branch = model_cfg["pretrained_mask_branch"]

        # set input size
        self.image_size = model_cfg["image_size"]
        self.sils_size = model_cfg["sils_size"]

        # set feature dim
        self.f4_dim = model_cfg["Mask_Branch"]['source_dim']
        self.fc_dim = self.f4_dim*4
        self.mask_dim = model_cfg["Mask_Branch"]['target_dim']
        self.app_dim = model_cfg["Appearance_Branch"]['target_dim']
        self.denoising_dim = model_cfg["Denoising_Branch"]['target_dim']

        # init submodules
        self.Denoising_Branch = infoDistillation(**model_cfg["Denoising_Branch"])
        self.Appearance_Branch = infoDistillation(**model_cfg["Appearance_Branch"])
        self.Mask_Branch = infoDistillation(**model_cfg["Mask_Branch"])
        self.gait_net = Baseline(model_cfg)

    def init_DINOv2(self):
        """Create the DINOv2 ViT-S backbone and load its pretrained weights (strict)."""
        self.backbone = vit_small(logger = self.msg_mgr)
        self.msg_mgr.log_info(f'load model from: {self.pretrained_dinov2}')
        pretrain_dict = torch.load(self.pretrained_dinov2)
        msg = self.backbone.load_state_dict(pretrain_dict, strict=True)
        n_parameters = sum(p.numel() for p in self.backbone.parameters())
        self.msg_mgr.log_info('Missing keys: {}'.format(msg.missing_keys))
        self.msg_mgr.log_info('Unexpected keys: {}'.format(msg.unexpected_keys))
        self.msg_mgr.log_info(f"=> loaded successfully '{self.pretrained_dinov2}'")
        self.msg_mgr.log_info('DINOv2 Count: {:.5f}M'.format(n_parameters / 1e6))

    def init_Mask_Branch(self):
        """Load pretrained Mask_Branch weights from the 'model' entry of the checkpoint."""
        self.msg_mgr.log_info(f'load model from: {self.pretrained_mask_branch}')
        load_dict = torch.load(self.pretrained_mask_branch, map_location=torch.device("cpu"))['model']
        msg = self.Mask_Branch.load_state_dict(load_dict, strict=True)
        n_parameters = sum(p.numel() for p in self.Mask_Branch.parameters())
        self.msg_mgr.log_info('Missing keys: {}'.format(msg.missing_keys))
        self.msg_mgr.log_info('Unexpected keys: {}'.format(msg.unexpected_keys))
        self.msg_mgr.log_info(f"=> loaded successfully '{self.pretrained_mask_branch}'")
        self.msg_mgr.log_info('SegmentationBranch Count: {:.5f}M'.format(n_parameters / 1e6))

    def init_parameters(self):
        """Xavier-init conv/linear weights and affine BN params, then load and
        freeze DINOv2 while keeping Mask_Branch trainable."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.Conv2d, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif isinstance(m, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
                if m.affine:
                    nn.init.normal_(m.weight.data, 1.0, 0.02)
                    nn.init.constant_(m.bias.data, 0.0)

        n_parameters = sum(p.numel() for p in self.parameters())
        self.msg_mgr.log_info('Expect backbone Count: {:.5f}M'.format(n_parameters / 1e6))

        # DINOv2 is a frozen feature extractor.
        self.init_DINOv2()
        self.backbone.eval()
        self.backbone.requires_grad_(False)

        # Mask_Branch trains from the start (may be frozen later in forward).
        self.Mask_Branch.train()
        self.Mask_Branch.requires_grad_(True)

        n_parameters = sum(p.numel() for p in self.parameters())
        self.msg_mgr.log_info('All Backbone Count: {:.5f}M'.format(n_parameters / 1e6))
        self.msg_mgr.log_info("=> init successfully")

    # resize image to a fixed 2:1 (h:w) resolution
    def preprocess(self, sils, image_size, mode='bilinear'):
        # shape: [nxs,c,h,w] / [nxs,c,224,112]
        return F.interpolate(sils, (image_size*2, image_size), mode=mode, align_corners=False)

    def min_max_norm(self, x):
        """Rescale x linearly into [0, 1] using its global min/max."""
        return (x - x.min())/(x.max() - x.min())

    # cal foreground
    def get_body(self, mask):
        # value: [0,1] shape: [nxs, h, w, c]
        def judge_edge(image, edge=1):
            # [nxs,h,w] -> bool per sample: strong activation on top/bottom rows
            edge_pixel_count = image[:, :edge, :].sum(dim=(1,2)) + image[:, -edge:, :].sum(dim=(1,2))
            return edge_pixel_count > (image.size(2)) * edge
        # straight-through rounding: hard-rounded forward, identity gradient
        condition_mask = torch.round(mask[...,0]) - mask[...,0].detach() + mask[...,0]
        condition_mask = judge_edge(condition_mask, 5)
        # samples whose channel-0 mask leaks into the frame edges fall back to channel 1
        mask[condition_mask, :, :, 0] = mask[condition_mask, :, :, 1]
        return mask[...,0]

    def connect_loss(self, images, n, s, c):
        """Smoothness penalty: mean L1 magnitude of Sobel gradients over the maps."""
        images = images.view(n*s,c,self.sils_size*2,self.sils_size)
        gradient_x = F.conv2d(images, torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]])[None,None,...].repeat(1,c,1,1).to(images.dtype).to(images.device), padding=1)
        gradient_y = F.conv2d(images, torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])[None,None,...].repeat(1,c,1,1).to(images.dtype).to(images.device), padding=1)
        loss_connectivity = (torch.sum(torch.abs(gradient_x)) + torch.sum(torch.abs(gradient_y))) / (n*s*c*self.sils_size*2*self.sils_size)
        return loss_connectivity

    # Binarization and Closing operations to enhance foreground
    def get_edge(self, sils, threshold=1):
        mask_sils = torch.round(sils * threshold)
        kernel = torch.ones((3,3))
        dilated_mask = morph.dilation(mask_sils, kernel.to(sils.device)).detach() # Dilation
        kernel = torch.ones((5,5))
        eroded_mask = morph.erosion(dilated_mask, kernel.to(sils.device)).detach() # Erosion
        # keep soft values on the boundary ring, hard ones inside the body
        edge_mask = (dilated_mask > 0.5) ^ (eroded_mask > 0.5)
        sils = edge_mask * sils + (eroded_mask > 0.5) * torch.ones_like(sils, dtype=sils.dtype, device=sils.device)
        return sils

    def diversity_loss(self, images, max_p):
        """Entropy gap between the per-channel mass distribution and the uniform
        distribution; small when all max_p channels carry similar mass."""
        # [ns, hw, c]
        p = torch.sum(images, dim=1) / (torch.sum(images, dim=(1,2)) + 1e-6).view(-1,1).repeat(1,max_p)
        entropies = -torch.sum(p * torch.log2(p + 1e-6), dim=1)
        max_p = torch.Tensor([1/max_p]).repeat(max_p).to(images.dtype).to(images.device)
        max_entropies = -torch.sum(max_p * torch.log2(max_p), dim=0)
        return torch.mean(max_entropies - entropies)

    def forward(self, inputs):
        """inputs: (ipts, labs, ty, vi, seqL) as produced by the framework loader.

        Returns the standard OpenGait retval dict with training losses,
        visual summaries (training only) and inference embeddings.
        """
        if self.training:
            # At iteration 500 optionally swap in pretrained Mask_Branch
            # weights, then freeze the branch for the rest of training.
            if self.iteration == 500 and '.pt' in self.pretrained_mask_branch:
                self.init_Mask_Branch()
            if self.iteration >= 500:
                self.Mask_Branch.eval()
                self.Mask_Branch.requires_grad_(False)

        ipts, labs, ty, vi, seqL = inputs
        sils = ipts[0] # input_images; shape: [n,s,c,h,w];
        ratios = ipts[1] # real_image_ratios shape: [n,s,ratio]; ratio: w/h, e.g. 112/224=0.5;
        del ipts

        # Frozen DINOv2 feature extraction -- no gradients needed here.
        with torch.no_grad():
            n,s,c,h,w = sils.size()
            sils = rearrange(sils, 'n s c h w -> (n s) c h w').contiguous()
            if h == 2*w:
                outs = self.preprocess(sils, self.image_size) # [ns,c,448,224] if have used pad_resize for input images
            else:
                outs = self.preprocess(padding_resize(sils, ratios, 256, 128), self.image_size) # [ns,c,448,224] if have not used pad_resize for input images
            outs = self.backbone(outs, is_training=True) # [ns,h*w,c]
            outs_last1 = outs["x_norm_patchtokens"].contiguous()
            outs_last4 = outs["x_norm_patchtokens_mid4"].contiguous()

            # reshape patch tokens to feature maps and resize to sils_size
            outs_last1 = rearrange(outs_last1.view(n, s, self.image_size//7, self.image_size//14, -1), 'n s h w c -> (n s) c h w').contiguous()
            outs_last4 = rearrange(outs_last4.view(n, s, self.image_size//7, self.image_size//14, -1), 'n s h w c -> (n s) c h w').contiguous()
            outs_last1 = self.preprocess(outs_last1, self.sils_size) # [ns,c,64,32]
            outs_last4 = self.preprocess(outs_last4, self.sils_size) # [ns,c,64,32]
            outs_last1 = rearrange(outs_last1.view(n, s, -1, self.sils_size*2, self.sils_size), 'n s c h w -> (n s) (h w) c').contiguous()
            outs_last4 = rearrange(outs_last4.view(n, s, -1, self.sils_size*2, self.sils_size), 'n s c h w -> (n s) (h w) c').contiguous()

        # get foreground: Mask_Branch scores only valid (non-padded) positions
        mask = torch.ones_like(outs_last1[...,0], device=outs_last1.device, dtype=outs_last1.dtype).view(n*s,1,self.sils_size*2,self.sils_size)
        mask = padding_resize(mask, ratios, self.sils_size*2, self.sils_size)
        foreground = outs_last1.view(-1, self.f4_dim)[mask.view(-1) != 0]
        fore_feat, loss_mse1 = self.Mask_Branch(foreground)
        foreground = torch.zeros_like(mask, dtype=fore_feat.dtype, device=fore_feat.device).view(-1,1).repeat(1,self.mask_dim)
        foreground[mask.view(-1) != 0] = fore_feat
        loss_connectivity_shape = self.connect_loss(foreground, n, s, self.mask_dim)
        # the mask used downstream is detached: gradients reach Mask_Branch
        # only through the connectivity/mse losses above
        foreground = foreground.detach().clone()
        foreground = self.get_body(foreground.view(n*s,self.sils_size*2,self.sils_size,self.mask_dim)).view(n*s,-1) # [n*s,h*w]
        foreground = self.get_edge(foreground.view(n*s,1,self.sils_size*2,self.sils_size)).view(n*s,-1) # [n*s,h*w]
        del fore_feat, mask

        # get denoising features on foreground positions only
        denosing = outs_last4.view(-1, self.fc_dim)[foreground.view(-1) != 0]
        den_feat, _ = self.Denoising_Branch(denosing)
        denosing = torch.zeros_like(foreground, dtype=den_feat.dtype, device=den_feat.device).view(-1,1).repeat(1,self.denoising_dim)
        denosing[foreground.view(-1) != 0] = den_feat
        loss_connectivity_part = self.connect_loss(denosing.view(n*s,-1,self.denoising_dim)[...,:-1].permute(0,2,1), n, s, (self.denoising_dim-1))
        loss_diversity_part = self.diversity_loss(denosing.view(n*s,-1,self.denoising_dim), self.denoising_dim)
        del den_feat

        # get appearance features on the same foreground positions
        appearance = outs_last4.view(-1, self.fc_dim)[foreground.view(-1) != 0]
        app_feat, _ = self.Appearance_Branch(appearance)
        appearance = torch.zeros_like(foreground, dtype=app_feat.dtype, device=app_feat.device).view(-1,1).repeat(1,self.app_dim)
        appearance[foreground.view(-1) != 0] = app_feat
        appearance = appearance.view(n*s,-1,self.app_dim)
        del app_feat

        # vis: PCA projections of the branch outputs for TensorBoard
        if self.training:
            try:
                vis_num = min(5, n*s)
                vis_mask = foreground.view(n*s, self.sils_size*2*self.sils_size, -1)[:vis_num].detach().cpu().numpy()
                vis_denosing = pca_image(data={'embeddings':denosing.view(n*s, self.sils_size*2*self.sils_size, -1)[:vis_num].detach().cpu().numpy()}, mask=vis_mask, root=None, model_name=None, dataset=None, n_components=3, is_return=True) # n s c h w
                vis_appearance = pca_image(data={'embeddings':appearance.view(n*s, self.sils_size*2*self.sils_size, -1)[:vis_num].detach().cpu().numpy()}, mask=vis_mask, root=None, model_name=None, dataset=None, n_components=3, is_return=True) # n s c h w
            # NOTE(review): bare except silently falls back to blank images;
            # consider narrowing to the exceptions pca_image can raise
            except:
                vis_denosing = torch.ones_like(foreground).view(n,s,1,self.sils_size*2,self.sils_size).detach().cpu().numpy()
                vis_appearance = torch.ones_like(foreground).view(n,s,1,self.sils_size*2,self.sils_size).detach().cpu().numpy()

        # Black DA: randomly zero one branch for ~20% of the sequences
        # (in-place on views of denosing/appearance)
        if self.training:
            mask_idx = random.sample(list(range(n)), int(round(n*0.2)))
            feat_list = [denosing.view(n,s,-1), appearance.view(n,s,-1)]
            for i in mask_idx:
                idx = random.sample(list(range(2)), 1)
                for j in idx:
                    feat_list[j][i] = torch.zeros_like(feat_list[j][i], device=feat_list[j].device, dtype=feat_list[j].dtype)

        # get embedding from the GaitBase head
        embed_1, logits = self.gait_net(
            denosing.view(n,s,self.sils_size*2,self.sils_size,self.denoising_dim).permute(0, 4, 1, 2, 3).contiguous(),
            appearance.view(n,s,self.sils_size*2,self.sils_size,self.app_dim).permute(0, 4, 1, 2, 3).contiguous(),
            seqL,
        )

        if self.training:
            retval = {
                'training_feat': {
                    'shape_connect':loss_connectivity_shape*0.02,
                    'shape_mse': loss_mse1,
                    'part_connect':loss_connectivity_part*0.01,
                    'part_diversity':loss_diversity_part*5,
                    'triplet': {'embeddings': embed_1, 'labels': labs},
                    'softmax': {'logits': logits, 'labels': labs},
                },
                'visual_summary': {
                    'image/input': sils.view(n*s, c, h, w),
                    'image/foreground': self.min_max_norm(rearrange(foreground.view(n, s, self.sils_size*2, self.sils_size, -1), 'n s h w c -> (n s) c h w').contiguous()),
                    'image/denosing':self.min_max_norm(rearrange(torch.from_numpy(vis_denosing).float(), 'n s c h w -> (n s) c h w').contiguous()),
                    'image/appearance': self.min_max_norm(rearrange(torch.from_numpy(vis_appearance).float(), 'n s c h w -> (n s) c h w').contiguous()),
                },
                'inference_feat': {
                    'embeddings': embed_1
                }
            }
        else:
            retval = {
                'training_feat': {},
                'visual_summary': {},
                'inference_feat': {'embeddings': embed_1}
            }
        return retval
|
||||||
@@ -0,0 +1,190 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from einops import rearrange
|
||||||
|
from ...modules import SetBlockWrapper, SeparateFCs, SeparateBNNecks, PackSequenceWrapper, HorizontalPoolingPyramid
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
# ######################################## GaitBase ###########################################
|
||||||
|
|
||||||
|
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (bias-free channel projection)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
|
||||||
|
|
||||||
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 bias-free convolution; padding equals dilation, so spatial size
    is preserved at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
|
||||||
|
|
||||||
|
class AttentionFusion(nn.Module):
    """Fuse ``feat_len`` feature maps of shape [n, c, s, h, w] with a
    per-position softmax attention computed by a small squeeze-expand CNN."""

    def __init__(self, in_channels, squeeze_ratio, feat_len):
        super(AttentionFusion, self).__init__()
        squeezed = int(in_channels / squeeze_ratio)
        self.feat_len = feat_len
        # 1x1 squeeze -> 3x3 mix -> 1x1 expand, applied per frame.
        self.conv = SetBlockWrapper(
            nn.Sequential(
                conv1x1(in_channels * feat_len, squeezed),
                nn.BatchNorm2d(squeezed),
                nn.ReLU(inplace=True),
                conv3x3(squeezed, squeezed),
                nn.BatchNorm2d(squeezed),
                nn.ReLU(inplace=True),
                conv1x1(squeezed, in_channels * feat_len),
            )
        )

    def forward(self, feat_list):
        '''
        feat_list: sequence of feat_len tensors, each [n, c, s, h, w]
        returns their attention-weighted sum, same shape as one input
        '''
        stacked = torch.cat(feat_list, dim=1)
        weights = self.conv(stacked)  # [n, c * feat_len, s, h, w]
        weights = rearrange(weights, 'n (c d) s h w -> n c d s h w', d=self.feat_len)
        # softmax over the feat_len axis: a convex combination per position
        weights = F.softmax(weights, dim=2)
        fused = feat_list[0] * weights[:, :, 0]
        for idx in range(1, self.feat_len):
            fused = fused + feat_list[idx] * weights[:, :, idx]
        return fused
|
||||||
|
|
||||||
|
|
||||||
|
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet
|
||||||
|
from ...modules import BasicConv2d
|
||||||
|
# Maps config strings to the torchvision residual block classes used by the
# split ResNet-9 backbones below.
block_map = {'BasicBlock': BasicBlock,
             'Bottleneck': Bottleneck}
|
class Pre_ResNet9(ResNet):
    """First half of a split ResNet-9: light 3x3 stem plus layer1 only.

    The deeper stages live in Post_ResNet9 so two input branches can each own
    a stem while sharing the tail.
    """

    def __init__(self, type, block, channels=[32, 64, 128, 256], in_channel=1, layers=[1, 2, 2, 1], strides=[1, 2, 2, 1], maxpool=True):
        if block not in block_map:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        block = block_map[block]
        self.maxpool_flag = maxpool
        super(Pre_ResNet9, self).__init__(block, layers)

        # Stages handled by Post_ResNet9 -- dropped here. #
        self.fc = None
        self.layer2 = None
        self.layer3 = None
        self.layer4 = None
        ############
        self.inplanes = channels[0]
        self.bn1 = nn.BatchNorm2d(self.inplanes)

        # Replace the default 7x7 ResNet stem with a light 3x3 conv.
        self.conv1 = BasicConv2d(in_channel, self.inplanes, 3, 1, 1)

        self.layer1 = self._make_layer(
            block, channels[0], layers[0], stride=strides[0], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        # Zero blocks degenerates to the identity mapping.
        if blocks < 1:
            def layer(x):
                return x
            return layer
        return super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        if self.maxpool_flag:
            out = self.maxpool(out)
        return self.layer1(out)
|
class Post_ResNet9(ResNet):
    """Second half of a split ResNet-9: layer2..layer4 only (no stem/head)."""

    def __init__(self, type, block, channels=[32, 64, 128, 256], in_channel=1, layers=[1, 2, 2, 1], strides=[1, 2, 2, 1], maxpool=True):
        if block not in block_map:
            raise ValueError(
                "Error type for -block-Cfg-, supported: 'BasicBlock' or 'Bottleneck'.")
        block = block_map[block]
        super(Post_ResNet9, self).__init__(block, layers)
        # Stem and layer1 are handled by Pre_ResNet9 -- dropped here. #
        self.fc = None
        self.conv1 = None
        self.bn1 = None
        self.relu = None
        self.layer1 = None
        ############
        self.inplanes = channels[0]
        self.layer2 = self._make_layer(
            block, channels[1], layers[1], stride=strides[1], dilate=False)
        self.layer3 = self._make_layer(
            block, channels[2], layers[2], stride=strides[2], dilate=False)
        self.layer4 = self._make_layer(
            block, channels[3], layers[3], stride=strides[3], dilate=False)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        # Zero blocks degenerates to the identity mapping.
        if blocks < 1:
            def layer(x):
                return x
            return layer
        return super()._make_layer(block, planes, blocks, stride=stride, dilate=dilate)

    def forward(self, x):
        out = self.layer2(x)
        out = self.layer3(out)
        return self.layer4(out)
||||||
|
|
||||||
|
from utils import get_valid_args, is_list, is_dict, np2var, ts2np, list2var, get_attr_from
|
||||||
|
from ... import backbones
|
||||||
|
class Baseline(nn.Module):
    """BigGait recognition head.

    Two Pre_ResNet9 branches (denoising / appearance) are fused by
    AttentionFusion, pushed through a shared Post_ResNet9, and finished with
    the GaitBase-style TP -> HPP -> SeparateFCs -> SeparateBNNecks pipeline.
    """

    def __init__(self, model_cfg):
        super(Baseline, self).__init__()
        # NOTE: backbone_cfg is mutated in place, so build order matters --
        # each Pre_ResNet9 must see its own branch's input width.
        model_cfg['backbone_cfg']['in_channel'] = model_cfg['Denoising_Branch']['target_dim']
        self.pre_part = SetBlockWrapper(Pre_ResNet9(**model_cfg['backbone_cfg']))

        model_cfg['backbone_cfg']['in_channel'] = model_cfg['Appearance_Branch']['target_dim']
        self.pre_rgb = SetBlockWrapper(Pre_ResNet9(**model_cfg['backbone_cfg']))

        self.post_backbone = SetBlockWrapper(Post_ResNet9(**model_cfg['backbone_cfg']))
        self.FCs = SeparateFCs(**model_cfg['SeparateFCs'])
        self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
        self.TP = PackSequenceWrapper(torch.max)
        self.HPP = HorizontalPoolingPyramid(bin_num=model_cfg['bin_num'])

        self.fusion = AttentionFusion(**model_cfg['AttentionFusion'])

    def get_backbone(self, backbone_cfg):
        """Get the backbone of the model."""
        if is_dict(backbone_cfg):
            Backbone = get_attr_from([backbones], backbone_cfg['type'])
            valid_args = get_valid_args(Backbone, backbone_cfg, ['type'])
            return Backbone(**valid_args)
        if is_list(backbone_cfg):
            return nn.ModuleList([self.get_backbone(cfg) for cfg in backbone_cfg])
        raise ValueError(
            "Error type for -Backbone-Cfg-, supported: (A list of) dict.")

    def vis_forward(self, denosing, appearance, seqL):
        """Return both branch features plus their fusion (for visualization)."""
        denosing = self.pre_part(denosing)     # [n, c, s, h, w]
        appearance = self.pre_rgb(appearance)  # [n, c, s, h, w]
        outs = self.fusion([denosing, appearance])
        return denosing, appearance, outs

    def forward(self, denosing, appearance, seqL):
        """Compute embeddings and classification logits for one batch."""
        denosing = self.pre_part(denosing)     # [n, c, s, h, w]
        appearance = self.pre_rgb(appearance)  # [n, c, s, h, w]
        fused = self.fusion([denosing, appearance])
        del denosing, appearance  # free branch activations early
        fused = self.post_backbone(fused)

        # Temporal Pooling, TP
        fused = self.TP(fused, seqL, options={"dim": 2})[0]  # [n, c, h, w]

        # Horizontal Pooling Matching, HPM
        fused = self.HPP(fused)  # [n, c, p]

        embed_1 = self.FCs(fused)          # [n, c, p]
        _, logits = self.BNNecks(embed_1)  # [n, c, p]
        return embed_1, logits
||||||
@@ -0,0 +1,343 @@
|
|||||||
|
from functools import partial
|
||||||
|
import math
|
||||||
|
from typing import Sequence, Tuple, Union, Callable
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from torch.nn.init import trunc_normal_
|
||||||
|
from .dino_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
||||||
|
|
||||||
|
# ######################################## DINO ###########################################
|
||||||
|
|
||||||
|
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over *module* and all
    children, passing dotted child names; returns *module* unchanged."""
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=full_name, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
||||||
|
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Only nn.Linear modules are touched: truncated-normal weights, zero bias.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
||||||
|
class BlockChunk(nn.ModuleList):
    """A callable ModuleList: forwards the input through its blocks in order."""

    def forward(self, x):
        out = x
        for blk in self:
            out = blk(out)
        return out
||||||
|
class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone, extended for BigGait to also
    return a normalized concatenation of patch tokens from four evenly
    spaced intermediate depths (see ``forward_features``)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        logger=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            logger: optional object with a ``log_info(msg)`` method; may be None
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        # Re-create the projection so in_chans may differ from the embed layer's default.
        self.patch_embed.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # BUGFIX: the original called logger.log_info unconditionally, which
        # raised AttributeError whenever the default logger=None was used.
        if ffn_layer == "mlp":
            if logger is not None:
                logger.log_info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            if logger is not None:
                logger.log_info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            if logger is not None:
                logger.log_info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Truncated-normal pos_embed, small-normal cls token, timm init for Linears."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resample the patch position embeddings to match an input
        of w x h pixels; the class-token embedding is passed through."""
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode="bicubic",
        )

        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify, optionally replace masked patches with the mask token,
        prepend the class token, and add (resampled) position embeddings."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        return x

    def forward_features_list(self, x_list, masks_list):
        """Nested-tensor variant: run a list of batches jointly through the
        blocks and return one output dict per batch."""
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_patchtokens": x_norm[:, 1:],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Return final-layer patch tokens plus a LayerNorm'd concatenation of
        patch tokens taken at four evenly spaced intermediate depths."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        # Last block of each quarter of the network, e.g. [2, 5, 8, 11] at depth 12.
        mid_feats = []
        idx_mid4 = [int(i * len(self.blocks) / 4 + len(self.blocks) / 4 - 1) for i in range(4)]
        assert len(idx_mid4) == 4
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in idx_mid4:
                mid_feats.append(x)

        # Parameter-free LayerNorm built on the fly (elementwise_affine=False,
        # so there are no weights or device state to manage).
        mid_dim = mid_feats[0].shape[-1] * 4
        x_mid4 = nn.LayerNorm(mid_dim, eps=1e-6, elementwise_affine=False)(torch.concat(mid_feats, dim=-1))
        return {
            "x_norm_patchtokens": self.norm(x)[:, 1:],
            "x_norm_patchtokens_mid4": x_mid4[:, 1:],
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Collect outputs of selected blocks, optionally normalized and/or
        reshaped back to [B, C, H', W'] feature maps."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the raw feature dict when training, else the (identity) head
        applied to the class token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
||||||
|
def vit_small(patch_size=16, **kwargs):
    """Build the DINOv2 ViT-S/14 backbone (518px, embed 384, depth 12).

    NOTE(review): the *patch_size* argument is ignored -- the model is always
    constructed with patch_size=14; confirm whether any caller relies on it.
    """
    return DinoVisionTransformer(
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer="mlp",
        block_chunks=0,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )
|
||||||
|
def vit_large(patch_size=16, **kwargs):
    """Build the DINOv2 ViT-L/14 backbone (518px, embed 1024, depth 24).

    NOTE(review): the *patch_size* argument is ignored -- the model is always
    constructed with patch_size=14; confirm whether any caller relies on it.
    """
    return DinoVisionTransformer(
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer="mlp",
        block_chunks=0,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        **kwargs,
    )
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from .dino_head import DINOHead
|
||||||
|
from .mlp import Mlp
|
||||||
|
from .patch_embed import PatchEmbed
|
||||||
|
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
||||||
|
from .block import NestedTensorBlock
|
||||||
|
from .attention import MemEffAttention
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# References:
|
||||||
|
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger shared by the dinov2 layer files.
logger = logging.getLogger("dinov2")


# xFormers provides fused memory-efficient attention kernels; when it is not
# installed we fall back to the plain PyTorch implementation below.
try:
    from xformers.ops import memory_efficient_attention, unbind, fmha

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
    """Standard multi-head self-attention with pre-softmax query scaling."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Queries are scaled by 1/sqrt(head_dim) before the dot product.
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        """Attend over the token axis of x [B, N, C] and return [B, N, C]."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]

        weights = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj(out)
        return self.proj_drop(out)
|
||||||
|
|
||||||
|
class MemEffAttention(Attention):
    """Attention variant that uses xFormers' fused kernel when available and
    accepts an optional block-diagonal attention bias for nested inputs."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            # The plain PyTorch path cannot honour an attention bias.
            assert attn_bias is None, "xFormers is required for nested tensors usage"
            return super().forward(x)

        batch, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)

        q, k, v = unbind(qkv, 2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, tokens, channels])

        out = self.proj(out)
        return self.proj_drop(out)
|
||||||
@@ -0,0 +1,252 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# References:
|
||||||
|
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Callable, List, Any, Tuple, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, Tensor
|
||||||
|
|
||||||
|
from .attention import Attention, MemEffAttention
|
||||||
|
from .drop_path import DropPath
|
||||||
|
from .layer_scale import LayerScale
|
||||||
|
from .mlp import Mlp
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger shared by the dinov2 layer files.
logger = logging.getLogger("dinov2")


# xFormers supplies the block-diagonal mask (fmha) and fused index ops used by
# the nested-tensor path; without it only single-tensor forwards are possible.
try:
    from xformers.ops import fmha
    from xformers.ops import scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
|
|
||||||
|
|
||||||
|
class Block(nn.Module):
    """Pre-norm transformer block (attention + FFN) with optional LayerScale
    and per-sample stochastic depth."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is a truthy number (None/0 => off).
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Run the attention and FFN residual branches, choosing the
        stochastic-depth strategy appropriate for the current mode."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # BUGFIX: use drop_path2 for the FFN branch. The original reused
            # drop_path1 and carried a FIXME; both modules are built with the
            # same rate, so the distribution of outputs is unchanged.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Per-sample stochastic depth: run *residual_func* on a random subset of
    the batch and scatter-add the rescaled result back into *x*."""
    batch = x.shape[0]

    # 1) pick a random subset of samples to keep (at least one).
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch, device=x.device)[:keep]
    subset = x[brange]

    # 2) compute the residual on the kept samples only.
    residual = residual_func(subset).flatten(1)

    # 3) add it back, rescaled so the expected residual magnitude is unchanged.
    scale = batch / keep
    flat = torch.index_add(x.flatten(1), 0, brange, residual.to(dtype=x.dtype), alpha=scale)
    return flat.view_as(x)
|
||||||
|
|
||||||
|
|
||||||
|
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick a random subset of batch indices of *x* (at least one) and return
    it together with the matching residual rescale factor batch/kept."""
    batch = x.shape[0]
    keep = max(int(batch * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch, device=x.device)[:keep]
    return brange, batch / keep
|
||||||
|
|
||||||
|
|
||||||
|
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add *residual* into the rows of *x* selected by *brange*.

    Without a scaling vector this is a plain index_add on flattened samples;
    with one it uses xFormers' fused scaled_index_add (layer-scale path).
    """
    if scaling_vector is None:
        return torch.index_add(
            x.flatten(1),
            0,
            brange,
            residual.flatten(1).to(dtype=x.dtype),
            alpha=residual_scale_factor,
        )
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Cache of BlockDiagonalMask objects keyed by the tuple of (batch, seq_len)
# shapes of the nested inputs, so each mask is built only once per shape combo.
attn_bias_cache: Dict[Tuple, Any] = {}
|
|
||||||
|
|
||||||
|
def get_attn_bias_and_cat(x_list, branges=None):
    """Pack a list of token tensors into one [1, sum, d] sequence and return
    it with the (cached) block-diagonal attention bias that keeps the packed
    sequences independent. *branges* optionally sub-selects each batch first.
    """
    if branges is not None:
        batch_sizes = [b.shape[0] for b in branges]
    else:
        batch_sizes = [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One seqlen entry per (kept) sample of every tensor in the list.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        flattened = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(flattened, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """List variant of stochastic depth: subsample each batch, run the
    residual once over the packed sequence, then scatter-add back per tensor.
    """
    # 1) random subset of indices (and rescale factors) for every tensor
    subsets = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [brange for brange, _ in subsets]
    scale_factors = [scale for _, scale in subsets]

    # 2) block-diagonal attention bias + packed/concatenated tokens
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) one residual pass over the packed sequence, split back per tensor
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, scale_factors)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
class NestedTensorBlock(Block):
    """Block variant that can run a *list* of token tensors jointly by packing
    them into one sequence with a block-diagonal attention bias (xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # BUGFIX: test ls2 (not ls1) before reading ls2's gamma. The
                # original checked ls1, which only worked because both are
                # always constructed identically.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: a single Tensor uses the plain Block
        forward; a list of Tensors uses the nested path (requires xFormers)."""
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.init import trunc_normal_
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
|
|
||||||
|
|
||||||
|
class DINOHead(nn.Module):
    """DINO projection head: an MLP bottleneck followed by a weight-normalized
    linear prototype layer applied to L2-normalized features."""

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        # Weight-normalized prototype layer; the magnitude component is
        # initialised to 1 so only the direction is learned at first.
        prototypes = nn.Linear(bottleneck_dim, out_dim, bias=False)
        self.last_layer = weight_norm(prototypes)
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal init for every linear layer, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project, L2-normalize, then score against the prototypes."""
        z = self.mlp(x)
        # Larger eps under fp16 to avoid underflow in the norm division.
        eps = 1e-6 if z.dtype == torch.float16 else 1e-12
        z = nn.functional.normalize(z, dim=-1, p=2, eps=eps)
        return self.last_layer(z)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
||||||
|
if nlayers == 1:
|
||||||
|
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
||||||
|
else:
|
||||||
|
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
||||||
|
if use_bn:
|
||||||
|
layers.append(nn.BatchNorm1d(hidden_dim))
|
||||||
|
layers.append(nn.GELU())
|
||||||
|
for _ in range(nlayers - 2):
|
||||||
|
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
||||||
|
if use_bn:
|
||||||
|
layers.append(nn.BatchNorm1d(hidden_dim))
|
||||||
|
layers.append(nn.GELU())
|
||||||
|
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# References:
|
||||||
|
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
||||||
|
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth), rescaling the
    survivors by 1/keep_prob so the expected value is unchanged.

    Identity (returns ``x`` itself) when ``drop_prob`` is zero or when not
    training.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims
    # (works for any tensor rank, not just 2D ConvNets).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks). Thin module wrapper around :func:`drop_path` that feeds
    in the module's own training flag."""

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of zeroing a sample's branch. NOTE(review): a None value
        # would fail inside drop_path when training (1 - None) — callers are
        # expected to pass a float.
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (CaiT-style).

    Multiplies the last dimension of the input by a learned vector ``gamma``,
    optionally in place.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # One learnable scale per channel, initialised to a small constant so
        # the branch starts close to identity-off.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# References:
|
||||||
|
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> activation -> dropout ->
    Linear -> dropout. Hidden and output widths default to the input width."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        # A single Dropout module reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||||||
@@ -0,0 +1,90 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# References:
|
||||||
|
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||||
|
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||||
|
|
||||||
|
from typing import Callable, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def make_2tuple(x):
    """Return ``x`` unchanged if it is already a 2-tuple, otherwise expand the
    int ``x`` to ``(x, x)``. Any other input fails the assertions."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple)
    assert len(x) == 2
    return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Splits the image into non-overlapping patches with a strided convolution,
    then optionally flattens the spatial grid into a token sequence.

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If True return (B, N, D), else (B, H', W', D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid_hw = (image_hw[0] // patch_hw[0], image_hw[1] // patch_hw[1])

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Patchify and embed in a single strided convolution.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        # NOTE: divisibility of H/W by the patch size is not enforced here;
        # Conv2d simply drops any remainder pixels.
        x = self.proj(x)  # B C H' W'
        grid_h, grid_w = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B (H'W') C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H' W' C
        return x

    def flops(self) -> float:
        """Rough FLOP count of the projection (+ norm) for one image."""
        grid_h, grid_w = self.patches_resolution
        n_tokens = grid_h * grid_w
        flops = n_tokens * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += n_tokens * self.embed_dim
        return flops
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from torch import Tensor, nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: a fused gate/value projection, SiLU gating, and an
    output projection. ``act_layer`` and ``drop`` are accepted for signature
    compatibility with :class:`Mlp` but are unused."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Single matmul producing both the gate (first half) and the value
        # (second half) projections.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
||||||
|
|
||||||
|
|
||||||
|
# Prefer the fused xFormers SwiGLU kernel; fall back to the pure-PyTorch
# implementation above when xFormers is not installed.
XFORMERS_AVAILABLE = True
try:
    from xformers.ops import SwiGLU
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is scaled by 2/3 (keeping the parameter
    count comparable to a plain MLP of the nominal width) and rounded up to a
    multiple of 8 for efficient fused kernels. ``act_layer`` and ``drop`` are
    accepted for signature compatibility but not forwarded."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden = hidden_features or in_features
        # 2/3 scaling, then round up to the next multiple of 8.
        hidden = (int(hidden * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden,
            out_features=out_features,
            bias=bias,
        )
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
from os import path as osp
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from PIL import Image
|
||||||
|
import imageio
|
||||||
|
from glob import glob
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
from sklearn.preprocessing import minmax_scale
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def pca_image(data, mask, root, model_name, dataset, n_components=3, is_return=False):
    """Project per-patch embeddings onto their top PCA components and either
    return the component maps or save them as images via `save_image`.

    Args:
        data: dict with 'embeddings' of shape (ns, hw, c); when saving it must
            also carry the 'labels'/'types'/'views' keys `save_image` reads.
        mask: array reshapeable to (ns*hw,); tokens where mask == 0 are
            excluded from the PCA fit and stay zero (black) in the output.
        root, model_name, dataset: output path components for `save_image`.
        n_components: number of PCA components to keep (first 3 become RGB).
        is_return: if True, return the maps instead of writing images.
    """
    features = data['embeddings']
    ns,hw,c = features.shape
    # Flatten sequence and spatial dims so PCA sees one row per patch token.
    features = features.reshape(ns*hw,c)
    mask = mask.reshape(ns*hw)

    # Fit PCA only on foreground tokens (mask != 0).
    pca = PCA(n_components=n_components)
    pca_features = pca.fit_transform(features[mask != 0])
    # Rescale to a displayable 0-255 range. axis=1 normalizes each token
    # across its components; NOTE(review): the commented axis=0 variant
    # (per-component scaling) may be the intended one — confirm visually.
    pca_features = minmax_scale(pca_features, (0,255), axis=1)
    # pca_features = minmax_scale(pca_features, (0,255), axis=0)

    # Scatter the scaled components back into the full token grid; background
    # tokens stay 0. The uint8 assignment truncates the float values.
    norm_features = np.zeros_like(mask,dtype=np.uint8).reshape(ns*hw,1).repeat(n_components,axis=1)
    norm_features[mask != 0] = pca_features

    if is_return:
        # Hard-coded 64x32 (H x W) token grid — assumes hw == 64*32.
        norm_features = norm_features.reshape(1,ns,64,32,n_components)[...,:3].transpose(0,1,4,2,3) #
        return norm_features

    # Group frames into clips of s=20 before saving; requires ns % 20 == 0.
    s = 20
    assert ns % s == 0
    norm_features = norm_features.reshape(ns//s,s,64,32,n_components)[...,:3].transpose(0,1,4,2,3)
    data['embeddings'] = norm_features
    save_image(data, root, model_name, dataset, need='image')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def save_image(data, root, model_name, dataset, need='image', mask=None):
    """Dump visualization results under <root>/<dataset>/<model_name>_<kind>/.

    `need` selects one or more output kinds by substring match:
      - 'image': per-frame PNGs via `save_func`, plus an animated GIF.
      - 'pkl':   pickle the first channel of each sequence.
      - 'w':     pickle the contents of data['w'] per sequence.
    Each sequence is written to its own <id>/<type>/<view>/ directory.
    """
    images, label, seq_type, view = data['embeddings'], data['labels'], data['types'], data['views'] # n s c h w
    if "image" in need:
        root_path = os.path.join(root, dataset, model_name+'_image')
        os.makedirs(os.path.join(root_path),exist_ok=True)
        for i, id in enumerate(label[:]):
            # One directory per (id, sequence type, view).
            tmp = os.path.join(root_path, str(id).zfill(5), str(seq_type[i]), str(view[i]))
            os.makedirs(tmp, exist_ok=True)
            mb = None if mask is None else mask[i]
            save_func(tmp, images[i], need, mb)
            # Bundle this sequence's frames into a GIF in the same directory.
            save_gif(tmp, tmp, str(view[i]))

    if 'pkl' in need:
        root_path = os.path.join(root, dataset, model_name+'_pkl')
        os.makedirs(os.path.join(root_path),exist_ok=True)
        for i, id in enumerate(label[:]):
            tmp = os.path.join(root_path, str(id).zfill(5), str(seq_type[i]), str(view[i]))
            os.makedirs(tmp, exist_ok=True)
            mb = None if mask is None else mask[i]
            save_func(tmp, images[i], 'pkl', mb)

    if 'w' in need:
        root_path = os.path.join(root, dataset, model_name+'_w')
        os.makedirs(os.path.join(root_path),exist_ok=True)
        for i, id in enumerate(label[:]):
            tmp = os.path.join(root_path, str(id).zfill(5), str(seq_type[i]), str(view[i]))
            os.makedirs(tmp, exist_ok=True)
            mb = None if mask is None else mask[i]
            # NOTE(review): passes the whole data['w'] (not data['w'][i]) for
            # every sequence, duplicating the same content in each directory —
            # confirm this is intentional.
            save_func(tmp, data['w'], 'w', mb)
    return
|
||||||
|
|
||||||
|
def save_func(tmp, data, ipts_type='image', mask=None):
    """Write one sequence to directory `tmp`.

    ipts_type (substring-matched for the image path):
      - contains 'image': save each frame as a numbered PNG. Single-channel
        frames are grayscale, unless 'jet' is also in ipts_type, in which
        case a JET heatmap blended with `mask` is saved as RGB.
      - 'pkl': pickle the first channel of the whole sequence to 00.pkl.
      - 'w':   pickle each element of `data` to its own NN.pkl.
    """
    if 'image' in ipts_type :
        for i, con in enumerate(data):
            if con.shape[0] == 1:
                if 'jet' in ipts_type :
                    # Heatmap overlay: JET colormap at half intensity,
                    # BGR->RGB via [...,::-1], plus the mask image.
                    im = ((cv2.applyColorMap(con[0], cv2.COLORMAP_JET) * 0.5)[...,::-1] + 1.0*mask[i])
                    # im = mask[i]
                    im = np.clip(im,0,255).astype(np.uint8)
                    im = Image.fromarray(im, mode='RGB') # [h,w,c]
                else:
                    im = Image.fromarray(con[0], mode='L')
            else:
                # Multi-channel frame: CHW -> HWC for PIL (assumes uint8 RGB
                # data — confirm against the producer).
                im = Image.fromarray(con.transpose(1,2,0), mode='RGB')
            im.save(os.path.join(tmp, '%03d.png' % i))
    elif ipts_type == 'pkl':
        # Only the first channel of every frame is kept in the pickle.
        with open(os.path.join(tmp,'00.pkl'), 'wb') as f:
            pickle.dump(data[:,0,:,:], f)
    elif ipts_type == 'w':
        for i in range(len(data)):
            with open(os.path.join(tmp, str(i).zfill(2) + '.pkl'), 'wb') as f:
                pickle.dump(data[i], f)
|
||||||
|
|
||||||
|
def save_gif(image_folder, save_folder, name="movie"):
    """Collect every PNG in `image_folder` (sorted by filename) into an
    animated, looping GIF written to `save_folder`/`name`.gif."""
    png_paths = sorted(glob(osp.join(image_folder, '*.png')))
    frames = [imageio.imread(p) for p in png_paths]
    imageio.mimsave(os.path.join(save_folder, f'{name}.gif'), frames, duration=50, loop=0)
|
||||||
Reference in New Issue
Block a user