Add GaitEdge training code

This commit is contained in:
darkliang
2022-07-17 13:47:50 +08:00
parent 4205c5f283
commit b183455eb8
17 changed files with 814 additions and 11 deletions
+105
View File
@@ -0,0 +1,105 @@
import torch.nn as nn
import torch
class ConvBlock(nn.Module):
    """Two stacked 3x3 Conv-BN-ReLU stages that preserve spatial size."""

    def __init__(self, ch_in, ch_out):
        super(ConvBlock, self).__init__()
        stages = []
        channels = ch_in
        # Build the two identical conv stages; only the first changes the
        # channel count (ch_in -> ch_out), the second keeps ch_out.
        for _ in range(2):
            stages.append(nn.Conv2d(channels, ch_out, kernel_size=3,
                                    stride=1, padding=1, bias=True))
            stages.append(nn.BatchNorm2d(ch_out))
            stages.append(nn.ReLU(inplace=True))
            channels = ch_out
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """[n, ch_in, h, w] -> [n, ch_out, h, w]."""
        return self.conv(x)
class UpConv(nn.Module):
    """2x upsampling (nearest-neighbour) followed by a 3x3 Conv-BN-ReLU."""

    def __init__(self, ch_in, ch_out):
        super(UpConv, self).__init__()
        stages = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in, ch_out, kernel_size=3,
                      stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*stages)

    def forward(self, feature):
        """[n, ch_in, h, w] -> [n, ch_out, 2h, 2w]."""
        return self.up(feature)
class U_Net(nn.Module):
    """Small U-Net (16/32/64/128 channels) producing a 1-channel logit map.

    When ``freeze_half`` is True the four encoder stages are frozen:
    ``requires_grad`` is disabled on their parameters and the encoder
    forward pass runs under ``torch.no_grad()``, so only the decoder is
    fine-tuned.

    Args:
        in_channels: number of input image channels.
        freeze_half: freeze the encoder (fine-tuning mode).
    """

    def __init__(self, in_channels=3, freeze_half=True):
        super(U_Net, self).__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = ConvBlock(ch_in=in_channels, ch_out=16)
        self.Conv2 = ConvBlock(ch_in=16, ch_out=32)
        self.Conv3 = ConvBlock(ch_in=32, ch_out=64)
        self.Conv4 = ConvBlock(ch_in=64, ch_out=128)
        self.freeze = freeze_half
        # Begin Fine-tuning
        if freeze_half:
            self.Conv1.requires_grad_(False)
            self.Conv2.requires_grad_(False)
            self.Conv3.requires_grad_(False)
            self.Conv4.requires_grad_(False)
        # End Fine-tuning
        self.Up4 = UpConv(ch_in=128, ch_out=64)
        self.Up_conv4 = ConvBlock(ch_in=128, ch_out=64)

        self.Up3 = UpConv(ch_in=64, ch_out=32)
        self.Up_conv3 = ConvBlock(ch_in=64, ch_out=32)

        self.Up2 = UpConv(ch_in=32, ch_out=16)
        self.Up_conv2 = ConvBlock(ch_in=32, ch_out=16)

        self.Conv_1x1 = nn.Conv2d(
            16, 1, kernel_size=1, stride=1, padding=0)

    def _encode(self, x):
        """Run the 4-stage encoder; returns the skip features (x1..x4)."""
        x1 = self.Conv1(x)
        x2 = self.Conv2(self.Maxpool(x1))
        x3 = self.Conv3(self.Maxpool(x2))
        x4 = self.Conv4(self.Maxpool(x3))
        return x1, x2, x3, x4

    def forward(self, x):
        """x: [n, in_channels, h, w], h and w divisible by 8 -> [n, 1, h, w] logits."""
        # The original duplicated the encoder code in both branches; the
        # shared path now lives in _encode().
        if self.freeze:
            # Frozen encoder: skip autograd bookkeeping entirely.
            with torch.no_grad():
                x1, x2, x3, x4 = self._encode(x)
        else:
            x1, x2, x3, x4 = self._encode(x)

        # Decoding path with skip connections (concat along channels).
        d4 = self.Up4(x4)
        d4 = self.Up_conv4(torch.cat((x3, d4), dim=1))

        d3 = self.Up3(d4)
        d3 = self.Up_conv3(torch.cat((x2, d3), dim=1))

        d2 = self.Up2(d3)
        d2 = self.Up_conv2(torch.cat((x1, d2), dim=1))

        return self.Conv_1x1(d2)
