7. Applying Regularization
2022-07-07 23:12:00 【booze-J】
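I. Applying regularization
Here we add regularization to the model definition from "6. Applying Dropout", starting from the version of that code that does not use Dropout.
Change the following code from "6. Applying Dropout":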
# Build the model: 784 input neurons, 10 output neurons
model = Sequential([
    # First hidden layer: 200 neurons, 784 inputs, biases initialized to 1, tanh activation
    Dense(units=200,input_dim=784,bias_initializer='one',activation="tanh"),
    # Second hidden layer: 100 neurons
    Dense(units=100,bias_initializer='one',activation="tanh"),
    # Output layer: 10 neurons, softmax activation
    Dense(units=10,bias_initializer='one',activation="softmax")
])
to:
# Build the model: 784 input neurons, 10 output neurons
model = Sequential([
    # First hidden layer: 200 neurons, 784 inputs, biases initialized to 1, tanh activation, L2 penalty on the weights
    Dense(units=200,input_dim=784,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Second hidden layer: 100 neurons
    Dense(units=100,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Output layer: 10 neurons, softmax activation
    Dense(units=10,bias_initializer='one',activation="softmax",kernel_regularizer=l2(0.0003))
])
Before using L2 regularization, import it with from keras.regularizers import l2.
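As a rough illustration of what kernel_regularizer=l2(0.0003) contributes, the sketch below computes the penalty term by hand with NumPy: the factor times the sum of squared weights, which Keras adds to the training loss for each regularized layer. The helper name l2_penalty and the small kernel matrix are hypothetical and exist only for this example.

```python
import numpy as np

def l2_penalty(weights, factor=0.0003):
    """Return factor * sum(w**2), the term an L2 kernel regularizer adds to the loss."""
    return factor * np.sum(np.square(weights))

# Hypothetical 2x2 kernel, used only for illustration.
kernel = np.array([[1.0, -2.0],
                   [0.5, 0.0]])

penalty = l2_penalty(kernel)
print("L2 penalty:", penalty)
```

Because the penalty grows with the square of each weight, large weights are discouraged much more strongly than small ones.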
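Results:
The results show that regularization noticeably reduces overfitting. However, this model is not very complex relative to the dataset, so adding regularization may not give very good results.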
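One way to see why the penalty curbs overfitting is to look at what it does to the weights during training. The sketch below assumes plain SGD with the learning rate 0.2 and the factor 0.0003 used in this post, and applies a single update using only the gradient of the penalty term (2 * factor * w), ignoring the data loss. The function name and the example weights are hypothetical.

```python
import numpy as np

LEARNING_RATE = 0.2  # same value as the SGD optimizer in this post
L2_FACTOR = 0.0003   # same value as l2(0.0003)

def sgd_step_penalty_only(weights, lr=LEARNING_RATE, factor=L2_FACTOR):
    """One SGD step using only the gradient of factor * sum(w**2), i.e. 2 * factor * w."""
    grad = 2.0 * factor * weights
    return weights - lr * grad

w = np.array([1.0, -0.5, 2.0])  # hypothetical weights
w_next = sgd_step_penalty_only(w)

# Each step multiplies every weight by the same shrink factor.
shrink = 1.0 - 2.0 * LEARNING_RATE * L2_FACTOR
print("shrink factor per step:", shrink)
print("weights after one step:", w_next)
```

Under these assumptions every update pulls each weight slightly toward zero, which limits how large the weights can grow and so limits how closely the network can fit noise in the training set.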
Complete code
The code was run in jupyter-notebook, and the code blocks in this post follow the same cell order as the notebook. To run it, paste each block into its own jupyter-notebook cell.
1. Import third-party libraries
import numpy as np
from keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD
from keras.regularizers import l2
2. Load and preprocess the data
# Load the data
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) labels are not one-hot encoded yet; that is done below
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000, 784); passing -1 to reshape() lets NumPy infer that dimension. Dividing by 255.0 scales pixels to [0, 1]
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# Convert the labels to one-hot format
y_train = to_categorical(y_train,num_classes=10)
y_test = to_categorical(y_test,num_classes=10)
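To make the preprocessing step concrete without downloading MNIST, here is a minimal NumPy sketch on a tiny stand-in array. The random images and the three labels are made up for illustration, and np.eye(10)[labels] is just one simple way to build the same one-hot values; it is not how Keras implements its own helper.

```python
import numpy as np

# Hypothetical stand-in for MNIST: 3 grayscale "images" of 28x28 with pixel values 0..255.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(3, 28, 28)).astype("uint8")
labels = np.array([0, 7, 9])

# Flatten each image to 784 values and scale the pixels to [0, 1].
flat = images.reshape(images.shape[0], -1) / 255.0

# One-hot encode: row i has a single 1 at column labels[i].
one_hot = np.eye(10)[labels]

print(flat.shape)     # (3, 784)
print(one_hot.shape)  # (3, 10)
```

Each row of one_hot lines up with the 10 softmax outputs of the network, which is the target format the categorical_crossentropy loss expects.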
3. Train the model
# Build the model: 784 input neurons, 10 output neurons
model = Sequential([
    # First hidden layer: 200 neurons, 784 inputs, biases initialized to 1, tanh activation, L2 penalty on the weights
    Dense(units=200,input_dim=784,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Second hidden layer: 100 neurons
    Dense(units=100,bias_initializer='one',activation="tanh",kernel_regularizer=l2(0.0003)),
    # Output layer: 10 neurons, softmax activation
    Dense(units=10,bias_initializer='one',activation="softmax",kernel_regularizer=l2(0.0003))
])
# Define the optimizer
sgd = SGD(learning_rate=0.2)
# Set the optimizer and the loss function, and track accuracy during training
model.compile(
    optimizer=sgd,
    loss="categorical_crossentropy",
    metrics=['accuracy']
)
# Train the model
model.fit(x_train,y_train,batch_size=32,epochs=10)
# Evaluate the model
# Loss and accuracy on the test set
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("test_accuracy:",accuracy)
# Loss and accuracy on the training set
loss,accuracy = model.evaluate(x_train,y_train)
print("\ntrain loss",loss)
print("train_accuracy:",accuracy)
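The two evaluate calls above reuse the names loss and accuracy, so the test values are overwritten by the training values. A simple way to compare them is to keep both and look at the difference. The helper below is a hypothetical sketch, and the accuracies passed to it are made-up numbers for illustration, not results from this post.

```python
def generalization_gap(train_accuracy, test_accuracy):
    """Return train accuracy minus test accuracy; a large positive gap hints at overfitting."""
    return train_accuracy - test_accuracy

# Made-up numbers for illustration only; they are not results from this post.
example_gap = generalization_gap(train_accuracy=0.99, test_accuracy=0.97)
print(f"gap: {example_gap:.4f}")
```

In the notebook, one would pass the accuracy returned by model.evaluate on the training set and on the test set, stored under separate names, and compare the gap with and without the kernel_regularizer arguments.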