Skip to content

Commit 16b4b27

Browse files
committed
add tensorrt support
1 parent 13486e4 commit 16b4b27

File tree

14 files changed

+411
-64
lines changed

14 files changed

+411
-64
lines changed
Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
'''
2+
Convert PyTorch models to TensorRT engines (via ONNX) and run inference with the built engines.
3+
'''
4+
5+
import os
6+
import torch
7+
import tensorrt as trt
8+
from loguru import logger
9+
10+
import pycuda.driver as cuda
11+
import pycuda.autoinit
12+
import numpy as np
13+
14+
from collections import OrderedDict, namedtuple
15+
16+
from loguru import logger
17+
18+
class TensorRTConverter(object):
    """Export a PyTorch model to ONNX and build a TensorRT engine from it.

    The ONNX file and the serialized engine are written next to the
    checkpoint and share its base name (``model.pth`` -> ``model.onnx`` and
    ``model.engine``). Both export steps are skipped when their output file
    already exists.
    """

    def __init__(self, model, input_shape, ckpt_path, min_opt_max_batch=[1, 1, 1], use_fp16=False,
                 device=torch.device('cpu'), load_ckpt=False, ckpt_key='state_dict'):
        """
        Args:
            model: torch.nn.Module, model
            input_shape: List[int], (c, h, w), input tensor shape
            ckpt_path: str, used to save the onnx and tensorrt file
            min_opt_max_batch: List[int], (min, opt, max) batch sizes of the
                optimization profile; the default list is only read, never mutated
            use_fp16: bool, whether use fp16
            device: str | torch.device, cpu or gpu
            load_ckpt: bool, whether load the pt/pth checkpoint file
            ckpt_key: str, the weight key in ckpt dict
        """
        self.model = model
        self.input_shape = input_shape

        # Swap only the extension of the file name, so paths without an
        # extension or with dots in a directory name keep their full stem.
        self.onnx_model = self._swap_extension(ckpt_path, 'onnx')
        self.trt_model = self._swap_extension(ckpt_path, 'engine')
        self.use_fp16 = use_fp16

        self.device = device

        if load_ckpt:
            logger.info(f'to convert TensorRT, load ckpt {ckpt_path} first')
            self.load_ckpt(ckpt_path, ckpt_key)
            logger.info('load ckpt done')

        # (batch, c, h, w) shapes for the dynamic-batch optimization profile
        self.min_input_shape = tuple([min_opt_max_batch[0], *input_shape])
        self.opt_input_shape = tuple([min_opt_max_batch[1], *input_shape])
        self.max_input_shape = tuple([min_opt_max_batch[2], *input_shape])

        # check tensor rt version
        self.is_trt_10 = int(trt.__version__.split(".")[0]) >= 10

        # constants
        self.WORK_SPACE = 1 << 30
        self.INPUT_NAME = "images"
        self.OUTPUT_NAME = "output"
        self.OPSET_VERSION = 12

    @staticmethod
    def _swap_extension(ckpt_path, extension):
        """Return ``ckpt_path`` with its file extension replaced by ``extension``."""
        stem, _ = os.path.splitext(ckpt_path)
        return f'{stem}.{extension}'

    def load_ckpt(self, ckpt_path, ckpt_key):
        """
        load checkpoint file if needed

        Args:
            ckpt_path: str, checkpoint passed to ``torch.load``
            ckpt_key: str, key holding the weights; when the key is absent the
                loaded object itself is used as the state dict
        """
        ckpt = torch.load(ckpt_path, map_location=self.device)
        state_dict = ckpt.get(ckpt_key, ckpt)
        new_state = {}
        for k, v in state_dict.items():
            name = k
            # strip the wrapper prefixes (first "module.", then "model.") from the key
            if name.startswith("module."):
                name = name[len("module."):]
            if name.startswith("model."):
                name = name[len("model."):]
            new_state[name] = v
        self.model.load_state_dict(new_state)
        self.model.eval()

    def export_onnx(self):
        """Export ``self.model`` to ``self.onnx_model`` with a dynamic batch axis.

        Does nothing (besides a warning) when the ONNX file already exists.
        """
        logger.info('to convert TensorRT, convert onnx first')

        if os.path.exists(self.onnx_model):
            logger.warning(f'the onnx {self.onnx_model} already exists, so the export progress is stopped')
            return

        # Trace on the device the model actually lives on instead of
        # unconditionally calling .cuda(); fall back to the configured device
        # for models without parameters.
        try:
            model_device = next(self.model.parameters()).device
        except StopIteration:
            model_device = torch.device(self.device)

        dummy = torch.randn([1, *self.input_shape], device=model_device)
        torch.onnx.export(
            self.model, dummy, self.onnx_model,
            input_names=[self.INPUT_NAME],
            output_names=[self.OUTPUT_NAME],
            opset_version=self.OPSET_VERSION,
            do_constant_folding=True,
            dynamic_axes={self.INPUT_NAME: {0: "batch_size"}, self.OUTPUT_NAME: {0: "batch_size"}}
        )
        logger.info(f'convert onnx done, save path: {self.onnx_model}')

    def export(self):
        """Build the TensorRT engine and save it to ``self.trt_model``.

        Does nothing (besides a warning) when the engine file already exists.

        Raises:
            RuntimeError: if the ONNX model cannot be parsed or the engine
                build returns nothing.
        """
        if os.path.exists(self.trt_model):
            logger.warning(f'the engine {self.trt_model} already exists, so the export progress is stopped')
            return

        trt_logger = trt.Logger(trt.Logger.WARNING)
        builder = trt.Builder(trt_logger)
        network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
        parser = trt.OnnxParser(network, trt_logger)

        # convert onnx
        self.export_onnx()

        logger.info('converting tensorrt')
        with open(self.onnx_model, "rb") as f:
            parsed = parser.parse(f.read())
        if not parsed:
            # Report every parser error and stop: building from a network that
            # failed to parse cannot produce a usable engine.
            for index in range(parser.num_errors):
                logger.error(f'onnx parse error: {parser.get_error(index)}')
            raise RuntimeError(f'failed to parse onnx model {self.onnx_model}')

        # build configs
        config = builder.create_builder_config()
        if self.use_fp16 and builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            logger.info('enabled fp16')

        if self.is_trt_10:
            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, self.WORK_SPACE)
        else:
            config.max_workspace_size = self.WORK_SPACE

        profile = builder.create_optimization_profile()
        # support dynamic batch
        profile.set_shape(
            self.INPUT_NAME,
            min=self.min_input_shape,
            opt=self.opt_input_shape,
            max=self.max_input_shape
        )
        config.add_optimization_profile(profile)

        logger.info('saving tensorrt engine')
        if self.is_trt_10:
            engine = builder.build_serialized_network(network, config)
        else:
            engine = builder.build_engine(network, config)

        # Check before opening the output file: otherwise a failed build would
        # leave an empty engine file that later runs treat as already exported.
        if engine is None:
            raise RuntimeError('failed to build the tensorrt engine')

        serialized = engine if self.is_trt_10 else engine.serialize()
        with open(self.trt_model, "wb") as f:
            f.write(serialized)

        logger.info(f'convert tensorrt done, save path: {self.trt_model}')
