Machine learning: the difference between stochastic gradient descent (SGD) and gradient descent (GD), with code implementation
2022-07-07 01:21:00 【HanZee】
For a more detailed explanation, see: Gradient descent method
Gradient descent method (GD)
Suppose the hypothesis function f(x) and the cost function cost(w) have the following form:
\begin{aligned}
f\left( x\right) &= w_{1}x_{1}+w_{2}x_{2}+b \\
cost\left( w\right) &= \dfrac{1}{n}\sum ^{n}_{i=1}\left( f(x_{i})-y_{i}\right)^{2} \\
w_{1} &= w_{1,old}-\alpha \dfrac{\partial cost\left( w\right) }{\partial w_{1}} \\
w_{2} &= w_{2,old}-\alpha \dfrac{\partial cost\left( w\right) }{\partial w_{2}}
\end{aligned}
From these formulas we can draw the following conclusions (the partial derivatives used in the updates are written out after this list):

1. Every update of the parameters w and b requires computing the partial derivatives over all of the data. This is computationally expensive, so convergence becomes slow when the dataset is large.
2. Unlike SGD, every parameter update is guaranteed to move cost toward the global minimum (for a convex cost).
3. If cost is non-convex, the method may get stuck in a local optimum.
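For reference, here are the partial derivatives that the update rules above rely on, derived from the mean-squared-error cost defined above (x_{i1} and x_{i2} denote the two features of sample i):

\begin{aligned}
\dfrac{\partial cost\left( w\right)}{\partial w_{1}} &= \dfrac{2}{n}\sum ^{n}_{i=1}\left( f(x_{i})-y_{i}\right) x_{i1} \\
\dfrac{\partial cost\left( w\right)}{\partial w_{2}} &= \dfrac{2}{n}\sum ^{n}_{i=1}\left( f(x_{i})-y_{i}\right) x_{i2} \\
\dfrac{\partial cost\left( w\right)}{\partial b} &= \dfrac{2}{n}\sum ^{n}_{i=1}\left( f(x_{i})-y_{i}\right)
\end{aligned}

The constant factor 2 is usually absorbed into the learning rate α; the code later in this article does the same and simply averages (f(x_i) − y_i) · x_i over the batch.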
Stochastic gradient descent (SGD)
The update procedure is as follows:
f\left( x\right) = w_{1}x_{1}+w_{2}x_{2}+b

For each sample i = 1, …, n (one update per sample, i.e. batch = 1):

\begin{aligned}
cost\left( w\right) &= \left( f(x_{i})-y_{i}\right)^{2} \\
w_{1} &= w_{1,old}-\alpha \dfrac{\partial cost\left( w\right) }{\partial w_{1}} \\
w_{2} &= w_{2,old}-\alpha \dfrac{\partial cost\left( w\right) }{\partial w_{2}}
\end{aligned}
From these formulas we can draw the following conclusions (a minimal runnable comparison of the two schemes follows this list):

1. Each SGD parameter update only computes the gradient of a single batch (the formula above assumes batch = 1), which greatly speeds up convergence.
2. Because each update looks at only one sample, individual steps may not move toward the global optimum and the iterates may never converge exactly to the minimum; on the other hand, this noise helps the method escape local optima.
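To make the difference concrete, here is a minimal sketch on synthetic data; the dataset, variable names and hyperparameters are invented for illustration and are not part of the Boston example below. GD performs one parameter update per pass over all samples, while plain SGD performs one update per sample:

import numpy as np

np.random.seed(0)
X = np.random.rand(100, 2)                 # 100 samples, 2 features
y = X @ np.array([[2.0], [-3.0]]) + 1.0    # ground truth: w = (2, -3), b = 1

# Gradient descent: each update averages the gradient over all 100 samples
w, b, alpha = np.zeros((2, 1)), 0.0, 0.1
for _ in range(1000):                      # 1000 updates = 1000 full passes
    err = X @ w + b - y                    # prediction error, shape (100, 1)
    w -= alpha * X.T @ err / len(X)
    b -= alpha * err.mean()

# Stochastic gradient descent: each update uses a single sample
w, b = np.zeros((2, 1)), 0.0
for _ in range(10):                        # 10 epochs x 100 samples = 1000 updates
    for i in np.random.permutation(len(X)):
        xi, yi = X[i:i+1], y[i:i+1]        # one sample, kept 2-D
        err = xi @ w + b - yi
        w -= alpha * xi.T @ err
        b -= alpha * err.mean()

Both loops perform 1000 updates in total, but each SGD update only touches one sample, so it is far cheaper than a full pass over the data.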
Code implementation
Take Boston house price forecast as an example
Import data
import numpy as np

# Load the Boston housing CSV, skipping the header row
path = 'Desktop/Boston prices/trian.csv'
data = np.loadtxt(path, delimiter=",", skiprows=1)
data.shape
Split data
# 80 / 20 train / test split
train = data[:int(data.shape[0]*0.8)]
test = data[int(data.shape[0]*0.8):]
print(train.shape, test.shape)

# The first 13 columns are the features, the last column is the house price
train_x = train[:, :-1]
train_y = train[:, -1:]
test_x = test[:, :-1]
test_y = test[:, -1:]
print(train_x.shape, train_y.shape)
class Network:
    def __init__(self, num_weights):
        self.num_weights = num_weights
        self.w = np.random.rand(num_weights, 1)
        self.b = 0.0

    # Linear model: z = x·w + b
    def forward(self, x):
        z = np.dot(x, self.w) + self.b
        return z

    # Mean squared error between predictions z and targets y
    def loss(self, z, y):
        cost = (z - y) * (z - y)
        cost = np.mean(cost)
        return cost

    # Gradients of the loss with respect to w and b, averaged over the
    # batch (x, y) that produced the predictions z
    def gradient(self, x, y, z):
        gradient_w = np.mean((z - y) * x, axis=0).reshape([-1, 1])
        gradient_b = np.mean(z - y)
        return gradient_w, gradient_b

    def update(self, gradient_w, gradient_b, eta):
        self.w = self.w - eta * gradient_w
        self.b = self.b - eta * gradient_b

    # Gradient descent: every update uses the whole training set
    def train_GD(self, items, eta):
        for i in range(items):
            z = self.forward(train_x)
            loss = self.loss(z, train_y)
            gradient_w, gradient_b = self.gradient(train_x, train_y, z)
            self.update(gradient_w, gradient_b, eta)
            test_loss = self.test()
            print('item:', i, 'loss:', loss, 'test_loss:', test_loss)

    # Stochastic gradient descent: shuffle the data every epoch and
    # update the parameters once per sample inside each mini-batch
    def train_SGD(self, num_epochs, batchsize, eta):
        for epoch_id in range(num_epochs):
            np.random.shuffle(train)
            losses = []
            for i in range(0, len(train), batchsize):
                mini_batchs = train[i:i + batchsize]
                for mini_batch in mini_batchs:
                    x = mini_batch[:-1].reshape(1, -1)   # one sample, kept 2-D
                    y = mini_batch[-1:].reshape(1, 1)
                    z = self.forward(x)
                    loss = self.loss(z, y)
                    gradient_w, gradient_b = self.gradient(x, y, z)
                    self.update(gradient_w, gradient_b, eta)
                    losses.append(loss)
            loss_mean = np.mean(losses)
            print('Epoch{}, loss{}, loss_mean{}'.format(epoch_id, loss, loss_mean))

    def test(self):
        z = self.forward(test_x)
        loss = self.loss(z, test_y)
        return loss
net = Network(13)                # 13 input features in the Boston housing data
net.train_GD(100, eta=1e-9)      # 100 iterations of full-batch gradient descent
net.train_SGD(100, 5, 1e-9)      # 100 epochs of SGD with batchsize = 5, continuing from the GD weights