Solve the problem of dimension misuse. (#59)
* commit for fix dimension
* fix dimension for all method
* restore config
* clean up baseline config
* add contiguous
* rm comment
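One convention change drives every hunk below: silhouette feature maps move from [n, s, c, h, w] to [n, c, s, h, w], and part features from [p, n, c] / [n, p, c] to a uniform [n, c, p], with the remaining permutes pushed inside the modules that need them (SeparateFCs, SeparateBNNecks, the losses). A toy comparison of the two part-feature layouts — an illustrative sketch, not code from this commit:

import torch

# Illustrative only: the same part features in the old and the new layout.
n, c, p = 4, 256, 16                      # batch, channels, horizontal parts
feat_old = torch.rand(p, n, c)            # old convention: parts first
feat_new = feat_old.permute(1, 2, 0)      # new convention: [n, c, p], channels-first
assert feat_new.shape == (n, c, p)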
@@ -46,9 +46,6 @@ model_cfg:
       - M
       - BC-256
       - BC-256
-      # - M
-      # - BC-512
-      # - BC-512
     type: Plain
   SeparateFCs:
     in_channels: 256
@@ -81,6 +78,7 @@ trainer_cfg:
   enable_float16: true # half_percesion float for memory reduction and speedup
   fix_BN: false
   log_iter: 100
+  with_test: true
   restore_ckpt_strict: true
   restore_hint: 0
   save_iter: 10000
@@ -47,8 +47,8 @@ scheduler_cfg:
   scheduler: MultiStepLR
 
 trainer_cfg:
-  enable_distributed: true
-  enable_float16: false
+  enable_float16: true
+  with_test: true
   log_iter: 100
   restore_ckpt_strict: true
   restore_hint: 0
@@ -59,6 +59,7 @@ scheduler_cfg:
 trainer_cfg:
   enable_float16: true
   log_iter: 100
+  with_test: true
   restore_ckpt_strict: true
   restore_hint: 0
   save_iter: 10000
@@ -57,12 +57,13 @@ scheduler_cfg:
 trainer_cfg:
   enable_float16: true
   log_iter: 100
+  with_test: true
   restore_ckpt_strict: true
   restore_hint: 0
   save_iter: 10000
   save_name: GaitSet
   sync_BN: false
-  total_iter: 42000
+  total_iter: 40000
 sampler:
   batch_shuffle: false
   batch_size:
@@ -74,10 +74,9 @@ scheduler_cfg:
   scheduler: MultiStepLR
 
 trainer_cfg:
-  enable_distributed: true
   enable_float16: true
   fix_layers: false
-  with_test: false
+  with_test: true
   log_iter: 100
   optimizer_reset: false
   restore_ckpt_strict: true
@@ -14,31 +14,29 @@ class CrossEntropyLoss(BaseLoss):
 
     def forward(self, logits, labels):
         """
-            logits: [n, p, c]
+            logits: [n, c, p]
            labels: [n]
         """
-        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c] -> [p, n, c]
-        p, _, c = logits.size()
-        log_preds = F.log_softmax(logits * self.scale, dim=-1)  # [p, n, c]
+        n, c, p = logits.size()
+        log_preds = F.log_softmax(logits * self.scale, dim=1)  # [n, c, p]
         one_hot_labels = self.label2one_hot(
-            labels, c).unsqueeze(0).repeat(p, 1, 1)  # [p, n, c]
+            labels, c).unsqueeze(2).repeat(1, 1, p)  # [n, c, p]
         loss = self.compute_loss(log_preds, one_hot_labels)
         self.info.update({'loss': loss.detach().clone()})
         if self.log_accuracy:
-            pred = logits.argmax(dim=-1)  # [p, n]
-            accu = (pred == labels.unsqueeze(0)).float().mean()
+            pred = logits.argmax(dim=1)  # [n, p]
+            accu = (pred == labels.unsqueeze(1)).float().mean()
             self.info.update({'accuracy': accu})
         return loss, self.info
 
     def compute_loss(self, predis, labels):
-        softmax_loss = -(labels * predis).sum(-1)  # [p, n]
-        losses = softmax_loss.mean(-1)
+        softmax_loss = -(labels * predis).sum(1)  # [n, p]
+        losses = softmax_loss.mean(0)  # [p]
 
         if self.label_smooth:
-            smooth_loss = - predis.mean(dim=-1)  # [p, n]
-            smooth_loss = smooth_loss.mean()  # [p]
-            smooth_loss = smooth_loss * self.eps
-            losses = smooth_loss + losses * (1. - self.eps)
+            smooth_loss = - predis.mean(dim=1)  # [n, p]
+            smooth_loss = smooth_loss.mean(0)  # [p]
+            losses = smooth_loss * self.eps + losses * (1. - self.eps)
         return losses
 
     def label2one_hot(self, label, class_num):
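The loss is unchanged by the relayout: log_softmax only has to run over whichever axis holds the classes. A self-contained check with assumed sizes, simplified from the class above (no scale factor, no label smoothing):

import torch
import torch.nn.functional as F

# Sketch, not repository test code: the cross-entropy value is identical
# whether logits are laid out as [n, c, p] (new) or [p, n, c] (old).
n, c, p = 4, 10, 16
logits_ncp = torch.randn(n, c, p)
labels = torch.randint(0, c, (n,))

# New layout: classes on dim 1.
log_preds = F.log_softmax(logits_ncp, dim=1)                # [n, c, p]
one_hot = F.one_hot(labels, c).float().unsqueeze(2)         # [n, c, 1]
loss_new = -(one_hot * log_preds).sum(1).mean()

# Old layout: permute to [p, n, c], classes on the last dim.
logits_pnc = logits_ncp.permute(2, 0, 1)
log_preds = F.log_softmax(logits_pnc, dim=-1)               # [p, n, c]
one_hot = F.one_hot(labels, c).float().unsqueeze(0)         # [1, n, c]
loss_old = -(one_hot * log_preds).sum(-1).mean()

assert torch.allclose(loss_new, loss_old, atol=1e-6)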
@@ -11,14 +11,13 @@ class TripletLoss(BaseLoss):
 
     @gather_and_scale_wrapper
     def forward(self, embeddings, labels):
-        # embeddings: [n, p, c], label: [n]
-        embeddings = embeddings.permute(
-            1, 0, 2).contiguous()  # [n, p, c] -> [p, n, c]
-        embeddings = embeddings.float()
+        # embeddings: [n, c, p], label: [n]
+        embeddings = embeddings.permute(
+            2, 0, 1).contiguous().float()  # [n, c, p] -> [p, n, c]
 
         ref_embed, ref_label = embeddings, labels
         dist = self.ComputeDistance(embeddings, ref_embed)  # [p, n1, n2]
-        mean_dist = dist.mean(1).mean(1)
+        mean_dist = dist.mean((1, 2))  # [p]
         ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
         dist_diff = (ap_dist - an_dist).view(dist.size(0), -1)
         loss = F.relu(dist_diff + self.margin)
@@ -50,7 +49,7 @@ class TripletLoss(BaseLoss):
         """
         x2 = torch.sum(x ** 2, -1).unsqueeze(2)  # [p, n_x, 1]
         y2 = torch.sum(y ** 2, -1).unsqueeze(1)  # [p, 1, n_y]
-        inner = x.matmul(y.transpose(-1, -2))  # [p, n_x, n_y]
+        inner = x.matmul(y.transpose(1, 2))  # [p, n_x, n_y]
         dist = x2 + y2 - 2 * inner
         dist = torch.sqrt(F.relu(dist))  # [p, n_x, n_y]
         return dist
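ComputeDistance expands ||x - y||^2 = x^2 + y^2 - 2<x, y> independently per part; the relu before the sqrt clamps small negative round-off. An illustrative check against torch.cdist, assuming [p, n, c] inputs as in the docstring:

import torch
import torch.nn.functional as F

p, n_x, n_y, c = 16, 8, 6, 64
x = torch.randn(p, n_x, c)
y = torch.randn(p, n_y, c)

x2 = torch.sum(x ** 2, -1).unsqueeze(2)         # [p, n_x, 1]
y2 = torch.sum(y ** 2, -1).unsqueeze(1)         # [p, 1, n_y]
inner = x.matmul(y.transpose(1, 2))             # [p, n_x, n_y]
dist = torch.sqrt(F.relu(x2 + y2 - 2 * inner))  # [p, n_x, n_y]

assert torch.allclose(dist, torch.cdist(x, y), atol=1e-4)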
@@ -60,9 +59,10 @@ class TripletLoss(BaseLoss):
             row_labels: tensor with size [n_r]
             clo_label : tensor with size [n_c]
         """
-        matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool()  # [n_r, n_c]
+        matches = (row_labels.unsqueeze(1) ==
+                   clo_label.unsqueeze(0)).bool()  # [n_r, n_c]
         diffenc = torch.logical_not(matches)  # [n_r, n_c]
-        p, n, m = dist.size()
+        p, n, _ = dist.size()
         ap_dist = dist[:, matches].view(p, n, -1, 1)
         an_dist = dist[:, diffenc].view(p, n, 1, -1)
         return ap_dist, an_dist
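Convert2Triplets uses broadcasting: anchor-positive distances become a column [p, n, n_pos, 1] and anchor-negative distances a row [p, n, 1, n_neg], so one subtraction enumerates every (anchor, positive, negative) triplet. A toy run (labels assumed balanced so the view sizes work out):

import torch

labels = torch.tensor([0, 0, 1, 1])     # toy labels, 2 samples per class
dist = torch.rand(2, 4, 4)              # toy [p, n, n] distance matrix

matches = labels.unsqueeze(1) == labels.unsqueeze(0)   # [n, n]
diffenc = torch.logical_not(matches)
p, n, _ = dist.size()
ap_dist = dist[:, matches].view(p, n, -1, 1)   # [p, n, n_pos, 1]
an_dist = dist[:, diffenc].view(p, n, 1, -1)   # [p, n, 1, n_neg]
print((ap_dist - an_dist).shape)               # torch.Size([2, 4, 2, 2])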
@@ -19,26 +19,21 @@ class Baseline(BaseModel):
 
         sils = ipts[0]
         if len(sils.size()) == 4:
-            sils = sils.unsqueeze(2)
+            sils = sils.unsqueeze(1)
 
         del ipts
-        outs = self.Backbone(sils)  # [n, s, c, h, w]
+        outs = self.Backbone(sils)  # [n, c, s, h, w]
 
         # Temporal Pooling, TP
-        outs = self.TP(outs, seqL, dim=1)[0]  # [n, c, h, w]
+        outs = self.TP(outs, seqL, options={"dim": 2})[0]  # [n, c, h, w]
         # Horizontal Pooling Matching, HPM
         feat = self.HPP(outs)  # [n, c, p]
-        feat = feat.permute(2, 0, 1).contiguous()  # [p, n, c]
 
-        embed_1 = self.FCs(feat)  # [p, n, c]
-        embed_2, logits = self.BNNecks(embed_1)  # [p, n, c]
-
-        embed_1 = embed_1.permute(1, 0, 2).contiguous()  # [n, p, c]
-        embed_2 = embed_2.permute(1, 0, 2).contiguous()  # [n, p, c]
-        logits = logits.permute(1, 0, 2).contiguous()  # [n, p, c]
+        embed_1 = self.FCs(feat)  # [n, c, p]
+        embed_2, logits = self.BNNecks(embed_1)  # [n, c, p]
         embed = embed_1
 
-        n, s, _, h, w = sils.size()
+        n, _, s, h, w = sils.size()
         retval = {
             'training_feat': {
                 'triplet': {'embeddings': embed_1, 'labels': labs},
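End to end, Baseline now keeps channels ahead of time from silhouettes to part features, so nothing needs a permute until the separate FC layers. A shape walk-through with stand-in tensors — sizes are assumed and the real Backbone/TP/HPP are replaced by dummies:

import torch

n, s, h, w, c, p = 2, 30, 64, 44, 256, 16
sils = torch.rand(n, s, h, w).unsqueeze(1)   # [n, 1, s, h, w]: channel axis first

outs = torch.rand(n, c, s, 16, 11)           # stand-in Backbone output [n, c, s, h', w']
outs = outs.max(dim=2)[0]                    # TP over the time axis -> [n, c, h', w']

strips = outs.view(n, c, p, -1, 11)          # cut height into p horizontal strips
feat = strips.mean(dim=(3, 4)) + strips.amax(dim=(3, 4))  # HPP-style pooling
print(feat.shape)                            # torch.Size([2, 256, 16]) = [n, c, p]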
@@ -75,7 +75,7 @@ class GaitGL(BaseModel):
         class_num = model_cfg['class_num']
         dataset_name = self.cfgs['data_cfg']['dataset_name']
 
-        if dataset_name in ['OUMVLP','GREW']:
+        if dataset_name in ['OUMVLP', 'GREW']:
             # For OUMVLP and GREW
             self.conv3d = nn.Sequential(
                 BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
@@ -135,12 +135,11 @@ class GaitGL(BaseModel):
         self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=(
             3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
 
-
         self.TP = PackSequenceWrapper(torch.max)
         self.HPP = GeMHPP()
 
         self.Head0 = SeparateFCs(64, in_c[-1], in_c[-1])
 
         if 'SeparateBNNecks' in model_cfg.keys():
             self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
             self.Bn_head = False
@@ -171,22 +170,18 @@ class GaitGL(BaseModel):
         outs = self.GLConvA1(outs)
         outs = self.GLConvB2(outs)  # [n, c, s, h, w]
 
-        outs = self.TP(outs, dim=2, seq_dim=2, seqL=seqL)[0]  # [n, c, h, w]
+        outs = self.TP(outs, seqL=seqL, options={"dim": 2})[0]  # [n, c, h, w]
         outs = self.HPP(outs)  # [n, c, p]
-        outs = outs.permute(2, 0, 1).contiguous()  # [p, n, c]
 
-        gait = self.Head0(outs)  # [p, n, c]
+        gait = self.Head0(outs)  # [n, c, p]
 
         if self.Bn_head:  # Original GaitGL Head
-            gait = gait.permute(1, 2, 0).contiguous()  # [n, c, p]
             bnft = self.Bn(gait)  # [n, c, p]
-            logi = self.Head1(bnft.permute(2, 0, 1).contiguous())  # [p, n, c]
-            embed = bnft.permute(0, 2, 1).contiguous()  # [n, p, c]
+            logi = self.Head1(bnft)  # [n, c, p]
+            embed = bnft
         else:  # BNNechk as Head
-            bnft, logi = self.BNNecks(gait)  # [p, n, c]
-            embed = gait.permute(1, 0, 2).contiguous()  # [n, p, c]
-
-        logi = logi.permute(1, 0, 2).contiguous()  # [n, p, c]
+            bnft, logi = self.BNNecks(gait)  # [n, c, p]
+            embed = gait
 
         n, _, s, h, w = sils.size()
         retval = {
@@ -45,12 +45,12 @@ class TemporalFeatureAggregator(nn.Module):
 
     def forward(self, x):
         """
-            Input: x, [n, s, c, p]
-            Output: ret, [n, p, c]
+            Input: x, [n, c, s, p]
+            Output: ret, [n, c, p]
         """
-        n, s, c, p = x.size()
-        x = x.permute(3, 0, 2, 1).contiguous()  # [p, n, c, s]
-        feature = x.split(1, 0)  # [[n, c, s], ...]
+        n, c, s, p = x.size()
+        x = x.permute(3, 0, 1, 2).contiguous()  # [p, n, c, s]
+        feature = x.split(1, 0)  # [[1, n, c, s], ...]
         x = x.view(-1, c, s)
 
         # MTB1: ConvNet1d & Sigmoid
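TemporalFeatureAggregator folds the p parts into the batch so a shared Conv1d can score every part's temporal sequence in one call — the same trick SetBlockWrapper plays with frames. A minimal sketch with an assumed stand-in conv:

import torch
import torch.nn as nn

n, c, s, p = 2, 128, 30, 16
conv1d = nn.Conv1d(c, c // 4, 3, padding=1)  # stand-in for one MTB branch
x = torch.rand(n, c, s, p)

x = x.permute(3, 0, 1, 2).contiguous()       # [p, n, c, s]
scores = conv1d(x.view(-1, c, s))            # parts folded into batch: [p*n, c//4, s]
print(scores.shape)                          # torch.Size([32, 32, 30])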
@@ -73,7 +73,7 @@ class TemporalFeatureAggregator(nn.Module):
 
         # Temporal Pooling
         ret = self.TP(feature3x1 + feature3x3, dim=-1)[0]  # [p, n, c]
-        ret = ret.permute(1, 0, 2).contiguous()  # [n, p, c]
+        ret = ret.permute(1, 2, 0).contiguous()  # [n, p, c]
         return ret
 
 
@@ -102,17 +102,16 @@ class GaitPart(BaseModel):
 
         sils = ipts[0]
         if len(sils.size()) == 4:
-            sils = sils.unsqueeze(2)
+            sils = sils.unsqueeze(1)
 
         del ipts
-        out = self.Backbone(sils)  # [n, s, c, h, w]
-        out = self.HPP(out)  # [n, s, c, p]
-        out = self.TFA(out, seqL)  # [n, p, c]
+        out = self.Backbone(sils)  # [n, c, s, h, w]
+        out = self.HPP(out)  # [n, c, s, p]
+        out = self.TFA(out, seqL)  # [n, c, p]
 
-        embs = self.Head(out.permute(1, 0, 2).contiguous())  # [p, n, c]
-        embs = embs.permute(1, 0, 2).contiguous()  # [n, p, c]
+        embs = self.Head(out)  # [n, c, p]
 
-        n, s, _, h, w = sils.size()
+        n, _, s, h, w = sils.size()
         retval = {
             'training_feat': {
                 'triplet': {'embeddings': embs, 'labels': labs}
@@ -49,30 +49,28 @@ class GaitSet(BaseModel):
         ipts, labs, _, _, seqL = inputs
         sils = ipts[0]  # [n, s, h, w]
         if len(sils.size()) == 4:
-            sils = sils.unsqueeze(2)
+            sils = sils.unsqueeze(1)
 
         del ipts
         outs = self.set_block1(sils)
-        gl = self.set_pooling(outs, seqL, dim=1)[0]
+        gl = self.set_pooling(outs, seqL, options={"dim": 2})[0]
         gl = self.gl_block2(gl)
 
         outs = self.set_block2(outs)
-        gl = gl + self.set_pooling(outs, seqL, dim=1)[0]
+        gl = gl + self.set_pooling(outs, seqL, options={"dim": 2})[0]
         gl = self.gl_block3(gl)
 
         outs = self.set_block3(outs)
-        outs = self.set_pooling(outs, seqL, dim=1)[0]
+        outs = self.set_pooling(outs, seqL, options={"dim": 2})[0]
         gl = gl + outs
 
         # Horizontal Pooling Matching, HPM
         feature1 = self.HPP(outs)  # [n, c, p]
         feature2 = self.HPP(gl)  # [n, c, p]
         feature = torch.cat([feature1, feature2], -1)  # [n, c, p]
-        feature = feature.permute(2, 0, 1).contiguous()  # [p, n, c]
         embs = self.Head(feature)
-        embs = embs.permute(1, 0, 2).contiguous()  # [n, p, c]
 
-        n, s, _, h, w = sils.size()
+        n, _, s, h, w = sils.size()
         retval = {
             'training_feat': {
                 'triplet': {'embeddings': embs, 'labels': labs}
@@ -89,12 +89,12 @@ class GLN(BaseModel):
         sils = ipts[0]  # [n, s, h, w]
         del ipts
         if len(sils.size()) == 4:
-            sils = sils.unsqueeze(2)
-        n, s, _, h, w = sils.size()
+            sils = sils.unsqueeze(1)
+        n, _, s, h, w = sils.size()
 
         ### stage 0 sil ###
         sil_0_outs = self.sil_stage_0(sils)
-        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0]
+        stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, options={"dim": 2})[0]
 
         ### stage 1 sil ###
         sil_1_ipts = self.MaxP_sil(sil_0_outs)
@@ -105,13 +105,13 @@ class GLN(BaseModel):
         sil_2_outs = self.sil_stage_2(sil_2_ipts)
 
         ### stage 1 set ###
-        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0]
-        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0]
+        set_1_ipts = self.set_pooling(sil_1_ipts, seqL, options={"dim": 2})[0]
+        stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, options={"dim": 2})[0]
         set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set
 
         ### stage 2 set ###
         set_2_ipts = self.MaxP_set(set_1_outs)
-        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0]
+        stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, options={"dim": 2})[0]
         set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set
 
         set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
@@ -133,11 +133,9 @@ class GLN(BaseModel):
         set2 = self.HPP(set2)
         set3 = self.HPP(set3)
 
-        feature = torch.cat([set1, set2, set3], -
-                            1).permute(2, 0, 1).contiguous()
+        feature = torch.cat([set1, set2, set3], -1)
 
         feature = self.Head(feature)
-        feature = feature.permute(1, 0, 2).contiguous()  # n p c
 
         # compact_bloack
         if not self.pretrain:
@@ -38,14 +38,14 @@ class SetBlockWrapper(nn.Module):
 
     def forward(self, x, *args, **kwargs):
         """
-            In x: [n, s, c, h, w]
-            Out x: [n, s, ...]
+            In x: [n, c_in, s, h_in, w_in]
+            Out x: [n, c_out, s, h_out, w_out]
         """
-        n, s, c, h, w = x.size()
-        x = self.forward_block(x.view(-1, c, h, w), *args, **kwargs)
-        input_size = x.size()
-        output_size = [n, s] + [*input_size[1:]]
-        return x.view(*output_size)
+        n, c, s, h, w = x.size()
+        x = self.forward_block(x.transpose(
+            1, 2).view(-1, c, h, w), *args, **kwargs)
+        output_size = x.size()
+        return x.reshape(n, s, *output_size[1:]).transpose(1, 2).contiguous()
 
 
 class PackSequenceWrapper(nn.Module):
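SetBlockWrapper runs a 2D block per frame by folding the sequence axis into the batch; with the new [n, c, s, h, w] layout the frame axis sits behind the channels, hence the transpose on the way in and back out (plus the final contiguous named in the commit message). Note that transpose followed by view needs contiguous memory, so the sketch below uses reshape for the folding step. A minimal reproduction with a stand-in conv and assumed sizes:

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 8, 3, padding=1)        # stand-in for forward_block
n, c, s, h, w = 2, 1, 5, 32, 22
x = torch.rand(n, c, s, h, w)

y = conv(x.transpose(1, 2).reshape(-1, c, h, w))   # per-frame: [n*s, 8, h, w]
y = y.reshape(n, s, *y.size()[1:]).transpose(1, 2).contiguous()
print(y.shape)                              # torch.Size([2, 8, 5, 32, 22])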
@@ -53,26 +53,20 @@ class PackSequenceWrapper(nn.Module):
         super(PackSequenceWrapper, self).__init__()
         self.pooling_func = pooling_func
 
-    def forward(self, seqs, seqL, seq_dim=1, **kwargs):
+    def forward(self, seqs, seqL, dim=2, options={}):
         """
-            In seqs: [n, s, ...]
+            In seqs: [n, c, s, ...]
            Out rets: [n, ...]
         """
         if seqL is None:
-            return self.pooling_func(seqs, **kwargs)
+            return self.pooling_func(seqs, **options)
         seqL = seqL[0].data.cpu().numpy().tolist()
         start = [0] + np.cumsum(seqL).tolist()[:-1]
 
         rets = []
         for curr_start, curr_seqL in zip(start, seqL):
-            narrowed_seq = seqs.narrow(seq_dim, curr_start, curr_seqL)
-            # save the memory
-            # splited_narrowed_seq = torch.split(narrowed_seq, 256, dim=1)
-            # ret = []
-            # for seq_to_pooling in splited_narrowed_seq:
-            #     ret.append(self.pooling_func(seq_to_pooling, keepdim=True, **kwargs)
-            #         [0] if self.is_tuple_result else self.pooling_func(seq_to_pooling, **kwargs))
-            rets.append(self.pooling_func(narrowed_seq, **kwargs))
+            narrowed_seq = seqs.narrow(dim, curr_start, curr_seqL)
+            rets.append(self.pooling_func(narrowed_seq, **options))
         if len(rets) > 0 and is_list_or_tuple(rets[0]):
             return [torch.cat([ret[j] for ret in rets])
                     for j in range(len(rets[0]))]
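At test time, variable-length sequences arrive concatenated along the time axis (now dim 2) and seqL holds the per-sequence lengths; the wrapper narrows out each sequence and pools it on its own. A toy run of that logic, with assumed lengths and torch.max written out directly:

import torch

seqs = torch.rand(1, 8, 9, 16, 11)          # [n, c, s_total, h, w], 4+2+3 frames packed
seqL = [4, 2, 3]
start = [0, 4, 6]                           # cumulative offsets, as in the class

rets = [seqs.narrow(2, st, sl).max(dim=2)[0] for st, sl in zip(start, seqL)]
pooled = torch.cat(rets)                    # [3, 8, 16, 11]: one pooled map per sequence
print(pooled.shape)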
@@ -101,13 +95,15 @@ class SeparateFCs(nn.Module):
 
     def forward(self, x):
         """
-            x: [p, n, c]
+            x: [n, c_in, p]
+            out: [n, c_out, p]
         """
+        x = x.permute(2, 0, 1).contiguous()
         if self.norm:
             out = x.matmul(F.normalize(self.fc_bin, dim=1))
         else:
             out = x.matmul(self.fc_bin)
-        return out
+        return out.permute(1, 2, 0).contiguous()
 
 
 class SeparateBNNecks(nn.Module):
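SeparateFCs holds one weight matrix per horizontal part — fc_bin is shaped [p, c_in, c_out], as the batched matmul implies — so all p linear layers apply in a single matmul, and the [n, c, p] <-> [p, n, c] permutes now live inside the module. An illustrative sketch:

import torch

p, c_in, c_out, n = 16, 256, 128, 4
fc_bin = torch.randn(p, c_in, c_out)        # assumed parameter shape, one FC per part
x = torch.rand(n, c_in, p)

out = x.permute(2, 0, 1).matmul(fc_bin)     # [p, n, c_in] @ [p, c_in, c_out]
out = out.permute(1, 2, 0).contiguous()     # back to [n, c_out, p]
print(out.shape)                            # torch.Size([4, 128, 16])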
@@ -133,24 +129,24 @@ class SeparateBNNecks(nn.Module):
 
     def forward(self, x):
         """
-            x: [p, n, c]
+            x: [n, c, p]
         """
         if self.parallel_BN1d:
-            p, n, c = x.size()
-            x = x.transpose(0, 1).contiguous().view(n, -1)  # [n, p*c]
+            n, c, p = x.size()
+            x = x.view(n, -1)  # [n, c*p]
             x = self.bn1d(x)
-            x = x.view(n, p, c).permute(1, 0, 2).contiguous()
+            x = x.view(n, c, p)
         else:
-            x = torch.cat([bn(_.squeeze(0)).unsqueeze(0)
-                           for _, bn in zip(x.split(1, 0), self.bn1d)], 0)  # [p, n, c]
+            x = torch.cat([bn(_x) for _x, bn in zip(
+                x.split(1, 2), self.bn1d)], 2)  # [p, n, c]
+        feature = x.permute(2, 0, 1).contiguous()
         if self.norm:
-            feature = F.normalize(x, dim=-1)  # [p, n, c]
+            feature = F.normalize(feature, dim=-1)  # [p, n, c]
             logits = feature.matmul(F.normalize(
                 self.fc_bin, dim=1))  # [p, n, c]
         else:
-            feature = x
             logits = feature.matmul(self.fc_bin)
-        return feature, logits
+        return feature.permute(1, 2, 0).contiguous(), logits.permute(1, 2, 0).contiguous()
 
 
 class FocalConv2d(nn.Module):
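With parallel_BN1d, the p per-part BatchNorm1d layers collapse into a single BatchNorm1d over c*p features, and since [n, c, p] is already contiguous in that order the flatten/restore is now a plain view. A sketch of the trick with assumed sizes and class count:

import torch
import torch.nn as nn

n, c, p, num_classes = 4, 128, 16, 100
bn1d = nn.BatchNorm1d(c * p)                # one fused BN instead of p separate ones
x = torch.rand(n, c, p)

bnft = bn1d(x.view(n, -1)).view(n, c, p)    # normalize all parts in one call
fc_bin = torch.randn(p, c, num_classes)     # per-part classifier weights
logits = bnft.permute(2, 0, 1).matmul(fc_bin).permute(1, 2, 0)
print(logits.shape)                         # torch.Size([4, 100, 16]) = [n, class, p]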
@@ -10,20 +10,20 @@ def cuda_dist(x, y, metric='euc'):
     x = torch.from_numpy(x).cuda()
     y = torch.from_numpy(y).cuda()
     if metric == 'cos':
-        x = F.normalize(x, p=2, dim=2)  # n p c
-        y = F.normalize(y, p=2, dim=2)  # n p c
-    num_bin = x.size(1)
+        x = F.normalize(x, p=2, dim=1)  # n c p
+        y = F.normalize(y, p=2, dim=1)  # n c p
+    num_bin = x.size(2)
     n_x = x.size(0)
     n_y = y.size(0)
     dist = torch.zeros(n_x, n_y).cuda()
     for i in range(num_bin):
-        _x = x[:, i, ...]
-        _y = y[:, i, ...]
+        _x = x[:, :, i]
+        _y = y[:, :, i]
         if metric == 'cos':
             dist += torch.matmul(_x, _y.transpose(0, 1))
         else:
             _dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze(
-                1).transpose(0, 1) - 2 * torch.matmul(_x, _y.transpose(0, 1))
+                0) - 2 * torch.matmul(_x, _y.transpose(0, 1))
             dist += torch.sqrt(F.relu(_dist))
     return 1 - dist/num_bin if metric == 'cos' else dist / num_bin
 
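cuda_dist scores gallery-probe pairs one part bin at a time and averages over bins; for the cosine metric the features are first L2-normalized along the channel axis. A CPU sketch of the cosine branch, assuming [n, c, p] inputs and dropping .cuda() for portability:

import torch
import torch.nn.functional as F

x = F.normalize(torch.rand(5, 64, 16), p=2, dim=1)   # gallery, n_x = 5
y = F.normalize(torch.rand(3, 64, 16), p=2, dim=1)   # probe,   n_y = 3

num_bin = x.size(2)
dist = sum(torch.matmul(x[:, :, i], y[:, :, i].T) for i in range(num_bin))
dist = 1 - dist / num_bin                            # [n_x, n_y] cosine distance
print(dist.shape)                                    # torch.Size([5, 3])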