当前位置:网站首页>模型搭建过程1==MIIDock
模型搭建过程1==MIIDock
2022-06-13 11:02:00 【火萤石】
引入模块
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
定义一个网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 3, 3, padding=(0, 0), bias=False)
def forward(self, x):
x = self.conv1(x)
return x
net = Net()
定义卷积权重
#sobel全边缘检测算子
conv_rgb_core_sobel = [
[[-1,-1,-1],[-1,8,-1], [-1, -1, -1],
[0,0,0],[0,0,0], [0,0,0],
[0,0,0],[0,0,0], [0,0,0]
],
[[0,0,0],[0,0,0], [0,0,0],
[-1,-1,-1],[-1,8,-1], [-1, -1, -1],
[0,0,0],[0,0,0], [0,0,0]
],
[[0,0,0],[0,0,0], [0,0,0],
[0,0,0],[0,0,0], [0,0,0],
[-1,-1,-1],[-1,8,-1], [-1, -1, -1],
]]
#sobel垂直边缘检测算子
conv_rgb_core_sobel_vertical = [
[[-1,0,1],[-2,0,2], [-1, 0, 1],
[0,0,0],[0,0,0], [0,0,0],
[0,0,0],[0,0,0], [0,0,0]
],
[[0,0,0],[0,0,0], [0,0,0],
[-1,0,1],[-2,0,2], [-1, 0, 1],
[0,0,0],[0,0,0], [0,0,0]
],
[[0,0,0],[0,0,0], [0,0,0],
[0,0,0],[0,0,0], [0,0,0],
[-1,0,1],[-2,0,2], [-1, 0, 1],
]]
#sobel水平边缘检测算子
conv_rgb_core_sobel_horizontal = [
[[1,2,1],[0,0,0], [-1, -2, -1],
[0,0,0],[0,0,0], [0,0,0],
[0,0,0],[0,0,0], [0,0,0]
],
[[0,0,0],[0,0,0], [0,0,0],
[1,2,1],[0,0,0], [-1, -2, -1],
[0,0,0],[0,0,0], [0,0,0]
],
[[0,0,0],[0,0,0], [0,0,0],
[0,0,0],[0,0,0], [0,0,0],
[1,2,1],[0,0,0], [-1, -2, -1],
]]
网络载入权重函数
def sobel(net, kernel):
sobel_kernel = np.array(kernel, dtype='float32')
sobel_kernel = sobel_kernel.reshape((3, 3, 3, 3))
net.conv1.weight.data = torch.from_numpy(sobel_kernel)
params = list(net.parameters())
打开一张图片
pil_img = Image.open("../yarn_2.jpg")
#display(pil_img)
input_img = np.array(pil_img)
print(input_img.shape)
图片归一化处理
input_tensor = (input_img.astype(np.float32) - 127.5) / 128 # to [-1, 1]
print(input_tensor.shape)
input_tensor = torch.Tensor(input_tensor).permute((2, 0, 1))
input_tensor = input_tensor.unsqueeze(0)
print("input shape:", input_tensor.shape)
输入图片转换成PyTorch张量个格式
input_tensor = (input_img.astype(np.float32) - 127.5) / 128 # to [-1, 1]
input_tensor = torch.Tensor(input_tensor).permute((2, 0, 1))
print(input_tensor.shape)
input_tensor = input_tensor.unsqueeze(0)
print("input shape:", input_tensor.shape)
模型推理
global sobel_img_t
global sobel_vertical_img_t
global sobel_horizontal_img_t
sobel_img_t = None
sobel_vertical_img_t = None
sobel_horizontal_img_t = None
#载入网络权重
sobel(net, conv_rgb_core_sobel_vertical)
#在推理模式下运行网络
with torch.no_grad():
out = net(input_tensor)
sobel_vertical_img_t = out.numpy()[0].transpose([1, 2, 0])
#载入网络权重
sobel(net, conv_rgb_core_sobel_horizontal)
在推理模式下运行网络
with torch.no_grad():
out = net(input_tensor)
sobel_horizontal_img_t = out.numpy()[0].transpose([1, 2, 0])
#载入网络权重
sobel(net, conv_rgb_core_sobel)
#在推理模式下运行网络
with torch.no_grad():
out = net(input_tensor)
sobel_img_t = out.numpy()[0].transpose([1, 2, 0])
#打印输出结果
plt.figure()
plt.subplot(1, 5, 1)
plt.imshow(input_img)
plt.subplot(1, 5, 2)
plt.imshow(sobel_img_t)
plt.subplot(1, 5, 3)
plt.imshow(sobel_vertical_img_t)
plt.subplot(1, 5, 4)
plt.imshow(sobel_horizontal_img_t)
plt.subplot(1, 5, 5)
out = np.sqrt(np.square(sobel_vertical_img_t) + np.square(sobel_horizontal_img_t))
plt.imshow(out)
plt.show()
导出onnx网络
with torch.no_grad():
torch.onnx.export(net,input_tensor,"./model.onnx",
export_params=True,input_names=["input0"],
output_names=["output0"])
print("导出网络完成")
使用ncnn工具将onnx网络抓换成ncnn网络
#以下代码中会调用用户环境中的ncnn工具,请确保已经安装好并加入环境变量
#ncnn编译完后,在build/tools/onnx里会生成个可执行文件onnx2ncnn
#在终端执行命令 ./onnx2ncnn mobilenetv2.onnx mobilenetv2.param mobilenetv2.bin
def onnx_to_ncnn(input_shape,
onnx="out/model.onnx",
ncnn_param="out/conv0.param",
ncnn_bin="out/conv0.bin"):
import os
cmd=f"onnx2ncnn {
onnx} {
ncnn_param} {
ncnn_bin}"#可以更换工具目录
os.system(cmd)
with open(ncnn_param) as f:
content=f.read().split("\n")
if len(input_shape)==1:
content[2]+="0={}".format(input_shape[0])
else:
content[2]+="0={} 1={} 2={}".format(input_shape[2],input_shape[1],input_shape[0])
content="\n".join(content)
with open(ncnn_param,"w") as f:
f.write(content)
onnx_to_ncnn(input_shape=(3,224,224),
onnx="./model.onnx",
ncnn_param="./conv0.param",
ncnn_bin="./conv0.bin")
print("net success!")
去maxihub量化模型(int8)
边缘检测模型部署
#v831 运行边缘检测的代码
from maix import nn, camera, display, image
import numpy as np
import time
model = {
"param": "./sobel_int8.param",
"bin": "./sobel_int8.bin"
}
input_size = (224, 224, 3)
output_size = (222, 222, 3)
options = {
"model_type": "awnn",
"inputs": {
"input0": input_size
},
"outputs": {
"output0": output_size
},
"mean": [127.5, 127.5, 127.5],
"norm": [0.0078125, 0.0078125, 0.0078125],
}
print("-- load model:", model)
m = nn.load(model, opt=options)
print("-- load ok")
while True:
img = camera.capture().resize(224,224)
out = m.forward(img, quantize=True, layout="hwc")
out, = out.astype(np.float32).reshape(output_size)
out = (np.ndarray.__abs__(out) * 255 / out.max()).astype(np.uint8)
data = out.tobytes()
img2 = img.load(data,(222, 222), mode="RGB")
display.show(img2)
边栏推荐
- 宝塔添加一个网站:PHP项目
- 求组合数四种方法
- 高斯消元求n元方程组
- 数位DP例题
- Private computing fat core concepts and stand-alone deployment
- Questions and answers of the labor worker general basic (labor worker) work license in 2022
- Do you agree that the salary of hardware engineers is falsely high?
- ue5 小知识点 random point in Bounding Boxf From Stream
- vivo大规模 Kubernetes 集群自动化运维实践
- Environ. Sci. Technol.(IF=9.028) | 城市绿化对大气环境的影响
猜你喜欢

