Some small updates.

Daniel
2024-12-04 11:41:12 +01:00
parent 97ff32b9ce
commit 6452d20ec8
3 changed files with 34 additions and 24 deletions


@@ -9,22 +9,23 @@ from tqdm import tqdm
 class BaseModel(ABC):
-    def __init__(self, model_path: str, warmup: int):
+    def __init__(
+        self, model_path: str, warmup: int, usetrt: bool = True, usegpu: bool = True
+    ):
         self.opt = ort.SessionOptions()
         providers = ort.get_available_providers()
         # ort.set_default_logger_severity(1)
-        provider = ""
-        if "CUDAExecutionProvider" in providers:
-            provider = "CUDAExecutionProvider"
-        else:
-            provider = "CPUExecutionProvider"
-        self.provider = provider
-        print("Found providers:", providers)
-        print("Using:", provider)
+        self.providers = []
+        if usetrt and "TensorrtExecutionProvider" in providers:
+            self.providers.append("TensorrtExecutionProvider")
+        if usegpu and "CUDAExecutionProvider" in providers:
+            self.providers.append("CUDAExecutionProvider")
+        self.providers.append("CPUExecutionProvider")
+        print("Using providers:", self.providers)
         self.session = ort.InferenceSession(
-            model_path, providers=[provider], sess_options=self.opt
+            model_path, providers=self.providers, sess_options=self.opt
         )
         self.input_names = [input.name for input in self.session.get_inputs()]
@@ -65,7 +66,7 @@ class BaseModel(ABC):
             if "image" in iname:
                 ishape = self.input_shapes[i]
                 if "batch_size" in ishape:
-                    if self.provider == "TensorrtExecutionProvider":
+                    if "TensorrtExecutionProvider" in self.providers:
                         # Using different image sizes for TensorRT warmup takes too long
                         ishape = [1, 1000, 1000, 3]
                     else:
@@ -89,7 +90,7 @@ class BaseModel(ABC):
                     ]
                 )
             else:
-                raise ValueError("Undefined input type")
+                raise ValueError("Undefined input type:", iname)
             tensor = tensor.astype(self.input_types[i])
             inputs[iname] = tensor
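
For reference, a minimal self-contained sketch of the provider-selection behaviour this commit introduces. The helper name select_providers and the path "model.onnx" are illustrative only, not part of the repository; the onnxruntime calls match the ones used in the diff above.

import onnxruntime as ort

def select_providers(usetrt: bool = True, usegpu: bool = True) -> list:
    """Build an execution-provider priority list, always keeping CPU as the final fallback."""
    available = ort.get_available_providers()
    providers = []
    if usetrt and "TensorrtExecutionProvider" in available:
        providers.append("TensorrtExecutionProvider")
    if usegpu and "CUDAExecutionProvider" in available:
        providers.append("CUDAExecutionProvider")
    providers.append("CPUExecutionProvider")
    return providers

# ONNX Runtime tries the providers in list order and falls back to the next one,
# so an unavailable TensorRT/CUDA build still ends up running on CPU.
session = ort.InferenceSession("model.onnx", providers=select_providers())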