Skip to content

Commit 62869c1

Browse files
committed
Add GUI id option
1 parent 853bcb9 commit 62869c1

3 files changed

Lines changed: 95 additions & 85 deletions

File tree

src/Detector/YoloDarknetDetector.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@ bool YoloDarknetDetector::Init(const config_t& config)
3434
{
3535
auto modelConfiguration = config.find("modelConfiguration");
3636
auto modelBinary = config.find("modelBinary");
37-
if (modelConfiguration != config.end() && modelBinary != config.end())
38-
{
39-
m_detector = std::make_unique<Detector>(modelConfiguration->second, modelBinary->second);
40-
m_detector->nms = 0.2f;
41-
m_WHRatio = static_cast<float>(m_detector->get_net_width()) / static_cast<float>(m_detector->get_net_height());
42-
}
37+
if (modelConfiguration == config.end() || modelBinary == config.end())
38+
return false;
39+
40+
int currGPUID = 0;
41+
auto gpuId = config.find("gpuId");
42+
if (gpuId != config.end())
43+
currGPUID = std::max(0, std::stoi(gpuId->second));
44+
45+
m_detector = std::make_unique<Detector>(modelConfiguration->second, modelBinary->second, currGPUID);
46+
m_detector->nms = 0.2f;
47+
m_WHRatio = static_cast<float>(m_detector->get_net_width()) / static_cast<float>(m_detector->get_net_height());
4348

4449
auto classNames = config.find("classNames");
4550
if (classNames != config.end())
@@ -59,22 +64,16 @@ bool YoloDarknetDetector::Init(const config_t& config)
5964

6065
auto confidenceThreshold = config.find("confidenceThreshold");
6166
if (confidenceThreshold != config.end())
62-
{
6367
m_confidenceThreshold = std::stof(confidenceThreshold->second);
64-
}
6568

6669
auto maxCropRatio = config.find("maxCropRatio");
6770
if (maxCropRatio != config.end())
68-
{
6971
m_maxCropRatio = std::stof(maxCropRatio->second);
70-
}
7172

7273
m_classesWhiteList.clear();
7374
auto whiteRange = config.equal_range("white_list");
7475
for (auto it = whiteRange.first; it != whiteRange.second; ++it)
75-
{
7676
m_classesWhiteList.insert(it->second);
77-
}
7877

7978
bool correct = m_detector.get() != nullptr;
8079
return correct;

src/Detector/YoloTensorRTDetector.cpp

Lines changed: 83 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ YoloTensorRTDetector::YoloTensorRTDetector(
1616
"cow", "diningtable", "dog", "horse",
1717
"motorbike", "person", "pottedplant",
1818
"sheep", "sofa", "train", "tvmonitor" };
19+
20+
m_localConfig.calibration_image_list_file_txt = "";
21+
m_localConfig.inference_precison = tensor_rt::FP32;
22+
m_localConfig.net_type = tensor_rt::YOLOV4;
23+
m_localConfig.detect_thresh = 0.5f;
24+
m_localConfig.gpu_id = 0;
25+
m_localConfig.n_max_batch = 4;
1926
}
2027

