当前位置:网站首页>13.模型的保存和载入
13.模型的保存和载入
2022-07-07 23:11:00 【booze-J】
我们以保存3.MNIST数据集分类中训练的模型为例,来演示模型的保存与载入。
第一种模型保存和载入方式
1.保存方式
保存模型只需要在模型训练完之后添加上
# 保存模型 可以同时保存模型的结构和参数
model.save("model.h5") # HDF5文件,pip install h5py
这种保存方式可以同时保存模型的结构和参数。
2.载入方式
载入模型之前需要先导入load_model方法
from keras.models import load_model
然后载入的代码就是简单一句:
# 载入模型
model = load_model("../model.h5")
这种载入方法可以同时载入模型的结构和参数。
第二种模型保存和载入方式
1.保存方式
模型参数和模型结构分开来保存:
# 保存参数
model.save_weights("my_model_weights.h5")
# 保存网络结构
json_string = model.to_json()
2.载入方式
在载入模型结构之前,需要先导入model_from_json()方法
from keras.models import model_from_json
分别载入网络参数和网络结构:
# 载入参数
model.load_weights("my_model_weights.h5")
# 载入模型结构
model = model_from_json(json_string)
模型再训练
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
其实模型载入之后是可以进行再训练的。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.模型再训练
# 载入模型
model = load_model("../model.h5")
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
运行结果:
对比首次保存的模型:
可以发现再训练模型在测试集上的准确率有所提高。
边栏推荐
- 手写一个模拟的ReentrantLock
- How to learn a new technology (programming language)
- letcode43:字符串相乘
- Analysis of 8 classic C language pointer written test questions
- Handwriting a simulated reentrantlock
- 新库上线 | 中国记者信息数据
- 1293_FreeRTOS中xTaskResumeAll()接口的实现分析
- 哪个券商公司开户佣金低又安全,又靠谱
- NVIDIA Jetson test installation yolox process record
- New library online | information data of Chinese journalists
猜你喜欢

My best game based on wechat applet development

【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出

letcode43:字符串相乘

8道经典C语言指针笔试题解析

DNS 系列(一):为什么更新了 DNS 记录不生效?

大数据开源项目,一站式全自动化全生命周期运维管家ChengYing(承影)走向何方?

New library online | cnopendata China Star Hotel data
![[necessary for R & D personnel] how to make your own dataset and display it.](/img/50/3d826186b563069fd8d433e8feefc4.png)
[necessary for R & D personnel] how to make your own dataset and display it.

FOFA-攻防挑战记录

玩轉Sonar
随机推荐
Installation and configuration of sublime Text3
Malware detection method based on convolutional neural network
SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
取消select的默认样式的向下箭头和设置select默认字样
German prime minister says Ukraine will not receive "NATO style" security guarantee
Kubernetes static pod (static POD)
[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output
爬虫实战(八):爬表情包
ReentrantLock 公平锁源码 第0篇
[note] common combined filter circuit
[necessary for R & D personnel] how to make your own dataset and display it.
22年秋招心得
Password recovery vulnerability of foreign public testing
新库上线 | CnOpenData中华老字号企业名录
Deep dive kotlin collaboration (the end of 23): sharedflow and stateflow
Reptile practice (VIII): reptile expression pack
Handwriting a simulated reentrantlock
Thinkphp内核工单系统源码商业开源版 多用户+多客服+短信+邮件通知
Course of causality, taught by Jonas Peters, University of Copenhagen
攻防世界Web进阶区unserialize3题解