当前位置:网站首页>14.绘制网络模型结构
14.绘制网络模型结构
2022-07-07 23:11:00 【booze-J】
绘制网络结构流程
运行代码之前需要需要安装pydot
和graphviz
安装pydot:pip install pydot
安装graphviz就比较麻烦了,大家自行百度一下。
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense,Dropout,Convolution2D,MaxPooling2D,Flatten
from tensorflow.keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt
# install pydot and graphviz
2.数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,28,28,1)=(图片数目,图片高度,图片宽度,图片的通道数) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
# 归一化很关键哈,可以大大减少计算量
x_train = x_train.reshape(-1,28,28,1)/255.0
x_test = x_test.reshape(-1,28,28,1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.搭建网络模型
# 定义顺序模型
model = Sequential()
# 第一个卷积层 注意第一层要写输入图片的大小 后面的层可以忽略
# input_shape 输入平面
# filters 卷积核/滤波器个数
# kernel_size 卷积窗口大小
# strides 步长
# padding padding方式 same/valid
# activation 激活函数
model.add(Convolution2D(
input_shape=(28,28,1),
filters=32,
kernel_size=5,
strides=1,
padding="same",
activation="relu"
))
# 第一个池化层
model.add(MaxPooling2D(
pool_size=2,
strides=2,
padding="same"
))
# 第二个池化层
model.add(Convolution2D(filters=64,kernel_size=5,strides=1,padding="same",activation="relu"))
# 第二个池化层
model.add(MaxPooling2D(pool_size=2,strides=2,padding="same"))
# 把第二个池化层的输出扁平化为1维
model.add(Flatten())
# 第一个全连接层
model.add(Dense(units=1024,activation="relu"))
# Dropout 随机选用50%神经元进行训练
model.add(Dropout(0.5))
# 第二个全连接层
model.add(Dense(units=10,activation="softmax"))
# # 定义优化器 设置学习率为1e-4
# adam = Adam(lr=1e-4)
# # 定义优化器,loss function,训练过程中计算准确率
# model.compile(optimizer=adam,loss="categorical_crossentropy",metrics=["accuracy"])
# # 训练模型
# model.fit(x_train,y_train,batch_size=64,epochs=10)
# # 评估模型
# loss,accuracy=model.evaluate(x_test,y_test)
# print("test loss:",loss)
# print("test accuracy:",accuracy)
4.绘制网络模型结构
# rankdir="TB" 最后这个就是决定方向的 T代表TOP B代表BOTTOM TB就是从上到下 如果要从左往右的话,修改rankdir="LR"即可
plot_model(model,to_file="model.png",show_shapes=True,show_layer_names="False",rankdir="TB")
plt.figure(figsize=(10,10))
img = plt.imread("model.png")
plt.imshow(img)
plt.axis("off")
plt.show()
运行结果:plot_model(model,to_file="model.png",show_shapes=True,show_layer_names="False",rankdir="TB")
中的rankdir="TB"
最后这个就是决定方向的 T
代表TOP ,B
代表BOTTOM,TB
就是从上到下,如果要从左往右的话,修改rankdir="LR"
即可。
边栏推荐
- Interface test advanced interface script use - apipost (pre / post execution script)
- 取消select的默认样式的向下箭头和设置select默认字样
- Stock account opening is free of charge. Is it safe to open an account on your mobile phone
- Cause analysis and solution of too laggy page of [test interview questions]
- 1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
- From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
- [go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
- Jemter distributed
- CVE-2022-28346:Django SQL注入漏洞
- What if the testing process is not perfect and the development is not active?
猜你喜欢
ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
51 communicates with the Bluetooth module, and 51 drives the Bluetooth app to light up
Password recovery vulnerability of foreign public testing
C # generics and performance comparison
Codeforces Round #804 (Div. 2)(A~D)
12.RNN应用于手写数字识别
大数据开源项目,一站式全自动化全生命周期运维管家ChengYing(承影)走向何方?
图像数据预处理
fabulous! How does idea open multiple projects in a single window?
8.优化器
随机推荐
Jemter distributed
FOFA-攻防挑战记录
Is it safe to open an account on the official website of Huatai Securities?
大数据开源项目,一站式全自动化全生命周期运维管家ChengYing(承影)走向何方?
Basic principle and usage of dynamic library, -fpic option context
Service Mesh的基本模式
Application practice | the efficiency of the data warehouse system has been comprehensively improved! Data warehouse construction based on Apache Doris in Tongcheng digital Department
STL -- common function replication of string class
韦东山第三期课程内容概要
丸子官网小程序配置教程来了(附详细步骤)
RPA cloud computer, let RPA out of the box with unlimited computing power?
Kubernetes static pod (static POD)
How to insert highlighted code blocks in WPS and word
A brief history of information by James Gleick
赞!idea 如何单窗口打开多个项目?
取消select的默认样式的向下箭头和设置select默认字样
Summary of the third course of weidongshan
Service Mesh介绍,Istio概述
SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
基于卷积神经网络的恶意软件检测方法