当前位置:网站首页>14.绘制网络模型结构
14.绘制网络模型结构
2022-07-07 23:11:00 【booze-J】
绘制网络结构流程
运行代码之前需要需要安装pydot
和graphviz
安装pydot:pip install pydot
安装graphviz就比较麻烦了,大家自行百度一下。
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense,Dropout,Convolution2D,MaxPooling2D,Flatten
from tensorflow.keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
# install pydot and graphviz
2.数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,28,28,1)=(图片数目,图片高度,图片宽度,图片的通道数) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
# 归一化很关键哈,可以大大减少计算量
x_train = x_train.reshape(-1,28,28,1)/255.0
x_test = x_test.reshape(-1,28,28,1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.搭建网络模型
# 定义顺序模型
model = Sequential()
# 第一个卷积层 注意第一层要写输入图片的大小 后面的层可以忽略
# input_shape 输入平面
# filters 卷积核/滤波器个数
# kernel_size 卷积窗口大小
# strides 步长
# padding padding方式 same/valid
# activation 激活函数
model.add(Convolution2D(
input_shape=(28,28,1),
filters=32,
kernel_size=5,
strides=1,
padding="same",
activation="relu"
))
# 第一个池化层
model.add(MaxPooling2D(
pool_size=2,
strides=2,
padding="same"
))
# 第二个池化层
model.add(Convolution2D(filters=64,kernel_size=5,strides=1,padding="same",activation="relu"))
# 第二个池化层
model.add(MaxPooling2D(pool_size=2,strides=2,padding="same"))
# 把第二个池化层的输出扁平化为1维
model.add(Flatten())
# 第一个全连接层
model.add(Dense(units=1024,activation="relu"))
# Dropout 随机选用50%神经元进行训练
model.add(Dropout(0.5))
# 第二个全连接层
model.add(Dense(units=10,activation="softmax"))
# # 定义优化器 设置学习率为1e-4
# adam = Adam(lr=1e-4)
# # 定义优化器,loss function,训练过程中计算准确率
# model.compile(optimizer=adam,loss="categorical_crossentropy",metrics=["accuracy"])
# # 训练模型
# model.fit(x_train,y_train,batch_size=64,epochs=10)
# # 评估模型
# loss,accuracy=model.evaluate(x_test,y_test)
# print("test loss:",loss)
# print("test accuracy:",accuracy)
4.绘制网络模型结构
# rankdir="TB" 最后这个就是决定方向的 T代表TOP B代表BOTTOM TB就是从上到下 如果要从左往右的话,修改rankdir="LR"即可
plot_model(model,to_file="model.png",show_shapes=True,show_layer_names="False",rankdir="TB")
plt.figure(figsize=(10,10))
img = plt.imread("model.png")
plt.imshow(img)
plt.axis("off")
plt.show()
运行结果:plot_model(model,to_file="model.png",show_shapes=True,show_layer_names="False",rankdir="TB")
中的rankdir="TB"
最后这个就是决定方向的 T
代表TOP ,B
代表BOTTOM,TB
就是从上到下,如果要从左往右的话,修改rankdir="LR"
即可。
边栏推荐
- Cancel the down arrow of the default style of select and set the default word of select
- Summary of the third course of weidongshan
- 华为交换机S5735S-L24T4S-QA2无法telnet远程访问
- 13.模型的保存和载入
- AI遮天传 ML-初识决策树
- Basic principle and usage of dynamic library, -fpic option context
- 9.卷积神经网络介绍
- How to insert highlighted code blocks in WPS and word
- 取消select的默认样式的向下箭头和设置select默认字样
- FOFA-攻防挑战记录
猜你喜欢
Qt不同类之间建立信号槽,并传递参数
Service mesh introduction, istio overview
接口测试进阶接口脚本使用—apipost(预/后执行脚本)
Tapdata 的 2.0 版 ,开源的 Live Data Platform 现已发布
【笔记】常见组合滤波电路
取消select的默认样式的向下箭头和设置select默认字样
浪潮云溪分布式数据库 Tracing(二)—— 源码解析
接口测试要测试什么?
[necessary for R & D personnel] how to make your own dataset and display it.
新库上线 | CnOpenData中国星级酒店数据
随机推荐
什么是负载均衡?DNS如何实现负载均衡?
Play sonar
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
How to learn a new technology (programming language)
8道经典C语言指针笔试题解析
[OBS] the official configuration is use_ GPU_ Priority effect is true
Su embedded training - Day3
ABAP ALV LVC模板
51与蓝牙模块通讯,51驱动蓝牙APP点灯
Basic types of 100 questions for basic grammar of Niuke
Application practice | the efficiency of the data warehouse system has been comprehensively improved! Data warehouse construction based on Apache Doris in Tongcheng digital Department
ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
【笔记】常见组合滤波电路
12.RNN应用于手写数字识别
Reentrantlock fair lock source code Chapter 0
炒股开户怎么最方便,手机上开户安全吗
Binder core API
RPA cloud computer, let RPA out of the box with unlimited computing power?
[go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
串口接收一包数据