[TensorRT] Convert PyTorch into deployable TensorRT
2022-07-29 06:03:00 【Dull cat】

When a deep learning model moves from research into production, you face the problem of deploying it to edge devices. Models are trained in different frameworks, and inference would normally have to use the same framework; across different kinds of platforms, however, tuning and implementation are very difficult, because each platform has different capabilities and features. Having to run multiple frameworks on one platform adds complexity, and this is where ONNX is useful: models trained in different frameworks can be converted into a common ONNX model, and then into the formats supported by the various platforms, which simplifies deployment.
1. What is ONNX
ONNX is short for Open Neural Network Exchange. It is a standard for representing deep learning models that lets a model be converted directly between different frameworks.
ONNX is a first step toward an open ecosystem: it provides an open-source format for models so that developers are not locked into one particular development tool.
Frameworks currently supported by ONNX include Caffe2, PyTorch, TensorFlow, MXNet, TensorRT, CNTK, and others.
Generally speaking, ONNX is an intermediary, a means to an end: convert the model to ONNX first, then convert it into a deployable form such as TensorRT.
Typical conversion routes:
- PyTorch → ONNX → TensorRT
- PyTorch → ONNX → TVM
- TF → ONNX → NCNN
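At the ONNX stage of any of these routes you can sanity-check the exported file before converting further. A minimal sketch using the onnx package ('tmp.onnx' stands for whatever file your export step produces, as in the example in the next section):

import onnx

# Load the exported file and run ONNX's structural validity checks.
model_proto = onnx.load("tmp.onnx")
onnx.checker.check_model(model_proto)

# Inspect the graph's declared inputs and outputs.
for tensor in model_proto.graph.input:
    print("input:", tensor.name)
for tensor in model_proto.graph.output:
    print("output:", tensor.name)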
2. PyTorch to ONNX
import torch

# Placeholders: `model` is the trained torch.nn.Module to export, `img_list`
# an example input tensor, `show` a bool for verbose graph printing,
# `opset_version` e.g. 11, and `dynamic_axes` e.g. {'input.1': {0: 'batch'}}.
model.eval()
torch.onnx.export(
    model,
    (img_list,),                       # example inputs, passed as a tuple
    'tmp.onnx',                        # output file
    input_names=['input.1'],
    output_names=['output'],
    export_params=True,                # store the trained weights in the file
    keep_initializers_as_inputs=False,
    verbose=show,
    opset_version=opset_version,
    dynamic_axes=dynamic_axes,
)
Simplify the exported ONNX model with onnx-simplifier:

python3 -m onnxsim tmp.onnx tmp_simplify.onnx
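Before handing the model to TensorRT, it is worth confirming that the ONNX file (original or simplified) reproduces the PyTorch outputs numerically. A minimal sketch, assuming the `model` and `img_list` placeholders from the export call above, with `img_list` a single example tensor and the input named 'input.1' as in the export:

import numpy as np
import onnxruntime as ort
import torch

sess = ort.InferenceSession("tmp.onnx", providers=["CPUExecutionProvider"])

with torch.no_grad():
    torch_out = model(img_list)

onnx_out = sess.run(None, {"input.1": img_list.cpu().numpy()})[0]
np.testing.assert_allclose(torch_out.cpu().numpy(), onnx_out,
                           rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match")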
3. What is TensorRT
TensorRT is a high-performance deep learning inference optimizer: a C++ inference framework that runs on NVIDIA GPU hardware platforms. It delivers low-latency, high-throughput deployment for deep learning models and can accelerate inference on embedded and autonomous-driving platforms.
Combining TensorRT with NVIDIA GPUs enables fast, efficient inference deployment for models from almost any framework, improving how quickly the model runs on NVIDIA GPUs; the speedups are considerable.
A model goes through two phases, training and inference. Training involves both forward and backward propagation, while inference only runs the forward pass, so prediction speed is what matters most in deployment.
Training usually runs distributed across multiple GPUs, while inference is often deployed on a single GPU or even an embedded platform. The framework chosen for training and the performance of different machines vary, which can slow inference down and make high real-time requirements impossible to meet. TensorRT is the inference optimizer for this step: convert the ONNX model to TensorRT, and it can then be deployed on the target device.
TensorRT's optimization methods:
TensorRT applies many optimizations; the first two below are the most important:
Layer fusion (tensor fusion):
TensorRT merges layers horizontally or vertically (the merged structure is called CBR, signifying that the convolution, bias, and ReLU layers are fused to form a single layer), which greatly reduces the number of layers. Vertical fusion merges the convolution, bias, and activation layers into a single CBR structure that occupies only one CUDA kernel. Horizontal fusion merges layers that have the same structure but different weights into a single wider layer, again occupying only one CUDA kernel. The merged computation graph has fewer levels and launches fewer CUDA kernels, so the whole model becomes smaller, faster, and more efficient.
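TensorRT performs this fusion automatically while building the engine, so there is nothing to call by hand. As a rough analogy only, PyTorch's own fusion utility can merge the same Conv-BN-ReLU pattern inside an eval-mode model; a minimal sketch (the module and class names are illustrative, and this uses torch.quantization.fuse_modules, not TensorRT):

import torch
import torch.nn as nn

# A tiny Conv-BN-ReLU block: the pattern TensorRT fuses into one CBR kernel.
class ConvBNReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = ConvBNReLU().eval()
# Merge the three modules into one; bn and relu become Identity and
# their math is absorbed into the convolution.
fused = torch.quantization.fuse_modules(m, [["conv", "bn", "relu"]])
print(fused)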
Data precision calibration:
Most deep learning frameworks train networks with 32-bit floating-point tensors (full 32-bit precision, FP32). Once training is complete, back-propagation is no longer needed during deployment, so the numeric precision can be reduced appropriately, for example to FP16 or INT8. Lower precision yields lower memory usage and latency, and a smaller model.
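If an int8_calibrator is passed to the conversion function in section 4 below, TensorRT also needs representative input batches from which to derive the INT8 scaling factors. A minimal calibrator sketch, assuming a single network input and calibration batches prepared as contiguous float32 numpy arrays (the class name, cache-file name, and pycuda-based memory handling are illustrative choices, not requirements):

import numpy as np
import pycuda.autoinit  # creates a CUDA context
import pycuda.driver as cuda
import tensorrt as trt

class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, batches, cache_file="calib.cache"):
        super().__init__()
        self.iterator = iter(batches)        # iterable of float32 arrays
        self.cache_file = cache_file
        self.batch_size = batches[0].shape[0]
        self.d_input = cuda.mem_alloc(batches[0].nbytes)  # reused device buffer

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        try:
            batch = next(self.iterator)
        except StopIteration:
            return None                      # no more data: calibration ends
        cuda.memcpy_htod(self.d_input, np.ascontiguousarray(batch))
        return [int(self.d_input)]

    def read_calibration_cache(self):
        try:
            with open(self.cache_file, "rb") as f:
                return f.read()
        except FileNotFoundError:
            return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)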
Kernel Auto-Tuning:
When the model runs inference, it executes as CUDA kernels on the GPU. TensorRT tunes these CUDA kernels for different algorithms, different model structures, and different GPU platforms, to guarantee that the model computes with optimal performance on the specific target.
Suppose you deploy on both a 3090 and a T4: the TensorRT conversion must be performed on each of the two platforms separately, and each resulting engine used on its own platform. You cannot convert on one platform and use the engine on a different one.
Dynamic Tensor Memory:
TensorRT allocates memory for each tensor only for the duration of its use, which avoids repeated GPU-memory allocations, reduces memory footprint, and improves reuse efficiency.
4. ONNX to TensorRT
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Note: max_workspace_size / max_batch_size / fp16_mode / int8_mode are
# builder attributes from the TensorRT 7.x Python API; TensorRT 8+ moved
# these settings into an IBuilderConfig.
def convert_tensorrt_engine(onnx_fn, trt_fn, max_batch_size,
                            fp16=True, int8_calibrator=None,
                            workspace=2_000_000_000):
    network_creation_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(network_creation_flag) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:
        builder.max_workspace_size = workspace   # scratch memory for tactic selection
        builder.max_batch_size = max_batch_size
        builder.fp16_mode = fp16                 # enable FP16 kernels
        if int8_calibrator:
            builder.int8_mode = True
            builder.int8_calibrator = int8_calibrator

        # Parse the ONNX file into the TensorRT network definition.
        with open(onnx_fn, "rb") as f:
            if not parser.parse(f.read()):
                print("got {} errors:".format(parser.num_errors))
                for i in range(parser.num_errors):
                    e = parser.get_error(i)
                    print(e.code(), e.desc(), e.node())
                return
        print("parse successful")

        print("inputs:", network.num_inputs)
        # inputs = [network.get_input(i) for i in range(network.num_inputs)]
        # opt_profiles = create_optimization_profiles(builder, inputs)
        # add_profiles(config, inputs, opt_profiles)
        for i in range(network.num_inputs):
            print(i, network.get_input(i).name, network.get_input(i).shape)
        print("outputs:", network.num_outputs)
        for i in range(network.num_outputs):
            output = network.get_output(i)
            print(i, output.name, output.shape)

        # Build and serialize the engine.
        engine = builder.build_cuda_engine(network)
        if engine is None:
            print("engine build failed")
            return
        with open(trt_fn, "wb") as f:
            f.write(engine.serialize())
        print("done")