Deep understanding of softmax
2022-07-08 02:17:00 【Strawberry sauce toast】
Preface
The code in this article is implemented with PyTorch.
1. The definition and code implementation of softmax
1.1 Definition
$$softmax(x_i) = \frac{exp(x_i)}{\sum_j^n exp(x_j)}$$
1.2 Code implementation
import torch

def softmax(X):
    """Compute softmax. The input X has shape [number of samples, output vector dimension]."""
    return torch.exp(X) / torch.sum(torch.exp(X), dim=1).reshape(-1, 1)
>>> X = torch.randn(5, 5)
>>> y = softmax(X)
>>> torch.sum(y, dim=1)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
2. The role of softmax
Softmax normalizes and calibrates the output of a linear layer, ensuring that the outputs are non-negative and sum to 1.
If the unnormalized outputs are used directly as probabilities, two problems arise:
- nothing constrains the outputs of the neurons to sum to 1;
- depending on the input, some outputs of the linear layer may be negative.
Both points violate the basic axioms of probability, as the sketch below illustrates.
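A minimal sketch of these two problems (the layer sizes, seed, and input below are illustrative assumptions, not from the original post): the raw outputs of a linear layer can be negative and do not sum to 1, while the softmax output does.

import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(4, 3)              # hypothetical linear layer: 4 features -> 3 class scores
logits = linear(torch.randn(1, 4))    # raw, unnormalized outputs
print(logits)                         # may contain negative values
print(logits.sum(dim=1))              # generally not equal to 1
probs = torch.softmax(logits, dim=1)
print(probs, probs.sum(dim=1))        # non-negative and sums to 1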
3. Overflow and underflow in softmax
3.1 Overflow
When the value of $x_i$ is too large, $exp(x_i)$ becomes enormous; once it exceeds the representable range of the floating-point type, it overflows:
>>> torch.exp(torch.tensor([1000.]))
tensor([inf])
3.2 Underflow
When every element $x_i$ of the vector $\boldsymbol x$ is a negative number with a large absolute value, $exp(x_i)$ is so small that it underflows to 0, and the denominator $\sum_j exp(x_j)$ becomes 0:
>>> X = torch.ones(1, 3) * (-1000)
>>> softmax(X)
tensor([[nan, nan, nan]])
3.3 Avoiding overflow
The trick from reference [1]:
- Find the maximum of the vector $\boldsymbol x$: $c = max(\boldsymbol x)$;
- subtract $c$ from each $x_i$, which is equivalent to dividing both the numerator and the denominator of softmax by $exp(c)$:
$$softmax(x_i - c) = \frac{exp(x_i - c)}{\sum_j^n exp(x_j - c)} = \frac{exp(x_i)exp(-c)}{\sum_j^n exp(x_j)exp(-c)} = softmax(x_i)$$
After this transformation, the largest term in the numerator becomes $exp(0) = 1$, which avoids overflow;
and the denominator is at least $1$ (it always contains the term $exp(0) = 1$), so it cannot underflow to 0:
$$\sum_j^n exp(x_j - c) = exp(x_1 - c) + exp(x_2 - c) + ... + exp(x_{max} - c) = exp(x_1 - c) + exp(x_2 - c) + ... + 1$$
def softmax_trick(X):
    c, _ = torch.max(X, dim=1, keepdim=True)    # per-row maximum
    return torch.exp(X - c) / torch.sum(torch.exp(X - c), dim=1).reshape(-1, 1)
>>> X = torch.tensor([[-1000., 1000., -1000.]])
>>> softmax_trick(X)
tensor([[0., 1., 0.]])
>>> softmax(X)
tensor([[0., nan, 0.]])
PyTorch's implementation already guards against overflow, so its result agrees with softmax_trick:
>>> import torch.nn.functional as F
>>> X = torch.tensor([[-1000., 1000., -1000.]])
>>> F.softmax(X, dim=1)
tensor([[0., 1., 0.]])
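As a quick sanity check (a sketch reusing the softmax_trick defined above on an arbitrary random input), the hand-rolled stable version also matches PyTorch's built-in softmax on ordinary inputs:

>>> X = torch.randn(4, 10)
>>> torch.allclose(softmax_trick(X), F.softmax(X, dim=1))
True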
3.4 Log-Sum-Exp trick [2] (taking the logarithm)
1. Avoiding underflow
Taking the logarithm turns multiplication into addition: $log(x_1 x_2) = log(x_1) + log(x_2)$. When two very small numbers $x_1$ and $x_2$ are multiplied, the product becomes even smaller and may underflow beyond the precision range; in log space the product becomes a sum, which reduces the risk of underflow.
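A minimal numerical illustration (the value 1e-200 is an assumption chosen only to trigger double-precision underflow, not a value from the original post):

import math

p1, p2 = 1e-200, 1e-200               # two very small probabilities (illustrative values)
print(p1 * p2)                        # 0.0 -> the product underflows in double precision
print(math.log(p1) + math.log(p2))    # ~ -921.03 -> easily representable in log space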
2. Avoiding overflow
l o g − s o f t m a x log-softmax log−softmax The definition of :
$$\begin{aligned} \text{log-softmax}(x_i) &= log[softmax(x_i)] \\ &= log\left(\frac{exp(x_i)}{\sum_j^n exp(x_j)}\right) \\ &= x_i - log\left[\sum_j^n exp(x_j)\right] \end{aligned}$$
Let $y = log\sum_j^n exp(x_j)$. When the values $x_j$ are too large, $y$ risks overflowing, so apply the same trick as in Section 3.3:
$$\begin{aligned} y &= log\sum_j^n exp(x_j) \\ &= log\sum_j^n exp(x_j - c)exp(c) \\ &= c + log\sum_j^n exp(x_j - c) \end{aligned}$$
Taking $c = max(\boldsymbol x)$ avoids the overflow, as the quick check below shows.
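A quick check of this step (a sketch; the input values are arbitrary and chosen only to overflow the naive form). PyTorch's built-in torch.logsumexp can serve as a reference:

>>> x = torch.tensor([1000., 1001., 1002.])
>>> torch.log(torch.sum(torch.exp(x)))            # naive form: exp overflows
tensor(inf)
>>> c = torch.max(x)
>>> c + torch.log(torch.sum(torch.exp(x - c)))    # stabilized form from the derivation above
tensor(1002.4076)
>>> torch.logsumexp(x, dim=0)
tensor(1002.4076)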
The log-softmax formula then becomes (in fact, this is just the trick from Section 3.3 with a logarithm applied):
$$\text{log-softmax}(x_i) = (x_i - c) - log\sum_j^n exp(x_j - c)$$
Code implementation:
def log_softmax(X):
    c, _ = torch.max(X, dim=1, keepdim=True)    # per-row maximum
    return X - c - torch.log(torch.sum(torch.exp(X - c), dim=1, keepdim=True))
>>> X = torch.tensor([[-1000., 1000., -1000.]])
>>> torch.exp(log_softmax(X))
tensor([[0., 1., 0.]])
# the same result with the PyTorch API
>>> torch.exp(F.log_softmax(X, dim=1))
tensor([[0., 1., 0.]])
3.5 The difference between log-softmax and softmax [3]
Combining the trick from Section 3.3 with my own understanding:
- In PyTorch's implementation, the result of softmax equals the result of log_softmax after exponentiation:
>>> X = torch.tensor([[-1000., 1000., -1000.]])
>>> torch.exp(F.log_softmax(X, dim=1)) == F.softmax(X, dim=1)
tensor([[True, True, True]])
- Taking the logarithm makes the derivative simpler, which can speed up backpropagation [4]:
$$\begin{aligned} \frac{\partial}{\partial x_i}\text{log-softmax}(x_i) &= \frac{\partial}{\partial x_i}\left[x_i - log\sum_j^n exp(x_j)\right] \\ &= 1 - softmax(x_i) \end{aligned}$$
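As a sanity check of this derivative (a sketch; the input, the index, and the use of autograd are my own assumptions, and only the diagonal entry of the Jacobian is verified), autograd should reproduce $1 - softmax(x_i)$:

>>> x = torch.randn(5, requires_grad=True)
>>> i = 2
>>> F.log_softmax(x, dim=0)[i].backward()
>>> torch.allclose(x.grad[i], 1 - F.softmax(x, dim=0)[i].detach())
True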