import cv2 import torch import torch.nn as nn import torch.nn.functional as F from torchvision.ops import roi_align # ================================================================================================== base_path = "/RapidPoseTriangulation/extras/mmdeploy/exports/" det_target_size = (320, 320) pose_target_size = (384, 288) # ================================================================================================== class Letterbox(nn.Module): def __init__(self, target_size, fill_value=128): """Resize and pad image while keeping aspect ratio""" super(Letterbox, self).__init__() self.target_size = target_size self.fill_value = fill_value def calc_params(self, ishape): ih, iw = ishape[1], ishape[2] th, tw = self.target_size scale = torch.min(tw / iw, th / ih) nw = torch.round(iw * scale) nh = torch.round(ih * scale) pad_w = tw - nw pad_h = th - nh pad_left = pad_w // 2 pad_top = pad_h // 2 pad_right = pad_w - pad_left pad_bottom = pad_h - pad_top paddings = (pad_left, pad_right, pad_top, pad_bottom) return paddings, scale, (nw, nh) def forward(self, img): paddings, _, (nw, nh) = self.calc_params(img.shape) # Resize the image img = img.to(torch.float32) img = img.permute(0, 3, 1, 2) img = F.interpolate( img, size=(nh, nw), mode="bilinear", align_corners=False, ) img = img.permute(0, 2, 3, 1) img = img.round() # Pad the image img = F.pad( img.permute(0, 3, 1, 2), pad=paddings, mode="constant", value=self.fill_value, ) img = img.permute(0, 2, 3, 1) return img # ================================================================================================== class BoxCrop(nn.Module): def __init__(self, target_size): """Crop bounding box from image""" super(BoxCrop, self).__init__() self.target_size = target_size self.padding_scale = 1.25 def calc_params(self, bbox): start_x, start_y, end_x, end_y = bbox[0, 0], bbox[0, 1], bbox[0, 2], bbox[0, 3] target_h, target_w = self.target_size # Calculate original bounding box width, height and center bbox_w = end_x - start_x bbox_h = end_y - start_y center_x = (start_x + end_x) / 2.0 center_y = (start_y + end_y) / 2.0 # Calculate the aspect ratios bbox_aspect = bbox_w / bbox_h target_aspect = target_w / target_h # Adjust the scaled bounding box to match the target aspect ratio if bbox_aspect > target_aspect: adjusted_h = bbox_w / target_aspect adjusted_w = bbox_w else: adjusted_w = bbox_h * target_aspect adjusted_h = bbox_h # Scale the bounding box by the padding_scale scaled_bbox_w = adjusted_w * self.padding_scale scaled_bbox_h = adjusted_h * self.padding_scale # Calculate scaled bounding box coordinates new_start_x = center_x - scaled_bbox_w / 2.0 new_start_y = center_y - scaled_bbox_h / 2.0 new_end_x = center_x + scaled_bbox_w / 2.0 new_end_y = center_y + scaled_bbox_h / 2.0 # Define the new box coordinates new_box = torch.stack((new_start_x, new_start_y, new_end_x, new_end_y), dim=0) new_box = new_box.unsqueeze(0) scale = torch.stack( ((target_w / scaled_bbox_w), (target_h / scaled_bbox_h)), dim=0 ) return scale, new_box def forward(self, img, bbox): _, bbox = self.calc_params(bbox) batch_indices = torch.zeros(bbox.shape[0], 1) rois = torch.cat([batch_indices, bbox], dim=1) # Resize and crop img = img.to(torch.float32) img = img.permute(0, 3, 1, 2) img = roi_align( img, rois, output_size=self.target_size, spatial_scale=1.0, sampling_ratio=0, ) img = img.permute(0, 2, 3, 1) img = img.round() return img # ================================================================================================== class DetPreprocess(nn.Module): def __init__(self, target_size, fill_value=114): super(DetPreprocess, self).__init__() self.letterbox = Letterbox(target_size, fill_value) def forward(self, img): # img: torch.Tensor of shape [batch, H, W, C], dtype=torch.uint8 img = self.letterbox(img) return img # ================================================================================================== class DetPostprocess(nn.Module): def __init__(self, target_size): super(DetPostprocess, self).__init__() self.target_size = target_size self.letterbox = Letterbox(target_size) def forward(self, img, boxes): paddings, scale, _ = self.letterbox.calc_params(img.shape) boxes = boxes.float() boxes[:, :, 0] -= paddings[0] boxes[:, :, 2] -= paddings[0] boxes[:, :, 1] -= paddings[2] boxes[:, :, 3] -= paddings[2] zero = torch.tensor(0) boxes = torch.max(boxes, zero) th, tw = self.target_size pad_w = paddings[0] + paddings[1] pad_h = paddings[2] + paddings[3] max_w = tw - pad_w - 1 max_h = th - pad_h - 1 b0 = boxes[:, :, 0] b1 = boxes[:, :, 1] b2 = boxes[:, :, 2] b3 = boxes[:, :, 3] b0 = torch.min(b0, max_w) b1 = torch.min(b1, max_h) b2 = torch.min(b2, max_w) b3 = torch.min(b3, max_h) boxes[:, :, 0] = b0 boxes[:, :, 1] = b1 boxes[:, :, 2] = b2 boxes[:, :, 3] = b3 boxes[:, :, 0:4] /= scale return boxes # ================================================================================================== class PosePreprocess(nn.Module): def __init__(self, target_size, fill_value=114): super(PosePreprocess, self).__init__() self.boxcrop = BoxCrop(target_size) def forward(self, img, bbox): # img: torch.Tensor of shape [1, H, W, C], dtype=torch.uint8 # bbox: torch.Tensor of shape [1, 4], dtype=torch.float32 img = self.boxcrop(img, bbox) return img # ================================================================================================== class PosePostprocess(nn.Module): def __init__(self, target_size): super(PosePostprocess, self).__init__() self.boxcrop = BoxCrop(target_size) self.target_size = target_size def forward(self, img, bbox, keypoints): scale, bbox = self.boxcrop.calc_params(bbox) kp = keypoints.float() kp[:, :, 0:2] /= scale kp[:, :, 0] += bbox[0, 0] kp[:, :, 1] += bbox[0, 1] zero = torch.tensor(0) kp = torch.max(kp, zero) max_w = img.shape[2] - 1 max_h = img.shape[1] - 1 k0 = kp[:, :, 0] k1 = kp[:, :, 1] k0 = torch.min(k0, max_w) k1 = torch.min(k1, max_h) kp[:, :, 0] = k0 kp[:, :, 1] = k1 return kp # ================================================================================================== def main(): img_path = "/RapidPoseTriangulation/scripts/../data/h1/54138969-img_003201.jpg" image = cv2.imread(img_path, 3) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Initialize the DetPreprocess module preprocess_model = DetPreprocess(target_size=det_target_size) det_dummy_input_a0 = torch.from_numpy(image).unsqueeze(0) # Export to ONNX torch.onnx.export( preprocess_model, det_dummy_input_a0, base_path + "det_preprocess.onnx", opset_version=11, input_names=["input_image"], output_names=["preprocessed_image"], dynamic_axes={ "input_image": {0: "batch_size", 1: "height", 2: "width"}, "preprocessed_image": {0: "batch_size"}, }, ) # Initialize the DetPostprocess module postprocess_model = DetPostprocess(target_size=det_target_size) det_dummy_input_b0 = torch.from_numpy(image).unsqueeze(0) det_dummy_input_b1 = torch.rand(1, 10, 5) # Export to ONNX torch.onnx.export( postprocess_model, (det_dummy_input_b0, det_dummy_input_b1), base_path + "det_postprocess.onnx", opset_version=11, input_names=["input_image", "boxes"], output_names=["output_boxes"], dynamic_axes={ "input_image": {0: "batch_size", 1: "height", 2: "width"}, "boxes": {0: "batch_size", 1: "num_boxes"}, "output_boxes": {0: "batch_size", 1: "num_boxes"}, }, ) # Initialize the PosePreprocess module preprocess_model = PosePreprocess(target_size=pose_target_size) det_dummy_input_c0 = torch.from_numpy(image).unsqueeze(0) det_dummy_input_c1 = torch.tensor([[352, 339, 518, 594]]).to(torch.int32) # Export to ONNX torch.onnx.export( preprocess_model, (det_dummy_input_c0, det_dummy_input_c1), base_path + "pose_preprocess.onnx", opset_version=11, input_names=["input_image", "bbox"], output_names=["preprocessed_image"], dynamic_axes={ "input_image": {0: "batch_size", 1: "height", 2: "width"}, "preprocessed_image": {0: "batch_size"}, }, ) # Initialize the PosePostprocess module postprocess_model = PosePostprocess(target_size=pose_target_size) det_dummy_input_d0 = torch.from_numpy(image).unsqueeze(0) det_dummy_input_d1 = torch.tensor([[352, 339, 518, 594]]).to(torch.int32) det_dummy_input_d2 = torch.rand(1, 17, 2) # Export to ONNX torch.onnx.export( postprocess_model, (det_dummy_input_d0, det_dummy_input_d1, det_dummy_input_d2), base_path + "pose_postprocess.onnx", opset_version=11, input_names=["input_image", "bbox", "keypoints"], output_names=["output_keypoints"], dynamic_axes={ "input_image": {0: "batch_size", 1: "height", 2: "width"}, "output_keypoints": {0: "batch_size"}, }, ) # ================================================================================================== if __name__ == "__main__": main()