[CV] Wu Enda machine learning course notes | Chapter 9
2022-07-04 08:13:00 【Fannnnf】
Unless otherwise noted in this series, the text explains the figure directly above it.
Machine Learning | Coursera
Wu Enda machine learning series | bilibili
9 Neural Networks: Learning
9-1 Cost function for neural networks
- Use $L$ to denote the total number of layers in the neural network (Layers).
- Use $s_l$ to denote the number of units (neurons) in layer $l$, not counting the bias unit.
- $h_\Theta(x)\in\mathbb{R}^K$: $h_\Theta(x)$ is a $K$-dimensional vector, since the output layer of the network has $K$ neurons, i.e. $K$ outputs.
- $(h_\Theta(x))_i=i^{th}\ \text{output}$: $(h_\Theta(x))_i$ denotes the $i$-th output.
The cost function for a neural network is:
$J(\Theta)=-\frac{1}{m}\left[\sum_{i=1}^m\sum_{k=1}^K y_k^{(i)}\log\left(h_\Theta(x^{(i)})\right)_k+\left(1-y_k^{(i)}\right)\log\left(1-\left(h_\Theta(x^{(i)})\right)_k\right)\right]+\frac{\lambda}{2m}\sum_{l=1}^{L-1}\sum_{i=1}^{s_l}\sum_{j=1}^{s_{l+1}}\left(\Theta_{ji}^{(l)}\right)^2$
- In the regularization term, $\sum_{i=1}^{s_l}\sum_{j=1}^{s_{l+1}}$ sums the squares of every element $\Theta_{ji}^{(l)}$ of the $s_{l+1}\times s_l$ weight matrix (the bias weights are not regularized).
- In the regularization term, $\sum_{l=1}^{L-1}$ sums over the weight matrices of the input layer and the hidden layers, i.e. layers $1$ through $L-1$. A code sketch of this cost function is given below.
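For reference, here is a minimal NumPy sketch of this cost function for a three-layer network. The function name `nn_cost`, the matrices `Theta1`/`Theta2`, and the one-hot label matrix `Y` are illustrative assumptions, not code from the course:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def nn_cost(Theta1, Theta2, X, Y, lam):
    """Regularized cost J(Theta) for a 3-layer network.

    Theta1: (s2, s1+1), Theta2: (K, s2+1); X: (m, s1); Y: (m, K) one-hot labels.
    """
    m = X.shape[0]
    # Forward propagation, prepending the bias unit at each layer
    A1 = np.hstack([np.ones((m, 1)), X])
    A2 = np.hstack([np.ones((m, 1)), sigmoid(A1 @ Theta1.T)])
    H = sigmoid(A2 @ Theta2.T)  # (m, K), i.e. h_Theta(x) for every example
    # Cross-entropy term, summed over examples i and output units k
    J = -np.sum(Y * np.log(H) + (1 - Y) * np.log(1 - H)) / m
    # Regularization term: every weight except the bias columns
    reg = (lam / (2 * m)) * (np.sum(Theta1[:, 1:] ** 2) + np.sum(Theta2[:, 1:] ** 2))
    return J + reg
```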
9-2 Backpropagation algorithm
- $\delta_j^{(l)}$ is defined as the "error" of neuron $j$ in layer $l$.
Take the four-layer neural network in the figure above as an example:
- $\delta_j^{(4)}=a_j^{(4)}-y_j$, where $y_j$ is the $j$-th output value in the training example and $a_j^{(4)}$ is the $j$-th output of the network; $a_j^{(4)}$ can also be written as $(h_\Theta(x))_j$.
- In vector form, $\delta^{(4)}=a^{(4)}-y$, which can also be written as $\delta^{(4)}=h_\Theta(x)-y$.
- $\delta^{(3)}=(\Theta^{(3)})^T\delta^{(4)}\cdot g^{\prime}(z^{(3)})$, where $g^{\prime}(z^{(3)})=a^{(3)}\cdot(1-a^{(3)})$
- $\delta^{(2)}=(\Theta^{(2)})^T\delta^{(3)}\cdot g^{\prime}(z^{(2)})$, where $g^{\prime}(z^{(2)})=a^{(2)}\cdot(1-a^{(2)})$
Note that the "$\cdot$" in these formulas denotes element-wise multiplication (Octave's `.*`), so the result is a vector of the same dimension, not a scalar dot product.
- $\frac{\partial}{\partial \Theta_{ij}^{(l)}}J(\Theta)=a_j^{(l)}\delta_i^{(l+1)}$
(the regularization term is ignored here, i.e. $\lambda=0$)
- The figure above shows the full flow of the backpropagation algorithm. It ultimately yields $\frac{\partial}{\partial \Theta_{ij}^{(l)}}J(\Theta)=D_{ij}^{(l)}$, after which the gradient descent algorithm can be carried out. A code sketch follows.
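Below is a minimal NumPy sketch of one backpropagation pass for a three-layer network, following the formulas above (the per-example loop and all names are illustrative assumptions; the course itself works in Octave):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def backprop(Theta1, Theta2, X, Y, lam):
    """Gradients D1, D2 of J(Theta) for a 3-layer network via backpropagation."""
    m = X.shape[0]
    Delta1 = np.zeros_like(Theta1)
    Delta2 = np.zeros_like(Theta2)
    for i in range(m):
        # Forward propagation (the prepended 1.0 is the bias unit of each layer)
        a1 = np.concatenate(([1.0], X[i]))
        z2 = Theta1 @ a1
        a2 = np.concatenate(([1.0], sigmoid(z2)))
        a3 = sigmoid(Theta2 @ a2)  # h_Theta(x)
        # Output-layer error: delta = a - y
        d3 = a3 - Y[i]
        # Propagate back through the hidden layer; "*" is element-wise
        d2 = (Theta2[:, 1:].T @ d3) * (sigmoid(z2) * (1.0 - sigmoid(z2)))
        # Accumulate Delta(l) += delta(l+1) a(l)^T
        Delta2 += np.outer(d3, a2)
        Delta1 += np.outer(d2, a1)
    D1, D2 = Delta1 / m, Delta2 / m
    # Regularization gradient (the bias column is not regularized)
    D1[:, 1:] += (lam / m) * Theta1[:, 1:]
    D2[:, 1:] += (lam / m) * Theta2[:, 1:]
    return D1, D2
```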
9-3 Understanding backpropagation
Take the neural network in the figure above as an example
- $\delta_2^{(2)}=\Theta_{12}^{(2)}\delta_1^{(3)}+\Theta_{22}^{(2)}\delta_2^{(3)}$
- $\delta_2^{(3)}=\Theta_{12}^{(3)}\delta_1^{(4)}$
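For concreteness, a small worked example with made-up values: if $\Theta_{12}^{(2)}=0.5$, $\Theta_{22}^{(2)}=-0.3$, $\delta_1^{(3)}=0.2$ and $\delta_2^{(3)}=0.4$, then $\delta_2^{(2)}=0.5\times 0.2+(-0.3)\times 0.4=-0.02$.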
9-4 Unrolling parameters
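The idea of this section is to unroll the weight matrices $\Theta^{(1)},\Theta^{(2)},...$ into a single long vector, so they can be passed to an optimization routine, and to reshape them back inside the cost function. A minimal sketch, with illustrative matrix sizes:

```python
import numpy as np

# Illustrative sizes: Theta1 is (5, 4), Theta2 is (3, 6)
Theta1 = np.random.rand(5, 4)
Theta2 = np.random.rand(3, 6)

# Unroll both matrices into one parameter vector
theta_vec = np.concatenate([Theta1.ravel(), Theta2.ravel()])

# Reshape back into matrices (e.g. inside the cost/gradient function)
T1 = theta_vec[:5 * 4].reshape(5, 4)
T2 = theta_vec[5 * 4:].reshape(3, 6)
assert np.array_equal(T1, Theta1) and np.array_equal(T2, Theta2)
```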
9-5 Gradient checking
To estimate the derivative of the cost function $J(\theta)$ at a point $(\theta, J(\theta))$, we can use the two-sided difference $\frac{\mathrm{d}}{\mathrm{d}\theta}J(\theta)\approx\frac{J(\theta+\varepsilon)-J(\theta-\varepsilon)}{2\varepsilon}$, where $\varepsilon=10^{-4}$ is a suitable choice.
Extended to vectors, as shown in the figure above:
- $\theta$ is an $n$-dimensional vector, obtained by unrolling the matrices $\Theta^{(1)},\Theta^{(2)},\Theta^{(3)},...$
- The value of each partial derivative $\frac{\partial}{\partial\theta_n}J(\theta)$ can then be estimated in the same way.
Compare the estimated partial derivatives with those computed by backpropagation; if the two sets of values are very close, the backpropagation computation is verified to be correct.
Once the values computed by backpropagation are confirmed to be correct, turn the gradient checking algorithm off, since it is far too slow to run during training. A sketch is given below.
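A minimal sketch of the numerical check, assuming `J` is a function of the unrolled parameter vector (all names are illustrative):

```python
import numpy as np

def numerical_gradient(J, theta, eps=1e-4):
    """Two-sided difference approximation of each partial derivative of J."""
    grad = np.zeros_like(theta)
    for n in range(theta.size):
        e = np.zeros_like(theta)
        e[n] = eps
        grad[n] = (J(theta + e) - J(theta - e)) / (2 * eps)
    return grad

# Usage: compare against the unrolled backpropagation gradient, then disable.
# num_grad = numerical_gradient(cost_fn, theta_vec)
# assert np.allclose(num_grad, backprop_grad, atol=1e-7)
```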
9-6 Random initialization
If every element of $\Theta$ is initialized to 0 at the start of the program, multiple neurons will compute the same features, producing redundancy; this is known as the symmetric weights problem.
Therefore, at initialization, set each $\Theta_{ij}^{(l)}$ to a random value in $[-\epsilon,\epsilon]$, for example as sketched below.
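A minimal sketch of this initialization (the value of `eps_init` and the matrix sizes are illustrative; the lecture does not fix a particular $\epsilon$):

```python
import numpy as np

eps_init = 0.12  # illustrative choice of epsilon
# Draw each weight uniformly from [-eps_init, eps_init] to break symmetry
Theta1 = np.random.rand(5, 4) * 2 * eps_init - eps_init
Theta2 = np.random.rand(3, 6) * 2 * eps_init - eps_init
```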
9-7 Review and summary
To train a neural network:
1. Randomly initialize the weights
2. Implement forward propagation to get $h_\Theta(x^{(i)})$ for every $x^{(i)}$
3. Compute the cost function $J(\Theta)$
4. Implement backpropagation to compute the partial derivatives $\frac{\partial}{\partial\Theta_{jk}^{(l)}}J(\Theta)$ (obtaining $a^{(l)}$ and $\delta^{(l)}$ for $l=2,...,L$)
5. Use gradient checking to estimate the partial derivatives of $J(\Theta)$ numerically and compare them with the values from backpropagation; if the two are very close, the backpropagation result is verified to be correct. After verification, disable the gradient checking code.
6. Use gradient descent or another more advanced optimization method, together with the gradients computed by backpropagation, to find the value of the parameters $\Theta$ that minimizes $J(\Theta)$. A combined sketch is given below.
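Putting the steps together, a minimal training sketch that reuses the `nn_cost` and `backprop` functions sketched earlier (SciPy's L-BFGS stands in for "a more advanced optimization method"; all names and sizes are illustrative):

```python
import numpy as np
from scipy.optimize import minimize

def train(X, Y, s1, s2, K, lam):
    # Step 1: random initialization of the unrolled parameter vector
    eps_init = 0.12
    n = s2 * (s1 + 1) + K * (s2 + 1)
    theta0 = np.random.rand(n) * 2 * eps_init - eps_init

    def cost_and_grad(theta_vec):
        # Reshape the unrolled vector back into Theta1, Theta2
        Theta1 = theta_vec[:s2 * (s1 + 1)].reshape(s2, s1 + 1)
        Theta2 = theta_vec[s2 * (s1 + 1):].reshape(K, s2 + 1)
        # Steps 2-4: forward propagation, cost, and backpropagation
        J = nn_cost(Theta1, Theta2, X, Y, lam)
        D1, D2 = backprop(Theta1, Theta2, X, Y, lam)
        return J, np.concatenate([D1.ravel(), D2.ravel()])

    # Step 6: advanced optimization using the backpropagation gradients
    res = minimize(cost_and_grad, theta0, jac=True, method="L-BFGS-B")
    return res.x  # unrolled Theta that (locally) minimizes J
```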