+41
View File
@@ -0,0 +1,41 @@
import torch
from .base import BaseLoss
from utils import MeanIOU
class BinaryCrossEntropyLoss(BaseLoss):
    """Pixel-wise binary cross-entropy on probability maps.

    NOTE: despite the parameter name, ``logits`` is expected to already be
    a probability map in [0, 1] — the caller applies the sigmoid (see the
    commented-out line below).
    """

    def __init__(self, loss_term_weight=1.0, eps=1.0e-9):
        super(BinaryCrossEntropyLoss, self).__init__(loss_term_weight)
        # Floor inside the logs to avoid log(0).
        self.eps = eps

    def forward(self, logits, labels):
        """
        logits: [n, 1, h, w]
        labels: [n, 1, h, w]
        """
        # predts = torch.sigmoid(logits.float())
        probs = logits.float()
        targets = labels.float()

        # Manual BCE: -[y*log(p) + (1-y)*log(1-p)], element-wise.
        pos_term = targets * torch.log(probs + self.eps)
        neg_term = (1 - targets) * torch.log(1. - probs + self.eps)
        per_pixel = -(pos_term + neg_term)

        flat = per_pixel.view(per_pixel.size(0), -1)
        mean_loss = flat.mean()
        hard_loss = flat.max()
        miou = MeanIOU((probs > 0.5).float(), targets)

        self.info.update({
            'loss': mean_loss.detach().clone(),
            'hard_loss': hard_loss.detach().clone(),
            'miou': miou.detach().clone()})

        return mean_loss, self.info
if __name__ == "__main__":
    # Smoke test: random "probabilities" against random binary targets.
    criterion = BinaryCrossEntropyLoss()
    preds = torch.randn(1, 1, 128, 64)
    masks = (torch.randn(1, 1, 128, 64) > 0.).float()
    print(criterion(preds, masks))
+135
View File
@@ -0,0 +1,135 @@
import torch
from kornia import morphology as morph
import torch.optim as optim
from ..base_model import BaseModel
from .gaitgl import GaitGL
from ..modules import SilhouetteCropAndResize
from torchvision.transforms import Resize
from utils import get_valid_args, get_attr_from, is_list_or_tuple
import os.path as osp
class Segmentation(BaseModel):
    """Stand-alone silhouette segmentation model.

    Predicts a per-frame probability mask from RGB frames; trained with the
    BCE term against the ground-truth silhouettes.
    """

    def forward(self, inputs):
        ipts, labs, typs, vies, seqL = inputs
        del seqL

        # ipts[0] carries the aspect ratios; unused by this model.
        rgbs = ipts[1]
        sils = ipts[2]

        n, s, c, h, w = rgbs.size()
        # Fold the sequence dimension into the batch for 2D convolutions.
        rgbs = rgbs.view(n * s, c, h, w)
        sils = sils.view(n * s, 1, h, w)

        probs = torch.sigmoid(self.Backbone(rgbs))  # [n*s, 1, h, w]
        pred = (probs > 0.5).float()                # hard mask for visualization

        return {
            'training_feat': {
                'bce': {'logits': probs, 'labels': sils}
            },
            'visual_summary': {
                'image/sils': sils, 'image/logits': probs, "image/pred": pred
            },
            'inference_feat': {
                'pred': pred, 'mask': sils
            }
        }