基于Vue+Nest.js+MySQL的跨平台开源社区运营管理系统

Interval modification multiplication and addition (a good example of understanding lazy tags)

Go 要加个箭头语法,这下更像 PHP 了!

电赛校赛经验-程控风力摆

Go zero microservice Practice Series (III. API definition and table structure design)

ue5 小知识点 random point in Bounding Boxf From Stream

Full stack development practice | integrated development of SSM framework

【TcaplusDB知识库】TcaplusDB运维单据介绍

【TcaplusDB知识库】TcaplusDB单据受理-建表审批介绍

宝塔中navicat连接mysql
随机推荐
Pyepics download and installation
ACP | 东北地理所在气象-空气质量双向耦合模式研究中取得进展
Nim游戏阶梯 Nim游戏和SG函数应用(集合游戏)
C#/VB.NET 在Word转PDF时生成目录书签
Multithreading starts from the lockless queue of UE4 (thread safe)
C file package and download
AcWing第 55 场周赛
21世纪以来的历次“粮食危机”,发生了什么?
判定二分图和二分图最大匹配
Brief description of redo logs and undo logs in MySQL
Brief request process
2022 tailings recurrent training question bank and simulated examination
Do you agree that the salary of hardware engineers is falsely high?
St table learning
Pagoda access changed from IP to domain name
SSM integration preliminary details
Interval modification multiplication and addition (a good example of understanding lazy tags)
Web3 system construction: principles, models and methods of decentralization (Part I)
宝塔添加一个网站:PHP项目
Database learning notes (Chapter 16)