当前位置:网站首页>13.模型的保存和载入
13.模型的保存和载入
2022-07-07 23:11:00 【booze-J】
我们以保存3.MNIST数据集分类中训练的模型为例,来演示模型的保存与载入。
第一种模型保存和载入方式
1.保存方式
保存模型只需要在模型训练完之后添加上
# 保存模型 可以同时保存模型的结构和参数
model.save("model.h5") # HDF5文件,pip install h5py
这种保存方式可以同时保存模型的结构和参数。
2.载入方式
载入模型之前需要先导入load_model方法
from keras.models import load_model
然后载入的代码就是简单一句:
# 载入模型
model = load_model("../model.h5")
这种载入方法可以同时载入模型的结构和参数。
第二种模型保存和载入方式
1.保存方式
模型参数和模型结构分开来保存:
# 保存参数
model.save_weights("my_model_weights.h5")
# 保存网络结构
json_string = model.to_json()
2.载入方式
在载入模型结构之前,需要先导入model_from_json()方法
from keras.models import model_from_json
分别载入网络参数和网络结构:
# 载入参数
model.load_weights("my_model_weights.h5")
# 载入模型结构
model = model_from_json(json_string)
模型再训练
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
其实模型载入之后是可以进行再训练的。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.模型再训练
# 载入模型
model = load_model("../model.h5")
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
运行结果:
对比首次保存的模型:
可以发现再训练模型在测试集上的准确率有所提高。
边栏推荐
- QT establish signal slots between different classes and transfer parameters
- Malware detection method based on convolutional neural network
- Reptile practice (VIII): reptile expression pack
- Letcode43: string multiplication
- 语义分割模型库segmentation_models_pytorch的详细使用介绍
- 什么是负载均衡?DNS如何实现负载均衡?
- C# 泛型及性能比较
- 接口测试进阶接口脚本使用—apipost(预/后执行脚本)
- Summary of the third course of weidongshan
- How to add automatic sorting titles in typora software?
猜你喜欢

取消select的默认样式的向下箭头和设置select默认字样

1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS

Cancel the down arrow of the default style of select and set the default word of select

【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出

NVIDIA Jetson测试安装yolox过程记录

Deep dive kotlin synergy (XXII): flow treatment

QT establish signal slots between different classes and transfer parameters

Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet

华为交换机S5735S-L24T4S-QA2无法telnet远程访问

玩转Sonar
随机推荐
My best game based on wechat applet development
DNS 系列(一):为什么更新了 DNS 记录不生效?
Introduction to paddle - using lenet to realize image classification method II in MNIST
Interface test advanced interface script use - apipost (pre / post execution script)
接口测试进阶接口脚本使用—apipost(预/后执行脚本)
【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
8道经典C语言指针笔试题解析
Semantic segmentation model base segmentation_ models_ Detailed introduction to pytorch
股票开户免费办理佣金最低的券商,手机上开户安全吗
牛客基础语法必刷100题之基本类型
Password recovery vulnerability of foreign public testing
Play sonar
Qt添加资源文件,为QAction添加图标,建立信号槽函数并实现
搭建ADG过程中复制报错 RMAN-03009 ORA-03113
jemter分布式
第一讲:链表中环的入口结点
他们齐聚 2022 ECUG Con,只为「中国技术力量」
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
How is it most convenient to open an account for stock speculation? Is it safe to open an account on your mobile phone
v-for遍历元素样式失效