forked from Smorodov/Multitarget-tracker
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdetect.h
More file actions
132 lines (122 loc) · 3.61 KB
/
detect.h
File metadata and controls
132 lines (122 loc) · 3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#ifndef _DETECT_H_
#define _DETECT_H_
#include <cstring>
#include <string>
#include <vector>
#include "NvInfer.h"
namespace nvinfer1
{
/**
 * Serializes one value into a raw byte buffer and advances the write cursor.
 *
 * The bytes are copied with std::memcpy instead of being stored through a
 * reinterpret_cast'ed pointer: the cursor is advanced by arbitrary sizeof(T)
 * steps, so it is not guaranteed to be aligned for T, and a typed store into
 * a char buffer would also violate the strict-aliasing rules.
 *
 * @tparam T      Type of the value; expected to be trivially copyable, since
 *                its object representation is copied byte-for-byte.
 * @param buffer  Write cursor (in/out). On return it has been advanced by
 *                sizeof(T). The caller must guarantee at least sizeof(T)
 *                writable bytes at the incoming position.
 * @param val     Value whose bytes are written.
 */
template <typename T>
void write(char*& buffer, const T& val)
{
    std::memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}
/**
 * Deserializes one value from a raw byte buffer and advances the read cursor.
 *
 * The bytes are copied with std::memcpy instead of being loaded through a
 * reinterpret_cast'ed pointer: the cursor is advanced by arbitrary sizeof(T)
 * steps, so it is not guaranteed to be aligned for T, and a typed load from
 * a char buffer would also violate the strict-aliasing rules.
 *
 * @tparam T      Type of the value; expected to be trivially copyable, since
 *                its object representation is filled byte-for-byte.
 * @param buffer  Read cursor (in/out). On return it has been advanced by
 *                sizeof(T). The caller must guarantee at least sizeof(T)
 *                readable bytes at the incoming position.
 * @param val     Receives the value reconstructed from the buffer.
 */
template <typename T>
void read(const char*& buffer, T& val)
{
    std::memcpy(&val, buffer, sizeof(T));
    buffer += sizeof(T);
}
/**
 * Plugin layer implemented on top of IPluginV2IOExt.
 *
 * What this header establishes:
 *  - the plugin produces exactly one output;
 *  - the output dimensions are a copy of the first input's dimensions;
 *  - only linear-format, float tensors are accepted/produced;
 *  - it reports plugin type "DETECT_TRT", version "1.0";
 *  - it requests no workspace memory.
 *
 * The constructors, destructor, enqueue(), serialization routines,
 * configurePlugin() and clone() are only declared here; their definitions
 * live outside this header.
 *
 * NOTE(review): the member names suggest a YOLO-style detection head
 * (anchors / classes / grid size) - confirm against the implementation file.
 */
class Detect : public IPluginV2IOExt
{
public:
    Detect();

    /// Builds a plugin from a serialized blob of `length` bytes at `data`.
    /// NOTE(review): presumably the counterpart of serialize() - confirm in the .cpp.
    Detect(const void* data, size_t length);

    Detect(const uint32_t n_anchor_, const uint32_t _n_classes_,
           const uint32_t n_grid_h_, const uint32_t n_grid_w_/*,
           const uint32_t &n_stride_h_, const uint32_t &n_stride_w_*/);

    ~Detect();

    /// The plugin always has a single output tensor.
    int getNbOutputs() const override
    {
        return 1;
    }

    /// Output dimensions are identical to those of the first input.
    Dims getOutputDimensions(int /*index*/, const Dims* inputs, int /*nbInputDims*/) override
    {
        return inputs[0];
    }

    /// Nothing to set up; always returns 0.
    int initialize() override
    {
        return 0;
    }

    /// Nothing to release.
    void terminate() override
    {
    }

    /// No scratch workspace is requested.
    size_t getWorkspaceSize(int /*maxBatchSize*/) const override
    {
        return 0;
    }

    int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override;

    size_t getSerializationSize() const override;

    void serialize(void* buffer) const override;

    const char* getPluginType() const override
    {
        return "DETECT_TRT";
    }

    const char* getPluginVersion() const override
    {
        return "1.0";
    }

    /// Self-destructs; the object must not be used after this call.
    void destroy() override
    {
        delete this;
    }

    /// Stores the plugin namespace. A null pointer is treated as the empty
    /// namespace, because assigning a null `const char*` to std::string is
    /// undefined behaviour.
    void setPluginNamespace(const char* pluginNamespace) override
    {
        _s_plugin_namespace = (pluginNamespace != nullptr) ? pluginNamespace : "";
    }

    /// Returns the stored namespace; the pointer stays valid until the
    /// namespace is changed or the plugin is destroyed.
    const char* getPluginNamespace() const override
    {
        return _s_plugin_namespace.c_str();
    }

    /// The output is always float, regardless of the input types.
    DataType getOutputDataType(int /*index*/, const nvinfer1::DataType* /*inputTypes*/, int /*nbInputs*/) const override
    {
        return DataType::kFLOAT;
    }

    bool isOutputBroadcastAcrossBatch(int /*outputIndex*/, const bool* /*inputIsBroadcasted*/, int /*nbInputs*/) const override
    {
        return false;
    }

    bool canBroadcastInputAcrossBatch(int /*inputIndex*/) const override
    {
        return false;
    }

    /// No cuDNN / cuBLAS / allocator resources are used by this plugin.
    void attachToContext(
        cudnnContext* /*cudnnContext*/, cublasContext* /*cublasContext*/, IGpuAllocator* /*gpuAllocator*/) override
    {
    }

    void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) override;

    void detachFromContext() override
    {
    }

    /// Accepts a tensor slot only if it is linear-format float.
    bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int /*nbInputs*/, int /*nbOutputs*/) const override
    {
        return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
    }

    IPluginV2IOExt* clone() const override;

private:
    // Zero-initialized by default so that a constructor which does not set a
    // field never leaves it with an indeterminate value.
    uint32_t _n_anchor{0};
    uint32_t _n_classes{0};
    uint32_t _n_grid_h{0};
    uint32_t _n_grid_w{0};
    //uint32_t _n_stride_h;
    // uint32_t _n_stride_w;
    uint64_t _n_output_size{0};
    std::string _s_plugin_namespace;
}; //end detect
/**
 * IPluginCreator implementation associated with the Detect plugin.
 *
 * This header only declares the overrides; every non-defaulted member
 * function is defined outside this file. createPlugin() and
 * deserializePlugin() both hand back an IPluginV2IOExt pointer.
 */
class DetectPluginCreator : public IPluginCreator
{
public:
    DetectPluginCreator();
    ~DetectPluginCreator() override = default;

    // Creator identity and advertised plugin fields.
    const char* getPluginName() const override;
    const char* getPluginVersion() const override;
    const PluginFieldCollection* getFieldNames() override;

    // Plugin instantiation: from a field collection, or from serialized bytes.
    IPluginV2IOExt* createPlugin(const char* name,
                                 const PluginFieldCollection* fc) override;
    IPluginV2IOExt* deserializePlugin(const char* name,
                                      const void* serialData,
                                      size_t serialLength) override;

    // Namespace accessors.
    void setPluginNamespace(const char* libNamespace) override;
    const char* getPluginNamespace() const override;

private:
    // Per-instance namespace string.
    std::string _s_name_space;

    // Shared across all creator instances (static storage; defined elsewhere).
    static PluginFieldCollection _fc;
    static std::vector<PluginField> _vec_plugin_attributes;
}; // end detect creator
}//end namespace nvinfer1
#endif