Skip to content

Commit d1ffaf0

Browse files
committed
TensorRT works with custom batch size
1 parent 823a6e4 commit d1ffaf0

6 files changed

Lines changed: 35 additions & 23 deletions

File tree

example/examples.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ class YoloDarknetExample : public VideoExample
595595
#else
596596
std::string pathToModel = "../data/";
597597
#endif
598+
size_t maxBatch = 1;
598599
enum class YOLOModels
599600
{
600601
TinyYOLOv3 = 0,
@@ -628,16 +629,19 @@ class YoloDarknetExample : public VideoExample
628629
config.emplace("modelConfiguration", pathToModel + "yolov4-tiny.cfg");
629630
config.emplace("modelBinary", pathToModel + "yolov4-tiny.weights");
630631
config.emplace("confidenceThreshold", "0.5");
631-
config.emplace("maxBatch", "4");
632+
maxBatch = 4;
632633
break;
633634

634635
case YOLOModels::ScaledYOLOv4:
635636
config.emplace("modelConfiguration", pathToModel + "yolov4-csp.cfg");
636637
config.emplace("modelBinary", pathToModel + "yolov4-csp.weights");
637638
config.emplace("confidenceThreshold", "0.5");
638-
config.emplace("maxBatch", "2");
639+
maxBatch = 2;
639640
break;
640641
}
642+
if (maxBatch < m_batchSize)
643+
maxBatch = m_batchSize;
644+
config.emplace("maxBatch", std::to_string(m_batchSize));
641645
config.emplace("classNames", pathToModel + "coco.names");
642646
config.emplace("maxCropRatio", "-1");
643647

@@ -787,7 +791,7 @@ class YoloTensorRTExample : public VideoExample
787791
#else
788792
std::string pathToModel = "../data/";
789793
#endif
790-
794+
size_t maxBatch = 1;
791795
enum class YOLOModels
792796
{
793797
TinyYOLOv3 = 0,
@@ -805,7 +809,7 @@ class YoloTensorRTExample : public VideoExample
805809
config.emplace("confidenceThreshold", "0.5");
806810
config.emplace("inference_precison", "FP32");
807811
config.emplace("net_type", "YOLOV3_TINY");
808-
config.emplace("maxBatch", "4");
812+
maxBatch = 4;
809813
config.emplace("maxCropRatio", "2");
810814
break;
811815

@@ -815,17 +819,17 @@ class YoloTensorRTExample : public VideoExample
815819
config.emplace("confidenceThreshold", "0.7");
816820
config.emplace("inference_precison", "FP32");
817821
config.emplace("net_type", "YOLOV3");
818-
config.emplace("maxBatch", "2");
822+
maxBatch = 2;
819823
config.emplace("maxCropRatio", "-1");
820824
break;
821825

822826
case YOLOModels::YOLOv4:
823827
config.emplace("modelConfiguration", pathToModel + "yolov4.cfg");
824828
config.emplace("modelBinary", pathToModel + "yolov4.weights");
825829
config.emplace("confidenceThreshold", "0.8");
826-
config.emplace("inference_precison", "FP16");
830+
config.emplace("inference_precison", "FP32");
827831
config.emplace("net_type", "YOLOV4");
828-
config.emplace("maxBatch", "1");
832+
maxBatch = 1;
829833
config.emplace("maxCropRatio", "-1");
830834
break;
831835

@@ -835,7 +839,7 @@ class YoloTensorRTExample : public VideoExample
835839
config.emplace("confidenceThreshold", "0.5");
836840
config.emplace("inference_precison", "FP32");
837841
config.emplace("net_type", "YOLOV4_TINY");
838-
config.emplace("maxBatch", "4");
842+
maxBatch = 4;
839843
config.emplace("maxCropRatio", "1");
840844
break;
841845

@@ -845,11 +849,13 @@ class YoloTensorRTExample : public VideoExample
845849
config.emplace("confidenceThreshold", "0.5");
846850
config.emplace("inference_precison", "FP32");
847851
config.emplace("net_type", "YOLOV5");
848-
config.emplace("maxBatch", "1");
852+
maxBatch = 1;
849853
config.emplace("maxCropRatio", "-1");
850854
break;
851855
}
852-
856+
if (maxBatch < m_batchSize)
857+
maxBatch = m_batchSize;
858+
config.emplace("maxBatch", std::to_string(m_batchSize));
853859
config.emplace("classNames", pathToModel + "coco.names");
854860