class GaitEdge(GaitGL):
    """GaitGL recognizer trained jointly with a segmentation backbone.

    The silhouettes fed to GaitGL are produced (fully or partially) by
    ``self.Backbone`` so the recognition loss can fine-tune segmentation.
    With ``edge`` enabled, only the uncertain boundary band comes from the
    segmentation network, while the confident interior is copied from the
    ground-truth silhouettes.
    """

    def build_network(self, model_cfg):
        super(GaitEdge, self).build_network(model_cfg["GaitGL"])
        self.Backbone = self.get_backbone(model_cfg['Segmentation'])
        self.align = model_cfg['align']   # aligned crop+resize vs. plain resize
        self.CROP = SilhouetteCropAndResize()
        self.resize = Resize((64, 44))
        self.is_edge = model_cfg['edge']  # mix GT interior with predicted edges
        self.seg_lr = model_cfg['seg_lr']  # dedicated lr for the backbone

    def finetune_parameters(self):
        """Split trainable parameters into two optimizer groups.

        The segmentation backbone gets the dedicated ``seg_lr``; every
        other trainable parameter keeps the optimizer's default lr.
        """
        fine_tune_params = list()
        others_params = list()
        for name, p in self.named_parameters():
            if not p.requires_grad:
                continue
            if 'Backbone' in name:
                fine_tune_params.append(p)
            else:
                others_params.append(p)
        return [{'params': fine_tune_params, 'lr': self.seg_lr},
                {'params': others_params}]

    def get_optimizer(self, optimizer_cfg):
        """Build the solver from cfg, using the two-group split above."""
        self.msg_mgr.log_info(optimizer_cfg)
        optimizer = get_attr_from([optim], optimizer_cfg['solver'])
        valid_arg = get_valid_args(optimizer, optimizer_cfg, ['solver'])
        optimizer = optimizer(self.finetune_parameters(), **valid_arg)
        return optimizer

    def resume_ckpt(self, restore_hint):
        """Restore checkpoint(s).

        ``restore_hint`` may be an iteration number (int), a checkpoint
        path (str), or a list/tuple of either — restored in order, which
        allows loading e.g. the segmentation and recognition weights from
        separate checkpoints.
        """
        if is_list_or_tuple(restore_hint):
            for restore_hint_i in restore_hint:
                self.resume_ckpt(restore_hint_i)
            return
        if isinstance(restore_hint, int):
            save_name = self.engine_cfg['save_name']
            save_name = osp.join(
                self.save_path, 'checkpoints/{}-{:0>5}.pt'.format(save_name, restore_hint))
            self.iteration = restore_hint
        elif isinstance(restore_hint, str):
            save_name = restore_hint
            self.iteration = 0
        else:
            raise ValueError(
                "Error type for -Restore_Hint-, supported: int or string.")
        self._load_ckpt(save_name)

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs

        ratios = ipts[0]
        rgbs = ipts[1]
        sils = ipts[2]

        n, s, c, h, w = rgbs.size()
        rgbs = rgbs.view(n*s, c, h, w)
        sils = sils.view(n*s, 1, h, w)
        logis = self.Backbone(rgbs)  # [n*s, 1, h, w]
        logits = torch.sigmoid(logis)
        mask = torch.round(logits).float()
        if self.is_edge:
            # 3x3 structuring element on the input's device (the original
            # hard-coded .cuda(), which broke CPU-only runs; it also built
            # two identical kernels).
            kernel = torch.ones((3, 3), device=sils.device)
            dilated_mask = (morph.dilation(sils, kernel).detach()
                            ) > 0.5  # Dilation
            eroded_mask = (morph.erosion(sils, kernel).detach()
                           ) > 0.5  # Erosion (original comment wrongly said Dilation)
            # Boundary band = dilation XOR erosion.
            edge_mask = dilated_mask ^ eroded_mask
            # Predicted probabilities on the edge, GT silhouette inside.
            new_logits = edge_mask*logits + eroded_mask*sils
            if self.align:
                cropped_logits = self.CROP(
                    new_logits, sils, ratios)
            else:
                cropped_logits = self.resize(new_logits)
        else:
            if self.align:
                cropped_logits = self.CROP(
                    logits, mask, ratios)
            else:
                cropped_logits = self.resize(logits)
        _, c, H, W = cropped_logits.size()
        cropped_logits = cropped_logits.view(n, s, H, W)
        # Feed the differentiable silhouettes into the GaitGL recognizer and
        # attach the segmentation BCE term + a visual of the aligned crops.
        retval = super(GaitEdge, self).forward(
            [[cropped_logits], labs, None, None, seqL])
        retval['training_feat']['bce'] = {'logits': logits, 'labels': sils}
        retval['visual_summary']['image/roi'] = cropped_logits.view(
            n*s, 1, H, W)
        return retval
