Some further optimizations.
This commit is contained in:
@ -2,7 +2,7 @@ import numpy as np
|
||||
from typing import List
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .utils import letterbox, nms, xywh2xyxy
|
||||
from .utils import letterbox, nms_optimized, xywh2xyxy
|
||||
|
||||
|
||||
class RTMDet(BaseModel):
|
||||
@ -21,11 +21,10 @@ class RTMDet(BaseModel):
|
||||
|
||||
def preprocess(self, image: np.ndarray):
|
||||
th, tw = self.input_shape[2:]
|
||||
tensor, self.dx, self.dy, self.scale = letterbox(
|
||||
image, self.dx, self.dy, self.scale = letterbox(
|
||||
image, (tw, th), fill_value=114
|
||||
)
|
||||
tensor = tensor.astype(self.input_type, copy=False)
|
||||
tensor = tensor[..., ::-1]
|
||||
tensor = np.asarray(image).astype(self.input_type, copy=False)[..., ::-1]
|
||||
tensor = np.expand_dims(tensor, axis=0).transpose((0, 3, 1, 2))
|
||||
return tensor
|
||||
|
||||
@ -34,7 +33,7 @@ class RTMDet(BaseModel):
|
||||
classes = np.expand_dims(np.squeeze(tensor[1], axis=0), axis=-1)
|
||||
boxes = np.concatenate([boxes, classes], axis=-1)
|
||||
|
||||
boxes = nms(boxes, self.iou_threshold, self.conf_threshold)
|
||||
boxes = nms_optimized(boxes, self.iou_threshold, self.conf_threshold)
|
||||
|
||||
if boxes.shape[0] == 0:
|
||||
return boxes
|
||||
|
||||
@ -44,14 +44,12 @@ class SimCC(BaseModel):
|
||||
self.scale = 0
|
||||
|
||||
def preprocess(self, image: np.ndarray):
|
||||
tensor, self.dx, self.dy, self.scale = image, 0, 0, 1
|
||||
tensor = tensor.astype(self.input_type, copy=False)
|
||||
tensor = np.asarray(image).astype(self.input_type, copy=False)
|
||||
tensor = np.expand_dims(tensor, axis=0).transpose((0, 3, 1, 2))
|
||||
return tensor
|
||||
|
||||
def postprocess(self, tensor: List[np.ndarray]):
|
||||
kpts = tensor[0][0]
|
||||
scores = np.expand_dims(tensor[1][0], axis=-1)
|
||||
keypoints = np.concatenate([kpts, scores], axis=-1)
|
||||
|
||||
keypoints = np.concatenate(
|
||||
[tensor[0][0], np.expand_dims(tensor[1][0], axis=-1)], axis=-1
|
||||
)
|
||||
return keypoints
|
||||
|
||||
@ -65,6 +65,61 @@ def nms(boxes: np.ndarray, iou_threshold: float, conf_threshold: float):
|
||||
return np.array(result)
|
||||
|
||||
|
||||
def nms_optimized(boxes: np.ndarray, iou_threshold: float, conf_threshold: float):
|
||||
"""
|
||||
Perform Non-Maximum Suppression (NMS) on bounding boxes for a single class.
|
||||
"""
|
||||
|
||||
# Filter out boxes with low confidence scores
|
||||
scores = boxes[:, 4]
|
||||
keep = scores > conf_threshold
|
||||
boxes = boxes[keep]
|
||||
scores = scores[keep]
|
||||
|
||||
if boxes.shape[0] == 0:
|
||||
return np.empty((0, 5), dtype=boxes.dtype)
|
||||
|
||||
# Compute the area of the bounding boxes
|
||||
x1 = boxes[:, 0]
|
||||
y1 = boxes[:, 1]
|
||||
x2 = boxes[:, 2]
|
||||
y2 = boxes[:, 3]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
|
||||
# Sort the boxes by scores in descending order
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep_indices = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep_indices.append(i)
|
||||
|
||||
# Compute IoU of the current box with the rest
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
# Compute width and height of the overlapping area
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
|
||||
# Compute the area of the intersection
|
||||
inter = w * h
|
||||
|
||||
# Compute the IoU
|
||||
iou = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
# Keep boxes with IoU less than the threshold
|
||||
inds = np.where(iou <= iou_threshold)[0]
|
||||
|
||||
# Update the order array
|
||||
order = order[inds + 1]
|
||||
|
||||
# Return the boxes that are kept
|
||||
return boxes[keep_indices]
|
||||
|
||||
|
||||
def get_heatmap_points(heatmap: np.ndarray):
|
||||
keypoints = np.zeros([1, heatmap.shape[0], 3], dtype=np.float32)
|
||||
for i in range(heatmap.shape[0]):
|
||||
|
||||
Reference in New Issue
Block a user