Machine Learning: The Differences Between Stochastic Gradient Descent (SGD) and Gradient Descent (GD), with Code Implementations
2022-07-06 17:34:00 【HanZee】
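For a more detailed explanation, see: Gradient Descent
Gradient Descent (GD)
Suppose the hypothesis function f(x) and the cost function cost are given by the following expressions: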
$$
\begin{aligned}
f(x) &= w_{1}x_{1} + w_{2}x_{2} + b \\
cost(w) &= \dfrac{1}{n}\sum_{i=1}^{n}\left(f(x_{i}) - y_{i}\right)^{2} \\
w_{1} &= w_{1old} - \alpha\dfrac{\partial cost(w)}{\partial w_{1}} \\
w_{2} &= w_{2old} - \alpha\dfrac{\partial cost(w)}{\partial w_{2}}
\end{aligned}
$$
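As a worked step, and assuming the cost is the mean squared error, the partial derivatives in the update rule can be expanded with the chain rule. Here $x_{i,1}$ and $x_{i,2}$ denote features 1 and 2 of sample $i$, a notation introduced only for this sketch:

```latex
\begin{aligned}
\frac{\partial cost(w)}{\partial w_{1}} &= \frac{2}{n}\sum_{i=1}^{n}\left(f(x_{i}) - y_{i}\right)x_{i,1} \\
\frac{\partial cost(w)}{\partial w_{2}} &= \frac{2}{n}\sum_{i=1}^{n}\left(f(x_{i}) - y_{i}\right)x_{i,2} \\
\frac{\partial cost(w)}{\partial b} &= \frac{2}{n}\sum_{i=1}^{n}\left(f(x_{i}) - y_{i}\right)
\end{aligned}
```

Each sum runs over all n samples. The constant factor 2 only rescales the step, so it can be folded into the learning rate α.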
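From the formulas above, we can draw the following conclusions:
1. Every update of the parameters w and b requires computing the partial derivatives over the entire dataset. This is computationally expensive, so training becomes slow when the dataset is large.
2. Unlike SGD, every update follows the exact gradient of cost over the whole dataset, so with a sufficiently small learning rate α each step decreases cost; if cost is convex, the updates move toward the global minimum.
3. If cost is non-convex, the procedure may get stuck in a local optimum.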
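The points above can be illustrated with a minimal sketch of full-batch gradient descent. It assumes a mean squared error cost and uses small synthetic data rather than the housing data; the function name `gd_fit`, the learning rate, and the step count are hypothetical choices made only for this example.

```python
import numpy as np

def gd_fit(x, y, alpha=0.1, num_steps=500):
    """Full-batch gradient descent for f(x) = x @ w + b with a mean squared error cost."""
    n, d = x.shape
    w = np.zeros((d, 1))
    b = 0.0
    costs = []
    for _ in range(num_steps):
        err = x @ w + b - y                  # shape (n, 1): every sample is used
        costs.append(float(np.mean(err ** 2)))
        grad_w = 2.0 * (x.T @ err) / n       # d cost / d w, averaged over all samples
        grad_b = 2.0 * float(np.mean(err))   # d cost / d b
        w -= alpha * grad_w
        b -= alpha * grad_b
    return w, b, costs

# Hypothetical synthetic data: y = 3*x1 - 2*x2 + 1, without noise.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2))
y = x @ np.array([[3.0], [-2.0]]) + 1.0

w, b, costs = gd_fit(x, y)
print(w.ravel(), b, costs[0], costs[-1])
```

Because this cost is convex and the step size is small, the recorded cost shrinks at every early step, at the price of touching all 200 samples for each update.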
Stochastic Gradient Descent (SGD)
The formulas are as follows:
$$
f(x) = w_{1}x_{1} + w_{2}x_{2} + b
$$

$$
\begin{aligned}
&\text{for } i = 1, \dots, n: \\
&\qquad cost(w) = \left(f(x_{i}) - y_{i}\right)^{2} \\
&\qquad w_{1} = w_{1old} - \alpha\dfrac{\partial cost(w)}{\partial w_{1}} \\
&\qquad w_{2} = w_{2old} - \alpha\dfrac{\partial cost(w)}{\partial w_{2}}
\end{aligned}
$$
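From the formulas above, we can draw the following conclusions:
1. In SGD, each parameter update computes the gradient of only one batch (the formulas above assume batch = 1), which makes each update far cheaper and usually speeds up training considerably on large datasets.
2. Because each update considers only one sample, a step does not always move toward the global optimum, and with a fixed learning rate the parameters may keep fluctuating around the minimum instead of converging to it exactly. On the other hand, this noise can help the parameters escape local optima.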
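The behaviour described above can be sketched with a small mini-batch SGD loop. This is an illustrative sketch, not the article's implementation: it assumes a squared-error cost, uses synthetic data with a little noise, and the names `sgd_fit`, `alpha`, `batch_size`, and `seed` are hypothetical.

```python
import numpy as np

def sgd_fit(x, y, alpha=0.02, num_epochs=20, batch_size=1, seed=0):
    """Mini-batch SGD for f(x) = x @ w + b with a squared-error cost."""
    rng = np.random.default_rng(seed)
    n, d = x.shape
    w = np.zeros((d, 1))
    b = 0.0
    epoch_costs = []
    for _ in range(num_epochs):
        order = rng.permutation(n)            # visit the samples in a new random order
        batch_costs = []
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            err = x[idx] @ w + b - y[idx]     # error on this batch only
            batch_costs.append(float(np.mean(err ** 2)))
            w -= alpha * 2.0 * (x[idx].T @ err) / len(idx)
            b -= alpha * 2.0 * float(np.mean(err))
        epoch_costs.append(float(np.mean(batch_costs)))
    return w, b, epoch_costs

# Hypothetical synthetic data: y = 3*x1 - 2*x2 + 1 plus a little noise.
rng = np.random.default_rng(1)
x = rng.normal(size=(200, 2))
y = x @ np.array([[3.0], [-2.0]]) + 1.0 + 0.1 * rng.normal(size=(200, 1))

w, b, epoch_costs = sgd_fit(x, y)
print(w.ravel(), b, epoch_costs[0], epoch_costs[-1])
```

With `batch_size=1` every update looks at a single sample, so the parameters end up close to, but generally not exactly at, the least-squares solution. Raising `batch_size` trades cheaper, noisier steps for fewer, smoother ones.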
Code Implementation
We use Boston housing price prediction as the example.
Load the data
import numpy as np
path = 'Desktop/波士顿房价/trian.csv'
data = np.loadtxt(path, delimiter = ",", skiprows=1)
data.shape
Split the data
train = data[:int(data.shape[0]*0.8)]
test = data[int(data.shape[0]*0.8):]
print(train.shape, test.shape)
train_x = train[:,:-1]
train_y = train[:,13:]
test_x = test[:,:-1]
test_y = test[:,13:]
print(train_x.shape, train_y.shape)
class Network:
    def __init__(self, num_weights):
        self.num_weights = num_weights
        self.w = np.random.rand(num_weights, 1)
        self.b = 0

    def forward(self, x):
        z = np.dot(x, self.w) + self.b
        return z

    def loss(self, z, y):
        # Mean squared error
        cost = (z - y) * (z - y)
        cost = np.mean(cost)
        return cost

    def gradient(self, x, z, y):
        # Gradient on the batch x that was passed in, not on the global train_x
        # (the constant factor 2 is absorbed into eta)
        w = (z - y) * x
        w = np.mean(w, axis=0)
        w = np.array(w).reshape([self.num_weights, 1])
        b = z - y
        b = np.mean(b)
        return w, b

    def update(self, gradient_w, gradient_b, eta):
        self.w = self.w - eta * gradient_w
        self.b = self.b - eta * gradient_b

    # Gradient descent
    def train_GD(self, items, eta):
        for i in range(items):
            z = self.forward(train_x)
            loss = self.loss(z, train_y)
            gradient_w, gradient_b = self.gradient(train_x, z, train_y)
            self.update(gradient_w, gradient_b, eta)
            # if i % 100 == 0:
            test_loss = self.test()
            print('item:', i, 'loss:', loss, 'test_loss:', test_loss)

    # Stochastic gradient descent
    def train_SGD(self, num_epochs, batchsize, eta):
        for epoch_id in range(num_epochs):
            # Shuffle the rows in place; train_x and train_y are views of train
            np.random.shuffle(train)
            losses = []
            for i in range(0, len(train), batchsize):
                mini_batch = train[i:i + batchsize]
                x = mini_batch[:, :-1]   # features of this mini-batch
                y = mini_batch[:, -1:]   # targets, kept as a column
                z = self.forward(x)
                loss = self.loss(z, y)
                gradient_w, gradient_b = self.gradient(x, z, y)
                self.update(gradient_w, gradient_b, eta)
                losses.append(loss)
            loss_mean = np.mean(losses)
            print('Epoch{}, loss{}, loss_mean{}'.format(epoch_id, loss, loss_mean))

    def test(self):
        z = self.forward(test_x)
        loss = self.loss(z, test_y)
        return loss
net = Network(13)
net.train_GD(100, eta=1e-9)
net.train_SGD(100, 5, 1e-9)
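One way to gain confidence in a hand-written gradient like the one above is to compare it with a finite-difference estimate. The sketch below is an assumption-laden illustration: `loss_fn`, `analytic_gradient`, and `numerical_gradient` are hypothetical helpers, and random arrays stand in for the housing data. Note that it keeps the factor 2 from differentiating the squared error, whereas the `gradient` method above leaves that constant to the learning rate.

```python
import numpy as np

def loss_fn(w, b, x, y):
    """Mean squared error of the linear model x @ w + b."""
    return float(np.mean((x @ w + b - y) ** 2))

def analytic_gradient(w, b, x, y):
    """Exact gradient of loss_fn with respect to w and b."""
    err = x @ w + b - y
    grad_w = 2.0 * np.mean(err * x, axis=0).reshape(-1, 1)
    grad_b = 2.0 * float(np.mean(err))
    return grad_w, grad_b

def numerical_gradient(w, b, x, y, eps=1e-6):
    """Central finite differences, one parameter at a time."""
    grad_w = np.zeros_like(w)
    for j in range(w.shape[0]):
        step = np.zeros_like(w)
        step[j, 0] = eps
        grad_w[j, 0] = (loss_fn(w + step, b, x, y) - loss_fn(w - step, b, x, y)) / (2 * eps)
    grad_b = (loss_fn(w, b + eps, x, y) - loss_fn(w, b - eps, x, y)) / (2 * eps)
    return grad_w, grad_b

# Hypothetical random data with 13 features, matching the model size used above.
rng = np.random.default_rng(0)
x = rng.normal(size=(50, 13))
y = rng.normal(size=(50, 1))
w = rng.normal(size=(13, 1))
b = 0.5

gw_a, gb_a = analytic_gradient(w, b, x, y)
gw_n, gb_n = numerical_gradient(w, b, x, y)
print(np.max(np.abs(gw_a - gw_n)), abs(gb_a - gb_n))
```

If the two estimates agree to several decimal places, the analytic formula is very likely correct; a large mismatch usually points to a shape or broadcasting mistake.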