7. Applying Regularization
2022-07-07 23:12:00 【booze-J】
I. Applying Regularization
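Here we add regularization to the model from "6. Applying Dropout", starting from the version of that code that does not use Dropout.
Change the model definition from "6. Applying Dropout":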
# Build the model: 784 input neurons, 10 output neurons
model = Sequential([
    # First hidden layer: 200 neurons, input dimension 784, biases initialized to 1, tanh activation
    Dense(units=200,input_dim=784,bias_initializer='one',activation="tanh"),
    # Second hidden layer: 100 neurons
    Dense(units=100,bias_initializer='one',activation="tanh"),
    # Output layer: 10 neurons with softmax activation
    Dense(units=10,bias_initializer='one',activation="softmax")
])
to:
# Build the model: 784 input neurons, 10 output neurons
model = Sequential([
    # First hidden layer: 200 neurons, input dimension 784, biases initialized to 1, tanh activation, L2 penalty on the kernel
    Dense(units=200,input_dim=784,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Second hidden layer: 100 neurons
    Dense(units=100,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Output layer: 10 neurons with softmax activation
    Dense(units=10,bias_initializer='one',activation="softmax",kernel_regularizer=l2(0.0003))
])
Before using L2 regularization, import it with from keras.regularizers import l2.
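As a rough illustration of what kernel_regularizer=l2(0.0003) contributes: the Keras L2 regularizer adds 0.0003 * sum(W**2) for each layer's weight matrix (kernel) to the training loss. The biases are not penalized here, because only kernel_regularizer is set. The NumPy sketch below recomputes that penalty by hand for the three kernels of this 784-200-100-10 network; the random weights and the helper name l2_penalty are illustrative assumptions, not part of the original code.

```python
import numpy as np

def l2_penalty(weights, l2=0.0003):
    """Hypothetical helper: the term added to the loss for one kernel,
    computed as l2 * sum of squared weights."""
    return l2 * np.sum(np.square(weights))

rng = np.random.default_rng(0)
# Illustrative random kernels with the shapes of the three Dense layers:
# 784 -> 200, 200 -> 100, 100 -> 10 (biases are not penalized)
kernels = [rng.normal(0.0, 0.05, size=shape)
           for shape in [(784, 200), (200, 100), (100, 10)]]

penalties = [l2_penalty(w) for w in kernels]
total_penalty = sum(penalties)
print("per-layer penalties:", penalties)
print("total penalty added to the loss:", total_penalty)

# Hand-checkable case: 0.0003 * (1 + 4 + 9 + 16) = 0.0003 * 30
print(l2_penalty(np.array([[1.0, 2.0], [3.0, 4.0]])))
```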
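Results:
The results show that regularization clearly reduces some of the overfitting. However, this model is not very complex relative to the dataset, so adding regularization may not improve the results much.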
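One way to see why the L2 term curbs overfitting: the gradient of λ·ΣW² is 2λW, so a plain SGD step becomes W ← W − lr·(∇loss + 2λW) = (1 − 2·lr·λ)·W − lr·∇loss. Every update therefore shrinks all weights slightly toward zero, which is why L2 regularization is often called weight decay. The toy sketch below uses the article's lr=0.2 and λ=0.0003 and, to isolate the effect, assumes a zero gradient from the data term; it illustrates the arithmetic and is not what Keras runs internally.

```python
import numpy as np

lr, lam = 0.2, 0.0003           # learning rate and L2 factor from the article
w0 = np.array([1.0, -2.0, 0.5])  # toy starting weights (illustrative)
w = w0.copy()
data_grad = np.zeros_like(w)     # assumption: no gradient from the data loss

for step in range(1000):
    # The gradient of lam * sum(w**2) is 2 * lam * w
    w = w - lr * (data_grad + 2 * lam * w)

# With a zero data gradient, each step multiplies w by (1 - 2 * lr * lam)
shrink_factor = (1 - 2 * lr * lam) ** 1000
print("weights after 1000 steps:", w)
print("shrink factor:", shrink_factor)
```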
Complete Code
The code was run in Jupyter Notebook, and the code blocks in this article follow the same cell order as the notebook. To run it, paste each block into its own Jupyter Notebook cell.
1. Import third-party libraries
import numpy as np
from keras.datasets import mnist
# np_utils is not available in newer Keras releases; to_categorical can be imported from keras.utils directly
from keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense
# Import the optimizer from the same keras package as the other modules
from keras.optimizers import SGD
from keras.regularizers import l2
2. Load and preprocess the data
# Load the data
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) labels are not one-hot encoded yet; that is done below
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000, 784); passing -1 to reshape() lets NumPy infer that dimension
# Dividing by 255.0 normalizes pixel values to [0, 1]
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# Convert labels to one-hot format
y_train = to_categorical(y_train,num_classes=10)
y_test = to_categorical(y_test,num_classes=10)
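The preprocessing above can be reproduced in plain NumPy to make the shapes explicit. This sketch uses a small random stand-in for MNIST (an assumption, so it runs without downloading anything): reshape(n, -1) lets NumPy infer 28 * 28 = 784, dividing by 255.0 maps pixel values into [0, 1], and indexing an identity matrix with the labels yields one-hot rows of the same form as to_categorical(..., num_classes=10).

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for MNIST: 5 grayscale 28x28 images with uint8 pixels, plus labels
x = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
y = np.array([3, 0, 9, 1, 3])

# (5, 28, 28) -> (5, 784); -1 tells NumPy to infer the second dimension
x_flat = x.reshape(x.shape[0], -1) / 255.0

# One-hot encoding: row i of the 10x10 identity matrix encodes class i
y_onehot = np.eye(10)[y]

print(x_flat.shape, x_flat.min() >= 0.0, x_flat.max() <= 1.0)
print(y_onehot[0])
```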
3. Train the model
# Build the model: 784 input neurons, 10 output neurons
model = Sequential([
    # First hidden layer: 200 neurons, input dimension 784, biases initialized to 1, tanh activation, L2 penalty on the kernel
    Dense(units=200,input_dim=784,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Second hidden layer: 100 neurons
    Dense(units=100,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Output layer: 10 neurons with softmax activation
    Dense(units=10,bias_initializer='one',activation="softmax",kernel_regularizer=l2(0.0003))
])
# Define the optimizer (learning_rate replaces the deprecated lr argument)
sgd = SGD(learning_rate=0.2)
# Set the optimizer and loss function, and track accuracy during training
model.compile(
    optimizer=sgd,
    loss="categorical_crossentropy",
    metrics=['accuracy']
)
# Train the model
model.fit(x_train,y_train,batch_size=32,epochs=10)
# Evaluate the model
# Loss and accuracy on the test set
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("test_accuracy:",accuracy)
# Loss and accuracy on the training set
loss,accuracy = model.evaluate(x_train,y_train)
print("\ntrain loss",loss)
print("train_accuracy:",accuracy)
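The loss values printed above are, as far as I know, not pure cross-entropy: in TF 2.x Keras and Keras 3 the layers' regularization losses (collected in model.losses) are added to the objective during both fit and evaluate, so the reported loss also includes the L2 penalty. The NumPy sketch below shows that composition on toy predictions and toy kernels; the numbers and the helper categorical_crossentropy are illustrative assumptions, not the Keras implementation.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over a batch of one-hot labels."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -np.mean(np.sum(y_true * np.log(y_pred), axis=1))

# Two samples, 3 classes (toy numbers for illustration)
y_true = np.array([[0, 1, 0], [1, 0, 0]], dtype=float)
y_pred = np.array([[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]])

# Toy kernels standing in for the Dense layers' weight matrices
kernels = [np.array([[0.5, -0.5], [1.0, 0.0]]), np.array([[2.0], [-1.0]])]
l2_factor = 0.0003

data_loss = categorical_crossentropy(y_true, y_pred)
reg_loss = sum(l2_factor * np.sum(np.square(w)) for w in kernels)
total_loss = data_loss + reg_loss

print("cross-entropy:", data_loss)
print("L2 penalty:", reg_loss)
print("total loss (cross-entropy + penalty):", total_loss)
```

Because the penalty is part of the reported loss, comparing the train and test loss values from the two evaluate calls is more informative than comparing either one against a run without regularization.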