当前位置:网站首页>Case practice --- Resnet classic convolutional neural network (Mindspore)
Case practice --- Resnet classic convolutional neural network (Mindspore)
2022-08-01 07:27:00 【swl. Crow】
目录
调用Model高阶APITo train and save the model file
前言
Practical reference for this caseAI Gallery-开发者-华为云
Because the neural network training steps are similar,Just summarize the relevant code and Resnet18网络结构.
Similar training steps can be referred tohttp://t.csdn.cn/SSmos
调用Model高阶APITo train and save the model file
Follow this sectionAI Gallery-开发者-华为云Notes from case practice,The original code is very clear,If you don't understand, you can find it on the official websiteModel API的详解——mindspore — MindSpore master documentation
import os,time
from mindspore import Model
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
model = Model(network, loss_fn = net_loss, optimizer = net_opt, metrics = {'acc'}) #完成Model初始化
#训练参数
batch_num = mnist_ds_train.get_dataset_size() #The size of the training dataset
max_epochs = 1 #训练轮数
model_path = "./model/ckpt" #Save the path of the trained model
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path)) #rm -f --->强制删除文件或者目录
#定义回调函数
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35) #对ckpt的配置 保存步骤、保存最多ckpt文件数
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_mnist", directory=model_path, config=config_ck) #Save the model and parameters after training
loss_cb = LossMonitor(batch_num) # for output loss
start_time = time.time()
model.train(max_epochs, mnist_ds_train, callbacks=[ckpoint_cb, loss_cb]) # 训练
res = model.eval(mnist_ds_test) # 验证测试集
print("result: ", res)
cost_time = time.time() - start_time
print("训练总耗时: %.1f s" % cost_time)Resnet18网络结构
图解
【参考这位博主的文章http://t.csdn.cn/83wbR】

代码详解----基于Jupternotebook
第一步,构建一个残差单元
according to the structure,Each residual unit is different输入、输出通道数和步长,So take these three variables as initialization parameters .
import mindspore.nn as nn
#构建一个残差单元
class basic_res(nn.Cell):
"""
需要设置的参数:
input_channels, output_channels, stride
"""
def __init__(self, input_channels, output_channels, stride = 1):
super(basic_res, self).__init__()
self.conv1 = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 3, stride = stride, pad_mode="same")
self.bn = nn.BatchNorm2d(output_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels = output_channels, out_channels = output_channels, kernel_size = 3, stride = 1, pad_mode="same") #The stride of the second convolutional layer is 1,不需要人为设置
self.downsample = nn.Conv2d(in_channels = input_channels, out_channels = output_channels, kernel_size = 1, stride = stride, pad_mode="same") #Guaranteed residual inputshapewith the residual outputshape相同
def construct(self, x):
out = self.conv1(x)
out = self.bn(out)
out = self.relu(out)
out = self.conv2(out)
identity = self.downsample(x)
out = out + identity
out = self.relu(out)
return out 第二步,构建残差结构
在Resnet18结构中,Each residual structure consists of two residual units,Observe the second diagram of the structure,可知,Except for the first residual structure,The first residual unit step size in the latter three residual structures has changed2,Therefore only the step size of the first residual unit needs to be set.


#Stacked residual units are built into a residual structure
def build_res(input_channels, output_channels,blocks, stride = 1):
res_build = nn.SequentialCell()
res_build.append(basic_res(input_channels, output_channels, stride = stride)) #The first residual unit step size will change,为2,With downsampling function
for _ in range(1, blocks):
res_build.append(basic_res(output_channels, output_channels, stride = 1))#in a residual structure,Except for the first residual unit,The following steps are all1
return res_build第三步,构建残差网络
Based on a well-defined residual structure,Now just follow the schema structure,构建输入层、隐藏层、输出层,设置相应的参数,can be builtResnet18网络结构.
#构建残差网络
from mindspore import nn
class Resnet(nn.Cell):
def __init__(self, layer_dims, num_classes):
super(Resnet, self).__init__()
#输入层--Preprocessing such as convolution pooling is performed on the original input
self.stem = nn.SequentialCell([nn.Conv2d(3, 64, 7, 2, pad_mode='same'),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(3, 2, pad_mode='same')])
#隐藏层---残差结构、卷积
self.layer1 = build_res(64, 64, layer_dims[0])
self.layer2 = build_res(64, 128, layer_dims[1], 2)
self.layer3 = build_res(128, 256, layer_dims[2], 2)
self.layer4 = build_res(256, 512, layer_dims[3], 2)
#平均池化
self.avgpool = nn.AvgPool2d(7, 1)
#展开
self.flatten = nn.Flatten()
#全连接
self.fc = nn.Dense(512, num_classes)
def construct(self, x):
#输入层
out = self.stem(x)
#隐藏层
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
#输出层
out = self.avgpool(out)
out = self.flatten(out)
out = self.fc(out)
return out 第四步,实例化测试
#Define the number of residual units in each residual structure
layer_dims = [2,2,2,2]#建立测试数据
from mindspore import Tensor
from mindspore import numpy as np
x = Tensor(np.ones([1,3,224,224]))#实例化网络
resnet18 = Resnet(layer_dims, 10)
#输入数据
x_resnet18 = resnet18(x)输入数据x:

