Machine Learning: Differences Between Stochastic Gradient Descent (SGD) and Gradient Descent (GD), with Code Implementations
2022-07-06 17:34:00 【HanZee】
For a more detailed introduction, see: Gradient Descent
Gradient Descent (GD)
Suppose the hypothesis function f(x) and the cost function cost(w) are given by:
$$
\begin{aligned}
f(x) &= w_1 x_1 + w_2 x_2 + b \\
\mathrm{cost}(w) &= \frac{1}{n}\sum_{i=1}^{n}\bigl(f(x_i) - y_i\bigr)^2 \\
w_1 &= w_{1,\mathrm{old}} - \alpha \frac{\partial\,\mathrm{cost}(w)}{\partial w_1} \\
w_2 &= w_{2,\mathrm{old}} - \alpha \frac{\partial\,\mathrm{cost}(w)}{\partial w_2}
\end{aligned}
$$
From these formulas we can draw the following conclusions:
1. Every time the parameters w and b are updated, the partial derivatives must be computed over the entire dataset. This is computationally expensive, so convergence becomes very slow when the dataset is large.
2. Unlike SGD, every parameter update is guaranteed to move the cost toward the global minimum (provided the cost is convex).
3. If the cost is non-convex, the method can get stuck in a local optimum.
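To make the full-batch update concrete, here is a minimal NumPy sketch of gradient descent for the two-weight model above. The data `X`, `y`, the learning rate `alpha`, and the step count are made-up values for illustration only:

```python
import numpy as np

# Toy dataset: 3 samples, 2 features (hypothetical values)
X = np.array([[1.0, 2.0],
              [2.0, 0.5],
              [3.0, 1.5]])
y = np.array([[3.0], [2.5], [4.0]])

w = np.zeros((2, 1))
b = 0.0
alpha = 0.01

for step in range(100):
    f = X @ w + b                        # predictions for ALL n samples
    grad_w = 2 * X.T @ (f - y) / len(X)  # d cost / d w averaged over the whole set
    grad_b = 2 * np.mean(f - y)          # d cost / d b averaged over the whole set
    w -= alpha * grad_w                  # one update = one full pass over the data
    b -= alpha * grad_b
```

The full pass over the dataset on every iteration is exactly what makes GD slow at scale.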
Stochastic Gradient Descent (SGD)
The formulas are as follows:
$$
f(x) = w_1 x_1 + w_2 x_2 + b
$$

For each sample $i = 1, \ldots, n$ in turn:

$$
\begin{aligned}
\mathrm{cost}(w) &= \bigl(f(x_i) - y_i\bigr)^2 \\
w_1 &= w_{1,\mathrm{old}} - \alpha \frac{\partial\,\mathrm{cost}(w)}{\partial w_1} \\
w_2 &= w_{2,\mathrm{old}} - \alpha \frac{\partial\,\mathrm{cost}(w)}{\partial w_2}
\end{aligned}
$$
From these formulas we can conclude:
1. Each SGD parameter update computes the gradient of only one batch (the formulas above assume batch = 1), which greatly speeds up convergence.
2. Because each update considers only a single sample, individual steps do not always move toward the global optimum, and the iterates may never converge exactly to the minimum; on the other hand, this noisiness helps the method escape local optima.
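For contrast, here is the same model trained with per-sample SGD. Again the data, `alpha`, and the epoch count are illustrative assumptions, not part of the article's implementation below:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))              # hypothetical toy data
y = X @ np.array([[1.5], [-0.7]]) + 0.3

w = np.zeros((2, 1))
b = 0.0
alpha = 0.01

for epoch in range(10):
    for i in rng.permutation(len(X)):      # visit samples in random order
        x_i, y_i = X[i:i + 1], y[i:i + 1]  # one sample (batch = 1)
        f = x_i @ w + b
        grad_w = 2 * x_i.T @ (f - y_i)     # gradient from ONE sample only
        grad_b = 2 * (f - y_i).item()
        w -= alpha * grad_w                # parameters move after every sample
        b -= alpha * grad_b
```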
Code Implementation
We use Boston house-price prediction as the example.
Load the data
import numpy as np

# Load the CSV (this is the author's local file path); skip the header row
path = 'Desktop/波士顿房价/trian.csv'
data = np.loadtxt(path, delimiter=",", skiprows=1)
print(data.shape)
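If you don't have this CSV at hand, you can substitute synthetic data of the same shape (13 feature columns plus one price column) to run the rest of the code. This stand-in is purely illustrative:

```python
import numpy as np

# Hypothetical stand-in for the Boston housing CSV: 500 rows, 13 features + 1 target
rng = np.random.default_rng(0)
features = rng.normal(size=(500, 13))
target = features @ rng.normal(size=(13, 1)) + rng.normal(scale=0.1, size=(500, 1))
data = np.hstack([features, target])
```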
Split the data
# First 80% of the rows for training, the remaining 20% for testing
train = data[:int(data.shape[0] * 0.8)]
test = data[int(data.shape[0] * 0.8):]
print(train.shape, test.shape)

# The first 13 columns are the features; the last column is the house price
train_x = train[:, :-1]
train_y = train[:, 13:]
test_x = test[:, :-1]
test_y = test[:, 13:]
print(train_x.shape, train_y.shape)
class Network:
    def __init__(self, num_weights):
        self.num_weights = num_weights
        # Random initial weights, zero initial bias
        self.w = np.random.rand(num_weights, 1)
        self.b = 0.0

    def forward(self, x):
        # Linear model: z = x . w + b
        z = np.dot(x, self.w) + self.b
        return z

    def loss(self, z, y):
        # Mean squared error
        cost = (z - y) * (z - y)
        cost = np.mean(cost)
        return cost

    def gradient(self, z, y, x):
        # d cost / d w = mean over the given samples of (z - y) * x; the constant
        # factor 2 from differentiating (z - y)^2 is absorbed into the learning rate
        gradient_w = (z - y) * x
        gradient_w = np.mean(gradient_w, axis=0).reshape([self.num_weights, 1])
        gradient_b = np.mean(z - y)
        return gradient_w, gradient_b

    def update(self, gradient_w, gradient_b, eta):
        self.w = self.w - eta * gradient_w
        self.b = self.b - eta * gradient_b
    # Gradient descent: every update uses the full training set
    def train_GD(self, items, eta):
        for i in range(items):
            z = self.forward(train_x)
            loss = self.loss(z, train_y)
            gradient_w, gradient_b = self.gradient(z, train_y, train_x)
            self.update(gradient_w, gradient_b, eta)
            test_loss = self.test()
            print('item:', i, 'loss:', loss, 'test_loss:', test_loss)
    # Stochastic gradient descent: every update uses one mini-batch
    def train_SGD(self, num_epochs, batchsize, eta):
        for epoch_id in range(num_epochs):
            # Reshuffle every epoch so the mini-batches differ from epoch to epoch
            np.random.shuffle(train)
            losses = []
            for i in range(0, len(train), batchsize):
                mini_batch = train[i:i + batchsize]
                x = mini_batch[:, :-1]   # features of this mini-batch
                y = mini_batch[:, -1:]   # targets of this mini-batch
                z = self.forward(x)
                loss = self.loss(z, y)
                gradient_w, gradient_b = self.gradient(z, y, x)
                self.update(gradient_w, gradient_b, eta)
                losses.append(loss)
            loss_mean = np.mean(losses)
            print('Epoch {}, loss {}, loss_mean {}'.format(epoch_id, loss, loss_mean))
    def test(self):
        z = self.forward(test_x)
        loss = self.loss(z, test_y)
        return loss
# 13 input features, matching the Boston housing data; the learning rate is
# very small because the raw features are not normalized
net = Network(13)
net.train_GD(100, eta=1e-9)
net.train_SGD(100, 5, 1e-9)
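Note that the two calls above continue training the same `net`, so the SGD run starts from whatever weights GD produced. To compare the two optimizers on an equal footing, one option (a sketch, under the same hyperparameters) is to train each from its own freshly initialized network:

```python
net_gd = Network(13)
net_gd.train_GD(100, eta=1e-9)

net_sgd = Network(13)
net_sgd.train_SGD(100, 5, 1e-9)

print('GD test loss:', net_gd.test())
print('SGD test loss:', net_sgd.test())
```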