+56
View File
@@ -3,6 +3,7 @@ import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from utils import clones, is_list_or_tuple
from torchvision.ops import RoIAlign
class HorizontalPoolingPyramid():
@@ -182,6 +183,61 @@ class BasicConv3d(nn.Module):
return outs
class SilhouetteCropAndResize(nn.Module):
    # Differentiable crop-and-resize: finds the subject's bounding box from a
    # binary mask and RoIAlign-resizes the matching region of the feature map
    # onto a fixed (H, W) canvas, keeping gradients flowing to feature_map.
    # NOTE(review): the box math concatenates [n, c] tensors along the last
    # dim, which yields valid [n, 4] boxes only when c == 1 (single-channel
    # silhouettes) — confirm against callers.
    def __init__(self, H=64, W=44, eps=1, **kwargs):
        super(SilhouetteCropAndResize, self).__init__()
        # H, W: output canvas size. eps: minimum per-row mask sum for the row
        # to count as part of the body when locating the vertical extent.
        self.H, self.W, self.eps = H, W, eps
        # Zero-pad W/2 columns on each side so boxes near the image border
        # can still be centred without leaving the tensor.
        self.Pad = nn.ZeroPad2d((int(self.W / 2), int(self.W / 2), 0, 0))
        self.RoiPool = RoIAlign((self.H, self.W), 1, sampling_ratio=-1)

    def forward(self, feature_map, binary_mask, w_h_ratio):
        """
        In sils: [n, c, h, w]
        w_h_ratio: [n, 1]
        Out aligned_sils: [n, c, H, W]
        """
        n, c, h, w = feature_map.size()
        # w_h_ratio = w_h_ratio.repeat(1, 1) # [n, 1]
        w_h_ratio = w_h_ratio.view(-1, 1)  # [n, 1]

        # Vertical extent: h_top is the count of leading empty rows, h_bot the
        # index just past the last occupied row (via cumsum plateau).
        h_sum = binary_mask.sum(-1)  # [n, c, h]
        _ = (h_sum >= self.eps).float().cumsum(axis=-1)  # [n, c, h]
        h_top = (_ == 0).float().sum(-1)  # [n, c]
        h_bot = (_ != torch.max(_, dim=-1, keepdim=True)
                 [0]).float().sum(-1) + 1.  # [n, c]

        # Horizontal centre: the column that splits the mask mass in half.
        w_sum = binary_mask.sum(-2)  # [n, c, w]
        w_cumsum = w_sum.cumsum(axis=-1)  # [n, c, w]
        w_h_sum = w_sum.sum(-1).unsqueeze(-1)  # [n, c, 1]
        w_center = (w_cumsum < w_h_sum / 2.).float().sum(-1)  # [n, c]

        # Horizontal margin (output px) needed when the body is narrower than
        # the W:H canvas ratio; p2 converts it back to input pixels via t_w.
        p1 = self.W - self.H * w_h_ratio
        p1 = p1 / 2.
        p1 = torch.clamp(p1, min=0)  # [n, c]
        t_w = w_h_ratio * self.H / w
        p2 = p1 / t_w  # [n, c]

        height = h_bot - h_top  # [n, c]
        # Crop width keeps the full frame's w/h aspect ratio.
        width = height * w / h  # [n, c]
        width_p = int(self.W / 2)

        feature_map = self.Pad(feature_map)
        # Shift the centre by the left padding added above.
        w_center = w_center + width_p  # [n, c]

        w_left = w_center - width / 2 - p2  # [n, c]
        w_right = w_center + width / 2 + p2  # [n, c]

        w_left = torch.clamp(w_left, min=0., max=w+2*width_p)
        w_right = torch.clamp(w_right, min=0., max=w+2*width_p)

        # RoIAlign box format: (x1, y1, x2, y2) per sample.
        boxes = torch.cat([w_left, h_top, w_right, h_bot], dim=-1)
        # index of bbox in batch
        box_index = torch.arange(n, device=feature_map.device)
        rois = torch.cat([box_index.view(-1, 1), boxes], -1)
        crops = self.RoiPool(feature_map, rois)  # [n, c, H, W]
        return crops
