Solve the problem of dimension misuse. (#59)

* commit for fix dimension

* fix dimension for all method

* restore config

* clean up baseline config

* add contiguous

* rm comment
This commit is contained in:
Junhao Liang
2022-06-28 12:27:16 +08:00
committed by GitHub
parent 715e7448fa
commit 14fa5212d4
14 changed files with 99 additions and 121 deletions
+1 -3
View File
@@ -46,9 +46,6 @@ model_cfg:
- M - M
- BC-256 - BC-256
- BC-256 - BC-256
# - M
# - BC-512
# - BC-512
type: Plain type: Plain
SeparateFCs: SeparateFCs:
in_channels: 256 in_channels: 256
@@ -81,6 +78,7 @@ trainer_cfg:
enable_float16: true # half-precision float for memory reduction and speedup enable_float16: true # half-precision float for memory reduction and speedup
fix_BN: false fix_BN: false
log_iter: 100 log_iter: 100
with_test: true
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
save_iter: 10000 save_iter: 10000
+2 -2
View File
@@ -47,8 +47,8 @@ scheduler_cfg:
scheduler: MultiStepLR scheduler: MultiStepLR
trainer_cfg: trainer_cfg:
enable_distributed: true enable_float16: true
enable_float16: false with_test: true
log_iter: 100 log_iter: 100
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
+1
View File
@@ -59,6 +59,7 @@ scheduler_cfg:
trainer_cfg: trainer_cfg:
enable_float16: true enable_float16: true
log_iter: 100 log_iter: 100
with_test: true
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
save_iter: 10000 save_iter: 10000
+2 -1
View File
@@ -57,12 +57,13 @@ scheduler_cfg:
trainer_cfg: trainer_cfg:
enable_float16: true enable_float16: true
log_iter: 100 log_iter: 100
with_test: true
restore_ckpt_strict: true restore_ckpt_strict: true
restore_hint: 0 restore_hint: 0
save_iter: 10000 save_iter: 10000
save_name: GaitSet save_name: GaitSet
sync_BN: false sync_BN: false
total_iter: 42000 total_iter: 40000
sampler: sampler:
batch_shuffle: false batch_shuffle: false
batch_size: batch_size:
+1 -2
View File
@@ -74,10 +74,9 @@ scheduler_cfg:
scheduler: MultiStepLR scheduler: MultiStepLR
trainer_cfg: trainer_cfg:
enable_distributed: true
enable_float16: true enable_float16: true
fix_layers: false fix_layers: false
with_test: false with_test: true
log_iter: 100 log_iter: 100
optimizer_reset: false optimizer_reset: false
restore_ckpt_strict: true restore_ckpt_strict: true
+11 -13
View File
@@ -14,31 +14,29 @@ class CrossEntropyLoss(BaseLoss):
def forward(self, logits, labels): def forward(self, logits, labels):
""" """
logits: [n, p, c] logits: [n, c, p]
labels: [n] labels: [n]
""" """
logits = logits.permute(1, 0, 2).contiguous() # [n, p, c] -> [p, n, c] n, c, p = logits.size()
p, _, c = logits.size() log_preds = F.log_softmax(logits * self.scale, dim=1) # [n, c, p]
log_preds = F.log_softmax(logits * self.scale, dim=-1) # [p, n, c]
one_hot_labels = self.label2one_hot( one_hot_labels = self.label2one_hot(
labels, c).unsqueeze(0).repeat(p, 1, 1) # [p, n, c] labels, c).unsqueeze(2).repeat(1, 1, p) # [n, c, p]
loss = self.compute_loss(log_preds, one_hot_labels) loss = self.compute_loss(log_preds, one_hot_labels)
self.info.update({'loss': loss.detach().clone()}) self.info.update({'loss': loss.detach().clone()})
if self.log_accuracy: if self.log_accuracy:
pred = logits.argmax(dim=-1) # [p, n] pred = logits.argmax(dim=1) # [n, p]
accu = (pred == labels.unsqueeze(0)).float().mean() accu = (pred == labels.unsqueeze(1)).float().mean()
self.info.update({'accuracy': accu}) self.info.update({'accuracy': accu})
return loss, self.info return loss, self.info
def compute_loss(self, predis, labels): def compute_loss(self, predis, labels):
softmax_loss = -(labels * predis).sum(-1) # [p, n] softmax_loss = -(labels * predis).sum(1) # [n, p]
losses = softmax_loss.mean(-1) losses = softmax_loss.mean(0) # [p]
if self.label_smooth: if self.label_smooth:
smooth_loss = - predis.mean(dim=-1) # [p, n] smooth_loss = - predis.mean(dim=1) # [n, p]
smooth_loss = smooth_loss.mean() # [p] smooth_loss = smooth_loss.mean(0) # [p]
smooth_loss = smooth_loss * self.eps losses = smooth_loss * self.eps + losses * (1. - self.eps)
losses = smooth_loss + losses * (1. - self.eps)
return losses return losses
def label2one_hot(self, label, class_num): def label2one_hot(self, label, class_num):
+8 -8
View File
@@ -11,14 +11,13 @@ class TripletLoss(BaseLoss):
@gather_and_scale_wrapper @gather_and_scale_wrapper
def forward(self, embeddings, labels): def forward(self, embeddings, labels):
# embeddings: [n, p, c], label: [n] # embeddings: [n, c, p], label: [n]
embeddings = embeddings.permute( embeddings = embeddings.permute(
1, 0, 2).contiguous() # [n, p, c] -> [p, n, c] 2, 0, 1).contiguous().float() # [n, c, p] -> [p, n, c]
embeddings = embeddings.float()
ref_embed, ref_label = embeddings, labels ref_embed, ref_label = embeddings, labels
dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2] dist = self.ComputeDistance(embeddings, ref_embed) # [p, n1, n2]
mean_dist = dist.mean(1).mean(1) mean_dist = dist.mean((1, 2)) # [p]
ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist) ap_dist, an_dist = self.Convert2Triplets(labels, ref_label, dist)
dist_diff = (ap_dist - an_dist).view(dist.size(0), -1) dist_diff = (ap_dist - an_dist).view(dist.size(0), -1)
loss = F.relu(dist_diff + self.margin) loss = F.relu(dist_diff + self.margin)
@@ -50,7 +49,7 @@ class TripletLoss(BaseLoss):
""" """
x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1] x2 = torch.sum(x ** 2, -1).unsqueeze(2) # [p, n_x, 1]
y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y] y2 = torch.sum(y ** 2, -1).unsqueeze(1) # [p, 1, n_y]
inner = x.matmul(y.transpose(-1, -2)) # [p, n_x, n_y] inner = x.matmul(y.transpose(1, 2)) # [p, n_x, n_y]
dist = x2 + y2 - 2 * inner dist = x2 + y2 - 2 * inner
dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y] dist = torch.sqrt(F.relu(dist)) # [p, n_x, n_y]
return dist return dist
@@ -60,9 +59,10 @@ class TripletLoss(BaseLoss):
row_labels: tensor with size [n_r] row_labels: tensor with size [n_r]
clo_label : tensor with size [n_c] clo_label : tensor with size [n_c]
""" """
matches = (row_labels.unsqueeze(1) == clo_label.unsqueeze(0)).bool() # [n_r, n_c] matches = (row_labels.unsqueeze(1) ==
diffenc = torch.logical_not(matches) # [n_r, n_c] clo_label.unsqueeze(0)).bool() # [n_r, n_c]
p, n, m = dist.size() diffenc = torch.logical_not(matches) # [n_r, n_c]
p, n, _ = dist.size()
ap_dist = dist[:, matches].view(p, n, -1, 1) ap_dist = dist[:, matches].view(p, n, -1, 1)
an_dist = dist[:, diffenc].view(p, n, 1, -1) an_dist = dist[:, diffenc].view(p, n, 1, -1)
return ap_dist, an_dist return ap_dist, an_dist
+6 -11
View File
@@ -19,26 +19,21 @@ class Baseline(BaseModel):
sils = ipts[0] sils = ipts[0]
if len(sils.size()) == 4: if len(sils.size()) == 4:
sils = sils.unsqueeze(2) sils = sils.unsqueeze(1)
del ipts del ipts
outs = self.Backbone(sils) # [n, s, c, h, w] outs = self.Backbone(sils) # [n, c, s, h, w]
# Temporal Pooling, TP # Temporal Pooling, TP
outs = self.TP(outs, seqL, dim=1)[0] # [n, c, h, w] outs = self.TP(outs, seqL, options={"dim": 2})[0] # [n, c, h, w]
# Horizontal Pooling Matching, HPM # Horizontal Pooling Matching, HPM
feat = self.HPP(outs) # [n, c, p] feat = self.HPP(outs) # [n, c, p]
feat = feat.permute(2, 0, 1).contiguous() # [p, n, c]
embed_1 = self.FCs(feat) # [p, n, c] embed_1 = self.FCs(feat) # [n, c, p]
embed_2, logits = self.BNNecks(embed_1) # [p, n, c] embed_2, logits = self.BNNecks(embed_1) # [n, c, p]
embed_1 = embed_1.permute(1, 0, 2).contiguous() # [n, p, c]
embed_2 = embed_2.permute(1, 0, 2).contiguous() # [n, p, c]
logits = logits.permute(1, 0, 2).contiguous() # [n, p, c]
embed = embed_1 embed = embed_1
n, s, _, h, w = sils.size() n, _, s, h, w = sils.size()
retval = { retval = {
'training_feat': { 'training_feat': {
'triplet': {'embeddings': embed_1, 'labels': labs}, 'triplet': {'embeddings': embed_1, 'labels': labs},
+12 -17
View File
@@ -75,7 +75,7 @@ class GaitGL(BaseModel):
class_num = model_cfg['class_num'] class_num = model_cfg['class_num']
dataset_name = self.cfgs['data_cfg']['dataset_name'] dataset_name = self.cfgs['data_cfg']['dataset_name']
if dataset_name in ['OUMVLP','GREW']: if dataset_name in ['OUMVLP', 'GREW']:
# For OUMVLP and GREW # For OUMVLP and GREW
self.conv3d = nn.Sequential( self.conv3d = nn.Sequential(
BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3), BasicConv3d(1, in_c[0], kernel_size=(3, 3, 3),
@@ -135,12 +135,11 @@ class GaitGL(BaseModel):
self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=( self.GLConvB2 = GLConv(in_c[2], in_c[2], halving=3, fm_sign=True, kernel_size=(
3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)) 3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
self.TP = PackSequenceWrapper(torch.max) self.TP = PackSequenceWrapper(torch.max)
self.HPP = GeMHPP() self.HPP = GeMHPP()
self.Head0 = SeparateFCs(64, in_c[-1], in_c[-1]) self.Head0 = SeparateFCs(64, in_c[-1], in_c[-1])
if 'SeparateBNNecks' in model_cfg.keys(): if 'SeparateBNNecks' in model_cfg.keys():
self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks']) self.BNNecks = SeparateBNNecks(**model_cfg['SeparateBNNecks'])
self.Bn_head = False self.Bn_head = False
@@ -171,22 +170,18 @@ class GaitGL(BaseModel):
outs = self.GLConvA1(outs) outs = self.GLConvA1(outs)
outs = self.GLConvB2(outs) # [n, c, s, h, w] outs = self.GLConvB2(outs) # [n, c, s, h, w]
outs = self.TP(outs, dim=2, seq_dim=2, seqL=seqL)[0] # [n, c, h, w] outs = self.TP(outs, seqL=seqL, options={"dim": 2})[0] # [n, c, h, w]
outs = self.HPP(outs) # [n, c, p] outs = self.HPP(outs) # [n, c, p]
outs = outs.permute(2, 0, 1).contiguous() # [p, n, c]
gait = self.Head0(outs) # [p, n, c] gait = self.Head0(outs) # [n, c, p]
if self.Bn_head: # Original GaitGL Head if self.Bn_head: # Original GaitGL Head
gait = gait.permute(1, 2, 0).contiguous() # [n, c, p]
bnft = self.Bn(gait) # [n, c, p] bnft = self.Bn(gait) # [n, c, p]
logi = self.Head1(bnft.permute(2, 0, 1).contiguous()) # [p, n, c] logi = self.Head1(bnft) # [n, c, p]
embed = bnft.permute(0, 2, 1).contiguous() # [n, p, c] embed = bnft
else: # BNNeck as Head else: # BNNeck as Head
bnft, logi = self.BNNecks(gait) # [p, n, c] bnft, logi = self.BNNecks(gait) # [n, c, p]
embed = gait.permute(1, 0, 2).contiguous() # [n, p, c] embed = gait
logi = logi.permute(1, 0, 2).contiguous() # [n, p, c]
n, _, s, h, w = sils.size() n, _, s, h, w = sils.size()
retval = { retval = {
+12 -13
View File
@@ -45,12 +45,12 @@ class TemporalFeatureAggregator(nn.Module):
def forward(self, x): def forward(self, x):
""" """
Input: x, [n, s, c, p] Input: x, [n, c, s, p]
Output: ret, [n, p, c] Output: ret, [n, c, p]
""" """
n, s, c, p = x.size() n, c, s, p = x.size()
x = x.permute(3, 0, 2, 1).contiguous() # [p, n, c, s] x = x.permute(3, 0, 1, 2).contiguous() # [p, n, c, s]
feature = x.split(1, 0) # [[n, c, s], ...] feature = x.split(1, 0) # [[1, n, c, s], ...]
x = x.view(-1, c, s) x = x.view(-1, c, s)
# MTB1: ConvNet1d & Sigmoid # MTB1: ConvNet1d & Sigmoid
@@ -73,7 +73,7 @@ class TemporalFeatureAggregator(nn.Module):
# Temporal Pooling # Temporal Pooling
ret = self.TP(feature3x1 + feature3x3, dim=-1)[0] # [p, n, c] ret = self.TP(feature3x1 + feature3x3, dim=-1)[0] # [p, n, c]
ret = ret.permute(1, 0, 2).contiguous() # [n, p, c] ret = ret.permute(1, 2, 0).contiguous() # [n, c, p]
return ret return ret
@@ -102,17 +102,16 @@ class GaitPart(BaseModel):
sils = ipts[0] sils = ipts[0]
if len(sils.size()) == 4: if len(sils.size()) == 4:
sils = sils.unsqueeze(2) sils = sils.unsqueeze(1)
del ipts del ipts
out = self.Backbone(sils) # [n, s, c, h, w] out = self.Backbone(sils) # [n, c, s, h, w]
out = self.HPP(out) # [n, s, c, p] out = self.HPP(out) # [n, c, s, p]
out = self.TFA(out, seqL) # [n, p, c] out = self.TFA(out, seqL) # [n, c, p]
embs = self.Head(out.permute(1, 0, 2).contiguous()) # [p, n, c] embs = self.Head(out) # [n, c, p]
embs = embs.permute(1, 0, 2).contiguous() # [n, p, c]
n, s, _, h, w = sils.size() n, _, s, h, w = sils.size()
retval = { retval = {
'training_feat': { 'training_feat': {
'triplet': {'embeddings': embs, 'labels': labs} 'triplet': {'embeddings': embs, 'labels': labs}
+5 -7
View File
@@ -49,30 +49,28 @@ class GaitSet(BaseModel):
ipts, labs, _, _, seqL = inputs ipts, labs, _, _, seqL = inputs
sils = ipts[0] # [n, s, h, w] sils = ipts[0] # [n, s, h, w]
if len(sils.size()) == 4: if len(sils.size()) == 4:
sils = sils.unsqueeze(2) sils = sils.unsqueeze(1)
del ipts del ipts
outs = self.set_block1(sils) outs = self.set_block1(sils)
gl = self.set_pooling(outs, seqL, dim=1)[0] gl = self.set_pooling(outs, seqL, options={"dim": 2})[0]
gl = self.gl_block2(gl) gl = self.gl_block2(gl)
outs = self.set_block2(outs) outs = self.set_block2(outs)
gl = gl + self.set_pooling(outs, seqL, dim=1)[0] gl = gl + self.set_pooling(outs, seqL, options={"dim": 2})[0]
gl = self.gl_block3(gl) gl = self.gl_block3(gl)
outs = self.set_block3(outs) outs = self.set_block3(outs)
outs = self.set_pooling(outs, seqL, dim=1)[0] outs = self.set_pooling(outs, seqL, options={"dim": 2})[0]
gl = gl + outs gl = gl + outs
# Horizontal Pooling Matching, HPM # Horizontal Pooling Matching, HPM
feature1 = self.HPP(outs) # [n, c, p] feature1 = self.HPP(outs) # [n, c, p]
feature2 = self.HPP(gl) # [n, c, p] feature2 = self.HPP(gl) # [n, c, p]
feature = torch.cat([feature1, feature2], -1) # [n, c, p] feature = torch.cat([feature1, feature2], -1) # [n, c, p]
feature = feature.permute(2, 0, 1).contiguous() # [p, n, c]
embs = self.Head(feature) embs = self.Head(feature)
embs = embs.permute(1, 0, 2).contiguous() # [n, p, c]
n, s, _, h, w = sils.size() n, _, s, h, w = sils.size()
retval = { retval = {
'training_feat': { 'training_feat': {
'triplet': {'embeddings': embs, 'labels': labs} 'triplet': {'embeddings': embs, 'labels': labs}
+7 -9
View File
@@ -89,12 +89,12 @@ class GLN(BaseModel):
sils = ipts[0] # [n, s, h, w] sils = ipts[0] # [n, s, h, w]
del ipts del ipts
if len(sils.size()) == 4: if len(sils.size()) == 4:
sils = sils.unsqueeze(2) sils = sils.unsqueeze(1)
n, s, _, h, w = sils.size() n, _, s, h, w = sils.size()
### stage 0 sil ### ### stage 0 sil ###
sil_0_outs = self.sil_stage_0(sils) sil_0_outs = self.sil_stage_0(sils)
stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, dim=1)[0] stage_0_sil_set = self.set_pooling(sil_0_outs, seqL, options={"dim": 2})[0]
### stage 1 sil ### ### stage 1 sil ###
sil_1_ipts = self.MaxP_sil(sil_0_outs) sil_1_ipts = self.MaxP_sil(sil_0_outs)
@@ -105,13 +105,13 @@ class GLN(BaseModel):
sil_2_outs = self.sil_stage_2(sil_2_ipts) sil_2_outs = self.sil_stage_2(sil_2_ipts)
### stage 1 set ### ### stage 1 set ###
set_1_ipts = self.set_pooling(sil_1_ipts, seqL, dim=1)[0] set_1_ipts = self.set_pooling(sil_1_ipts, seqL, options={"dim": 2})[0]
stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, dim=1)[0] stage_1_sil_set = self.set_pooling(sil_1_outs, seqL, options={"dim": 2})[0]
set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set set_1_outs = self.set_stage_1(set_1_ipts) + stage_1_sil_set
### stage 2 set ### ### stage 2 set ###
set_2_ipts = self.MaxP_set(set_1_outs) set_2_ipts = self.MaxP_set(set_1_outs)
stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, dim=1)[0] stage_2_sil_set = self.set_pooling(sil_2_outs, seqL, options={"dim": 2})[0]
set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set set_2_outs = self.set_stage_2(set_2_ipts) + stage_2_sil_set
set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1) set1 = torch.cat((stage_0_sil_set, stage_0_sil_set), dim=1)
@@ -133,11 +133,9 @@ class GLN(BaseModel):
set2 = self.HPP(set2) set2 = self.HPP(set2)
set3 = self.HPP(set3) set3 = self.HPP(set3)
feature = torch.cat([set1, set2, set3], - feature = torch.cat([set1, set2, set3], -1)
1).permute(2, 0, 1).contiguous()
feature = self.Head(feature) feature = self.Head(feature)
feature = feature.permute(1, 0, 2).contiguous() # n p c
# compact_block # compact_block
if not self.pretrain: if not self.pretrain:
+25 -29
View File
@@ -38,14 +38,14 @@ class SetBlockWrapper(nn.Module):
def forward(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs):
""" """
In x: [n, s, c, h, w] In x: [n, c_in, s, h_in, w_in]
Out x: [n, s, ...] Out x: [n, c_out, s, h_out, w_out]
""" """
n, s, c, h, w = x.size() n, c, s, h, w = x.size()
x = self.forward_block(x.view(-1, c, h, w), *args, **kwargs) x = self.forward_block(x.transpose(
input_size = x.size() 1, 2).view(-1, c, h, w), *args, **kwargs)
output_size = [n, s] + [*input_size[1:]] output_size = x.size()
return x.view(*output_size) return x.reshape(n, s, *output_size[1:]).transpose(1, 2).contiguous()
class PackSequenceWrapper(nn.Module): class PackSequenceWrapper(nn.Module):
@@ -53,26 +53,20 @@ class PackSequenceWrapper(nn.Module):
super(PackSequenceWrapper, self).__init__() super(PackSequenceWrapper, self).__init__()
self.pooling_func = pooling_func self.pooling_func = pooling_func
def forward(self, seqs, seqL, seq_dim=1, **kwargs): def forward(self, seqs, seqL, dim=2, options={}):
""" """
In seqs: [n, s, ...] In seqs: [n, c, s, ...]
Out rets: [n, ...] Out rets: [n, ...]
""" """
if seqL is None: if seqL is None:
return self.pooling_func(seqs, **kwargs) return self.pooling_func(seqs, **options)
seqL = seqL[0].data.cpu().numpy().tolist() seqL = seqL[0].data.cpu().numpy().tolist()
start = [0] + np.cumsum(seqL).tolist()[:-1] start = [0] + np.cumsum(seqL).tolist()[:-1]
rets = [] rets = []
for curr_start, curr_seqL in zip(start, seqL): for curr_start, curr_seqL in zip(start, seqL):
narrowed_seq = seqs.narrow(seq_dim, curr_start, curr_seqL) narrowed_seq = seqs.narrow(dim, curr_start, curr_seqL)
# save the memory rets.append(self.pooling_func(narrowed_seq, **options))
# splited_narrowed_seq = torch.split(narrowed_seq, 256, dim=1)
# ret = []
# for seq_to_pooling in splited_narrowed_seq:
# ret.append(self.pooling_func(seq_to_pooling, keepdim=True, **kwargs)
# [0] if self.is_tuple_result else self.pooling_func(seq_to_pooling, **kwargs))
rets.append(self.pooling_func(narrowed_seq, **kwargs))
if len(rets) > 0 and is_list_or_tuple(rets[0]): if len(rets) > 0 and is_list_or_tuple(rets[0]):
return [torch.cat([ret[j] for ret in rets]) return [torch.cat([ret[j] for ret in rets])
for j in range(len(rets[0]))] for j in range(len(rets[0]))]
@@ -101,13 +95,15 @@ class SeparateFCs(nn.Module):
def forward(self, x): def forward(self, x):
""" """
x: [p, n, c] x: [n, c_in, p]
out: [n, c_out, p]
""" """
x = x.permute(2, 0, 1).contiguous()
if self.norm: if self.norm:
out = x.matmul(F.normalize(self.fc_bin, dim=1)) out = x.matmul(F.normalize(self.fc_bin, dim=1))
else: else:
out = x.matmul(self.fc_bin) out = x.matmul(self.fc_bin)
return out return out.permute(1, 2, 0).contiguous()
class SeparateBNNecks(nn.Module): class SeparateBNNecks(nn.Module):
@@ -133,24 +129,24 @@ class SeparateBNNecks(nn.Module):
def forward(self, x): def forward(self, x):
""" """
x: [p, n, c] x: [n, c, p]
""" """
if self.parallel_BN1d: if self.parallel_BN1d:
p, n, c = x.size() n, c, p = x.size()
x = x.transpose(0, 1).contiguous().view(n, -1) # [n, p*c] x = x.view(n, -1) # [n, c*p]
x = self.bn1d(x) x = self.bn1d(x)
x = x.view(n, p, c).permute(1, 0, 2).contiguous() x = x.view(n, c, p)
else: else:
x = torch.cat([bn(_.squeeze(0)).unsqueeze(0) x = torch.cat([bn(_x) for _x, bn in zip(
for _, bn in zip(x.split(1, 0), self.bn1d)], 0) # [p, n, c] x.split(1, 2), self.bn1d)], 2) # [n, c, p]
feature = x.permute(2, 0, 1).contiguous()
if self.norm: if self.norm:
feature = F.normalize(x, dim=-1) # [p, n, c] feature = F.normalize(feature, dim=-1) # [p, n, c]
logits = feature.matmul(F.normalize( logits = feature.matmul(F.normalize(
self.fc_bin, dim=1)) # [p, n, c] self.fc_bin, dim=1)) # [p, n, c]
else: else:
feature = x
logits = feature.matmul(self.fc_bin) logits = feature.matmul(self.fc_bin)
return feature, logits return feature.permute(1, 2, 0).contiguous(), logits.permute(1, 2, 0).contiguous()
class FocalConv2d(nn.Module): class FocalConv2d(nn.Module):
+6 -6
View File
@@ -10,20 +10,20 @@ def cuda_dist(x, y, metric='euc'):
x = torch.from_numpy(x).cuda() x = torch.from_numpy(x).cuda()
y = torch.from_numpy(y).cuda() y = torch.from_numpy(y).cuda()
if metric == 'cos': if metric == 'cos':
x = F.normalize(x, p=2, dim=2) # n p c x = F.normalize(x, p=2, dim=1) # n c p
y = F.normalize(y, p=2, dim=2) # n p c y = F.normalize(y, p=2, dim=1) # n c p
num_bin = x.size(1) num_bin = x.size(2)
n_x = x.size(0) n_x = x.size(0)
n_y = y.size(0) n_y = y.size(0)
dist = torch.zeros(n_x, n_y).cuda() dist = torch.zeros(n_x, n_y).cuda()
for i in range(num_bin): for i in range(num_bin):
_x = x[:, i, ...] _x = x[:, :, i]
_y = y[:, i, ...] _y = y[:, :, i]
if metric == 'cos': if metric == 'cos':
dist += torch.matmul(_x, _y.transpose(0, 1)) dist += torch.matmul(_x, _y.transpose(0, 1))
else: else:
_dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze( _dist = torch.sum(_x ** 2, 1).unsqueeze(1) + torch.sum(_y ** 2, 1).unsqueeze(
1).transpose(0, 1) - 2 * torch.matmul(_x, _y.transpose(0, 1)) 0) - 2 * torch.matmul(_x, _y.transpose(0, 1))
dist += torch.sqrt(F.relu(_dist)) dist += torch.sqrt(F.relu(_dist))
return 1 - dist/num_bin if metric == 'cos' else dist / num_bin return 1 - dist/num_bin if metric == 'cos' else dist / num_bin