当前位置:网站首页>TensorRTx-YOLOv5工程解读(一)
TensorRTx-YOLOv5工程解读(一)
2022-08-04 05:24:00 【单胖】
TensorRTx-YOLOv5工程解读(一)
权重生成:gen_wts.py
作者先是使用了gen_wts.py这个脚本去生成wts文件。顾名思义,这个.wts文件里面存放的就是.pt文件的权重。脚本内容如下:
import sys
import argparse
import os
import struct
import torch
from utils.torch_utils import select_device
def parse_args():
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
parser.add_argument('-w', '--weights', required=True, help='Input weights (.pt) file path (required)')
parser.add_argument('-o', '--output', help='Output (.wts) file path (optional)')
args = parser.parse_args()
if not os.path.isfile(args.weights):
raise SystemExit('Invalid input file')
if not args.output:
args.output = os.path.splitext(args.weights)[0] + '.wts'
elif os.path.isdir(args.output):
args.output = os.path.join(
args.output,
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
return args.weights, args.output
pt_file, wts_file = parse_args()
# Initialize
device = select_device('cpu')
# Load model
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
model.to(device).eval()
with open(wts_file, 'w') as f:
f.write('{}\n'.format(len(model.state_dict().keys())))
for k, v in model.state_dict().items():
vr = v.reshape(-1).cpu().numpy()
f.write('{} {} '.format(k, len(vr)))
for vv in vr:
f.write(' ')
f.write(struct.pack('>f' ,float(vv)).hex())
f.write('\n')
第一个函数parse_args()
就是正常处理输入的命令行参数,不多做赘述。
主函数内,先是设置设备为CPU,再load进pt文件获得model并转成FP32格式。并设置模型的device和eval模式。
设置完毕后,作者保存权重文件,其中权重文件的内容是作者自定义的。第一行存入的是model的keys的个数,再分别遍历pt文件内的每一个权重,保存为该层名称
该层参数量
16进制权重
。
权重读取:common.cpp
首先顺着之前的思路,看看作者是如何load权重的。
// TensorRT weight files have a simple space delimited format:
// [type] [size] <data x size in hex>
std::map<std::string, Weights> loadWeights(const std::string file) {
std::cout << "Loading weights: " << file << std::endl;
std::map<std::string, Weights> weightMap;
// Open weights file
std::ifstream input(file);
assert(input.is_open() && "Unable to load weight file. please check if the .wts file path is right!!!!!!");
// Read number of weight blobs
int32_t count;
input >> count;
assert(count > 0 && "Invalid weight map file.");
while (count--)
{
Weights wt{
DataType::kFLOAT, nullptr, 0 };
uint32_t size;
// Read name and type of blob
std::string name;
input >> name >> std::dec >> size;
wt.type = DataType::kFLOAT;
// Load blob
uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
for (uint32_t x = 0, y = size; x < y; ++x)
{
input >> std::hex >> val[x];
}
wt.values = val;
wt.count = size;
weightMap[name] = wt;
}
return weightMap;
}
此为loadWeight()
函数。作者此处使用了std::map
容器。map容器在OpenCV和OpenVINO中本身就是大量使用的,所以除了vector
之外,也需要掌握map
的使用。后面需要往这个<std::string, Weights>
型的map中添加权重信息。
同时应该注意,此处的Weights
类型在TensorRT的NvInferRuntime.h头文件中有定义:
class Weights
{
public:
DataType type; //!< The type of the weights.
const void* values; //!< The weight values, in a contiguous array.
int64_t count; //!< The number of weights in the array.
};
作者使用了std::ifstream
进行输入流变量的定义,并设置了一些变量。代码中的input >> count
就是将.wts文件中的第一行的算子数传递给count这个变量,从而构建while循环。
在While循环中,作者先定义了Weights
型的wt
变量,其类型为DataType::kFLOAT
,values直接初始化为nullptr
,count初始化一个0在上面即可。
这一句input >> name >> std::dec >> size
是将input中的第一部分:权重的名称,赋值给name
变量,再将紧跟着name后的size推入给size
变量。具体的形式可以参考之前分析gen_wts.py
脚本中的权重生成的部分。作者之所以要存入这一算子的权重的size,就是为了方便分配空间大小。声明指针val指向一个大小为sizeof(val) * size
的uint32_t
的数组,并且将input中这一行的权重全部推入给val这个数组即可。
这一步完成后,设置Weights的values
成员为val
,count
成员为size
,并将name
作为weightMap
的keys
,wt
作为其values
即可。
至此,模型权重加载完毕。
边栏推荐
猜你喜欢
随机推荐
C Expert Programming Chapter 5 Thinking about Chaining 5.6 Take it easy --- see who's talking: take the Turning quiz
擎朗智能全国研发创新中心落地光谷:去年曾获2亿美元融资
FPGA学习笔记——知识点总结
C Expert Programming Chapter 4 The Shocking Fact: Arrays and pointers are not the same 4.1 Arrays are not pointers
【JS】js给对象动态添加、设置、删除属性名和属性值
一个对象引用的思考
想低成本保障软件安全?5大安全任务值得考虑
在被面试官说了无数次后,终于潜下心来整理了一下JVM的类加载器
Write golang simple C2 remote control based on gRPC
npm报错Beginning October 4, 2021, all connections to the npm registry - including for package installa
Typora 使用保姆级教程 | 看这一篇就够了 | 历史版本已被禁用
Get the selected content of the radio box
应届生软件测试薪资大概多少?
MySQL日期函数
JS basics - forced type conversion (error-prone, self-use)
Performance testing with Loadrunner
力扣:96.不同的二叉搜索树
4.1 声明式事务之JdbcTemplate
8、自定义映射resultMap
The idea setting recognizes the .sql file type and other file types