当前位置:网站首页>13.模型的保存和载入
13.模型的保存和载入
2022-07-07 23:11:00 【booze-J】
我们以保存3.MNIST数据集分类中训练的模型为例,来演示模型的保存与载入。
第一种模型保存和载入方式
1.保存方式
保存模型只需要在模型训练完之后添加上
# 保存模型 可以同时保存模型的结构和参数
model.save("model.h5") # HDF5文件,pip install h5py
这种保存方式可以同时保存模型的结构和参数。
2.载入方式
载入模型之前需要先导入load_model
方法
from keras.models import load_model
然后载入的代码就是简单一句:
# 载入模型
model = load_model("../model.h5")
这种载入方法可以同时载入模型的结构和参数。
第二种模型保存和载入方式
1.保存方式
模型参数和模型结构分开来保存:
# 保存参数
model.save_weights("my_model_weights.h5")
# 保存网络结构
json_string = model.to_json()
2.载入方式
在载入模型结构之前,需要先导入model_from_json()
方法
from keras.models import model_from_json
分别载入网络参数和网络结构:
# 载入参数
model.load_weights("my_model_weights.h5")
# 载入模型结构
model = model_from_json(json_string)
模型再训练
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
其实模型载入之后是可以进行再训练的。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.模型再训练
# 载入模型
model = load_model("../model.h5")
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
运行结果:
对比首次保存的模型:
可以发现再训练模型在测试集上的准确率有所提高。
边栏推荐
- 5g NR system messages
- Introduction to paddle - using lenet to realize image classification method II in MNIST
- 【GO记录】从零开始GO语言——用GO语言做一个示波器(一)GO语言基础
- New library online | information data of Chinese journalists
- ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
- NVIDIA Jetson test installation yolox process record
- German prime minister says Ukraine will not receive "NATO style" security guarantee
- 51与蓝牙模块通讯,51驱动蓝牙APP点灯
- Codeforces Round #804 (Div. 2)(A~D)
- 基于人脸识别实现课堂抬头率检测
猜你喜欢
They gathered at the 2022 ecug con just for "China's technological power"
Lecture 1: the entry node of the link in the linked list
ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
C language 001: download, install, create the first C project and execute the first C language program of CodeBlocks
搭建ADG过程中复制报错 RMAN-03009 ORA-03113
NVIDIA Jetson test installation yolox process record
新库上线 | CnOpenData中国星级酒店数据
Jouer sonar
A network composed of three convolution layers completes the image classification task of cifar10 data set
[go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
随机推荐
How to learn a new technology (programming language)
51 communicates with the Bluetooth module, and 51 drives the Bluetooth app to light up
基于卷积神经网络的恶意软件检测方法
詹姆斯·格雷克《信息简史》读后感记录
A network composed of three convolution layers completes the image classification task of cifar10 data set
The weight of the product page of the second level classification is low. What if it is not included?
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
Introduction to paddle - using lenet to realize image classification method II in MNIST
Service Mesh的基本模式
SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
5G NR 系统消息
接口测试进阶接口脚本使用—apipost(预/后执行脚本)
"An excellent programmer is worth five ordinary programmers", and the gap lies in these seven key points
Lecture 1: the entry node of the link in the linked list
深潜Kotlin协程(二十三 完结篇):SharedFlow 和 StateFlow
Codeforces Round #804 (Div. 2)(A~D)
NVIDIA Jetson test installation yolox process record
浪潮云溪分布式数据库 Tracing(二)—— 源码解析
攻防世界Web进阶区unserialize3题解
服务器防御DDOS的方法,杭州高防IP段103.219.39.x