def RmBN2dAffine(model):
for m in model.modules():
if isinstance(m, nn.BatchNorm2d):
+1
View File
@@ -7,4 +7,5 @@ from .common import mkdir, clones
from .common import MergeCfgsDict
from .common import get_attr_from
from .common import NoOp
from .common import MeanIOU
from .msg_manager import get_msg_mgr
+13 -1
View File
@@ -138,7 +138,7 @@ def clones(module, N):
def config_loader(path):
    """Load the YAML config at ``path``, merged over ./configs/default.yaml.

    Values from the user config take precedence; the merge is done in-place
    into the defaults dict by ``MergeCfgsDict``.
    """
    with open(path, 'r') as stream:
        src_cfgs = yaml.safe_load(stream)
    # Only the (renamed) "./configs" directory is read; the stale
    # "./config/default.yaml" line left a body-less `with` statement.
    with open("./configs/default.yaml", 'r') as stream:
        dst_cfgs = yaml.safe_load(stream)
    MergeCfgsDict(src_cfgs, dst_cfgs)
    return dst_cfgs
@@ -203,3 +203,15 @@ def get_ddp_module(module, **kwargs):
def params_count(net):
    """Return a human-readable parameter count for *net*, in millions."""
    total = 0
    for p in net.parameters():
        total += p.numel()
    return 'Parameters Count: {:.5f}M'.format(total / 1e6)
def MeanIOU(msk1, msk2, eps=1.0e-9):
    """Per-sample mean intersection-over-union of two binary masks.

    Args:
        msk1, msk2: [n, ...] binary (0/1) tensors or numpy arrays.
        eps: guard against division by zero for empty masks.
    Returns:
        [n] tensor of IoU values, one per sample.
    """
    # Accept numpy inputs. Convert with torch.from_numpy and move to the
    # device of any tensor argument so the element-wise ops below don't mix
    # devices. (The original unconditionally called .cuda(), which crashed
    # on CPU-only machines.)
    device = msk2.device if torch.is_tensor(msk2) else (
        msk1.device if torch.is_tensor(msk1) else None)
    if not torch.is_tensor(msk1):
        msk1 = torch.from_numpy(msk1)
    if not torch.is_tensor(msk2):
        msk2 = torch.from_numpy(msk2)
    if device is not None:
        msk1 = msk1.to(device)
        msk2 = msk2.to(device)
    n = msk1.size(0)
    inter = msk1 * msk2
    union = ((msk1 + msk2) > 0.).float()
    MeIOU = inter.view(n, -1).sum(-1) / (union.view(n, -1).sum(-1) + eps)
    return MeIOU
+11 -3
View File
@@ -3,7 +3,7 @@ from time import strftime, localtime
import torch
import numpy as np
import torch.nn.functional as F
from utils import get_msg_mgr, mkdir
from utils import get_msg_mgr, mkdir, MeanIOU
def cuda_dist(x, y, metric='euc'):
@@ -124,10 +124,10 @@ def identification_real_scene(data, dataset, metric='euc'):
gallery_seq_type = {'0001-1000': ['1', '2'],
"HID2021": ['0'], '0001-1000-test': ['0'],
'GREW': ['01']}
'GREW': ['01'], 'TTG-200': ['1']}
probe_seq_type = {'0001-1000': ['3', '4', '5', '6'],
"HID2021": ['1'], '0001-1000-test': ['1'],
'GREW': ['02']}
'GREW': ['02'], 'TTG-200': ['2', '3', '4', '5', '6']}
num_rank = 20
acc = np.zeros([num_rank]) - 1.
@@ -274,3 +274,11 @@ def re_ranking(original_dist, query_num, k1, k2, lambda_value):
del jaccard_dist
final_dist = final_dist[:query_num, query_num:]
return final_dist
def mean_iou(data, dataset):
    """Segmentation evaluation: mean IoU between predicted and GT masks."""
    # data carries the stacked inference outputs; 'dataset' is unused but
    # kept for the evaluator's uniform signature.
    miou = MeanIOU(data['pred'], data['mask'])
    get_msg_mgr().log_info('mIOU: %.3f' % (miou.mean()))
    return {"scalar/test_accuracy/mIOU": miou}