输出结果:
JupternotebookIt is really convenient for code debugging,The built structure can be disassembled,Test the input layer step by step、隐藏层和输出层,对于新手来说,Very useful for understanding neural network structure,As well as checking for changes in data throughout processing.
案例补充
Batch prediction is done last in this case,Because it is simpler,But there are a few things to understand,So a little summary
import numpy as np
from PIL import Image
import mindspore
import mindspore.ops as ops
from mindspore import Tensor
dic_ds_test = mnist_ds_test.create_dict_iterator(output_numpy = True) #Create iterative data,返回字典类型,数据类型是数组
ds_test = next(dic_ds_test) #Take the created iteration data
images_test = ds_test["image"]
labels_test = ds_test["label"]
output = model.predict(Tensor(images_test)) #开始预测,Returns a predicted score for each class
pred_labels = ops.Argmax(output_type=mindspore.int32)(output) #Returns the index of the largest value in the predicted score,即预测值
print("预测值 -- > ", pred_labels) # 打印预测值
print("真实值 -- > ", labels_test) # Print the real value
batch_img = np.squeeze(images_test[0])
for i in range(1, len(labels_test)):
batch_img = np.hstack((batch_img, np.squeeze(images_test[i]))) # Stitch a batch of images horizontally,It is convenient to display in the next step
Image.fromarray((batch_img*255).astype('uint8'), mode= "L") # Display the true value
- next ()------My understanding is if not adoptednext(),也可以遍历数据,when empty element is read,won't check,可能会有问题,但使用next()时,Empty element encountered,will automatically terminate with an error.详细原理可以参考:http://t.csdn.cn/OTTm2
预测结果显示:

边栏推荐
- 我说过无数遍了:从来没有一种技术是为灵活组合这个目标而设计的
- Dbeaver connect the MySQL database and error Connection refusedconnect processing
- 日志导致线程Block的这些坑,你不得不防
- 05-SDRAM: Arbitration
- Golang:go连接和使用mysql
- 图像基本操作的其他内容
- Srping中bean的生命周期
- R语言使用tidyquant包的tq_transmute函数计算持有某只股票的天、月、周收益率、ggplot2使用条形图可视化股票月收益率数据、使用百分比显示Y轴坐标数据、使用不同的色彩表征正负收益率
- 国内外最顶级的8大plm项目管理系统
- return;代表含义
猜你喜欢

"By sharing" northwestern university life service | | bytes a second interview on three sides by HR

The log causes these pits in the thread block, you have to prevent

小程序全面屏手势配置案例

对于升级go1.18的goland问题

【HDLBits 刷题】Circuits(1)Combinational Logic

VSCode 快捷键及通用插件推荐

小程序通过云函数操作数据库【使用get取数据库】

special day to remember
![Explosive 30,000 words, the hardest core丨Mysql knowledge system, complete collection of commands [recommended collection]](/img/7f/08b323ffc5b5f8e3354bee6775b994.png)
Explosive 30,000 words, the hardest core丨Mysql knowledge system, complete collection of commands [recommended collection]

案例实践 --- Resnet经典卷积神经网络(Mindspore)
随机推荐
好的plm软件有哪些?plm软件排行榜
Golang: go open web service
MySQL row locks and gap locks
2022杭电多校第二场1011 DOS Card(线段树)
史上超强最常用SQL语句大全
华为深度学习课程第六、七章
插入排序—直接插入排序和希尔排序
牛客刷SQL---2
return;代表含义
pytest接口自动化测试框架 | 使用函数返回值的形式传入参数值
小程序更多的手势事件(左右滑动、放大缩小、双击、长按)
centos 安装php7.4,搭建hyperf,转发RDS
最小生成树
Datagrip error "The specified database userpassword combination is rejected..."Solutions
Image lossless compression software which works: try completely free JPG - C image batch finishing compression reduces weight tools | latest JPG batch dressing tools download
Fist game copyright-free music download, League of Legends copyright-free music, can be used for video creation, live broadcast
聊一聊ICMP协议以及ping的过程
VoLTE基础学习系列 | 企业语音网简述
The log causes these pits in the thread block, you have to prevent
MVVM project development (commodity management system 1)