66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from typing import List
|
|
import time
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
from tqdm import tqdm
|
|
|
|
|
|
class BaseModel(ABC):
    """Abstract base for models served through an ONNX Runtime session.

    The constructor opens an ``ort.InferenceSession`` for ``model_path`` on the
    requested device, records the name, shape and NumPy dtype of the model's
    first input, and optionally runs a warm-up loop. Subclasses implement
    :meth:`preprocess` (raw image -> model input tensor) and
    :meth:`postprocess` (list of session outputs -> final result); calling the
    instance runs preprocess -> session -> postprocess.
    """

    # Map ONNX Runtime element-type strings to NumPy dtypes. ONNX Runtime
    # reports 32-bit float inputs as 'tensor(float)' (and 64-bit floats as
    # 'tensor(double)'); the 'tensor(float32)' spelling the original code
    # checked for is kept so nothing that relied on it stops resolving.
    _ONNX_TO_NUMPY = {
        'tensor(float)': np.float32,
        'tensor(float32)': np.float32,
        'tensor(float16)': np.float16,
        'tensor(double)': np.float64,
        'tensor(uint8)': np.uint8,
        'tensor(int8)': np.int8,
        'tensor(int32)': np.int32,
        'tensor(int64)': np.int64,
    }

    def __init__(self, model_path: str, device: str = 'CUDA', warmup: int = 30):
        """Create the inference session and optionally warm it up.

        Args:
            model_path: Model path handed to ``ort.InferenceSession``.
            device: ``'CUDA'`` or ``'CPU'``. ``'CUDA'`` falls back to the CPU
                provider (emitting a ``UserWarning``) when the CUDA execution
                provider is not available.
            warmup: Number of warm-up inference runs; ``<= 0`` disables it.

        Raises:
            ValueError: If ``device`` is not recognised, or the model's first
                input has an element type with no NumPy mapping.
        """
        self.opt = ort.SessionOptions()

        provider = self._select_provider(device)
        self.session = ort.InferenceSession(model_path,
                                            providers=[provider],
                                            sess_options=self.opt)

        # Only the first model input is used by this wrapper.
        first_input = self.session.get_inputs()[0]
        self.input_name = first_input.name
        self.input_shape = first_input.shape
        self.input_type = self._onnx_to_numpy_dtype(first_input.type)

        if warmup > 0:
            self.warmup(warmup)

    @staticmethod
    def _select_provider(device: str) -> str:
        """Return the ONNX Runtime execution-provider name for ``device``.

        Raises:
            ValueError: If ``device`` is neither ``'CUDA'`` nor ``'CPU'``.
        """
        if device == 'CUDA':
            provider = 'CUDAExecutionProvider'
            if provider in ort.get_available_providers():
                return provider
            warnings.warn("No CUDAExecutionProvider found, switched to CPUExecutionProvider.", UserWarning)
            return 'CPUExecutionProvider'
        if device == 'CPU':
            return 'CPUExecutionProvider'
        raise ValueError('Provider {} does not exist.'.format(device))

    @classmethod
    def _onnx_to_numpy_dtype(cls, onnx_type: str):
        """Translate an ONNX Runtime type string (e.g. ``'tensor(float)'``) to a NumPy dtype.

        Raises:
            ValueError: If the type string has no mapping.
        """
        try:
            return cls._ONNX_TO_NUMPY[onnx_type]
        except KeyError:
            raise ValueError('Unknown input type: {}'.format(onnx_type)) from None

    @staticmethod
    def _static_shape(shape, fill: int = 1) -> tuple:
        """Return ``shape`` with every non-concrete dimension replaced by ``fill``.

        ONNX Runtime reports dynamic dimensions as symbolic names (strings) or
        ``None``; anything that is not a positive ``int`` is treated as dynamic.
        """
        return tuple(dim if isinstance(dim, int) and dim > 0 else fill
                     for dim in shape)

    @abstractmethod
    def preprocess(self, image: np.ndarray):
        """Convert a raw input image into the tensor fed to the session."""
        pass

    @abstractmethod
    def postprocess(self, tensor: List[np.ndarray]):
        """Convert the list returned by ``session.run`` into the final result."""
        pass

    def forward(self, image: np.ndarray):
        """Run preprocess -> ONNX inference (all outputs) -> postprocess."""
        tensor = self.preprocess(image)
        result = self.session.run(None, {self.input_name: tensor})
        output = self.postprocess(result)
        return output

    def warmup(self, epoch: int = 30):
        """Run ``epoch`` inferences on random input to warm up the session.

        Dynamic dimensions of the model input are set to 1 so a concrete
        tensor can be built. Values come from ``np.random.random`` in [0, 1),
        so integer dtypes yield an all-zero tensor; only shape and dtype
        matter for warm-up.
        """
        print('{} start warmup!'.format(self.__class__.__name__))
        shape = self._static_shape(self.input_shape)
        tensor = np.random.random(shape).astype(self.input_type)
        for _ in tqdm(range(epoch)):
            self.session.run(None, {self.input_name: tensor})

    def __call__(self, image: np.ndarray, *args, **kwargs):
        """Alias for :meth:`forward`; extra arguments are accepted and ignored."""
        return self.forward(image)
|