first commit
This commit is contained in:
68
include/YOLOv5/DetectorCOCO.h
Normal file
68
include/YOLOv5/DetectorCOCO.h
Normal file
@@ -0,0 +1,68 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include "cuda_utils.h"
|
||||
#include "logging.h"
|
||||
#include "utils.h"
|
||||
#include "preprocess.h"
|
||||
#include "structs.h"
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "NvInfer.h"
|
||||
#include "yololayer.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
|
||||
#define USE_FP32 // set USE_INT8 or USE_FP16 or USE_FP32
|
||||
#define DEVICE 0 // GPU id
|
||||
#define NMS_THRESH_YOLO 0.4
|
||||
#define CONF_THRESH_YOLO_HEAD 0.75
|
||||
#define CONF_THRESH_YOLO_CHAIR 0.8
|
||||
#define BATCH_SIZE_YOLO 1
|
||||
#define MAX_IMAGE_INPUT_SIZE_THRESH 3000 * 3000 // ensure it exceed the maximum size in the input images !
|
||||
|
||||
class DetectorCOCO{
|
||||
|
||||
public:
|
||||
static DetectorCOCO &getInst()
|
||||
{
|
||||
static DetectorCOCO instance;
|
||||
return instance;
|
||||
};
|
||||
void _init();
|
||||
DetectorCOCO();
|
||||
cv::Rect get_rect_yolov5(cv::Mat& img, float bbox[4]);
|
||||
float iou_yolov5(float lbox[4], float rbox[4]);
|
||||
static bool cmp_yolov5(const Yolo::Detection& a, const Yolo::Detection& b);
|
||||
void nms_yolov5(std::vector<Yolo::Detection>& res, float *output, float conf_thresh, float nms_thresh);
|
||||
void doInference(IExecutionContext& context, cudaStream_t& stream, void **buffers, float* output, int batchSize);
|
||||
std::vector<ObjectDetect> predict(cv::Mat inputImage, std::string type);
|
||||
|
||||
private:
|
||||
const int INPUT_H = Yolo::INPUT_H;
|
||||
const int INPUT_W = Yolo::INPUT_W;
|
||||
const int CLASS_NUM = Yolo::CLASS_NUM;
|
||||
const int OUTPUT_SIZE = Yolo::MAX_OUTPUT_BBOX_COUNT * sizeof(Yolo::Detection) / sizeof(float) + 1;
|
||||
const char* INPUT_BLOB_NAME = "data";
|
||||
const char* OUTPUT_BLOB_NAME = "prob";
|
||||
Logger gLogger;
|
||||
std::string engine_name_head = "/opt/nvidia/deepstream/deepstream-6.0/sources/apps/sample_apps/deepstream-app/Model_Classify/verify/crowdhuman_0-body_1-head__yolov5m.engine";
|
||||
std::string engine_name_chair = "/media/thai/A0B6A6B3B6A688FC2/code/PROJECT_BI/CLONE/tensorrtx/yolov5/build/20220615_weight_80class.engine";
|
||||
|
||||
IRuntime* runtime_head;
|
||||
IRuntime* runtime_chair;
|
||||
|
||||
ICudaEngine* engine_head;
|
||||
ICudaEngine* engine_chair;
|
||||
|
||||
IExecutionContext* context_head;
|
||||
IExecutionContext* context_chair;
|
||||
|
||||
|
||||
|
||||
|
||||
};
|
||||
18
include/YOLOv5/cuda_utils.h
Executable file
18
include/YOLOv5/cuda_utils.h
Executable file
@@ -0,0 +1,18 @@
|
||||
#ifndef TRTX_CUDA_UTILS_H_
|
||||
#define TRTX_CUDA_UTILS_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#ifndef CUDA_CHECK
|
||||
#define CUDA_CHECK(callstr)\
|
||||
{\
|
||||
cudaError_t error_code = callstr;\
|
||||
if (error_code != cudaSuccess) {\
|
||||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
|
||||
assert(0);\
|
||||
}\
|
||||
}
|
||||
#endif // CUDA_CHECK
|
||||
|
||||
#endif // TRTX_CUDA_UTILS_H_
|
||||
|
||||
504
include/YOLOv5/logging.h
Executable file
504
include/YOLOv5/logging.h
Executable file
@@ -0,0 +1,504 @@
|
||||
/*
|
||||
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef TENSORRT_LOGGING_H
|
||||
#define TENSORRT_LOGGING_H
|
||||
|
||||
#include "NvInferRuntimeCommon.h"
|
||||
#include <cassert>
|
||||
#include <ctime>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include "macros.h"
|
||||
|
||||
using Severity = nvinfer1::ILogger::Severity;
|
||||
|
||||
class LogStreamConsumerBuffer : public std::stringbuf
|
||||
{
|
||||
public:
|
||||
LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
|
||||
: mOutput(stream)
|
||||
, mPrefix(prefix)
|
||||
, mShouldLog(shouldLog)
|
||||
{
|
||||
}
|
||||
|
||||
LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
|
||||
: mOutput(other.mOutput)
|
||||
{
|
||||
}
|
||||
|
||||
~LogStreamConsumerBuffer()
|
||||
{
|
||||
// std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
|
||||
// std::streambuf::pptr() gives a pointer to the current position of the output sequence
|
||||
// if the pointer to the beginning is not equal to the pointer to the current position,
|
||||
// call putOutput() to log the output to the stream
|
||||
if (pbase() != pptr())
|
||||
{
|
||||
putOutput();
|
||||
}
|
||||
}
|
||||
|
||||
// synchronizes the stream buffer and returns 0 on success
|
||||
// synchronizing the stream buffer consists of inserting the buffer contents into the stream,
|
||||
// resetting the buffer and flushing the stream
|
||||
virtual int sync()
|
||||
{
|
||||
putOutput();
|
||||
return 0;
|
||||
}
|
||||
|
||||
void putOutput()
|
||||
{
|
||||
if (mShouldLog)
|
||||
{
|
||||
// prepend timestamp
|
||||
std::time_t timestamp = std::time(nullptr);
|
||||
tm* tm_local = std::localtime(×tamp);
|
||||
std::cout << "[";
|
||||
std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
|
||||
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
|
||||
std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
|
||||
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
|
||||
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
|
||||
std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
|
||||
// std::stringbuf::str() gets the string contents of the buffer
|
||||
// insert the buffer contents pre-appended by the appropriate prefix into the stream
|
||||
mOutput << mPrefix << str();
|
||||
// set the buffer to empty
|
||||
str("");
|
||||
// flush the stream
|
||||
mOutput.flush();
|
||||
}
|
||||
}
|
||||
|
||||
void setShouldLog(bool shouldLog)
|
||||
{
|
||||
mShouldLog = shouldLog;
|
||||
}
|
||||
|
||||
private:
|
||||
std::ostream& mOutput;
|
||||
std::string mPrefix;
|
||||
bool mShouldLog;
|
||||
};
|
||||
|
||||
//!
|
||||
//! \class LogStreamConsumerBase
|
||||
//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
|
||||
//!
|
||||
class LogStreamConsumerBase
|
||||
{
|
||||
public:
|
||||
LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
|
||||
: mBuffer(stream, prefix, shouldLog)
|
||||
{
|
||||
}
|
||||
|
||||
protected:
|
||||
LogStreamConsumerBuffer mBuffer;
|
||||
};
|
||||
|
||||
//!
|
||||
//! \class LogStreamConsumer
|
||||
//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
|
||||
//! Order of base classes is LogStreamConsumerBase and then std::ostream.
|
||||
//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
|
||||
//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
|
||||
//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
|
||||
//! Please do not change the order of the parent classes.
|
||||
//!
|
||||
class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
|
||||
{
|
||||
public:
|
||||
//! \brief Creates a LogStreamConsumer which logs messages with level severity.
|
||||
//! Reportable severity determines if the messages are severe enough to be logged.
|
||||
LogStreamConsumer(Severity reportableSeverity, Severity severity)
|
||||
: LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
|
||||
, std::ostream(&mBuffer) // links the stream buffer with the stream
|
||||
, mShouldLog(severity <= reportableSeverity)
|
||||
, mSeverity(severity)
|
||||
{
|
||||
}
|
||||
|
||||
LogStreamConsumer(LogStreamConsumer&& other)
|
||||
: LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
|
||||
, std::ostream(&mBuffer) // links the stream buffer with the stream
|
||||
, mShouldLog(other.mShouldLog)
|
||||
, mSeverity(other.mSeverity)
|
||||
{
|
||||
}
|
||||
|
||||
void setReportableSeverity(Severity reportableSeverity)
|
||||
{
|
||||
mShouldLog = mSeverity <= reportableSeverity;
|
||||
mBuffer.setShouldLog(mShouldLog);
|
||||
}
|
||||
|
||||
private:
|
||||
static std::ostream& severityOstream(Severity severity)
|
||||
{
|
||||
return severity >= Severity::kINFO ? std::cout : std::cerr;
|
||||
}
|
||||
|
||||
static std::string severityPrefix(Severity severity)
|
||||
{
|
||||
switch (severity)
|
||||
{
|
||||
case Severity::kINTERNAL_ERROR: return "[F] ";
|
||||
case Severity::kERROR: return "[E] ";
|
||||
case Severity::kWARNING: return "[W] ";
|
||||
case Severity::kINFO: return "[I] ";
|
||||
case Severity::kVERBOSE: return "[V] ";
|
||||
default: assert(0); return "";
|
||||
}
|
||||
}
|
||||
|
||||
bool mShouldLog;
|
||||
Severity mSeverity;
|
||||
};
|
||||
|
||||
//! \class Logger
|
||||
//!
|
||||
//! \brief Class which manages logging of TensorRT tools and samples
|
||||
//!
|
||||
//! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
|
||||
//! and supports logging two types of messages:
|
||||
//!
|
||||
//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
|
||||
//! - Test pass/fail messages
|
||||
//!
|
||||
//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
|
||||
//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
|
||||
//!
|
||||
//! In the future, this class could be extended to support dumping test results to a file in some standard format
|
||||
//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
|
||||
//!
|
||||
//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
|
||||
//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
|
||||
//! library and messages coming from the sample.
|
||||
//!
|
||||
//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
|
||||
//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
|
||||
//! object.
|
||||
|
||||
class Logger : public nvinfer1::ILogger
|
||||
{
|
||||
public:
|
||||
Logger(Severity severity = Severity::kWARNING)
|
||||
: mReportableSeverity(severity)
|
||||
{
|
||||
}
|
||||
|
||||
//!
|
||||
//! \enum TestResult
|
||||
//! \brief Represents the state of a given test
|
||||
//!
|
||||
enum class TestResult
|
||||
{
|
||||
kRUNNING, //!< The test is running
|
||||
kPASSED, //!< The test passed
|
||||
kFAILED, //!< The test failed
|
||||
kWAIVED //!< The test was waived
|
||||
};
|
||||
|
||||
//!
|
||||
//! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
|
||||
//! \return The nvinfer1::ILogger associated with this Logger
|
||||
//!
|
||||
//! TODO Once all samples are updated to use this method to register the logger with TensorRT,
|
||||
//! we can eliminate the inheritance of Logger from ILogger
|
||||
//!
|
||||
nvinfer1::ILogger& getTRTLogger()
|
||||
{
|
||||
return *this;
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief Implementation of the nvinfer1::ILogger::log() virtual method
|
||||
//!
|
||||
//! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
|
||||
//! inheritance from nvinfer1::ILogger
|
||||
//!
|
||||
void log(Severity severity, const char* msg) TRT_NOEXCEPT override
|
||||
{
|
||||
LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief Method for controlling the verbosity of logging output
|
||||
//!
|
||||
//! \param severity The logger will only emit messages that have severity of this level or higher.
|
||||
//!
|
||||
void setReportableSeverity(Severity severity)
|
||||
{
|
||||
mReportableSeverity = severity;
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief Opaque handle that holds logging information for a particular test
|
||||
//!
|
||||
//! This object is an opaque handle to information used by the Logger to print test results.
|
||||
//! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
|
||||
//! with Logger::reportTest{Start,End}().
|
||||
//!
|
||||
class TestAtom
|
||||
{
|
||||
public:
|
||||
TestAtom(TestAtom&&) = default;
|
||||
|
||||
private:
|
||||
friend class Logger;
|
||||
|
||||
TestAtom(bool started, const std::string& name, const std::string& cmdline)
|
||||
: mStarted(started)
|
||||
, mName(name)
|
||||
, mCmdline(cmdline)
|
||||
{
|
||||
}
|
||||
|
||||
bool mStarted;
|
||||
std::string mName;
|
||||
std::string mCmdline;
|
||||
};
|
||||
|
||||
//!
|
||||
//! \brief Define a test for logging
|
||||
//!
|
||||
//! \param[in] name The name of the test. This should be a string starting with
|
||||
//! "TensorRT" and containing dot-separated strings containing
|
||||
//! the characters [A-Za-z0-9_].
|
||||
//! For example, "TensorRT.sample_googlenet"
|
||||
//! \param[in] cmdline The command line used to reproduce the test
|
||||
//
|
||||
//! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
|
||||
//!
|
||||
static TestAtom defineTest(const std::string& name, const std::string& cmdline)
|
||||
{
|
||||
return TestAtom(false, name, cmdline);
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
|
||||
//! as input
|
||||
//!
|
||||
//! \param[in] name The name of the test
|
||||
//! \param[in] argc The number of command-line arguments
|
||||
//! \param[in] argv The array of command-line arguments (given as C strings)
|
||||
//!
|
||||
//! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
|
||||
static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
|
||||
{
|
||||
auto cmdline = genCmdlineString(argc, argv);
|
||||
return defineTest(name, cmdline);
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief Report that a test has started.
|
||||
//!
|
||||
//! \pre reportTestStart() has not been called yet for the given testAtom
|
||||
//!
|
||||
//! \param[in] testAtom The handle to the test that has started
|
||||
//!
|
||||
static void reportTestStart(TestAtom& testAtom)
|
||||
{
|
||||
reportTestResult(testAtom, TestResult::kRUNNING);
|
||||
assert(!testAtom.mStarted);
|
||||
testAtom.mStarted = true;
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief Report that a test has ended.
|
||||
//!
|
||||
//! \pre reportTestStart() has been called for the given testAtom
|
||||
//!
|
||||
//! \param[in] testAtom The handle to the test that has ended
|
||||
//! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
|
||||
//! TestResult::kFAILED, TestResult::kWAIVED
|
||||
//!
|
||||
static void reportTestEnd(const TestAtom& testAtom, TestResult result)
|
||||
{
|
||||
assert(result != TestResult::kRUNNING);
|
||||
assert(testAtom.mStarted);
|
||||
reportTestResult(testAtom, result);
|
||||
}
|
||||
|
||||
static int reportPass(const TestAtom& testAtom)
|
||||
{
|
||||
reportTestEnd(testAtom, TestResult::kPASSED);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
static int reportFail(const TestAtom& testAtom)
|
||||
{
|
||||
reportTestEnd(testAtom, TestResult::kFAILED);
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
static int reportWaive(const TestAtom& testAtom)
|
||||
{
|
||||
reportTestEnd(testAtom, TestResult::kWAIVED);
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
static int reportTest(const TestAtom& testAtom, bool pass)
|
||||
{
|
||||
return pass ? reportPass(testAtom) : reportFail(testAtom);
|
||||
}
|
||||
|
||||
Severity getReportableSeverity() const
|
||||
{
|
||||
return mReportableSeverity;
|
||||
}
|
||||
|
||||
private:
|
||||
//!
|
||||
//! \brief returns an appropriate string for prefixing a log message with the given severity
|
||||
//!
|
||||
static const char* severityPrefix(Severity severity)
|
||||
{
|
||||
switch (severity)
|
||||
{
|
||||
case Severity::kINTERNAL_ERROR: return "[F] ";
|
||||
case Severity::kERROR: return "[E] ";
|
||||
case Severity::kWARNING: return "[W] ";
|
||||
case Severity::kINFO: return "[I] ";
|
||||
case Severity::kVERBOSE: return "[V] ";
|
||||
default: assert(0); return "";
|
||||
}
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief returns an appropriate string for prefixing a test result message with the given result
|
||||
//!
|
||||
static const char* testResultString(TestResult result)
|
||||
{
|
||||
switch (result)
|
||||
{
|
||||
case TestResult::kRUNNING: return "RUNNING";
|
||||
case TestResult::kPASSED: return "PASSED";
|
||||
case TestResult::kFAILED: return "FAILED";
|
||||
case TestResult::kWAIVED: return "WAIVED";
|
||||
default: assert(0); return "";
|
||||
}
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
|
||||
//!
|
||||
static std::ostream& severityOstream(Severity severity)
|
||||
{
|
||||
return severity >= Severity::kINFO ? std::cout : std::cerr;
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief method that implements logging test results
|
||||
//!
|
||||
static void reportTestResult(const TestAtom& testAtom, TestResult result)
|
||||
{
|
||||
severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
|
||||
<< testAtom.mCmdline << std::endl;
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief generate a command line string from the given (argc, argv) values
|
||||
//!
|
||||
static std::string genCmdlineString(int argc, char const* const* argv)
|
||||
{
|
||||
std::stringstream ss;
|
||||
for (int i = 0; i < argc; i++)
|
||||
{
|
||||
if (i > 0)
|
||||
ss << " ";
|
||||
ss << argv[i];
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
Severity mReportableSeverity;
|
||||
};
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
//!
|
||||
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
|
||||
//!
|
||||
//! Example usage:
|
||||
//!
|
||||
//! LOG_VERBOSE(logger) << "hello world" << std::endl;
|
||||
//!
|
||||
inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
|
||||
{
|
||||
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
|
||||
//!
|
||||
//! Example usage:
|
||||
//!
|
||||
//! LOG_INFO(logger) << "hello world" << std::endl;
|
||||
//!
|
||||
inline LogStreamConsumer LOG_INFO(const Logger& logger)
|
||||
{
|
||||
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
|
||||
//!
|
||||
//! Example usage:
|
||||
//!
|
||||
//! LOG_WARN(logger) << "hello world" << std::endl;
|
||||
//!
|
||||
inline LogStreamConsumer LOG_WARN(const Logger& logger)
|
||||
{
|
||||
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
|
||||
//!
|
||||
//! Example usage:
|
||||
//!
|
||||
//! LOG_ERROR(logger) << "hello world" << std::endl;
|
||||
//!
|
||||
inline LogStreamConsumer LOG_ERROR(const Logger& logger)
|
||||
{
|
||||
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
|
||||
}
|
||||
|
||||
//!
|
||||
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
|
||||
// ("fatal" severity)
|
||||
//!
|
||||
//! Example usage:
|
||||
//!
|
||||
//! LOG_FATAL(logger) << "hello world" << std::endl;
|
||||
//!
|
||||
inline LogStreamConsumer LOG_FATAL(const Logger& logger)
|
||||
{
|
||||
return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
#endif // TENSORRT_LOGGING_H
|
||||
27
include/YOLOv5/macros.h
Executable file
27
include/YOLOv5/macros.h
Executable file
@@ -0,0 +1,27 @@
|
||||
#ifndef __MACROS_H
|
||||
#define __MACROS_H
|
||||
|
||||
#ifdef API_EXPORTS
|
||||
#if defined(_MSC_VER)
|
||||
#define API __declspec(dllexport)
|
||||
#else
|
||||
#define API __attribute__((visibility("default")))
|
||||
#endif
|
||||
#else
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#define API __declspec(dllimport)
|
||||
#else
|
||||
#define API
|
||||
#endif
|
||||
#endif // API_EXPORTS
|
||||
|
||||
#if NV_TENSORRT_MAJOR >= 8
|
||||
#define TRT_NOEXCEPT noexcept
|
||||
#define TRT_CONST_ENQUEUE const
|
||||
#else
|
||||
#define TRT_NOEXCEPT
|
||||
#define TRT_CONST_ENQUEUE
|
||||
#endif
|
||||
|
||||
#endif // __MACROS_H
|
||||
16
include/YOLOv5/preprocess.h
Executable file
16
include/YOLOv5/preprocess.h
Executable file
@@ -0,0 +1,16 @@
|
||||
#ifndef __PREPROCESS_H
|
||||
#define __PREPROCESS_H
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
|
||||
|
||||
struct AffineMatrix{
|
||||
float value[6];
|
||||
};
|
||||
|
||||
|
||||
void preprocess_kernel_img(uint8_t* src, int src_width, int src_height,
|
||||
float* dst, int dst_width, int dst_height,
|
||||
cudaStream_t stream);
|
||||
#endif // __PREPROCESS_H
|
||||
52
include/YOLOv5/utils.h
Executable file
52
include/YOLOv5/utils.h
Executable file
@@ -0,0 +1,52 @@
|
||||
#ifndef TRTX_YOLOV5_UTILS_H_
|
||||
#define TRTX_YOLOV5_UTILS_H_
|
||||
|
||||
#include <dirent.h>
|
||||
#include <opencv2/opencv.hpp>
|
||||
|
||||
static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h) {
|
||||
int w, h, x, y;
|
||||
float r_w = input_w / (img.cols*1.0);
|
||||
float r_h = input_h / (img.rows*1.0);
|
||||
if (r_h > r_w) {
|
||||
w = input_w;
|
||||
h = r_w * img.rows;
|
||||
x = 0;
|
||||
y = (input_h - h) / 2;
|
||||
} else {
|
||||
w = r_h * img.cols;
|
||||
h = input_h;
|
||||
x = (input_w - w) / 2;
|
||||
y = 0;
|
||||
}
|
||||
cv::Mat re(h, w, CV_8UC3);
|
||||
cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR);
|
||||
cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128));
|
||||
re.copyTo(out(cv::Rect(x, y, re.cols, re.rows)));
|
||||
return out;
|
||||
}
|
||||
|
||||
static inline int read_files_in_dir(const char *p_dir_name, std::vector<std::string> &file_names) {
|
||||
DIR *p_dir = opendir(p_dir_name);
|
||||
if (p_dir == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct dirent* p_file = nullptr;
|
||||
while ((p_file = readdir(p_dir)) != nullptr) {
|
||||
if (strcmp(p_file->d_name, ".") != 0 &&
|
||||
strcmp(p_file->d_name, "..") != 0) {
|
||||
//std::string cur_file_name(p_dir_name);
|
||||
//cur_file_name += "/";
|
||||
//cur_file_name += p_file->d_name;
|
||||
std::string cur_file_name(p_file->d_name);
|
||||
file_names.push_back(cur_file_name);
|
||||
}
|
||||
}
|
||||
|
||||
closedir(p_dir);
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // TRTX_YOLOV5_UTILS_H_
|
||||
|
||||
138
include/YOLOv5/yololayer.h
Executable file
138
include/YOLOv5/yololayer.h
Executable file
@@ -0,0 +1,138 @@
|
||||
#ifndef _YOLO_LAYER_H
|
||||
#define _YOLO_LAYER_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <NvInfer.h>
|
||||
#include "macros.h"
|
||||
|
||||
namespace Yolo
|
||||
{
|
||||
static constexpr int CHECK_COUNT = 3;
|
||||
static constexpr float IGNORE_THRESH = 0.1f;
|
||||
struct YoloKernel
|
||||
{
|
||||
int width;
|
||||
int height;
|
||||
float anchors[CHECK_COUNT * 2];
|
||||
};
|
||||
static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
|
||||
static constexpr int CLASS_NUM = 80;
|
||||
static constexpr int INPUT_H = 640; // yolov5's input height and width must be divisible by 32.
|
||||
static constexpr int INPUT_W = 640;
|
||||
|
||||
static constexpr int LOCATIONS = 4;
|
||||
struct alignas(float) Detection {
|
||||
//center_x center_y w h
|
||||
float bbox[LOCATIONS];
|
||||
float conf; // bbox_conf * cls_conf
|
||||
float class_id;
|
||||
};
|
||||
}
|
||||
|
||||
namespace nvinfer1
|
||||
{
|
||||
class API YoloLayerPlugin : public IPluginV2IOExt
|
||||
{
|
||||
public:
|
||||
YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
|
||||
YoloLayerPlugin(const void* data, size_t length);
|
||||
~YoloLayerPlugin();
|
||||
|
||||
int getNbOutputs() const TRT_NOEXCEPT override
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
|
||||
|
||||
int initialize() TRT_NOEXCEPT override;
|
||||
|
||||
virtual void terminate() TRT_NOEXCEPT override {};
|
||||
|
||||
virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
|
||||
|
||||
virtual int enqueue(int batchSize, const void* const* inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
|
||||
virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override {
|
||||
return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
|
||||
}
|
||||
|
||||
const char* getPluginType() const TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
void destroy() TRT_NOEXCEPT override;
|
||||
|
||||
IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
|
||||
|
||||
void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginNamespace() const TRT_NOEXCEPT override;
|
||||
|
||||
DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
|
||||
|
||||
void attachToContext(
|
||||
cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
|
||||
|
||||
void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
|
||||
|
||||
void detachFromContext() TRT_NOEXCEPT override;
|
||||
|
||||
private:
|
||||
void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
|
||||
int mThreadCount = 256;
|
||||
const char* mPluginNamespace;
|
||||
int mKernelCount;
|
||||
int mClassCount;
|
||||
int mYoloV5NetWidth;
|
||||
int mYoloV5NetHeight;
|
||||
int mMaxOutObject;
|
||||
std::vector<Yolo::YoloKernel> mYoloKernel;
|
||||
void** mAnchor;
|
||||
};
|
||||
|
||||
class API YoloPluginCreator : public IPluginCreator
|
||||
{
|
||||
public:
|
||||
YoloPluginCreator();
|
||||
|
||||
~YoloPluginCreator() override = default;
|
||||
|
||||
const char* getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char* getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
|
||||
|
||||
IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
|
||||
|
||||
IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
|
||||
|
||||
void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
|
||||
{
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char* getPluginNamespace() const TRT_NOEXCEPT override
|
||||
{
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::string mNamespace;
|
||||
static PluginFieldCollection mFC;
|
||||
static std::vector<PluginField> mPluginAttributes;
|
||||
};
|
||||
REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
|
||||
};
|
||||
|
||||
#endif // _YOLO_LAYER_H
|
||||
Reference in New Issue
Block a user