1+ '''
2+ convert models to tensorrt engine and inference
3+ '''
4+
5+ import os
6+ import torch
7+ import tensorrt as trt
8+ from loguru import logger
9+
10+ import pycuda .driver as cuda
11+ import pycuda .autoinit
12+ import numpy as np
13+
14+ from collections import OrderedDict , namedtuple
15+
16+ from loguru import logger
17+
18+ class TensorRTConverter (object ):
19+
20+ def __init__ (self , model , input_shape , ckpt_path , min_opt_max_batch = [1 , 1 , 1 ], use_fp16 = False ,
21+ device = torch .device ('cpu' ), load_ckpt = False , ckpt_key = 'state_dict' ):
22+ """
23+ Args:
24+ model: torch.nn.Module, model
25+ input_shape: List[int], (c, h, w), input tensor shape
26+ ckpt_path: str, used to save the onnx and tensorrt file
27+ use_fp16: bool, whether use fp16
28+ device: str | torch.device, cpu or gpu
29+ load_ckpt: bool, whether load the pt/pth checkpoint file
30+ ckpt_key: str, the weight key in ckpt dict
31+ """
32+ self .model = model
33+ self .input_shape = input_shape
34+
35+ postfix_length = len (ckpt_path .split ('.' )[- 1 ]) # .pt/.pth
36+ self .onnx_model = ckpt_path [:- postfix_length ] + 'onnx'
37+ self .trt_model = ckpt_path [:- postfix_length ] + 'engine'
38+ self .use_fp16 = use_fp16
39+
40+ self .device = device
41+
42+ if load_ckpt :
43+ logger .info (f'to convert TensorRT, load ckpt { ckpt_path } first' )
44+ self .load_ckpt (ckpt_path , ckpt_key )
45+ logger .info ('load ckpt done' )
46+
47+
48+ self .min_input_shape = tuple ([min_opt_max_batch [0 ], * input_shape ])
49+ self .opt_input_shape = tuple ([min_opt_max_batch [1 ], * input_shape ])
50+ self .max_input_shape = tuple ([min_opt_max_batch [2 ], * input_shape ])
51+
52+ # check tensor rt version
53+ self .is_trt_10 = int (trt .__version__ .split ("." )[0 ]) >= 10
54+
55+ # constants
56+ self .WORK_SPACE = 1 << 30
57+ self .INPUT_NAME = "images"
58+ self .OUTPUT_NAME = "output"
59+ self .OPSET_VERSION = 12
60+
61+ def load_ckpt (self , ckpt_path , ckpt_key ):
62+ """
63+ load checkpoint file if needed
64+ """
65+ ckpt = torch .load (ckpt_path , map_location = self .device )
66+ state_dict = ckpt .get (ckpt_key , ckpt )
67+ new_state = {}
68+ for k , v in state_dict .items ():
69+ name = k
70+ if name .startswith ("module." ):
71+ name = name [len ("module." ):]
72+ if name .startswith ("model." ):
73+ name = name [len ("model." ):]
74+ new_state [name ] = v
75+ self .model .load_state_dict (new_state )
76+ self .model .eval ()
77+
78+ def export_onnx (self ):
79+ logger .info ('to convert TensorRT, convert onnx first' )
80+
81+ if os .path .exists (self .onnx_model ):
82+ logger .warning (f'the onnx { self .onnx_model } already exists, so the export progress is stopped' )
83+ return
84+
85+ dummy = torch .randn ([1 , * self .input_shape ]).cuda ()
86+ torch .onnx .export (
87+ self .model , dummy , self .onnx_model ,
88+ input_names = [self .INPUT_NAME ],
89+ output_names = [self .OUTPUT_NAME ],
90+ opset_version = self .OPSET_VERSION ,
91+ do_constant_folding = True ,
92+ dynamic_axes = {self .INPUT_NAME :{0 :"batch_size" }, self .OUTPUT_NAME :{0 :"batch_size" }}
93+ )
94+ logger .info (f'convert onnx done, save path: { self .onnx_model } ' )
95+
96+ def export (self ):
97+
98+ if os .path .exists (self .trt_model ):
99+ logger .warning (f'the engine { self .trt_model } already exists, so the export progress is stopped' )
100+ return
101+
102+ trt_logger = trt .Logger (trt .Logger .WARNING )
103+ builder = trt .Builder (trt_logger )
104+ network = builder .create_network (1 << int (trt .NetworkDefinitionCreationFlag .EXPLICIT_BATCH ))
105+ parser = trt .OnnxParser (network , trt_logger )
106+
107+ # convert onnx
108+ self .export_onnx ()
109+
110+ logger .info ('converting tensorrt' )
111+ with open (self .onnx_model , "rb" ) as f :
112+ if not parser .parse (f .read ()):
113+ for error in range (parser .num_errors ):
114+ trt_logger .error (parser .get_error (error ))
115+
116+ # build configs
117+ config = builder .create_builder_config ()
118+ if self .use_fp16 and builder .platform_has_fast_fp16 :
119+ config .set_flag (trt .BuilderFlag .FP16 )
120+ logger .info ('enabled fp16' )
121+
122+ if self .is_trt_10 :
123+ config .set_memory_pool_limit (trt .MemoryPoolType .WORKSPACE , self .WORK_SPACE )
124+ else :
125+ config .max_workspace_size = self .WORK_SPACE
126+ profile = builder .create_optimization_profile ()
127+ # support dynamic batch
128+ profile .set_shape (
129+ self .INPUT_NAME ,
130+ min = self .min_input_shape ,
131+ opt = self .opt_input_shape ,
132+ max = self .max_input_shape
133+ )
134+ config .add_optimization_profile (profile )
135+
136+ logger .info ('saving tensorrt engine' )
137+ if self .is_trt_10 :
138+ engine = builder .build_serialized_network (network , config )
139+ else :
140+ engine = builder .build_engine (network , config )
141+
142+ with open (self .trt_model , "wb" ) as f :
143+ f .write (engine if self .is_trt_10 else engine .serialize ())
144+
145+ logger .info (f'convert tensorrt done, save path: { self .trt_model } ' )
146+
147+
148+ class TensorRTInference (object ):
149+
150+ def __init__ (self , engine_path , min_opt_max_batch = [1 , 1 , 1 ], use_fp16 = False ,
151+ device = torch .device ('cpu' )):
152+
153+
154+ self .use_fp16 = use_fp16
155+ self .engine_path = engine_path
156+ self .device = device
157+
158+ self .is_trt_10 = int (trt .__version__ .split ("." )[0 ]) >= 10
159+
160+ self .trt_logger = trt .Logger (trt .Logger .WARNING )
161+ self .runtime = trt .Runtime (self .trt_logger )
162+
163+ # load engine
164+ logger .info (f'loading tensor rt engine { engine_path } ' )
165+ self .load_engine ()
166+ logger .info (f'load tensor rt engine done' )
167+
168+ # constants
169+ self .INPUT_NAME = "images"
170+ self .OUTPUT_NAME = "output"
171+
172+ def load_engine (self ):
173+ Binding = namedtuple ("Binding" , ("name" , "dtype" , "shape" , "data" , "ptr" ))
174+
175+ # Deserialize the engine
176+ with open (self .engine_path , "rb" ) as f :
177+ self .engine = self .runtime .deserialize_cuda_engine (f .read ())
178+
179+ # create context
180+ self .context = self .engine .create_execution_context ()
181+
182+ # Execution context
183+ self .bindings = OrderedDict ()
184+
185+ num = range (self .engine .num_io_tensors ) if self .is_trt_10 else range (self .engine .num_bindings )
186+
187+ # Parse bindings
188+ for index in num :
189+ if self .is_trt_10 :
190+ name = self .engine .get_tensor_name (index )
191+ dtype = trt .nptype (self .engine .get_tensor_dtype (name ))
192+ is_input = self .engine .get_tensor_mode (name ) == trt .TensorIOMode .INPUT
193+ if is_input and - 1 in tuple (self .engine .get_tensor_shape (name )):
194+ self .context .set_input_shape (name , tuple (self .engine .get_tensor_profile_shape (name , 0 )[1 ]))
195+
196+ shape = tuple (self .context .get_tensor_shape (name ))
197+
198+ else :
199+ name = self .engine .get_binding_name (index )
200+ dtype = trt .nptype (self .engine .get_binding_dtype (index ))
201+ is_input = self .engine .binding_is_input (index )
202+
203+ # Handle dynamic shapes
204+ if is_input and - 1 in self .engine .get_binding_shape (index ):
205+ profile_index = 0
206+ min_shape , opt_shape , max_shape = self .engine .get_profile_shape (profile_index , index )
207+ self .context .set_binding_shape (index , opt_shape )
208+
209+ shape = tuple (self .context .get_binding_shape (index ))
210+ data = torch .from_numpy (np .empty (shape , dtype = dtype )).to (self .device )
211+ self .bindings [name ] = Binding (name , dtype , shape , data , int (data .data_ptr ()))
212+
213+ self .binding_addrs = OrderedDict ((n , d .ptr ) for n , d in self .bindings .items ())
214+
215+ def inference (self , input_tensor ):
216+ """
217+ Args:
218+ input_tensor: torch.Tensor, shape (b, c, h, w)
219+ """
220+
221+ temp_im_batch = input_tensor .clone ()
222+ batch_array = []
223+ inp_batch = input_tensor .shape [0 ]
224+ out_batch = self .bindings [self .OUTPUT_NAME ].shape [0 ]
225+ resultant_features = []
226+
227+ # Divide batch to sub batches
228+ while inp_batch > out_batch :
229+ batch_array .append (temp_im_batch [:out_batch ])
230+ temp_im_batch = temp_im_batch [out_batch :]
231+ inp_batch = temp_im_batch .shape [0 ]
232+ if temp_im_batch .shape [0 ] > 0 :
233+ batch_array .append (temp_im_batch )
234+
235+ for temp_batch in batch_array :
236+ # Adjust for dynamic shapes
237+ if temp_batch .shape != self .bindings [self .INPUT_NAME ].shape :
238+ if self .is_trt_10 :
239+
240+ self .context .set_input_shape (self .INPUT_NAME , temp_batch .shape )
241+ self .bindings [self .INPUT_NAME ] = self .bindings [self .INPUT_NAME ]._replace (shape = temp_batch .shape )
242+ self .bindings [self .OUTPUT_NAME ].data .resize_ (tuple (self .context .get_tensor_shape (self .OUTPUT_NAME )))
243+ else :
244+ i_in = self .model_ .get_binding_index (self .INPUT_NAME )
245+ i_out = self .model_ .get_binding_index (self .OUTPUT_NAME )
246+ self .context .set_binding_shape (i_in , temp_batch .shape )
247+ self .bindings [self .INPUT_NAME ] = self .bindings [self .INPUT_NAME ]._replace (shape = temp_batch .shape )
248+ output_shape = tuple (self .context .get_binding_shape (i_out ))
249+ self .bindings [self .OUTPUT_NAME ].data .resize_ (output_shape )
250+
251+ s = self .bindings [self .INPUT_NAME ].shape
252+ assert temp_batch .shape == s , f"Input size { temp_batch .shape } does not match model size { s } "
253+
254+ self .binding_addrs [self .INPUT_NAME ] = int (temp_batch .data_ptr ())
255+
256+ # Execute inference
257+ self .context .execute_v2 (list (self .binding_addrs .values ()))
258+ features = self .bindings [self .OUTPUT_NAME ].data
259+ resultant_features .append (features .clone ())
260+
261+ if len (resultant_features ) == 1 :
262+ return resultant_features [0 ]
263+ else :
264+ rslt_features = torch .cat (resultant_features , dim = 0 )
265+ rslt_features = rslt_features [: input_tensor .shape [0 ]]
266+ return rslt_features
267+
268+ def __call__ (self , input_tensor ):
269+ return self .inference (input_tensor )
270+
271+ def eval (self , ):
272+ # for compatibility
273+ return
274+
0 commit comments