[CV] Andrew Ng (Wu Enda) machine learning course notes | Chapter 9
2022-07-04 08:13:00 【Fannnnf】
Unless otherwise noted, the text in this series of articles explains the figure directly above it.
machine learning | Coursera
Wu Enda machine learning series _bilibili
9 Neural networks: Learning
9-1 Cost function for neural networks
- $L$ denotes the total number of layers in the neural network.
- $s_l$ denotes the number of units (neurons) in layer $l$, not counting the bias unit.
- $h_\Theta(x)\in\mathbb{R}^K$: $h_\Theta(x)$ is a $K$-dimensional vector, because the output layer has $K$ neurons and therefore $K$ outputs.
- $(h_\Theta(x))_i$ denotes the $i$-th output.
The cost function for a neural network is:
$$J(\Theta)=-\frac{1}{m}\left[\sum_{i=1}^m\sum_{k=1}^K y_k^{(i)}\log\left((h_\Theta(x^{(i)}))_k\right)+(1-y_k^{(i)})\log\left(1-(h_\Theta(x^{(i)}))_k\right)\right]+\frac{\lambda}{2m}\sum_{l=1}^{L-1}\sum_{i=1}^{s_l}\sum_{j=1}^{s_{l+1}}\left(\Theta_{ji}^{(l)}\right)^2$$
- In the second term, $\sum_{i=1}^{s_l}\sum_{j=1}^{s_{l+1}}$ adds up the squares of the elements $\Theta_{ji}^{(l)}$ of a matrix with $s_{l+1}$ rows and $s_l$ columns. Since $i$ starts at 1, the bias weights ($i=0$) are not regularized.
- In the second term, $\sum_{l=1}^{L-1}$ sums over all the weight matrices $\Theta^{(1)},\dots,\Theta^{(L-1)}$, that is, the matrices leaving the input layer and each hidden layer.
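The cost function above can be sketched in NumPy as follows. This is only an illustrative sketch: the names `forward` and `nn_cost`, the list-of-matrices layout for `thetas`, and the one-hot layout of `Y` are assumptions made here, not something the course notes prescribe.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(thetas, X):
    """Return h_Theta(x) for every row of X.

    thetas[l] is assumed to have shape (s_{l+1}, s_l + 1), bias weights in column 0.
    """
    a = X
    for theta in thetas:
        a = np.hstack([np.ones((a.shape[0], 1)), a])  # prepend the bias unit
        a = sigmoid(a @ theta.T)
    return a

def nn_cost(thetas, X, Y, lam):
    """Regularized cross-entropy cost J(Theta); Y is one-hot with shape (m, K)."""
    m = X.shape[0]
    H = forward(thetas, X)
    data_term = -np.sum(Y * np.log(H) + (1 - Y) * np.log(1 - H)) / m
    # Column 0 holds the bias weights, which the regularization term skips.
    reg_term = lam / (2 * m) * sum(np.sum(t[:, 1:] ** 2) for t in thetas)
    return data_term + reg_term
```

With all weights set to zero every output is 0.5, so the data term reduces to $K\ln 2$, which is a convenient sanity check.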
9-2 Back propagation algorithm
- $\delta_j^{(l)}$ is defined as the "error" of neuron $j$ in layer $l$.
Take the four-layer neural network in the figure above as an example:
- $\delta_j^{(4)}=a_j^{(4)}-y_j$, where $y_j$ is the $j$-th output value in the data set and $a_j^{(4)}$ is the $j$-th output of the neural network, which can also be written as $(h_\Theta(x))_j$.
- In vector form this is $\delta^{(4)}=a^{(4)}-y$, or equivalently $\delta^{(4)}=h_\Theta(x)-y$.
- $\delta^{(3)}=(\Theta^{(3)})^T\delta^{(4)}\odot g'(z^{(3)})$
where $g'(z^{(3)})=a^{(3)}\odot(1-a^{(3)})$
- $\delta^{(2)}=(\Theta^{(2)})^T\delta^{(3)}\odot g'(z^{(2)})$
where $g'(z^{(2)})=a^{(2)}\odot(1-a^{(2)})$
Here $\odot$ is element-wise multiplication (.* in Octave), so each $\delta^{(l)}$ is a vector. It is not the dot product, whose result is a single number, and it is not the cross product.
- $\frac{\partial}{\partial \Theta_{ij}^{(l)}}J(\Theta)=a_j^{(l)}\delta_i^{(l+1)}$
The regularization term is ignored here, i.e. $\lambda=0$.
- The figure above shows the flow of the back propagation algorithm. It ends with $\frac{\partial}{\partial \Theta_{ij}^{(l)}}J(\Theta)=D^{(l)}_{ij}$, which is then used by gradient descent.
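The procedure above can be sketched in NumPy roughly as follows. This is a minimal sketch under stated assumptions: sigmoid activations, $\lambda=0$, and a hypothetical function name `backprop` that takes the weight matrices as a Python list. List indices are 0-based, so `thetas[0]` plays the role of $\Theta^{(1)}$.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def backprop(thetas, X, Y):
    """Return a list D with D[l] = dJ/dTheta^(l+1) for lambda = 0.

    thetas[l] has shape (s_out, s_in + 1); X is (m, n); Y is (m, K).
    """
    m = X.shape[0]
    Delta = [np.zeros_like(t) for t in thetas]
    for x, y in zip(X, Y):
        # Forward pass: keep each layer's activations with the bias unit prepended.
        activations = []
        a = x
        for theta in thetas:
            a = np.concatenate(([1.0], a))
            activations.append(a)
            a = sigmoid(theta @ a)
        # Output-layer error: delta^(L) = a^(L) - y.
        delta = a - y
        # Walk backwards, accumulating Delta^(l) += delta^(l+1) (a^(l))^T.
        for l in range(len(thetas) - 1, -1, -1):
            a_prev = activations[l]
            Delta[l] += np.outer(delta, a_prev)
            if l > 0:
                # delta^(l) = (Theta^(l))^T delta^(l+1) .* a^(l) .* (1 - a^(l)),
                # then drop the entry that belongs to the bias unit.
                delta = ((thetas[l].T @ delta) * a_prev * (1 - a_prev))[1:]
    return [d / m for d in Delta]
```

The per-example loop mirrors the notes step by step; a vectorized version over all $m$ examples is usually faster but harder to compare with the formulas.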
9-3 Understanding back propagation
Take the neural network in the figure above as an example:
- $\delta_2^{(2)}=\Theta_{12}^{(2)}\delta_1^{(3)}+\Theta_{22}^{(2)}\delta_2^{(3)}$
- $\delta_2^{(3)}=\Theta_{12}^{(3)}\delta_1^{(4)}$
Each error term is a weighted sum of the error terms in the next layer, weighted by the parameters on the connecting edges. The factor $g'(z)$ from 9-2 is left out of these two formulas to keep the intuition simple.
9-4 Unrolling parameters
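Unrolling flattens the matrices $\Theta^{(1)},\Theta^{(2)},\dots$ into a single vector $\theta$ and reshapes that vector back into matrices when needed. A minimal NumPy sketch of the idea is given below; the helper names `unroll` and `reshape_params` are hypothetical, and NumPy flattens row by row whereas Octave flattens column by column, which does not matter as long as both directions use the same order.

```python
import numpy as np

def unroll(thetas):
    """Flatten a list of weight matrices into one 1-D parameter vector."""
    return np.concatenate([t.ravel() for t in thetas])

def reshape_params(theta_vec, shapes):
    """Rebuild the weight matrices from the vector, given their (rows, cols) shapes."""
    thetas, start = [], 0
    for rows, cols in shapes:
        size = rows * cols
        thetas.append(theta_vec[start:start + size].reshape(rows, cols))
        start += size
    return thetas
```

The flat vector is the form that generic optimizers and gradient checking typically expect, while forward and back propagation are easier to write with the matrices.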
9-5 Gradient checking
To estimate the derivative of the cost function $J(\theta)$ at a point $\theta$, use the two-sided difference $\frac{\mathrm{d}}{\mathrm{d}\theta}J(\theta)\approx\frac{J(\theta+\varepsilon)-J(\theta-\varepsilon)}{2\varepsilon}$, where $\varepsilon=10^{-4}$ is a suitable choice.
The same idea extends to vectors, as shown in the figure above:
- $\theta$ is an $n$-dimensional vector, the unrolled form of the matrices $\Theta^{(1)},\Theta^{(2)},\Theta^{(3)},\dots$
- Each partial derivative $\frac{\partial}{\partial \theta_i}J(\theta)$, for $i=1,\dots,n$, is estimated by perturbing only $\theta_i$: $\frac{\partial}{\partial \theta_i}J(\theta)\approx\frac{J(\theta_1,\dots,\theta_i+\varepsilon,\dots,\theta_n)-J(\theta_1,\dots,\theta_i-\varepsilon,\dots,\theta_n)}{2\varepsilon}$
Compare the estimated partial derivatives with the ones obtained by back propagation. If the two sets of values are very close, the back propagation implementation is correct.
Once back propagation has been verified, turn gradient checking off, because the numerical estimate is far slower than back propagation.
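The two-sided estimate can be sketched as below. The function name `numerical_gradient` is an assumption made for illustration, and the toy cost $J(\theta)=\sum_i\theta_i^3$ is used only because its exact gradient $3\theta_i^2$ is known; in practice `J` would be the network cost as a function of the unrolled parameter vector.

```python
import numpy as np

def numerical_gradient(J, theta, eps=1e-4):
    """Two-sided difference estimate of the gradient of J at the vector theta."""
    grad = np.zeros_like(theta, dtype=float)
    for i in range(theta.size):
        step = np.zeros_like(theta, dtype=float)
        step[i] = eps  # perturb only component i
        grad[i] = (J(theta + step) - J(theta - step)) / (2 * eps)
    return grad

# Toy check against a gradient that is known in closed form.
theta = np.array([1.0, -2.0, 0.5])
approx = numerical_gradient(lambda t: np.sum(t ** 3), theta)
exact = 3 * theta ** 2
print(np.max(np.abs(approx - exact)))  # small, on the order of eps**2
```

Each component needs two evaluations of `J`, which is why this check is only run on a few iterations and then disabled.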
9-6 Random initialization
If all elements of $\Theta$ are 0 at the start of the program, the neurons in a layer all compute the same feature, which is redundant. This is known as the symmetric weights problem.
So at initialization, set each $\Theta^{(l)}_{ij}$ to a random value in $[-\epsilon,\epsilon]$.
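A minimal sketch of this initialization in NumPy follows. The function name `random_init` and the default `epsilon=0.12` are illustrative assumptions; the notes do not specify a value for $\epsilon$.

```python
import numpy as np

def random_init(s_in, s_out, epsilon=0.12, rng=None):
    """Return an (s_out, s_in + 1) weight matrix drawn uniformly from [-epsilon, epsilon).

    The extra column holds the weights of the bias unit.
    """
    rng = np.random.default_rng() if rng is None else rng
    return rng.uniform(-epsilon, epsilon, size=(s_out, s_in + 1))
```

Because every entry is drawn independently, the rows of the matrix differ, so the neurons in a layer no longer start out computing the same feature.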
9-7 Review summary
To train a neural network:
1. Randomly initialize the weights.
2. Run forward propagation to obtain $h_\Theta(x^{(i)})$ for every $x^{(i)}$.
3. Compute the cost function $J(\Theta)$.
4. Run back propagation to compute $\frac{\partial}{\partial\Theta_{jk}^{(l)}}J(\Theta)$ (get $a^{(l)}$ and $\delta^{(l)}$ for $l=2,\dots,L$).
5. Use gradient checking to estimate the partial derivatives of $J(\Theta)$ numerically and compare them with the values from back propagation. If the two are very close, the back propagation result is correct. After verification, disable the gradient checking code.
6. Use gradient descent or a more advanced optimization method, together with the back propagation results, to find the parameters $\Theta$ that minimize $J(\Theta)$.
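The steps above can be sketched end to end as follows. This is a small illustration, not the course's implementation: the XOR data, the single hidden layer of 4 units, the learning rate `alpha=0.5`, the initialization range, and all function names are assumptions chosen for the example, and the gradient checking of step 5 is left out.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def forward(thetas, X):
    """Return the activations of every layer; all but the last include a bias column."""
    acts, A = [], X
    for theta in thetas:
        A = np.hstack([np.ones((A.shape[0], 1)), A])
        acts.append(A)
        A = sigmoid(A @ theta.T)
    acts.append(A)
    return acts

def cost(thetas, X, Y, lam):
    m, H = X.shape[0], forward(thetas, X)[-1]
    data = -np.sum(Y * np.log(H) + (1 - Y) * np.log(1 - H)) / m
    return data + lam / (2 * m) * sum(np.sum(t[:, 1:] ** 2) for t in thetas)

def gradients(thetas, X, Y, lam):
    """Back propagation over all m examples at once."""
    m, acts = X.shape[0], forward(thetas, X)
    delta, grads = acts[-1] - Y, [None] * len(thetas)
    for l in range(len(thetas) - 1, -1, -1):
        # Bias weights (column 0) are not regularized.
        reg = np.hstack([np.zeros((thetas[l].shape[0], 1)), thetas[l][:, 1:]])
        grads[l] = (delta.T @ acts[l] + lam * reg) / m
        if l > 0:
            delta = ((delta @ thetas[l]) * acts[l] * (1 - acts[l]))[:, 1:]
    return grads

def train(X, Y, hidden=4, lam=0.0, alpha=0.5, iters=5000, seed=0):
    rng = np.random.default_rng(seed)
    sizes = [X.shape[1], hidden, Y.shape[1]]
    # Step 1: random initialization.
    thetas = [rng.uniform(-0.5, 0.5, (sizes[i + 1], sizes[i] + 1))
              for i in range(len(sizes) - 1)]
    history = []
    for _ in range(iters):
        history.append(cost(thetas, X, Y, lam))                  # steps 2 and 3
        grads = gradients(thetas, X, Y, lam)                     # step 4
        thetas = [t - alpha * g for t, g in zip(thetas, grads)]  # step 6
    return thetas, history

X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
Y = np.array([[0.0], [1.0], [1.0], [0.0]])  # XOR targets
thetas, history = train(X, Y)
print(history[0], history[-1])  # cost before and after training
```

Plain gradient descent is used here for brevity; a more advanced optimizer could take the same `cost` and `gradients` functions after unrolling the parameters.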