当前位置:网站首页>13.模型的保存和載入
13.模型的保存和載入
2022-07-08 00:54:00 【booze-J】
我們以保存3.MNIST數據集分類中訓練的模型為例,來演示模型的保存與載入。
第一種模型保存和載入方式
1.保存方式
保存模型只需要在模型訓練完之後添加上
# 保存模型 可以同時保存模型的結構和參數
model.save("model.h5") # HDF5文件,pip install h5py
這種保存方式可以同時保存模型的結構和參數。
2.載入方式
載入模型之前需要先導入load_model
方法
from keras.models import load_model
然後載入的代碼就是簡單一句:
# 載入模型
model = load_model("../model.h5")
這種載入方法可以同時載入模型的結構和參數。
第二種模型保存和載入方式
1.保存方式
模型參數和模型結構分開來保存:
# 保存參數
model.save_weights("my_model_weights.h5")
# 保存網絡結構
json_string = model.to_json()
2.載入方式
在載入模型結構之前,需要先導入model_from_json()
方法
from keras.models import model_from_json
分別載入網絡參數和網絡結構:
# 載入參數
model.load_weights("my_model_weights.h5")
# 載入模型結構
model = model_from_json(json_string)
模型再訓練
代碼運行平臺為jupyter-notebook,文章中的代碼塊,也是按照jupyter-notebook中的劃分順序進行書寫的,運行文章代碼,直接分單元粘入到jupyter-notebook即可。
其實模型載入之後是可以進行再訓練的。
1.導入第三方庫
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2.加載數據及數據預處理
# 載入數據
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 還未進行one-hot編碼 需要後面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中參數填入-1的話可以自動計算出參數結果 除以255.0是為了歸一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 換one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.模型再訓練
# 載入模型
model = load_model("../model.h5")
# 評估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
運行結果:
對比首次保存的模型:
可以發現再訓練模型在測試集上的准確率有所提高。
边栏推荐
- Malware detection method based on convolutional neural network
- 大数据开源项目,一站式全自动化全生命周期运维管家ChengYing(承影)走向何方?
- New library online | cnopendata China Star Hotel data
- 【GO记录】从零开始GO语言——用GO语言做一个示波器(一)GO语言基础
- Tapdata 的 2.0 版 ,开源的 Live Data Platform 现已发布
- 新库上线 | CnOpenData中国星级酒店数据
- Play sonar
- 《因果性Causality》教程,哥本哈根大学Jonas Peters讲授
- Handwriting a simulated reentrantlock
- ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
猜你喜欢
Interface test advanced interface script use - apipost (pre / post execution script)
From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
NVIDIA Jetson测试安装yolox过程记录
深潜Kotlin协程(二十二):Flow的处理
Codeforces Round #804 (Div. 2)(A~D)
AI遮天传 ML-初识决策树
ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
Cve-2022-28346: Django SQL injection vulnerability
v-for遍历元素样式失效
Fofa attack and defense challenge record
随机推荐
Class head up rate detection based on face recognition
ThinkPHP kernel work order system source code commercial open source version multi user + multi customer service + SMS + email notification
Kubernetes static pod (static POD)
DNS series (I): why does the updated DNS record not take effect?
新库上线 | CnOpenData中国星级酒店数据
股票开户免费办理佣金最低的券商,手机上开户安全吗
ABAP ALV LVC模板
玩转Sonar
5g NR system messages
NVIDIA Jetson测试安装yolox过程记录
Experience of autumn recruitment in 22 years
How is it most convenient to open an account for stock speculation? Is it safe to open an account on your mobile phone
【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
韦东山第三期课程内容概要
Reptile practice (VIII): reptile expression pack
Basic types of 100 questions for basic grammar of Niuke
Prompt configure: error: required tool not found: libtool solution when configuring and installing crosstool ng tool
Course of causality, taught by Jonas Peters, University of Copenhagen
My best game based on wechat applet development
C# 泛型及性能比较