146+
147+
148+
class TensorRTInference(object):
    """Run inference with a serialized TensorRT engine.

    The engine is expected to expose one input tensor named ``"images"`` and
    one output tensor named ``"output"``. Instances are callable, and ``eval()``
    is provided so the object can stand in for a ``torch.nn.Module``.
    """

    def __init__(self, engine_path, min_opt_max_batch=[1, 1, 1], use_fp16=False,
                 device=torch.device('cpu')):
        """
        Args:
            engine_path: str, path of the serialized tensorrt engine
            min_opt_max_batch: List[int], accepted for signature parity with the
                converter; not used by this class
            use_fp16: bool, stored on the instance only
            device: str | torch.device, device the binding buffers are allocated on
                NOTE(review): the buffer addresses are handed to ``execute_v2``,
                so this presumably has to be a CUDA device - confirm with callers
        """
        self.use_fp16 = use_fp16
        self.engine_path = engine_path
        self.device = device

        self.is_trt_10 = int(trt.__version__.split(".")[0]) >= 10

        self.trt_logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.trt_logger)

        # constants
        self.INPUT_NAME = "images"
        self.OUTPUT_NAME = "output"

        # load engine
        logger.info(f'loading tensor rt engine {engine_path}')
        self.load_engine()
        logger.info('load tensor rt engine done')

    def load_engine(self):
        """Deserialize the engine, create its context and allocate I/O buffers.

        Populates ``self.engine``, ``self.context``, ``self.bindings`` (name ->
        ``Binding`` namedtuple) and ``self.binding_addrs`` (name -> buffer address).

        Raises:
            RuntimeError: if the engine file cannot be deserialized.
        """
        Binding = namedtuple("Binding", ("name", "dtype", "shape", "data", "ptr"))

        # Deserialize the engine
        with open(self.engine_path, "rb") as f:
            self.engine = self.runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:
            raise RuntimeError(f'failed to deserialize tensorrt engine {self.engine_path}')

        # create context
        self.context = self.engine.create_execution_context()

        self.bindings = OrderedDict()

        count = self.engine.num_io_tensors if self.is_trt_10 else self.engine.num_bindings

        # Parse bindings
        for index in range(count):
            if self.is_trt_10:
                name = self.engine.get_tensor_name(index)
                dtype = trt.nptype(self.engine.get_tensor_dtype(name))
                is_input = self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
                # Handle dynamic shapes: start from the profile's opt shape
                if is_input and -1 in tuple(self.engine.get_tensor_shape(name)):
                    opt_shape = self.engine.get_tensor_profile_shape(name, 0)[1]
                    self.context.set_input_shape(name, tuple(opt_shape))
                shape = tuple(self.context.get_tensor_shape(name))
            else:
                name = self.engine.get_binding_name(index)
                dtype = trt.nptype(self.engine.get_binding_dtype(index))
                is_input = self.engine.binding_is_input(index)
                # Handle dynamic shapes: start from the profile's opt shape
                if is_input and -1 in self.engine.get_binding_shape(index):
                    profile_index = 0
                    _, opt_shape, _ = self.engine.get_profile_shape(profile_index, index)
                    self.context.set_binding_shape(index, opt_shape)
                shape = tuple(self.context.get_binding_shape(index))

            data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
            self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))

        self.binding_addrs = OrderedDict((n, d.ptr) for n, d in self.bindings.items())

    def _resize_bindings(self, shape):
        """Tell the execution context about a new input shape and resize the output buffer."""
        if self.is_trt_10:
            accepted = self.context.set_input_shape(self.INPUT_NAME, shape)
            if accepted:
                output_shape = tuple(self.context.get_tensor_shape(self.OUTPUT_NAME))
        else:
            input_index = self.engine.get_binding_index(self.INPUT_NAME)
            output_index = self.engine.get_binding_index(self.OUTPUT_NAME)
            accepted = self.context.set_binding_shape(input_index, shape)
            if accepted:
                output_shape = tuple(self.context.get_binding_shape(output_index))

        if not accepted:
            raise ValueError(f"Input size {shape} is not accepted by the tensorrt engine")

        self.bindings[self.INPUT_NAME] = self.bindings[self.INPUT_NAME]._replace(shape=shape)
        output_data = self.bindings[self.OUTPUT_NAME].data
        output_data.resize_(output_shape)
        # resize_ may move the storage, so refresh the address given to TensorRT
        self.binding_addrs[self.OUTPUT_NAME] = int(output_data.data_ptr())

    def inference(self, input_tensor):
        """
        Run the engine on a batch, splitting it into sub batches when it is
        larger than the batch size the output buffer was allocated with.

        Args:
            input_tensor: torch.Tensor, shape (b, c, h, w)

        Returns:
            torch.Tensor, outputs of all sub batches concatenated along dim 0;
            an empty tensor with a zero-sized batch dimension when ``b`` is 0

        Raises:
            ValueError: if the engine rejects the input shape.
            RuntimeError: if the engine execution reports a failure.
        """
        output_binding = self.bindings[self.OUTPUT_NAME]
        sub_batch = output_binding.shape[0]

        # Nothing to run: return an empty result instead of failing in torch.cat
        if input_tensor.shape[0] == 0:
            return output_binding.data.new_empty((0, *output_binding.shape[1:]))

        # The engine reads raw memory through data_ptr(), so the input must be
        # dense and share the dtype/device of the input binding buffer.
        input_data = self.bindings[self.INPUT_NAME].data
        source = input_tensor.to(device=input_data.device, dtype=input_data.dtype).contiguous()

        resultant_features = []
        # Divide batch to sub batches (views along dim 0 of a contiguous tensor)
        for temp_batch in torch.split(source, sub_batch):
            batch_shape = tuple(temp_batch.shape)

            # Adjust for dynamic shapes
            if batch_shape != tuple(self.bindings[self.INPUT_NAME].shape):
                self._resize_bindings(batch_shape)

            self.binding_addrs[self.INPUT_NAME] = int(temp_batch.data_ptr())

            # Execute inference
            if not self.context.execute_v2(list(self.binding_addrs.values())):
                raise RuntimeError('tensorrt engine execution failed')

            # clone: the output buffer is overwritten by the next sub batch
            resultant_features.append(self.bindings[self.OUTPUT_NAME].data.clone())

        if len(resultant_features) == 1:
            return resultant_features[0]
        return torch.cat(resultant_features, dim=0)

    def __call__(self, input_tensor):
        """Alias of :meth:`inference`."""
        return self.inference(input_tensor)

    def eval(self, ):
        """No-op kept for compatibility; returns ``self`` like ``torch.nn.Module.eval``."""
        return self
274+

0 commit comments

Comments
 (0)