2128
///
@@ -31,78 +38,82 @@ YoloTensorRTDetector::~YoloTensorRTDetector(void)
3138
///
3239
bool YoloTensorRTDetector::Init(const config_t& config)
3340
{
34-
auto modelConfiguration = config.find("modelConfiguration");
35-
auto modelBinary = config.find("modelBinary");
36-
if (modelConfiguration != config.end() && modelBinary != config.end())
37-
{
38-
m_detector = std::make_unique<tensor_rt::Detector>();
39-
tensor_rt::Config localConfig;
40-
localConfig.file_model_cfg = modelConfiguration->second;
41-
localConfig.file_model_weights = modelBinary->second;
42-
localConfig.calibration_image_list_file_txt = "";
43-
localConfig.inference_precison = tensor_rt::FP32;
44-
localConfig.net_type = tensor_rt::YOLOV4;
45-
46-
auto inference_precison = config.find("inference_precison");
47-
if (inference_precison != config.end())
48-
{
49-
std::map<std::string, tensor_rt::Precision> dictPrecison;
50-
dictPrecison["INT8"] = tensor_rt::INT8;
51-
dictPrecison["FP16"] = tensor_rt::FP16;
52-
dictPrecison["FP32"] = tensor_rt::FP32;
53-
auto precison = dictPrecison.find(inference_precison->second);
54-
if (precison != dictPrecison.end())
55-
localConfig.inference_precison = precison->second;
56-
}
57-
58-
auto net_type = config.find("net_type");
59-
if (net_type != config.end())
60-
{
61-
std::map<std::string, tensor_rt::ModelType> dictNetType;
62-
dictNetType["YOLOV2"] = tensor_rt::YOLOV2;
63-
dictNetType["YOLOV3"] = tensor_rt::YOLOV3;
64-
dictNetType["YOLOV2_TINY"] = tensor_rt::YOLOV2_TINY;
65-
dictNetType["YOLOV3_TINY"] = tensor_rt::YOLOV3_TINY;
66-
dictNetType["YOLOV4"] = tensor_rt::YOLOV4;
67-
dictNetType["YOLOV4_TINY"] = tensor_rt::YOLOV4_TINY;
68-
69-
auto netType = dictNetType.find(net_type->second);
70-
if (netType != dictNetType.end())
71-
localConfig.net_type = netType->second;
72-
}
73-
74-
m_detector->init(localConfig);
75-
}
76-
77-
auto classNames = config.find("classNames");
78-
if (classNames != config.end())
79-
{
80-
std::ifstream classNamesFile(classNames->second);
81-
if (classNamesFile.is_open())
82-
{
83-
m_classNames.clear();
84-
std::string className;
85-
for (; std::getline(classNamesFile, className); )
86-
{
87-
m_classNames.push_back(className);
88-
}
89-
}
90-
}
91-
92-
auto confidenceThreshold = config.find("confidenceThreshold");
93-
if (confidenceThreshold != config.end())
94-
m_confidenceThreshold = std::stof(confidenceThreshold->second);
95-
96-
auto maxCropRatio = config.find("maxCropRatio");
97-
if (maxCropRatio != config.end())
98-
{
99-
m_maxCropRatio = std::stof(maxCropRatio->second);
100-
if (m_maxCropRatio < 1.f)
101-
m_maxCropRatio = 1.f;
102-
}
103-
104-
bool correct = m_detector.get() != nullptr;
105-
return correct;
41+
m_detector.reset();
42+
43+
auto modelConfiguration = config.find("modelConfiguration");
44+
auto modelBinary = config.find("modelBinary");
45+
if (modelConfiguration == config.end() || modelBinary == config.end())
46+
return false;
47+
48+
auto confidenceThreshold = config.find("confidenceThreshold");
49+
if (confidenceThreshold != config.end())
50+
m_localConfig.detect_thresh = std::stof(confidenceThreshold->second);
51+
52+
auto gpuId = config.find("gpuId");
53+
if (gpuId != config.end())
54+
m_localConfig.gpu_id = std::max(0, std::stoi(gpuId->second));
55+
56+
auto maxBatch = config.find("maxBatch");
57+
if (maxBatch != config.end())
58+
m_localConfig.n_max_batch = std::max(1, std::stoi(maxBatch->second));
59+
60+
m_localConfig.file_model_cfg = modelConfiguration->second;
61+
m_localConfig.file_model_weights = modelBinary->second;
62+
63+
auto inference_precison = config.find("inference_precison");
64+
if (inference_precison != config.end())
65+
{
66+
std::map<std::string, tensor_rt::Precision> dictPrecison;
67+
dictPrecison["INT8"] = tensor_rt::INT8;
68+
dictPrecison["FP16"] = tensor_rt::FP16;
69+
dictPrecison["FP32"] = tensor_rt::FP32;
70+
auto precison = dictPrecison.find(inference_precison->second);
71+
if (precison != dictPrecison.end())
72+
m_localConfig.inference_precison = precison->second;
73+
}
74+
75+
auto net_type = config.find("net_type");
76+
if (net_type != config.end())
77+
{
78+
std::map<std::string, tensor_rt::ModelType> dictNetType;
79+
dictNetType["YOLOV2"] = tensor_rt::YOLOV2;
80+
dictNetType["YOLOV3"] = tensor_rt::YOLOV3;
81+
dictNetType["YOLOV2_TINY"] = tensor_rt::YOLOV2_TINY;
82+
dictNetType["YOLOV3_TINY"] = tensor_rt::YOLOV3_TINY;
83+
dictNetType["YOLOV4"] = tensor_rt::YOLOV4;
84+
dictNetType["YOLOV4_TINY"] = tensor_rt::YOLOV4_TINY;
85+
86+
auto netType = dictNetType.find(net_type->second);
87+
if (netType != dictNetType.end())
88+
m_localConfig.net_type = netType->second;
89+
}
90+
91+
auto classNames = config.find("classNames");
92+
if (classNames != config.end())
93+
{
94+
std::ifstream classNamesFile(classNames->second);
95+
if (classNamesFile.is_open())
96+
{
97+
m_classNames.clear();
98+
std::string className;
99+
for (; std::getline(classNamesFile, className); )
100+
{
101+
m_classNames.push_back(className);
102+
}
103+
}
104+
}
105+
106+
auto maxCropRatio = config.find("maxCropRatio");
107+
if (maxCropRatio != config.end())
108+
{
109+
m_maxCropRatio = std::stof(maxCropRatio->second);
110+
if (m_maxCropRatio < 1.f)
111+
m_maxCropRatio = 1.f;
112+
}
113+
114+
m_detector = std::make_unique<tensor_rt::Detector>();
115+
m_detector->init(m_localConfig);
116+
return m_detector.get() != nullptr;
106117
}
107118

108119
///

src/Detector/YoloTensorRTDetector.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class YoloTensorRTDetector : public BaseDetector
2424
private:
2525
std::unique_ptr<tensor_rt::Detector> m_detector;
2626

27-
float m_confidenceThreshold = 0.5f;
2827
float m_maxCropRatio = 3.0f;
2928
std::vector<std::string> m_classNames;
29+
tensor_rt::Config m_localConfig;
3030
};

0 commit comments

Comments
 (0)