13. Saving and Loading Models
2022-07-07 23:11:00 【booze-J】
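Using the model trained in the earlier article "3. MNIST dataset classification" as an example, this article demonstrates how to save and load a model.
Method 1: saving and loading the whole model
1. Saving
To save the model, simply add the following after training finishes: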
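# Save the model: this stores both its architecture and its parameters
model.save("model.h5") # HDF5 file; requires h5py (pip install h5py)
This single file holds the architecture, the weights, and the compile settings (loss, optimizer and its state), so the model can later be evaluated or trained further without rebuilding it.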
2. Loading
Before loading, import the load_model function:
from keras.models import load_model
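Loading then takes a single line:
# Load the model; the path must point to the file written by model.save()
model = load_model("../model.h5")
This restores both the architecture and the parameters, and because the compile settings were saved too, the returned model can be passed straight to evaluate() or fit().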
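To check that method 1 really restores the whole model, the round trip can be sketched on a tiny stand-in network. This is a minimal sketch assuming TensorFlow 2.x with its bundled Keras (`from tensorflow import keras`); the single Dense layer and the random inputs are hypothetical placeholders, not the network from the MNIST article. If the reloaded model gives the same predictions as the original, both the architecture and the weights survived the save.

```python
import numpy as np
from tensorflow import keras

# Hypothetical stand-in for the MNIST classifier: 784 inputs -> 10-way softmax.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])

# Random inputs, only used to compare predictions before and after the round trip.
x = np.random.rand(5, 784).astype("float32")
before = model.predict(x, verbose=0)

# Method 1: one HDF5 file holding architecture, weights and compile settings.
model.save("model.h5")
restored = keras.models.load_model("model.h5")
after = restored.predict(x, verbose=0)

# True if the reloaded model reproduces the original predictions.
print(np.allclose(before, after))
```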
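Method 2: saving the weights and the architecture separately
1. Saving
Save the model parameters and the model architecture separately: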
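# Save the parameters (weights only)
model.save_weights("my_model_weights.h5")
# Save the network architecture as a JSON string
json_string = model.to_json()
# to_json() only returns a string; write it to a file to keep it between sessions
with open("my_model.json", "w") as f:
    f.write(json_string)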
2. Loading
Before loading the architecture, import model_from_json():
from keras.models import model_from_json
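Rebuild the network from its architecture first and then load the parameters into it; model_from_json() creates a brand-new model with freshly initialized weights, so loading the weights before that call would throw them away.
# Load the architecture: this builds a new model with freshly initialized weights
model = model_from_json(json_string)
# Load the parameters into the rebuilt model
model.load_weights("my_model_weights.h5")
# The rebuilt model is not compiled; call model.compile(...) before evaluate() or fit()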
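The ordering rule for method 2 can be checked the same way. The sketch below assumes TensorFlow 2.x Keras and a hypothetical one-layer model; it writes the JSON architecture to a file (`model_architecture.json` is an illustrative name), saves the weights, then rebuilds the structure before loading the weights into it. The `.weights.h5` suffix is used because newer Keras releases (Keras 3) require it for save_weights(), while older releases still treat it as an HDF5 file since the name ends in `.h5`.

```python
import numpy as np
from tensorflow import keras

# Hypothetical one-layer stand-in model.
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(10, activation="softmax"),
])
x = np.random.rand(5, 784).astype("float32")
before = model.predict(x, verbose=0)

# Save the architecture (JSON text) and the weights to two separate files.
with open("model_architecture.json", "w") as f:
    f.write(model.to_json())
model.save_weights("my_model.weights.h5")

# Load in the right order: rebuild the structure first, then fill in the weights.
with open("model_architecture.json") as f:
    rebuilt = keras.models.model_from_json(f.read())
rebuilt.load_weights("my_model.weights.h5")

# model_from_json() returns an uncompiled model; compile before evaluate()/fit().
rebuilt.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])
after = rebuilt.predict(x, verbose=0)
print(np.allclose(before, after))
```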
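Retraining a loaded model
The code was run in Jupyter Notebook, and the code blocks below follow the notebook's cell divisions; to run them, paste each block into its own cell.
A loaded model can in fact be trained further.
1. Import the third-party libraries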
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2. Load and preprocess the data
# Load the data
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) labels are not one-hot encoded yet; that is done below
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000, 784); passing -1 to reshape() lets NumPy infer that dimension, and dividing by 255.0 normalizes pixel values to [0, 1]
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# Convert the labels to one-hot format
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
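Both preprocessing steps can be reproduced with plain NumPy on fake data, which makes the shapes easy to see. This sketch assumes only NumPy; the random images and the label array are made up, and `np.eye(10)[y]` is used as an equivalent of `to_categorical(y, num_classes=10)` for integer labels 0-9, not as Keras's own implementation.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 6 random 28x28 uint8 images and their labels.
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(6, 28, 28)).astype(np.uint8)
y = np.array([5, 0, 4, 1, 9, 2])

# (6, 28, 28) -> (6, 784): -1 lets NumPy infer 28 * 28 = 784; /255.0 scales to [0, 1].
x_flat = x.reshape(x.shape[0], -1) / 255.0

# One-hot encoding: row i of the 10x10 identity matrix encodes label i.
y_onehot = np.eye(10)[y]

print(x_flat.shape)  # (6, 784)
print(y_onehot[0])   # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
```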
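3. Retraining the model
# Load the model saved earlier (the path must point to the file written by model.save())
model = load_model("../model.h5")
# Evaluate the loaded model before retraining
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
# Continue training the loaded model (batch_size and epochs are illustrative values)
model.fit(x_train,y_train,batch_size=32,epochs=2)
# Evaluate again after retraining
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)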
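Output, and for comparison the output of the model as originally saved: [screenshots not preserved]
Comparing the two, the retrained model reaches a somewhat higher accuracy on the test set.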
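The full cycle of saving, reloading, training further and saving again can be sketched end to end without downloading MNIST. This is a minimal sketch assuming TensorFlow 2.x Keras; the random data and the one-layer model are hypothetical stand-ins, so any accuracy it reports means nothing. It only checks that the reloaded model resumes training (its weights change) and can be written back to the same file.

```python
import numpy as np
from tensorflow import keras

# Hypothetical synthetic data standing in for MNIST (random pixels, random labels).
rng = np.random.default_rng(0)
x = rng.random((256, 784)).astype("float32")
y = keras.utils.to_categorical(rng.integers(0, 10, size=256), num_classes=10)

# A tiny stand-in model, trained briefly and saved (the first-training step).
model = keras.Sequential([
    keras.Input(shape=(784,)),
    keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(x, y, batch_size=32, epochs=1, verbose=0)
model.save("model.h5")

# Later session: reload the saved model and keep training it.
restored = keras.models.load_model("model.h5")
kernel_before = restored.get_weights()[0].copy()
restored.fit(x, y, batch_size=32, epochs=1, verbose=0)
kernel_after = restored.get_weights()[0]

# Overwrite the file with the retrained model.
restored.save("model.h5")

# True if the extra training actually updated the Dense kernel.
changed = not np.allclose(kernel_before, kernel_after)
print(changed)
```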