855861
config.emplace("white_list", std::to_string((objtype_t)ObjectTypes::obj_person));

src/Detector/YoloTensorRTDetector.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ bool YoloTensorRTDetector::Init(const config_t& config)
4747
auto maxBatch = config.find("maxBatch");
4848
if (maxBatch != config.end())
4949
m_batchSize = std::max(1, std::stoi(maxBatch->second));
50+
m_localConfig.batch_size = m_batchSize;
5051

5152
m_localConfig.file_model_cfg = modelConfiguration->second;
5253
m_localConfig.file_model_weights = modelBinary->second;

src/Detector/tensorrt_yolo/class_detector.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ namespace tensor_rt
2323

2424
enum ModelType
2525
{
26-
YOLOV2 = 0,
27-
YOLOV3,
28-
YOLOV2_TINY,
29-
YOLOV3_TINY,
30-
YOLOV4,
31-
YOLOV4_TINY,
32-
YOLOV5
26+
YOLOV2 = 0,
27+
YOLOV3,
28+
YOLOV2_TINY,
29+
YOLOV3_TINY,
30+
YOLOV4,
31+
YOLOV4_TINY,
32+
YOLOV5
3333
};
3434

3535
enum Precision
@@ -41,9 +41,9 @@ namespace tensor_rt
4141

4242
struct Config
4343
{
44-
std::string file_model_cfg = "configs/yolov3.cfg";
44+
std::string file_model_cfg = "yolov4.cfg";
4545

46-
std::string file_model_weights = "configs/yolov3.weights";
46+
std::string file_model_weights = "yolov4.weights";
4747

4848
float detect_thresh = 0.9f;
4949

@@ -53,6 +53,8 @@ namespace tensor_rt
5353

5454
int gpu_id = 0;
5555

56+
uint32_t batch_size = 1;
57+
5658
std::string calibration_image_list_file_txt = "configs/calibration_images.txt";
5759
};
5860

src/Detector/tensorrt_yolo/class_yolo_detector.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class YoloDectector
116116
_infer_param.calibImagesPath = "";
117117
_infer_param.probThresh = _config.detect_thresh;
118118
_infer_param.nmsThresh = 0.5;
119+
_infer_param.batchSize = _config.batch_size;
119120
}
120121

121122
void build_net()

src/Detector/tensorrt_yolo/yolo.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Yolo::Yolo( const NetworkInfo& networkInfo, const InferParams& inferParams) :
2828
m_NMSThresh(inferParams.nmsThresh),
2929
m_PrintPerfInfo(inferParams.printPerfInfo),
3030
m_PrintPredictions(inferParams.printPredictionInfo),
31+
m_BatchSize(inferParams.batchSize),
3132
m_Logger(Logger()),
3233
m_Network(nullptr),
3334
m_Builder(nullptr),

src/Detector/tensorrt_yolo/yolo.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,13 @@ struct NetworkInfo
6464
*/
6565
struct InferParams
6666
{
67-
bool printPerfInfo;
68-
bool printPredictionInfo;
67+
bool printPerfInfo = false;
68+
bool printPredictionInfo = false;
6969
std::string calibImages;
7070
std::string calibImagesPath;
71-
float probThresh;
72-
float nmsThresh;
71+
float probThresh = 0.5f;
72+
float nmsThresh = 0.5f;
73+
uint32_t batchSize = 1;
7374
};
7475

7576
/**

0 commit comments

Comments
 (0)