remove img_w arg from transform function

This commit is contained in:
darkliang
2023-02-07 13:37:09 +08:00
parent bda25e8978
commit 311bc848fa
2 changed files with 7 additions and 10 deletions
-2
View File
@@ -18,7 +18,6 @@ evaluator_cfg:
sample_type: all_ordered sample_type: all_ordered
type: InferenceSampler type: InferenceSampler
transform: transform:
- img_w: 64
type: BaseSilCuttingTransform type: BaseSilCuttingTransform
metric: euc # cos metric: euc # cos
cross_view_gallery: false cross_view_gallery: false
@@ -70,5 +69,4 @@ trainer_cfg:
sample_type: fixed_unordered sample_type: fixed_unordered
type: TripletSampler type: TripletSampler
transform: transform:
- img_w: 64
type: BaseSilCuttingTransform type: BaseSilCuttingTransform
+7 -8
View File
@@ -10,8 +10,8 @@ class NoOperation():
class BaseSilTransform(): class BaseSilTransform():
def __init__(self, disvor=255.0, img_shape=None): def __init__(self, divsor=255.0, img_shape=None):
self.disvor = disvor self.divsor = divsor
self.img_shape = img_shape self.img_shape = img_shape
def __call__(self, x): def __call__(self, x):
@@ -19,22 +19,21 @@ class BaseSilTransform():
s = x.shape[0] s = x.shape[0]
_ = [s] + [*self.img_shape] _ = [s] + [*self.img_shape]
x = x.reshape(*_) x = x.reshape(*_)
return x / self.disvor return x / self.divsor
class BaseSilCuttingTransform(): class BaseSilCuttingTransform():
def __init__(self, img_w=64, disvor=255.0, cutting=None): def __init__(self, divsor=255.0, cutting=None):
self.img_w = img_w self.divsor = divsor
self.disvor = disvor
self.cutting = cutting self.cutting = cutting
def __call__(self, x): def __call__(self, x):
if self.cutting is not None: if self.cutting is not None:
cutting = self.cutting cutting = self.cutting
else: else:
cutting = int(self.img_w // 64) * 10 cutting = int(x.shape[-1] // 64) * 10
x = x[..., cutting:-cutting] x = x[..., cutting:-cutting]
return x / self.disvor return x / self.divsor
class BaseRgbTransform(): class BaseRgbTransform():