当前位置:网站首页>13.模型的保存和載入
13.模型的保存和載入
2022-07-08 00:54:00 【booze-J】
我們以保存3.MNIST數據集分類中訓練的模型為例,來演示模型的保存與載入。
第一種模型保存和載入方式
1.保存方式
保存模型只需要在模型訓練完之後添加上
# 保存模型 可以同時保存模型的結構和參數
model.save("model.h5") # HDF5文件,pip install h5py
這種保存方式可以同時保存模型的結構和參數。
2.載入方式
載入模型之前需要先導入load_model方法
from keras.models import load_model
然後載入的代碼就是簡單一句:
# 載入模型
model = load_model("../model.h5")
這種載入方法可以同時載入模型的結構和參數。
第二種模型保存和載入方式
1.保存方式
模型參數和模型結構分開來保存:
# 保存參數
model.save_weights("my_model_weights.h5")
# 保存網絡結構
json_string = model.to_json()
2.載入方式
在載入模型結構之前,需要先導入model_from_json()方法
from keras.models import model_from_json
分別載入網絡參數和網絡結構:
# 載入參數
model.load_weights("my_model_weights.h5")
# 載入模型結構
model = model_from_json(json_string)
模型再訓練
代碼運行平臺為jupyter-notebook,文章中的代碼塊,也是按照jupyter-notebook中的劃分順序進行書寫的,運行文章代碼,直接分單元粘入到jupyter-notebook即可。
其實模型載入之後是可以進行再訓練的。
1.導入第三方庫
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.models import load_model
2.加載數據及數據預處理
# 載入數據
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 還未進行one-hot編碼 需要後面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中參數填入-1的話可以自動計算出參數結果 除以255.0是為了歸一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 換one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.模型再訓練
# 載入模型
model = load_model("../model.h5")
# 評估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
運行結果:
對比首次保存的模型:
可以發現再訓練模型在測試集上的准確率有所提高。
边栏推荐
- An error is reported during the process of setting up ADG. Rman-03009 ora-03113
- [necessary for R & D personnel] how to make your own dataset and display it.
- Cve-2022-28346: Django SQL injection vulnerability
- Reptile practice (VIII): reptile expression pack
- 串口接收一包数据
- 接口测试进阶接口脚本使用—apipost(预/后执行脚本)
- [C language] objective questions - knowledge points
- ReentrantLock 公平锁源码 第0篇
- 51与蓝牙模块通讯,51驱动蓝牙APP点灯
- Cancel the down arrow of the default style of select and set the default word of select
猜你喜欢

How does the markdown editor of CSDN input mathematical formulas--- Latex syntax summary

Where is the big data open source project, one-stop fully automated full life cycle operation and maintenance steward Chengying (background)?
![[note] common combined filter circuit](/img/2f/a8c2ef0d76dd7a45b50a64a928a9c8.png)
[note] common combined filter circuit

A network composed of three convolution layers completes the image classification task of cifar10 data set

New library launched | cnopendata China Time-honored enterprise directory

基于人脸识别实现课堂抬头率检测

1293_FreeRTOS中xTaskResumeAll()接口的实现分析
![[necessary for R & D personnel] how to make your own dataset and display it.](/img/50/3d826186b563069fd8d433e8feefc4.png)
[necessary for R & D personnel] how to make your own dataset and display it.

Kubernetes static pod (static POD)

v-for遍历元素样式失效
随机推荐
玩轉Sonar
Marubeni official website applet configuration tutorial is coming (with detailed steps)
股票开户免费办理佣金最低的券商,手机上开户安全吗
基于卷积神经网络的恶意软件检测方法
5g NR system messages
Service mesh introduction, istio overview
How to learn a new technology (programming language)
牛客基础语法必刷100题之基本类型
Reentrantlock fair lock source code Chapter 0
C # generics and performance comparison
From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run
[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output
语义分割模型库segmentation_models_pytorch的详细使用介绍
Malware detection method based on convolutional neural network
德总理称乌不会获得“北约式”安全保障
新库上线 | 中国记者信息数据
[OBS] the official configuration is use_ GPU_ Priority effect is true
Cascade-LSTM: A Tree-Structured Neural Classifier for Detecting Misinformation Cascades(KDD20)
Class head up rate detection based on face recognition
【笔记】常见组合滤波电路