当前位置:网站首页>模型推理模板
模型推理模板
2022-07-29 19:58:00 【洪流之源】
infer.h
#pragma once
#include <string>
#include <future>
#include <memory>
// 封装接口类
class Infer
{
public:
virtual std::shared_future<std::string> commit(const std::string& input) = 0;
};
std::shared_ptr<Infer> create_infer(const std::string& file);infer.cpp
#include "infer.h"
#include <thread>
#include <vector>
#include <condition_variable>
#include <mutex>
#include <string>
#include <future>
#include <queue>
#include <functional>
// 封装接口类
struct Job
{
std::shared_ptr<std::promise<std::string>> pro;
std::string input;
};
class InferImpl : public Infer
{
public:
virtual ~InferImpl()
{
stop();
}
void stop()
{
if (running_)
{
running_ = false;
cv_.notify_one();
}
if (worker_thread_.joinable())
worker_thread_.join();
}
bool startup(const std::string& file)
{
file_ = file;
running_ = true; // 启动后,运行状态设置为true
// 线程传递promise的目的,是获得线程是否初始化成功的状态
// 而在线程内做初始化,好处是,初始化跟释放在同一个线程内
// 代码可读性好,资源管理方便
std::promise<bool> pro;
worker_thread_ = std::thread(&InferImpl::worker, this, std::ref(pro));
/*
注意:这里thread 一构建好后,worker函数就开始执行了
第一个参数是该线程要执行的worker函数,第二个参数是this指的是class InferImpl,第三个参数指的是传引用,因为我们在worker函数里要修改pro。
*/
return pro.get_future().get();
}
virtual std::shared_future<std::string> commit(const std::string& input) override
{
Job job;
job.input = input;
job.pro.reset(new std::promise<std::string>());
std::shared_future<std::string> fut = job.pro->get_future();
{
std::lock_guard<std::mutex> l(lock_);
jobs_.emplace(std::move(job));
}
cv_.notify_one();
return fut;
}
void worker(std::promise<bool>& pro)
{
// load model
if (file_ != "trtfile")
{
// failed
pro.set_value(false);
printf("Load model failed: %s\n", file_.c_str());
return;
}
// load success
pro.set_value(true); // 这里的promise用来负责确认infer初始化成功了
std::vector<Job> fetched_jobs;
while (running_)
{
{
std::unique_lock<std::mutex> l(lock_);
// 一直等着,cv_.wait(lock, predicate)
// 如果 running不在运行状态 或者说 jobs_有东西 而且接收到了notify one的信号
cv_.wait(l, [&]() {return not running_ || not jobs_.empty(); });
// 如果停止运行,则直接结束循环
if (not running_) break;
int batch_size = 5;
for (int i = 0; i < batch_size && not jobs_.empty(); ++i)
{ // jobs_不为空的时候
fetched_jobs.emplace_back(std::move(jobs_.front())); // 就往里面fetched_jobs里塞东西
jobs_.pop(); // fetched_jobs塞进来一个,jobs_那边就要pop掉一个。(因为move)
}
}
// 一次加载一批,并进行批处理
// forward(fetched_jobs)
for (auto& job : fetched_jobs)
{
job.pro->set_value(job.input + "---processed");
}
fetched_jobs.clear();
}
printf("Infer worker done.\n");
}
private:
std::atomic<bool> running_{ false };
std::string file_;
std::thread worker_thread_;
std::queue<Job> jobs_;
std::mutex lock_;
std::condition_variable cv_;
};
std::shared_ptr<Infer> create_infer(const std::string& file)
{
// 实例化一个推理器的实现类(inferImpl),以指针形式返回
std::shared_ptr<InferImpl> instance = std::make_shared<InferImpl>();
// 推理器实现类实例(instance)启动。这里的file是engine file
if (not instance->startup(file))
{
instance.reset(); // 如果启动不成功就reset
}
return instance;
}
main.cpp
#include "infer.h"
int main()
{
auto infer = create_infer("trtfile"); // 创建及初始化
if (infer == nullptr)
{
printf("Infer is nullptr.\n");
return -1;
}
// 将任务提交给推理器(推理器执行commit),同时推理器(infer)也等着获取(get)结果
printf("commit msg = %s\n", infer->commit("msg").get().c_str());
return 0;
}边栏推荐
猜你喜欢

C language learning books zero-based introductory articles

The ambition of glory: "high-end civilians" in a smart world

RNA修饰质谱检测|dextran-siRNA 葡聚糖化学偶联DNA/RNA|siRNA-PLGA聚乳酸-羟基乙酸共聚物修饰核糖核酸

SAG1-MIC8复合DNA基因疫苗|新型脂质-HAP-DNA复合体|实验要求

Monitoring basic resources through observation cloud monitor, automatic alarm
![[mathematical foundation] probability and mathematical statistics related concept learning](/img/bc/d3a246240ff7aca2b84c3766383758.png)
[mathematical foundation] probability and mathematical statistics related concept learning

分布式限流 redission RRateLimiter 的使用及原理

数据可视化----网页显示温湿度

通过观测云监控器监控基础资源,自动报警

4D Summary: 38 Knowledge Points of Distributed Systems
随机推荐
webUI测试框架设计思路详解
全渠道电商 | 国内知名的药妆要如何抓住风口实现快速增长?
使用MD5加密后的字符串存密码安全吗?你不得不了解的Hash算法
JMeter usage tutorial (2)
Private domain growth | Private domain members: 15 case collections from 9 major chain industries
Single-core browser and what is the difference between dual-core browser, which to use?
JMeter使用教程(二)
指定宽度截取字符串
scratch programming + elementary math
12437字,带你深入探究RPC通讯原理
R language for airbnb data nlp text mining, geography, word cloud visualization, regression GAM model, cross-validation analysis
C language learning books (improvement)
mos管闩锁效应理解学习
The difference between uri and url is simple to understand (what is the difference between uri and url)
促进二十一世纪创客教育的新发展
Chrome——插件推荐
经验分享|编写简单易用的在线产品手册小妙招
R语言对airbnb数据nlp文本挖掘、地理、词云可视化、回归GAM模型、交叉验证分析
用对象字面量或Map替代Switch/if语句
Verilog的时间格式系统任务----$printtimescale、$timeformat