当前位置:网站首页>13.模型的保存和載入
13.模型的保存和載入
2022-07-08 00:54:00 【booze-J】
我們以保存3.MNIST數據集分類中訓練的模型為例,來演示模型的保存與載入。
第一種模型保存和載入方式
1.保存方式
保存模型只需要在模型訓練完之後添加上
# 保存模型 可以同時保存模型的結構和參數
model.save("model.h5") # HDF5文件,pip install h5py
這種保存方式可以同時保存模型的結構和參數。
2.載入方式
載入模型之前需要先導入load_model方法
from keras.models import load_model
然後載入的代碼就是簡單一句:
# 載入模型
model = load_model("../model.h5")
這種載入方法可以同時載入模型的結構和參數。
第二種模型保存和載入方式
1.保存方式
模型參數和模型結構分開來保存:
# 保存參數
model.save_weights("my_model_weights.h5")
# 保存網絡結構
json_string = model.to_json()
2.載入方式
在載入模型結構之前,需要先導入model_from_json()方法
from keras.models import model_from_json
分別載入網絡參數和網絡結構:
# 載入參數
model.load_weights("my_model_weights.h5")
# 載入模型結構
model = model_from_json(json_string)
模型再訓練
代碼運行平臺為jupyter-notebook,文章中的代碼塊,也是按照jupyter-notebook中的劃分順序進行書寫的,運行文章代碼,直接分單元粘入到jupyter-notebook即可。
其實模型載入之後是可以進行再訓練的。
1.導入第三方庫
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2.加載數據及數據預處理
# 載入數據
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 還未進行one-hot編碼 需要後面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中參數填入-1的話可以自動計算出參數結果 除以255.0是為了歸一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 換one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.模型再訓練
# 載入模型
model = load_model("../model.h5")
# 評估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
運行結果:
對比首次保存的模型:
可以發現再訓練模型在測試集上的准確率有所提高。
边栏推荐
- 詹姆斯·格雷克《信息简史》读后感记录
- The weight of the product page of the second level classification is low. What if it is not included?
- [OBS] the official configuration is use_ GPU_ Priority effect is true
- 牛客基础语法必刷100题之基本类型
- jemter分布式
- 浪潮云溪分布式数据库 Tracing(二)—— 源码解析
- SDNU_ACM_ICPC_2022_Summer_Practice(1~2)
- Codeforces Round #804 (Div. 2)(A~D)
- 新库上线 | 中国记者信息数据
- Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
猜你喜欢

Reptile practice (VIII): reptile expression pack
![[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output](/img/79/f5cffe62d5d1e4a69b6143aef561d9.png)
[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output
![[note] common combined filter circuit](/img/2f/a8c2ef0d76dd7a45b50a64a928a9c8.png)
[note] common combined filter circuit

Class head up rate detection based on face recognition

Interface test advanced interface script use - apipost (pre / post execution script)

接口测试要测试什么?

Redis, do you understand the list

C # generics and performance comparison

5g NR system messages

从服务器到云托管,到底经历了什么?
随机推荐
Service mesh introduction, istio overview
手机上炒股安全么?
After going to ByteDance, I learned that there are so many test engineers with an annual salary of 40W?
基于微信小程序开发的我最在行的小游戏
Letcode43: string multiplication
1293_FreeRTOS中xTaskResumeAll()接口的实现分析
Play sonar
2022-07-07: the original array is a monotonic array with numbers greater than 0 and less than or equal to K. there may be equal numbers in it, and the overall trend is increasing. However, the number
【obs】官方是配置USE_GPU_PRIORITY 效果为TRUE的
LeetCode刷题
华为交换机S5735S-L24T4S-QA2无法telnet远程访问
RPA cloud computer, let RPA out of the box with unlimited computing power?
Langchao Yunxi distributed database tracing (II) -- source code analysis
Analysis of 8 classic C language pointer written test questions
ABAP ALV LVC template
Interface test advanced interface script use - apipost (pre / post execution script)
Four stages of sand table deduction in attack and defense drill
Su embedded training - day4
[OBS] the official configuration is use_ GPU_ Priority effect is true
New library launched | cnopendata China Time-honored enterprise directory