update gaitedge

This commit is contained in:
darkliang
2022-07-19 14:14:48 +08:00
parent 13894439a4
commit 4b681fb9bd
7 changed files with 31 additions and 22 deletions
+19 -14
View File
@@ -4,7 +4,7 @@ import torch.optim as optim
from ..base_model import BaseModel
from .gaitgl import GaitGL
from ..modules import SilhouetteCropAndResize
from ..modules import GaitAlign
from torchvision.transforms import Resize
from utils import get_valid_args, get_attr_from, is_list_or_tuple
import os.path as osp
@@ -46,10 +46,12 @@ class GaitEdge(GaitGL):
super(GaitEdge, self).build_network(model_cfg["GaitGL"])
self.Backbone = self.get_backbone(model_cfg['Segmentation'])
self.align = model_cfg['align']
self.CROP = SilhouetteCropAndResize()
self.gait_align = GaitAlign()
self.resize = Resize((64, 44))
self.is_edge = model_cfg['edge']
self.seg_lr = model_cfg['seg_lr']
self.kernel = torch.ones(
(model_cfg['kernel_size'], model_cfg['kernel_size'])).cuda()
def finetune_parameters(self):
fine_tune_params = list()
@@ -88,14 +90,22 @@ class GaitEdge(GaitGL):
"Error type for -Restore_Hint-, supported: int or string.")
self._load_ckpt(save_name)
def preprocess(self, sils):
dilated_mask = (morph.dilation(sils, self.kernel).detach()
) > 0.5 # Dilation
eroded_mask = (morph.erosion(sils, self.kernel).detach()
) > 0.5 # Dilation
edge_mask = dilated_mask ^ eroded_mask
return edge_mask, eroded_mask
def forward(self, inputs):
ipts, labs, _, _, seqL = inputs
ratios = ipts[0]
rgbs = ipts[1]
sils = ipts[2]
# if len(sils.size()) == 4:
# sils = sils.unsqueeze(2)
n, s, c, h, w = rgbs.size()
rgbs = rgbs.view(n*s, c, h, w)
sils = sils.view(n*s, 1, h, w)
@@ -103,24 +113,19 @@ class GaitEdge(GaitGL):
logits = torch.sigmoid(logis)
mask = torch.round(logits).float()
if self.is_edge:
kernel_1 = torch.ones((3, 3)).cuda()
kernel_2 = torch.ones((3, 3)).cuda()
dilated_mask = (morph.dilation(sils, kernel_1).detach()
) > 0.5 # Dilation
eroded_mask = (morph.erosion(sils, kernel_2).detach()
) > 0.5 # Dilation
edge_mask = dilated_mask ^ eroded_mask
edge_mask, eroded_mask = self.preprocess(sils)
# Gait Synthesis
new_logits = edge_mask*logits+eroded_mask*sils
if self.align:
cropped_logits = self.CROP(
cropped_logits = self.gait_align(
new_logits, sils, ratios)
else:
cropped_logits = self.resize(new_logits)
else:
if self.align:
cropped_logits = self.CROP(
cropped_logits = self.gait_align(
logits, mask, ratios)
else:
cropped_logits = self.resize(logits)