当前位置:网站首页>TensorRTx-YOLOv5工程解读(一)
TensorRTx-YOLOv5工程解读(一)
2022-08-04 05:24:00 【单胖】
TensorRTx-YOLOv5工程解读(一)
权重生成:gen_wts.py
作者先是使用了gen_wts.py这个脚本去生成wts文件。顾名思义,这个.wts文件里面存放的就是.pt文件的权重。脚本内容如下:
import sys
import argparse
import os
import struct
import torch
from utils.torch_utils import select_device
def parse_args():
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid input file')
if not args.output:
args.output = os.path.splitext(args.weights)[0] + '.wts'
elif os.path.isdir(args.output):
args.output = os.path.join(
args.output,
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
return args.weights, args.output
pt_file, wts_file = parse_args()
# Initialize
device = select_device('cpu')
# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
model.to(device).eval()
with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f' ,float(vv)).hex())
f.write('\n')
第一个函数parse_args()
就是正常处理输入的命令行参数,不多做赘述。
主函数内,先是设置设备为CPU,再load进pt文件获得model并转成FP32格式。并设置模型的device和eval模式。
设置完毕后,作者保存权重文件,其中权重文件的内容是作者自定义的。第一行存入的是model的keys的个数,再分别遍历pt文件内的每一个权重,保存为该层名称
该层参数量
16进制权重
。
权重读取:common.cpp
首先顺着之前的思路,看看作者是如何load权重的。
// TensorRT weight files have a simple space delimited format:
// [type] [size] <data x size in hex>
std::map<std::string, Weights> loadWeights(const std::string file) {
std::cout << "Loading weights: " << file << std::endl;
std::map<std::string, Weights> weightMap;
// Open weights file
std::ifstream input(file);
assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");
// Read number of weight blobs
int32_t count;
input >> count;
assert(count > 0 && "Invalid weight map file.");
while (count--)
{
Weights wt{
DataType::kFLOAT, nullptr, 0 };
uint32_t size;
// Read name and type of blob
std::string name;
input >> name >> std::dec >> size;
wt.type = DataType::kFLOAT;
// Load blob
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
for (uint32_t x = 0, y = size; x < y; ++x)
{
input >> std::hex >> val[x];
}
wt.values = val;
wt.count = size;
weightMap[name] = wt;
}
return weightMap;
}
此为loadWeight()
函数。作者此处使用了std::map
容器。map容器在OpenCV和OpenVINO中本身就是大量使用的,所以除了vector
之外,也需要掌握map
的使用。后面需要往这个<std::string, Weights>
型的map中添加权重信息。
同时应该注意,此处的Weights
类型在TensorRT的NvInferRuntime.h头文件中有定义:
class Weights
{
public:
DataType type; //!< The type of the weights.
const void* values; //!< The weight values, in a contiguous array.
int64_t count; //!< The number of weights in the array.
};
作者使用了std::ifstream
进行输入流变量的定义,并设置了一些变量。代码中的input >> count
就是将.wts文件中的第一行的算子数传递给count这个变量,从而构建while循环。
在While循环中,作者先定义了Weights
型的wt
变量,其类型为DataType::kFLOAT
,values直接初始化为nullptr
,count初始化一个0在上面即可。
这一句input >> name >> std::dec >> size
是将input中的第一部分:权重的名称,赋值给name
变量,再将紧跟着name后的size推入给size
变量。具体的形式可以参考之前分析gen_wts.py
脚本中的权重生成的部分。作者之所以要存入这一算子的权重的size,就是为了方便分配空间大小。声明指针val指向一个大小为sizeof(val) * size
的uint32_t
的数组,并且将input中这一行的权重全部推入给val这个数组即可。
这一步完成后,设置Weights的values
成员为val
,count
成员为size
,并将name
作为weightMap
的keys
,wt
作为其values
即可。
至此,模型权重加载完毕。
边栏推荐
- 读者让我总结一波 redis 面试题,现在肝出来了
- 高性能高可靠性高扩展性分布式防火墙架构
- 败给“MySQL”的第60天,我重振旗鼓,四面拿下蚂蚁金服offer
- What are the steps for how to develop a mall system APP?
- C专家编程 第4章 令人震惊的事实:数组和指针并不相同 4.2 我的代码为什么无法运行
- 3面头条,花7天整理了面试题和学习笔记,已正式入职半个月
- 8、自定义映射resultMap
- TSF微服务治理实战系列(一)——治理蓝图
- [Cloud Native--Kubernetes] Pod Resource Management and Probe Detection
- 谷粒商城-基础篇(项目简介&项目搭建)
猜你喜欢
el-Select 选择器 底部固定
心余力绌:企业面临的软件供应链安全困境
【评价类模型】Topsis法(优劣解距离法)
npm报错Beginning October 4, 2021, all connections to the npm registry - including for package installa
高性能高可靠性高扩展性分布式防火墙架构
少年成就黑客,需要这些技能
8. Custom mapping resultMap
应届生软件测试薪资大概多少?
Can‘t connect to MySQL server on ‘localhost3306‘ (10061) 简洁明了的解决方法
[Cocos 3.5.2]开启模型合批
随机推荐
力扣:63. 不同路径 II
震惊,99.9% 的同学没有真正理解字符串的不可变性
MySQL log articles, binlog log of MySQL log, detailed explanation of binlog log
力扣:746. 使用最小花费爬楼梯
Turn: Management is the love of possibility, and managers must have the courage to break into the unknown
代码重构:面向单元测试
嵌入式系统驱动初级【4】——字符设备驱动基础下_并发控制
一个对象引用的思考
C专家编程 第4章 令人震惊的事实:数组和指针并不相同 4.5 数组和指针的其他区别
动态规划总括
Will the 2023 PMP exam use the new version of the textbook?Reply is here!
The Road to Ad Monetization for Uni-app Mini Program Apps: Full Screen Video Ads
FPGA学习笔记——知识点总结
DataTable使用Linq进行分组汇总,将Linq结果集转化为DataTable
[Cocos] cc.sys.browserType可能的属性
少年成就黑客,需要这些技能
Can‘t connect to MySQL server on ‘localhost3306‘ (10061) 简洁明了的解决方法
力扣:343. 整数拆分
C Expert Programming Chapter 4 The Shocking Fact: Arrays and Pointers Are Not the Same 4.5 Other Differences Between Arrays and Pointers
商城系统APP如何开发 都有哪些步骤