当前位置:网站首页>13.模型的保存和載入
13.模型的保存和載入
2022-07-08 00:54:00 【booze-J】
我們以保存3.MNIST數據集分類中訓練的模型為例,來演示模型的保存與載入。
第一種模型保存和載入方式
1.保存方式
保存模型只需要在模型訓練完之後添加上
# 保存模型 可以同時保存模型的結構和參數
model.save("model.h5") # HDF5文件,pip install h5py
這種保存方式可以同時保存模型的結構和參數。
2.載入方式
載入模型之前需要先導入load_model
方法
from keras.models import load_model
然後載入的代碼就是簡單一句:
# 載入模型
model = load_model("../model.h5")
這種載入方法可以同時載入模型的結構和參數。
第二種模型保存和載入方式
1.保存方式
模型參數和模型結構分開來保存:
# 保存參數
model.save_weights("my_model_weights.h5")
# 保存網絡結構
json_string = model.to_json()
2.載入方式
在載入模型結構之前,需要先導入model_from_json()
方法
from keras.models import model_from_json
分別載入網絡參數和網絡結構:
# 載入參數
model.load_weights("my_model_weights.h5")
# 載入模型結構
model = model_from_json(json_string)
模型再訓練
代碼運行平臺為jupyter-notebook,文章中的代碼塊,也是按照jupyter-notebook中的劃分順序進行書寫的,運行文章代碼,直接分單元粘入到jupyter-notebook即可。
其實模型載入之後是可以進行再訓練的。
1.導入第三方庫
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2.加載數據及數據預處理
# 載入數據
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 還未進行one-hot編碼 需要後面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中參數填入-1的話可以自動計算出參數結果 除以255.0是為了歸一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 換one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.模型再訓練
# 載入模型
model = load_model("../model.h5")
# 評估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
運行結果:
對比首次保存的模型:
可以發現再訓練模型在測試集上的准確率有所提高。
边栏推荐
- STL--String类的常用功能复写
- Deep dive kotlin synergy (XXII): flow treatment
- [go record] start go language from scratch -- make an oscilloscope with go language (I) go language foundation
- From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
- jemter分布式
- Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
- ABAP ALV LVC template
- 德总理称乌不会获得“北约式”安全保障
- 5g NR system messages
- 基于人脸识别实现课堂抬头率检测
猜你喜欢
Interface test advanced interface script use - apipost (pre / post execution script)
[note] common combined filter circuit
Installation and configuration of sublime Text3
How to learn a new technology (programming language)
Password recovery vulnerability of foreign public testing
他们齐聚 2022 ECUG Con,只为「中国技术力量」
Application practice | the efficiency of the data warehouse system has been comprehensively improved! Data warehouse construction based on Apache Doris in Tongcheng digital Department
【笔记】常见组合滤波电路
基于微信小程序开发的我最在行的小游戏
深潜Kotlin协程(二十二):Flow的处理
随机推荐
Summary of weidongshan phase II course content
Binder core API
Codeforces Round #804 (Div. 2)(A~D)
【GO记录】从零开始GO语言——用GO语言做一个示波器(一)GO语言基础
After going to ByteDance, I learned that there are so many test engineers with an annual salary of 40W?
【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
Qt添加资源文件,为QAction添加图标,建立信号槽函数并实现
Kubernetes Static Pod (静态Pod)
QT establish signal slots between different classes and transfer parameters
FOFA-攻防挑战记录
Service mesh introduction, istio overview
Solution to the problem of unserialize3 in the advanced web area of the attack and defense world
ReentrantLock 公平锁源码 第0篇
接口测试进阶接口脚本使用—apipost(预/后执行脚本)
Implementation of adjacency table of SQLite database storage directory structure 2-construction of directory tree
他们齐聚 2022 ECUG Con,只为「中国技术力量」
5g NR system messages
韦东山第三期课程内容概要
Deep dive kotlin synergy (XXII): flow treatment
ABAP ALV LVC模板