1.0.0 official release (#18)
* fix bug in fix_BN * gaitgl OUMVLP support. * update ./doc/3.advance_usage.md Cross-Dataset Evaluation & Data Augmentation * update config * update docs.3 * update docs.3 * add loss doc and gather input decorator * refine the create model doc * support rearrange directory of unzipped OUMVLP * fix some bugs in loss_aggregator.py * refine docs and little fix * add oumvlp pretreatment description * pretreatment dataset fix oumvlp description * add gaitgl oumvlp result * assert gaitgl input size * add pipeline * update the readme. * update pipeline and readme * Corrigendum. * add logo and remove path * update new logo * Update README.md * modify logo size Co-authored-by: 12131100 <12131100@mail.sustech.edu.cn> Co-authored-by: noahshen98 <77523610+noahshen98@users.noreply.github.com> Co-authored-by: Noah <595311942@qq.com>
This commit is contained in:
@@ -1,7 +1,28 @@
|
||||
"""The plain backbone.
|
||||
|
||||
The plain backbone contains only BasicConv2d, FocalConv2d, MaxPool2d and LeakyReLU layers.
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
from ..modules import BasicConv2d, FocalConv2d
|
||||
|
||||
|
||||
class Plain(nn.Module):
|
||||
"""
|
||||
The Plain backbone class.
|
||||
|
||||
A LeakyReLU is implicitly appended to each layer except max pooling.
|
||||
The kernel size, stride and padding of the first convolution layer are 5, 1, 2, the ones of other layers are 3, 1, 1.
|
||||
|
||||
Typical usage:
|
||||
- BC-64: Basic conv2d with output channel 64. The input channel is the output channel of previous layer.
|
||||
|
||||
- M: nn.MaxPool2d(kernel_size=2, stride=2).
|
||||
|
||||
- FC-128-1: Focal conv2d with output channel 128 and halving 1 (divided into 2^1=2 parts).
|
||||
|
||||
Use it in your configuration file.
|
||||
"""
|
||||
|
||||
def __init__(self, layers_cfg, in_channels=1):
|
||||
super(Plain, self).__init__()
|
||||
@@ -13,9 +34,11 @@ class Plain(nn.Module):
|
||||
def forward(self, seqs):
|
||||
out = self.feature(seqs)
|
||||
return out
|
||||
|
||||
# torchvision/models/vgg.py
|
||||
|
||||
def make_layers(self):
|
||||
"""
|
||||
Reference: torchvision/models/vgg.py
|
||||
"""
|
||||
def get_layer(cfg, in_c, kernel_size, stride, padding):
|
||||
cfg = cfg.split('-')
|
||||
typ = cfg[0]
|
||||
@@ -27,7 +50,8 @@ class Plain(nn.Module):
|
||||
return BasicConv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
return FocalConv2d(in_c, out_c, kernel_size=kernel_size, stride=stride, padding=padding, halving=int(cfg[2]))
|
||||
|
||||
Layers = [get_layer(self.layers_cfg[0], self.in_channels, 5, 1, 2), nn.LeakyReLU(inplace=True)]
|
||||
Layers = [get_layer(self.layers_cfg[0], self.in_channels,
|
||||
5, 1, 2), nn.LeakyReLU(inplace=True)]
|
||||
in_c = int(self.layers_cfg[0].split('-')[1])
|
||||
for cfg in self.layers_cfg[1:]:
|
||||
if cfg == 'M':
|
||||
@@ -37,6 +61,3 @@ class Plain(nn.Module):
|
||||
Layers += [conv2d, nn.LeakyReLU(inplace=True)]
|
||||
in_c = int(cfg.split('-')[1])
|
||||
return nn.Sequential(*Layers)
|
||||
|
||||
|
||||
|
||||
|
||||
+84
-30
@@ -1,3 +1,14 @@
|
||||
"""The base model definition.
|
||||
|
||||
This module defines the abstract meta model class and base model class. In the base model,
|
||||
we define the basic model functions, like get_loader, build_network, and run_train, etc.
|
||||
The public API of the base model is run_train and run_test, which are used in `lib/main.py`.
|
||||
|
||||
Typical usage:
|
||||
|
||||
BaseModel.run_train(model)
|
||||
BaseModel.run_test(model)
|
||||
"""
|
||||
import torch
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
@@ -13,7 +24,6 @@ from abc import abstractmethod
|
||||
|
||||
from . import backbones
|
||||
from .loss_aggregator import LossAggregator
|
||||
from modeling.modules import fix_BN
|
||||
from data.transform import get_transform
|
||||
from data.collate_fn import CollateFn
|
||||
from data.dataset import DataSet
|
||||
@@ -28,80 +38,97 @@ __all__ = ['BaseModel']
|
||||
|
||||
|
||||
class MetaModel(metaclass=ABCMeta):
|
||||
"""The necessary functions for the base model.
|
||||
|
||||
This class defines the necessary functions for the base model; they are implemented in the base model.
|
||||
"""
|
||||
@abstractmethod
|
||||
def get_loader(self, data_cfg):
|
||||
'''
|
||||
Build your data Loader here.
|
||||
Inputs: data_cfg, dict
|
||||
Return: Loader
|
||||
'''
|
||||
"""Based on the given data_cfg, we get the data loader."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def build_network(self, model_cfg):
|
||||
'''
|
||||
Build your Model here.
|
||||
Inputs: model_cfg, dict
|
||||
Return: Network, nn.Module(s)
|
||||
'''
|
||||
"""Build your network here."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def init_parameters(self):
|
||||
"""Initialize the parameters of your network."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_optimizer(self, optimizer_cfg):
|
||||
'''
|
||||
Build your Optimizer here.
|
||||
Inputs: optimizer_cfg, dict
|
||||
Return: Optimizer, a optimizer object
|
||||
'''
|
||||
"""Based on the given optimizer_cfg, we get the optimizer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_scheduler(self, scheduler_cfg):
|
||||
'''
|
||||
Build your Scheduler.
|
||||
Inputs: scheduler_cfg, dict
|
||||
Optimizer, your optimizer
|
||||
Return: Scheduler, a scheduler object
|
||||
'''
|
||||
"""Based on the given scheduler_cfg, we get the scheduler."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def save_ckpt(self, iteration):
|
||||
"""Save the checkpoint, including model parameter, optimizer and scheduler."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def resume_ckpt(self, restore_hint):
|
||||
"""Resume the model from the checkpoint, including model parameter, optimizer and scheduler."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def inputs_pretreament(self, inputs):
|
||||
"""Transform the input data based on transform setting."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def train_step(self, loss_num) -> bool:
|
||||
"""Do one training step."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
"""Do inference (calculate features)."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def run_train(model):
|
||||
"""Run a whole train schedule."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def run_test(model):
|
||||
"""Run a whole test schedule."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseModel(MetaModel, nn.Module):
|
||||
"""Base model.
|
||||
|
||||
This class inherits the MetaModel class, and implements the basic model functions, like get_loader, build_network, etc.
|
||||
|
||||
Attributes:
|
||||
msg_mgr: the message manager.
|
||||
cfgs: the configs.
|
||||
iteration: the current iteration of the model.
|
||||
engine_cfg: the configs of the engine(train or test).
|
||||
save_path: the path to save the checkpoints.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cfgs, training):
|
||||
"""Initialize the base model.
|
||||
|
||||
Complete the model initialization, including the data loader, the network, the optimizer, the scheduler, the loss.
|
||||
|
||||
Args:
|
||||
cfgs:
|
||||
All of the configs.
|
||||
training:
|
||||
Whether the model is in training mode.
|
||||
"""
|
||||
|
||||
super(BaseModel, self).__init__()
|
||||
self.msg_mgr = get_msg_mgr()
|
||||
self.cfgs = cfgs
|
||||
@@ -132,8 +159,6 @@ class BaseModel(MetaModel, nn.Module):
|
||||
"cuda", self.device))
|
||||
|
||||
if training:
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
fix_BN(self)
|
||||
self.loss_aggregator = LossAggregator(cfgs['loss_cfg'])
|
||||
self.optimizer = self.get_optimizer(self.cfgs['optimizer_cfg'])
|
||||
self.scheduler = self.get_scheduler(cfgs['scheduler_cfg'])
|
||||
@@ -142,7 +167,12 @@ class BaseModel(MetaModel, nn.Module):
|
||||
if restore_hint != 0:
|
||||
self.resume_ckpt(restore_hint)
|
||||
|
||||
if training:
|
||||
if cfgs['trainer_cfg']['fix_BN']:
|
||||
self.fix_BN()
|
||||
|
||||
def get_backbone(self, model_cfg):
|
||||
"""Get the backbone of the model."""
|
||||
def _get_backbone(backbone_cfg):
|
||||
if is_dict(backbone_cfg):
|
||||
Backbone = get_attr_from([backbones], backbone_cfg['type'])
|
||||
@@ -266,7 +296,20 @@ class BaseModel(MetaModel, nn.Module):
|
||||
"Error type for -Restore_Hint-, supported: int or string.")
|
||||
self._load_ckpt(save_name)
|
||||
|
||||
def fix_BN(self):
    """Put every BatchNorm-like submodule into evaluation mode.

    Walks ``self.modules()`` and calls ``eval()`` on each module whose
    class name contains ``'BatchNorm'``; all other modules are left
    untouched.
    """
    for module in self.modules():
        # Match on the class name so that any class with 'BatchNorm' in its
        # name is covered by the same check.
        if 'BatchNorm' in module.__class__.__name__:
            module.eval()
