Improved box cutting with always fixed tensor shapes.

This commit is contained in:
Daniel
2024-12-04 17:54:57 +01:00
parent 6452d20ec8
commit acf1d19b64
2 changed files with 254 additions and 216 deletions

View File

@ -2,6 +2,7 @@ import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
# ==================================================================================================
@ -20,35 +21,10 @@ class Letterbox(nn.Module):
self.target_size = target_size
self.fill_value = fill_value
def calc_params_and_crop(self, ishape, bbox=None):
ih0, iw0 = ishape[1], ishape[2]
def calc_params(self, ishape):
ih, iw = ishape[1], ishape[2]
th, tw = self.target_size
if bbox is not None:
bbox = bbox[0].float()
x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
# Slightly increase bbox size
factor = 1.25
w = x2 - x1
h = y2 - y1
x1 -= w * (factor - 1) / 2
x2 += w * (factor - 1) / 2
y1 -= h * (factor - 1) / 2
y2 += h * (factor - 1) / 2
zero = torch.tensor(0)
x1 = torch.max(x1, zero).to(torch.int64)
y1 = torch.max(y1, zero).to(torch.int64)
x2 = torch.min(x2, iw0).to(torch.int64)
y2 = torch.min(y2, ih0).to(torch.int64)
bbox = torch.stack((x1, y1, x2, y2), dim=0).unsqueeze(0)
ih = y2 - y1
iw = x2 - x1
else:
ih, iw = ih0, iw0
scale = torch.min(tw / iw, th / ih)
nw = torch.round(iw * scale)
nh = torch.round(ih * scale)
@ -61,21 +37,16 @@ class Letterbox(nn.Module):
pad_bottom = pad_h - pad_top
paddings = (pad_left, pad_right, pad_top, pad_bottom)
return paddings, scale, (nw, nh), bbox
return paddings, scale, (nw, nh)
def forward(self, img, bbox=None):
paddings, _, (nw, nh), bbox = self.calc_params_and_crop(img.shape, bbox)
# Optional: Crop the image
if bbox is not None:
x1, y1, x2, y2 = bbox[0, 0], bbox[0, 1], bbox[0, 2], bbox[0, 3]
img = img.to(torch.float32)
img = img[:, y1:y2, x1:x2, :]
def forward(self, img):
paddings, _, (nw, nh) = self.calc_params(img.shape)
# Resize the image
img = img.to(torch.float32)
img = img.permute(0, 3, 1, 2)
img = F.interpolate(
img.permute(0, 3, 1, 2),
img,
size=(nh, nw),
mode="bilinear",
align_corners=False,
@ -91,9 +62,82 @@ class Letterbox(nn.Module):
value=self.fill_value,
)
img = img.permute(0, 2, 3, 1)
canvas = img
return canvas
return img
# ==================================================================================================
class BoxCrop(nn.Module):
def __init__(self, target_size):
"""Crop bounding box from image"""
super(BoxCrop, self).__init__()
self.target_size = target_size
self.padding_scale = 1.25
def calc_params(self, bbox):
start_x, start_y, end_x, end_y = bbox[0, 0], bbox[0, 1], bbox[0, 2], bbox[0, 3]
target_h, target_w = self.target_size
# Calculate original bounding box width, height and center
bbox_w = end_x - start_x
bbox_h = end_y - start_y
center_x = (start_x + end_x) / 2.0
center_y = (start_y + end_y) / 2.0
# Calculate the aspect ratios
bbox_aspect = bbox_w / bbox_h
target_aspect = target_w / target_h
# Adjust the scaled bounding box to match the target aspect ratio
if bbox_aspect > target_aspect:
adjusted_h = bbox_w / target_aspect
adjusted_w = bbox_w
else:
adjusted_w = bbox_h * target_aspect
adjusted_h = bbox_h
# Scale the bounding box by the padding_scale
scaled_bbox_w = adjusted_w * self.padding_scale
scaled_bbox_h = adjusted_h * self.padding_scale
# Calculate scaled bounding box coordinates
new_start_x = center_x - scaled_bbox_w / 2.0
new_start_y = center_y - scaled_bbox_h / 2.0
new_end_x = center_x + scaled_bbox_w / 2.0
new_end_y = center_y + scaled_bbox_h / 2.0
# Define the new box coordinates
new_box = torch.stack((new_start_x, new_start_y, new_end_x, new_end_y), dim=0)
new_box = new_box.unsqueeze(0)
scale = torch.stack(
((target_w / scaled_bbox_w), (target_h / scaled_bbox_h)), dim=0
)
return scale, new_box
def forward(self, img, bbox):
_, bbox = self.calc_params(bbox)
batch_indices = torch.zeros(bbox.shape[0], 1)
rois = torch.cat([batch_indices, bbox], dim=1)
# Resize and crop
img = img.to(torch.float32)
img = img.permute(0, 3, 1, 2)
img = roi_align(
img,
rois,
output_size=self.target_size,
spatial_scale=1.0,
sampling_ratio=0,
)
img = img.permute(0, 2, 3, 1)
img = img.round()
return img
# ==================================================================================================
@ -106,7 +150,7 @@ class DetPreprocess(nn.Module):
def forward(self, img):
# img: torch.Tensor of shape [batch, H, W, C], dtype=torch.uint8
img = self.letterbox(img, None)
img = self.letterbox(img)
return img
@ -121,7 +165,7 @@ class DetPostprocess(nn.Module):
self.letterbox = Letterbox(target_size)
def forward(self, img, boxes):
paddings, scale, _, _ = self.letterbox.calc_params_and_crop(img.shape, None)
paddings, scale, _ = self.letterbox.calc_params(img.shape)
boxes = boxes.float()
boxes[:, :, 0] -= paddings[0]
@ -160,12 +204,12 @@ class DetPostprocess(nn.Module):
class PosePreprocess(nn.Module):
def __init__(self, target_size, fill_value=114):
super(PosePreprocess, self).__init__()
self.letterbox = Letterbox(target_size, fill_value)
self.boxcrop = BoxCrop(target_size)
def forward(self, img, bbox):
# img: torch.Tensor of shape [1, H, W, C], dtype=torch.uint8
# bbox: torch.Tensor of shape [1, 4], dtype=torch.float32
img = self.letterbox(img, bbox)
img = self.boxcrop(img, bbox)
return img
@ -175,25 +219,22 @@ class PosePreprocess(nn.Module):
class PosePostprocess(nn.Module):
def __init__(self, target_size):
super(PosePostprocess, self).__init__()
self.boxcrop = BoxCrop(target_size)
self.target_size = target_size
self.letterbox = Letterbox(target_size)
def forward(self, img, bbox, keypoints):
paddings, scale, _, bbox = self.letterbox.calc_params_and_crop(img.shape, bbox)
scale, bbox = self.boxcrop.calc_params(bbox)
kp = keypoints.float()
kp[:, :, 0] -= paddings[0]
kp[:, :, 1] -= paddings[2]
kp[:, :, 0:2] /= scale
kp[:, :, 0] += bbox[0, 0]
kp[:, :, 1] += bbox[0, 1]
zero = torch.tensor(0)
kp = torch.max(kp, zero)
th, tw = self.target_size
pad_w = paddings[0] + paddings[1]
pad_h = paddings[2] + paddings[3]
max_w = tw - pad_w - 1
max_h = th - pad_h - 1
max_w = img.shape[2] - 1
max_h = img.shape[1] - 1
k0 = kp[:, :, 0]
k1 = kp[:, :, 1]
k0 = torch.min(k0, max_w)
@ -201,10 +242,6 @@ class PosePostprocess(nn.Module):
kp[:, :, 0] = k0
kp[:, :, 1] = k1
kp[:, :, 0:2] /= scale
kp[:, :, 0] += bbox[0, 0]
kp[:, :, 1] += bbox[0, 1]
return kp
@ -215,6 +252,7 @@ def main():
img_path = "/RapidPoseTriangulation/scripts/../data/h1/54138969-img_003201.jpg"
image = cv2.imread(img_path, 3)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Initialize the DetPreprocess module
preprocess_model = DetPreprocess(target_size=det_target_size)
@ -257,7 +295,7 @@ def main():
# Initialize the PosePreprocess module
preprocess_model = PosePreprocess(target_size=pose_target_size)
det_dummy_input_c0 = torch.from_numpy(image).unsqueeze(0)
det_dummy_input_c1 = torch.tensor([[10, 10, 90, 40]]).to(torch.int32)
det_dummy_input_c1 = torch.tensor([[352, 339, 518, 594]]).to(torch.int32)
# Export to ONNX
torch.onnx.export(
@ -276,8 +314,8 @@ def main():
# Initialize the PosePostprocess module
postprocess_model = PosePostprocess(target_size=pose_target_size)
det_dummy_input_d0 = torch.from_numpy(image).unsqueeze(0)
det_dummy_input_d1 = torch.tensor([[10, 10, 90, 40]]).to(torch.int32)
det_dummy_input_d2 = torch.rand(1, 17, 3)
det_dummy_input_d1 = torch.tensor([[352, 339, 518, 594]]).to(torch.int32)
det_dummy_input_d2 = torch.rand(1, 17, 2)
# Export to ONNX
torch.onnx.export(