Some small updates.
This commit is contained in:
@ -9,22 +9,23 @@ from tqdm import tqdm
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
def __init__(self, model_path: str, warmup: int):
|
||||
def __init__(
|
||||
self, model_path: str, warmup: int, usetrt: bool = True, usegpu: bool = True
|
||||
):
|
||||
self.opt = ort.SessionOptions()
|
||||
providers = ort.get_available_providers()
|
||||
# ort.set_default_logger_severity(1)
|
||||
|
||||
provider = ""
|
||||
if "CUDAExecutionProvider" in providers:
|
||||
provider = "CUDAExecutionProvider"
|
||||
else:
|
||||
provider = "CPUExecutionProvider"
|
||||
self.provider = provider
|
||||
print("Found providers:", providers)
|
||||
print("Using:", provider)
|
||||
self.providers = []
|
||||
if usetrt and "TensorrtExecutionProvider" in providers:
|
||||
self.providers.append("TensorrtExecutionProvider")
|
||||
if usegpu and "CUDAExecutionProvider" in providers:
|
||||
self.providers.append("CUDAExecutionProvider")
|
||||
self.providers.append("CPUExecutionProvider")
|
||||
print("Using providers:", self.providers)
|
||||
|
||||
self.session = ort.InferenceSession(
|
||||
model_path, providers=[provider], sess_options=self.opt
|
||||
model_path, providers=self.providers, sess_options=self.opt
|
||||
)
|
||||
|
||||
self.input_names = [input.name for input in self.session.get_inputs()]
|
||||
@ -65,7 +66,7 @@ class BaseModel(ABC):
|
||||
if "image" in iname:
|
||||
ishape = self.input_shapes[i]
|
||||
if "batch_size" in ishape:
|
||||
if self.provider == "TensorrtExecutionProvider":
|
||||
if "TensorrtExecutionProvider" in self.providers:
|
||||
# Using different images sizes for TensorRT warmup takes too long
|
||||
ishape = [1, 1000, 1000, 3]
|
||||
else:
|
||||
@ -89,7 +90,7 @@ class BaseModel(ABC):
|
||||
]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Undefined input type")
|
||||
raise ValueError("Undefined input type:", iname)
|
||||
|
||||
tensor = tensor.astype(self.input_types[i])
|
||||
inputs[iname] = tensor
|
||||
|
||||
Reference in New Issue
Block a user