|
||||
|
||||
def inputs_pretreament(self, inputs):
|
||||
"""Conduct transforms on input data.
|
||||
|
||||
Args:
|
||||
inputs: the input data.
|
||||
Returns:
|
||||
tuple: training data including inputs, labels, and some meta data.
|
||||
"""
|
||||
seqs_batch, labs_batch, typs_batch, vies_batch, seqL_batch = inputs
|
||||
trf_cfgs = self.engine_cfg['transform']
|
||||
seq_trfs = get_transform(trf_cfgs)
|
||||
@@ -293,9 +336,13 @@ class BaseModel(MetaModel, nn.Module):
|
||||
return ipts, labs, typs, vies, seqL
|
||||
|
||||
def train_step(self, loss_sum) -> bool:
|
||||
'''
|
||||
Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
|
||||
'''
|
||||
"""Conduct loss_sum.backward(), self.optimizer.step() and self.scheduler.step().
|
||||
|
||||
Args:
|
||||
loss_sum: The loss of the current batch.
|
||||
Returns:
|
||||
bool: True if the training is finished, False otherwise.
|
||||
"""
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
if loss_sum <= 1e-9:
|
||||
@@ -322,6 +369,13 @@ class BaseModel(MetaModel, nn.Module):
|
||||
return True
|
||||
|
||||
def inference(self, rank):
|
||||
"""Inference all the test data.
|
||||
|
||||
Args:
|
||||
rank: the rank of the current process.
|
||||
Returns:
|
||||
Odict: contains the inference results.
|
||||
"""
|
||||
total_size = len(self.test_loader)
|
||||
if rank == 0:
|
||||
pbar = tqdm(total=total_size, desc='Transforming')
|
||||
@@ -355,9 +409,7 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
@ staticmethod
|
||||
def run_train(model):
|
||||
'''
|
||||
Accept the instance object(model) here, and then run the train loop handler.
|
||||
'''
|
||||
"""Accept the instance object(model) here, and then run the train loop."""
|
||||
for inputs in model.train_loader:
|
||||
ipts = model.inputs_pretreament(inputs)
|
||||
with autocast(enabled=model.engine_cfg['enable_float16']):
|
||||
@@ -390,6 +442,8 @@ class BaseModel(MetaModel, nn.Module):
|
||||
|
||||
@ staticmethod
|
||||
def run_test(model):
|
||||
"""Accept the instance object(model) here, and then run the test loop."""
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
with torch.no_grad():
|
||||
info_dict = model.inference(rank)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""The loss aggregator."""
|
||||
|
||||
import torch
|
||||
from . import losses
|
||||
from utils import is_dict, get_attr_from, get_valid_args, is_tensor, get_ddp_module
|
||||
@@ -6,18 +8,48 @@ from utils import get_msg_mgr
|
||||
|
||||
|
||||
class LossAggregator():
|
||||
"""The loss aggregator.
|
||||
|
||||
This class is used to aggregate the losses.
|
||||
For example, if you have two losses, one is triplet loss, the other is cross entropy loss,
|
||||
you can aggregate them as follows:
|
||||
loss_sum = triplet_loss + cross_entropy_loss
|
||||
|
||||
Attributes:
|
||||
losses: A dict of losses.
|
||||
"""
|
||||
|
||||
def __init__(self, loss_cfg) -> None:
|
||||
"""
|
||||
Initialize the loss aggregator.
|
||||
|
||||
Args:
|
||||
loss_cfg: Config of losses. List for multiple losses.
|
||||
"""
|
||||
self.losses = {loss_cfg['log_prefix']: self._build_loss_(loss_cfg)} if is_dict(loss_cfg) \
|
||||
else {cfg['log_prefix']: self._build_loss_(cfg) for cfg in loss_cfg}
|
||||
|
||||
def _build_loss_(self, loss_cfg):
|
||||
"""Build the losses from loss_cfg.
|
||||
|
||||
Args:
|
||||
loss_cfg: Config of loss.
|
||||
"""
|
||||
Loss = get_attr_from([losses], loss_cfg['type'])
|
||||
valid_loss_arg = get_valid_args(
|
||||
Loss, loss_cfg, ['type', 'pair_based_loss'])
|
||||
loss = get_ddp_module(Loss(**valid_loss_arg))
|
||||
Loss, loss_cfg, ['type', 'gather_and_scale'])
|
||||
loss = get_ddp_module(Loss(**valid_loss_arg).cuda())
|
||||
return loss
|
||||
|
||||
def __call__(self, training_feats):
|
||||
"""Compute the sum of all losses.
|
||||
|
||||
The input is a dict of features. The key is the name of the loss and the value is the feature and label. If the key is not in
|
||||
the built losses and the value is a torch.Tensor, then it is the computed loss to be added to loss_sum.
|
||||
|
||||
Args:
|
||||
training_feats: A dict of features. The same as the output["training_feat"] of the model.
|
||||
"""
|
||||
loss_sum = .0
|
||||
loss_info = Odict()
|
||||
|
||||
@@ -28,14 +60,12 @@ class LossAggregator():
|
||||
for name, value in info.items():
|
||||
loss_info['scalar/%s/%s' % (k, name)] = value
|
||||
loss = loss.mean() * loss_func.loss_term_weights
|
||||
if loss_func.pair_based_loss:
|
||||
loss = loss * torch.distributed.get_world_size()
|
||||
loss_sum += loss
|
||||
|
||||
else:
|
||||
if isinstance(v, dict):
|
||||
raise ValueError(
|
||||
"The key %s in -Trainng-Feat- should be stated as the log_prefix of a certain loss defined in your loss_cfg."
|
||||
"The key %s in -Trainng-Feat- should be stated as the log_prefix of a certain loss defined in your loss_cfg."%v
|
||||
)
|
||||
elif is_tensor(v):
|
||||
_ = v.mean()
|
||||
|
||||
@@ -1,13 +1,54 @@
|
||||
from ctypes import ArgumentError
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from utils import Odict
|
||||
import functools
|
||||
from utils import ddp_all_gather
|
||||
|
||||
class BasicLoss(nn.Module):
|
||||
def __init__(self, loss_term_weights=1.0):
|
||||
super(BasicLoss, self).__init__()
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = True
|
||||
self.info = Odict()
|
||||
|
||||
def gather_and_scale_wrapper(func):
    """Decorate a loss function so it gathers its inputs and scales its loss.

    Every keyword argument is passed through ``ddp_all_gather`` before
    ``func`` is called, and the returned loss is multiplied in place by
    ``torch.distributed.get_world_size()``. Positional arguments are
    forwarded untouched, so inputs must be passed by keyword to be gathered.

    Args:
        func: callable returning a ``(loss, loss_info)`` pair.

    Returns:
        The wrapped callable, which returns the scaled ``(loss, loss_info)``.

    Raises:
        ArgumentError: if gathering, ``func`` itself, or the scaling fails.
            The original exception is attached as ``__cause__``.
    """

    @functools.wraps(func)
    def inner(*args, **kwds):
        try:
            gathered = {name: ddp_all_gather(value)
                        for name, value in kwds.items()}
            loss, loss_info = func(*args, **gathered)
            # NOTE(review): presumably compensates for gradient averaging
            # across processes - confirm against the training setup.
            loss *= torch.distributed.get_world_size()
            return loss, loss_info
        except Exception as err:
            # Keep the historical exception type, but do not trap
            # KeyboardInterrupt/SystemExit and do not hide the root cause.
            raise ArgumentError(
                "%s failed inside gather_and_scale_wrapper: %r"
                % (getattr(func, '__qualname__', func), err)) from err

    return inner
|
||||
|
||||
|
||||
class BaseLoss(nn.Module):
|
||||
"""
|
||||
Base class for all losses.
|
||||
|
||||
Your loss should also subclass this class.
|
||||
|
||||
Attributes:
|
||||
loss_term_weights: the weight of the loss.
|
||||
info: the loss info.
|
||||
"""
|
||||
loss_term_weights = 1.0
|
||||
info = Odict()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
raise NotImplementedError
|
||||
"""
|
||||
The default forward function.
|
||||
|
||||
This function should be overridden by the subclass.
|
||||
|
||||
Args:
|
||||
logits: the logits of the model.
|
||||
labels: the labels of the data.
|
||||
|
||||
Returns:
|
||||
tuple of loss and info.
|
||||
"""
|
||||
return .0, self.info
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
from .base import BaseLoss
|
||||
|
||||
|
||||
class CrossEntropyLoss(BasicLoss):
|
||||
class CrossEntropyLoss(BaseLoss):
|
||||
def __init__(self, scale=2**4, label_smooth=True, eps=0.1, loss_term_weights=1.0, log_accuracy=False):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.scale = scale
|
||||
@@ -13,7 +13,6 @@ class CrossEntropyLoss(BasicLoss):
|
||||
self.log_accuracy = log_accuracy
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = False
|
||||
|
||||
def forward(self, logits, labels):
|
||||
"""
|
||||
@@ -26,7 +25,7 @@ class CrossEntropyLoss(BasicLoss):
|
||||
one_hot_labels = self.label2one_hot(
|
||||
labels, c).unsqueeze(0).repeat(p, 1, 1) # [p, n, c]
|
||||
loss = self.compute_loss(log_preds, one_hot_labels)
|
||||
self.info.update({'loss': loss})
|
||||
self.info.update({'loss': loss.detach().clone()})
|
||||
if self.log_accuracy:
|
||||
pred = logits.argmax(dim=-1) # [p, n]
|
||||
accu = (pred == labels.unsqueeze(0)).float().mean()
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import BasicLoss
|
||||
from utils import ddp_all_gather
|
||||
from .base import BaseLoss, gather_and_scale_wrapper
|
||||
|
||||
|
||||
class TripletLoss(BasicLoss):
|
||||
class TripletLoss(BaseLoss):
|
||||
def __init__(self, margin, loss_term_weights=1.0):
|
||||
super(TripletLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
self.loss_term_weights = loss_term_weights
|
||||
self.pair_based_loss = True
|
||||
|
||||
@gather_and_scale_wrapper
|
||||
def forward(self, embeddings, labels):
|
||||
# embeddings: [n, p, c], label: [n]
|
||||
embeddings = ddp_all_gather(embeddings)
|
||||
labels = ddp_all_gather(labels)
|
||||
embeddings = embeddings.permute(
|
||||
1, 0, 2).contiguous() # [n, p, c] -> [p, n, c]
|
||||
embeddings = embeddings.float()
|
||||
@@ -32,10 +29,10 @@ class TripletLoss(BasicLoss):
|
||||
loss_avg, loss_num = self.AvgNonZeroReducer(loss)
|
||||
|
||||
self.info.update({
|
||||
'loss': loss_avg,
|
||||
'hard_loss': hard_loss,
|
||||
'loss_num': loss_num,
|
||||
'mean_dist': mean_dist})
|
||||
'loss': loss_avg.detach().clone(),
|
||||
'hard_loss': hard_loss.detach().clone(),
|
||||
'loss_num': loss_num.detach().clone(),
|
||||
'mean_dist': mean_dist.detach().clone()})
|
||||
|
||||
return loss_avg, self.info
|
||||
|
||||
|
||||
@@ -63,8 +63,8 @@ class GeMHPP(nn.Module):
|
||||
|
||||
class GaitGL(BaseModel):
|
||||
"""
|
||||
Title: Gait Recognition via Effective Global-Local Feature Representation and Local Temporal Aggregation
|
||||
ICCV2021: https://openaccess.thecvf.com/content/ICCV2021/papers/Lin_Gait_Recognition_via_Effective_Global-Local_Feature_Representation_and_Local_Temporal_ICCV_2021_paper.pdf
|
||||
GaitGL: Gait Recognition via Effective Global-Local Feature Representation and Local Temporal Aggregation
|
||||
Arxiv : https://arxiv.org/pdf/2011.01461.pdf
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kargs):
|
||||
@@ -73,31 +73,71 @@ class GaitGL(BaseModel):
|
||||
def build_network(self, model_cfg):
|
||||
in_c = model_cfg['channels']
|
||||
class_num = model_cfg['class_num']
|
||||
dataset_name = self.cfgs['data_cfg']['dataset_name']
|
||||
|
||||
# For CASIA-B
|
||||
self.conv3d = nn.Sequential(
|
||||
BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
nn.LeakyReLU(inplace=True)
|
||||
)
|
||||
self.LTA = nn.Sequential(
|
||||
BasicConv3d(in_c[0], in_c[0], kernel_size=(
|
||||
3, 1, 1), stride=(3, 1, 1), padding=(0, 0, 0)),
|
||||
nn.LeakyReLU(inplace=True)
|
||||
)
|
||||
if dataset_name == 'OUMVLP':
|
||||
# For OUMVLP
|
||||
self.conv3d = nn.Sequential(
|
||||
BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
BasicConv3d(in_c[0], in_c[0], kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
)
|
||||
self.LTA = nn.Sequential(
|
||||
BasicConv3d(in_c[0], in_c[0], kernel_size=(
|
||||
3, 1, 1), stride=(3, 1, 1), padding=(0, 0, 0)),
|
||||
nn.LeakyReLU(inplace=True)
|
||||
)
|
||||
|
||||
self.GLConvA0 = GLConv(in_c[0], in_c[1], halving=3, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
|
||||
self.MaxPool0 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
self.GLConvA0 = nn.Sequential(
|
||||
GLConv(in_c[0], in_c[1], halving=1, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
GLConv(in_c[1], in_c[1], halving=1, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
)
|
||||
self.MaxPool0 = nn.MaxPool3d(
|
||||
kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
|
||||
self.GLConvA1 = GLConv(in_c[1], in_c[2], halving=3, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
|
||||
self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
|
||||
self.GLConvA1 = nn.Sequential(
|
||||
GLConv(in_c[1], in_c[2], halving=1, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
GLConv(in_c[2], in_c[2], halving=1, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
)
|
||||
self.GLConvB2 = nn.Sequential(
|
||||
GLConv(in_c[2], in_c[3], halving=1, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
GLConv(in_c[3], in_c[3], halving=1, fm_sign=True, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
)
|
||||
else:
|
||||
# For CASIA-B or other unstated datasets.
|
||||
self.conv3d = nn.Sequential(
|
||||
BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
|
||||
stride=(1, 1, 1), padding=(1, 1, 1)),
|
||||
nn.LeakyReLU(inplace=True)
|
||||
)
|
||||
self.LTA = nn.Sequential(
|
||||
BasicConv3d(in_c[0], in_c[0], kernel_size=(
|
||||
3, 1, 1), stride=(3, 1, 1), padding=(0, 0, 0)),
|
||||
nn.LeakyReLU(inplace=True)
|
||||
)
|
||||
|
||||
self.Head0 = SeparateFCs(64, in_c[2], in_c[2])
|
||||
self.Bn = nn.BatchNorm1d(in_c[2])
|
||||
self.Head1 = SeparateFCs(64, in_c[2], class_num)
|
||||
self.GLConvA0 = GLConv(in_c[0], in_c[1], halving=3, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
|
||||
self.MaxPool0 = nn.MaxPool3d(
|
||||
kernel_size=(1, 2, 2), stride=(1, 2, 2))
|
||||
|
||||
self.GLConvA1 = GLConv(in_c[1], in_c[2], halving=3, fm_sign=False, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
|
||||
self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=(
|
||||
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
|
||||
|
||||
self.Head0 = SeparateFCs(64, in_c[-1], in_c[-1])
|
||||
self.Bn = nn.BatchNorm1d(in_c[-1])
|
||||
self.Head1 = SeparateFCs(64, in_c[-1], class_num)
|
||||
|
||||
self.TP = PackSequenceWrapper(torch.max)
|
||||
self.HPP = GeMHPP()
|
||||
@@ -105,7 +145,9 @@ class GaitGL(BaseModel):
|
||||
def forward(self, inputs):
|
||||
ipts, labs, _, _, seqL = inputs
|
||||
seqL = None if not self.training else seqL
|
||||
|
||||
if not self.training and len(labs) != 1:
|
||||
raise ValueError(
|
||||
'The input size of each GPU must be 1 in testing mode, but got {}!'.format(len(labs)))
|
||||
sils = ipts[0].unsqueeze(1)
|
||||
del ipts
|
||||
n, _, s, h, w = sils.size()
|
||||
|
||||
@@ -191,10 +191,3 @@ def RmBN2dAffine(model):
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.requires_grad = False
|
||||
m.bias.requires_grad = False
|
||||
|
||||
|
||||
def fix_BN(model):
|
||||
for module in model.modules():
|
||||
classname = module.__class__.__name__
|
||||
if classname.find('BatchNorm2d') != -1:
|
||||
module.eval()
|
||||
|
||||
Reference in New Issue
Block a user