diff --git a/configs/default.yaml b/configs/default.yaml index 2c5e4b1..8467662 100644 --- a/configs/default.yaml +++ b/configs/default.yaml @@ -18,7 +18,6 @@ evaluator_cfg: sample_type: all_ordered type: InferenceSampler transform: - - img_w: 64 type: BaseSilCuttingTransform metric: euc # cos cross_view_gallery: false @@ -70,5 +69,4 @@ trainer_cfg: sample_type: fixed_unordered type: TripletSampler transform: - - img_w: 64 type: BaseSilCuttingTransform diff --git a/opengait/data/transform.py b/opengait/data/transform.py index 3655123..376e19a 100644 --- a/opengait/data/transform.py +++ b/opengait/data/transform.py @@ -10,8 +10,8 @@ class NoOperation(): class BaseSilTransform(): - def __init__(self, disvor=255.0, img_shape=None): - self.disvor = disvor + def __init__(self, divsor=255.0, img_shape=None): + self.divsor = divsor self.img_shape = img_shape def __call__(self, x): @@ -19,22 +19,21 @@ class BaseSilTransform(): s = x.shape[0] _ = [s] + [*self.img_shape] x = x.reshape(*_) - return x / self.disvor + return x / self.divsor class BaseSilCuttingTransform(): - def __init__(self, img_w=64, disvor=255.0, cutting=None): - self.img_w = img_w - self.disvor = disvor + def __init__(self, divsor=255.0, cutting=None): + self.divsor = divsor self.cutting = cutting def __call__(self, x): if self.cutting is not None: cutting = self.cutting else: - cutting = int(self.img_w // 64) * 10 + cutting = int(x.shape[-1] // 64) * 10 x = x[..., cutting:-cutting] - return x / self.disvor + return x / self.divsor class BaseRgbTransform():