Machine Learning: The Difference Between Stochastic Gradient Descent (SGD) and Gradient Descent (GD), with Code Implementations
2022-07-06 17:34:00 【HanZee】
If you want a more detailed explanation, see: Gradient Descent
Gradient Descent (GD)
Suppose the hypothesis function f(x) and the cost function cost(w) are given by:
$$
\begin{aligned}
f(x) &= w_1 x_1 + w_2 x_2 + b \\
\mathrm{cost}(w) &= \frac{1}{n}\sum_{i=1}^{n}\bigl(f(x_i) - y_i\bigr)^2 \\
w_1 &= w_{1,\mathrm{old}} - \alpha\,\frac{\partial\,\mathrm{cost}(w)}{\partial w_1} \\
w_2 &= w_{2,\mathrm{old}} - \alpha\,\frac{\partial\,\mathrm{cost}(w)}{\partial w_2}
\end{aligned}
$$
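One step left implicit above: for this squared loss, the partial derivatives that the update rule needs are

$$
\frac{\partial\,\mathrm{cost}(w)}{\partial w_1} = \frac{2}{n}\sum_{i=1}^{n}\bigl(f(x_i) - y_i\bigr)x_{i1},\qquad
\frac{\partial\,\mathrm{cost}(w)}{\partial w_2} = \frac{2}{n}\sum_{i=1}^{n}\bigl(f(x_i) - y_i\bigr)x_{i2}
$$

(in the code below, the constant factor 2 is absorbed into the learning rate α).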
From the formulas above we can draw the following conclusions:
1. Every update of the parameters w and b requires computing the partial derivatives over the entire dataset. This is computationally expensive, so convergence becomes slow when the dataset is large.
2. Unlike SGD, every parameter update is guaranteed to move the cost toward the global minimum (when the cost is convex).
3. If the cost is non-convex, the algorithm may get stuck in a local optimum.
Stochastic Gradient Descent (SGD)
The update rule is applied one sample at a time:
$$
f(x) = w_1 x_1 + w_2 x_2 + b
$$

For each sample $i = 1, \dots, n$:

$$
\begin{aligned}
\mathrm{cost}(w) &= \bigl(f(x_i) - y_i\bigr)^2 \\
w_1 &= w_{1,\mathrm{old}} - \alpha\,\frac{\partial\,\mathrm{cost}(w)}{\partial w_1} \\
w_2 &= w_{2,\mathrm{old}} - \alpha\,\frac{\partial\,\mathrm{cost}(w)}{\partial w_2}
\end{aligned}
$$
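Here the per-sample gradient is the same expression as before, just without the average:

$$
\frac{\partial\,\mathrm{cost}(w)}{\partial w_j} = 2\bigl(f(x_i) - y_i\bigr)x_{ij}
$$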
From these formulas we can conclude:
1. In SGD, each parameter update computes the gradient of only one batch (the formulas above assume batch = 1), which greatly speeds up convergence; see the sketch after this list.
2. Because each update considers only a single sample, individual steps do not always move toward the global optimum, so the loss may never settle exactly at the minimum; on the other hand, this noise helps the algorithm escape local optima.
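To make the contrast concrete, here is a minimal toy sketch of both update schemes on synthetic 1-D data (the data and all names here are illustrative, not from the original article):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=100)                 # synthetic inputs
y = 3.0 * x + 0.5 + rng.normal(0, 0.1, 100)      # true w = 3.0, b = 0.5, plus noise

def gd(epochs=200, alpha=0.3):
    """Full-batch GD: one update per pass over the whole dataset."""
    w, b = 0.0, 0.0
    for _ in range(epochs):
        err = (w * x + b) - y                    # residuals on ALL samples
        w -= alpha * np.mean(err * x)
        b -= alpha * np.mean(err)
    return w, b

def sgd(epochs=20, alpha=0.1):
    """SGD with batch = 1: one update per sample."""
    w, b = 0.0, 0.0
    for _ in range(epochs):
        for i in rng.permutation(len(x)):
            err = (w * x[i] + b) - y[i]
            w -= alpha * err * x[i]
            b -= alpha * err
    return w, b

print(gd())    # both should land near (3.0, 0.5)
print(sgd())
```

Note that one SGD epoch performs 100 updates versus GD's single update, which is exactly the speed/noise trade-off described above.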
Code Implementation
We use Boston housing price prediction as the running example.
Load the data
```python
import numpy as np

path = 'Desktop/波士顿房价/trian.csv'               # path as used in the original article
data = np.loadtxt(path, delimiter=",", skiprows=1)  # skip the header row
data.shape
```
Split the data
```python
# 80/20 train/test split
train = data[:int(data.shape[0] * 0.8)]
test = data[int(data.shape[0] * 0.8):]
print(train.shape, test.shape)

# the first 13 columns are features, the last column is the target price
train_x = train[:, :-1]
train_y = train[:, 13:]
test_x = test[:, :-1]
test_y = test[:, 13:]
print(train_x.shape, train_y.shape)
```
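An aside not in the original code: the features are left unscaled here, which is why the training calls at the end use a tiny learning rate (1e-9). Standardizing the feature columns with training-split statistics would allow a much larger learning rate (e.g. around 0.01):

```python
# Optional standardization (assumption: not part of the original article).
# Scaling is done in place on `train` and `test`, so the train_x/test_x views
# and the rows read by train_SGD below stay consistent.
mean = train[:, :-1].mean(axis=0)
std = train[:, :-1].std(axis=0)
train[:, :-1] = (train[:, :-1] - mean) / std
test[:, :-1] = (test[:, :-1] - mean) / std
```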
Define the network

```python
class Network:
    def __init__(self, num_weights):
        self.num_weights = num_weights
        self.w = np.random.rand(num_weights, 1)  # random initial weights
        self.b = 0

    def forward(self, x):
        # linear model: z = x·w + b
        z = np.dot(x, self.w) + self.b
        return z

    def loss(self, z, y):
        # mean squared error
        cost = (z - y) * (z - y)
        return np.mean(cost)

    def gradient(self, x, z, y):
        # gradient of the squared loss averaged over the batch
        # (the constant factor 2 is absorbed into the learning rate);
        # x is passed in explicitly so GD and SGD each differentiate
        # with respect to the samples they actually used
        x = np.reshape(x, (-1, self.num_weights))
        z = np.reshape(z, (-1, 1))
        y = np.reshape(y, (-1, 1))
        gradient_w = np.mean((z - y) * x, axis=0).reshape(self.num_weights, 1)
        gradient_b = np.mean(z - y)
        return gradient_w, gradient_b

    def update(self, gradient_w, gradient_b, eta):
        self.w = self.w - eta * gradient_w
        self.b = self.b - eta * gradient_b

    # gradient descent: every update uses the entire training set
    def train_GD(self, items, eta):
        for i in range(items):
            z = self.forward(train_x)
            loss = self.loss(z, train_y)
            gradient_w, gradient_b = self.gradient(train_x, z, train_y)
            self.update(gradient_w, gradient_b, eta)
            test_loss = self.test()
            print('item:', i, 'loss:', loss, 'test_loss:', test_loss)

    # stochastic gradient descent: every update uses one mini-batch
    def train_SGD(self, num_epochs, batchsize, eta):
        for epoch_id in range(num_epochs):
            np.random.shuffle(train)  # reshuffle the training rows each epoch
            losses = []
            for i in range(0, len(train), batchsize):
                mini_batch = train[i:i + batchsize]
                x = mini_batch[:, :-1]
                y = mini_batch[:, -1:]
                z = self.forward(x)
                loss = self.loss(z, y)
                gradient_w, gradient_b = self.gradient(x, z, y)
                self.update(gradient_w, gradient_b, eta)
                losses.append(loss)
            loss_mean = sum(losses) / len(losses)
            print('Epoch {}, loss {}, loss_mean {}'.format(epoch_id, loss, loss_mean))

    def test(self):
        # loss on the held-out split
        z = self.forward(test_x)
        return self.loss(z, test_y)
```
```python
net = Network(13)
net.train_GD(100, eta=1e-9)
net.train_SGD(100, 5, 1e-9)
```
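A small usage note (not in the original): after either training routine, the held-out loss can be checked directly:

```python
# report the final loss on the test split defined above
print('final test loss:', net.test())
```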