update gaitedge
This commit is contained in:
@@ -4,7 +4,7 @@ import torch.optim as optim
|
||||
|
||||
from ..base_model import BaseModel
|
||||
from .gaitgl import GaitGL
|
||||
from ..modules import SilhouetteCropAndResize
|
||||
from ..modules import GaitAlign
|
||||
from torchvision.transforms import Resize
|
||||
from utils import get_valid_args, get_attr_from, is_list_or_tuple
|
||||
import os.path as osp
|
||||
@@ -46,10 +46,12 @@ class GaitEdge(GaitGL):
|
||||
super(GaitEdge, self).build_network(model_cfg["GaitGL"])
|
||||
self.Backbone = self.get_backbone(model_cfg['Segmentation'])
|
||||
self.align = model_cfg['align']
|
||||
self.CROP = SilhouetteCropAndResize()
|
||||
self.gait_align = GaitAlign()
|
||||
self.resize = Resize((64, 44))
|
||||
self.is_edge = model_cfg['edge']
|
||||
self.seg_lr = model_cfg['seg_lr']
|
||||
self.kernel = torch.ones(
|
||||
(model_cfg['kernel_size'], model_cfg['kernel_size'])).cuda()
|
||||
|
||||
def finetune_parameters(self):
|
||||
fine_tune_params = list()
|
||||
@@ -88,14 +90,22 @@ class GaitEdge(GaitGL):
|
||||
"Error type for -Restore_Hint-, supported: int or string.")
|
||||
self._load_ckpt(save_name)
|
||||
|
||||
def preprocess(self, sils):
|
||||
|
||||
dilated_mask = (morph.dilation(sils, self.kernel).detach()
|
||||
) > 0.5 # Dilation
|
||||
eroded_mask = (morph.erosion(sils, self.kernel).detach()
|
||||
) > 0.5 # Dilation
|
||||
edge_mask = dilated_mask ^ eroded_mask
|
||||
return edge_mask, eroded_mask
|
||||
|
||||
def forward(self, inputs):
|
||||
ipts, labs, _, _, seqL = inputs
|
||||
|
||||
ratios = ipts[0]
|
||||
rgbs = ipts[1]
|
||||
sils = ipts[2]
|
||||
# if len(sils.size()) == 4:
|
||||
# sils = sils.unsqueeze(2)
|
||||
|
||||
n, s, c, h, w = rgbs.size()
|
||||
rgbs = rgbs.view(n*s, c, h, w)
|
||||
sils = sils.view(n*s, 1, h, w)
|
||||
@@ -103,24 +113,19 @@ class GaitEdge(GaitGL):
|
||||
logits = torch.sigmoid(logis)
|
||||
mask = torch.round(logits).float()
|
||||
if self.is_edge:
|
||||
kernel_1 = torch.ones((3, 3)).cuda()
|
||||
kernel_2 = torch.ones((3, 3)).cuda()
|
||||
|
||||
dilated_mask = (morph.dilation(sils, kernel_1).detach()
|
||||
) > 0.5 # Dilation
|
||||
eroded_mask = (morph.erosion(sils, kernel_2).detach()
|
||||
) > 0.5 # Dilation
|
||||
edge_mask = dilated_mask ^ eroded_mask
|
||||
edge_mask, eroded_mask = self.preprocess(sils)
|
||||
|
||||
# Gait Synthesis
|
||||
new_logits = edge_mask*logits+eroded_mask*sils
|
||||
|
||||
if self.align:
|
||||
cropped_logits = self.CROP(
|
||||
cropped_logits = self.gait_align(
|
||||
new_logits, sils, ratios)
|
||||
else:
|
||||
cropped_logits = self.resize(new_logits)
|
||||
else:
|
||||
if self.align:
|
||||
cropped_logits = self.CROP(
|
||||
cropped_logits = self.gait_align(
|
||||
logits, mask, ratios)
|
||||
else:
|
||||
cropped_logits = self.resize(logits)
|
||||
|
||||
Reference in New Issue
Block a user