VAE: Understanding and Implementation of the Variational Autoencoder
2022-07-28 19:26:00 【I'm Mr. rhubarb】
1. Understanding VAE
VAE is a generative model. It assumes that the input X (dimension d) follows a true probability distribution $p_{gt}(x)$, and considers a latent variable z related to X that lives in a lower-dimensional space (dimension k, k < d). The distribution of z, $p(z)$, is usually called the prior distribution (typically a normal distribution). Using z, $p_{gt}(x)$ can be rewritten as:
$$p_{gt}(x)=\int_{z} p_{gt}(x|z)\,p(z)\,dz$$
The prior distribution is something we set in advance, so it is known.
From the formula above we can see that as long as we can find $p_{gt}(x|z)$, we can use it to generate new samples (first sample z from the prior, then sample a new X according to $p_{gt}(x|z)$):

We would therefore like to use maximum likelihood estimation to solve for $P_{\theta}=\{\,p_{\theta}(x|z)\mid\theta\,\}$:
$$\theta^{*}=\arg\max_{\theta}\sum_{x}\log p_{\theta}(x),\qquad p_{\theta}(x)=\int_{z} p_{\theta}(x|z)\,p(z)\,dz$$
But this objective is difficult to solve mathematically, so VAE constructs a new probability distribution $q(z|x)$, also called the encoder distribution, and rewrites the maximum likelihood objective:
This is similar to the decomposition used in the EM algorithm.
$$
\begin{aligned}
L &= \sum_{x}\log P(x)\\
\log P(x) &= \int_{z} q(z|x)\log P(x)\,dz\\
&= \int_{z} q(z|x)\log\frac{P(z,x)}{P(z|x)}\,dz\\
&= \int_{z} q(z|x)\log\frac{P(z,x)}{q(z|x)}\,dz+\int_{z} q(z|x)\log\frac{q(z|x)}{P(z|x)}\,dz\\
&\ge \int_{z} q(z|x)\log\frac{P(z,x)}{q(z|x)}\,dz = L_b
\end{aligned}
$$
Going through it line by line:
Line 1 is just the maximum likelihood objective over the (discrete) data points;
Line 2 introduces the encoder distribution $q(z|x)$ inside $\log P(x)$; since $q(z|x)$ integrates to 1 over z, the integral equals the left-hand side;
Line 3 rewrites $P(x)$ as $P(z,x)/P(z|x)$;
Line 4 splits the log of a product into a sum; the second term is the KL divergence between $q(z|x)$ and $P(z|x)$, which is always greater than or equal to 0;
Line 5 drops the KL term from line 4; because the KL divergence is ≥ 0, $\log P(x)$ has a lower bound $L_b$, also known as the ELBO (Evidence Lower BOund).
The derivation above is borrowed from Li Hongyi's slides, where P(x) denotes the true distribution $p_{gt}(x)$ mentioned earlier.
Now we know that $\log P(x)$ can be decomposed into the ELBO plus a KL divergence, and we want $\log P(x)$ to be as large as possible over the whole dataset. How do we optimize it? Since maximum likelihood estimation is hard to carry out directly, we optimize the ELBO instead, trying to make it larger so that $\log P(x)$ grows with it.
$$\log P(x)=L_b+KL\big(q(z|x)\,\|\,P(z|x)\big),\qquad L_b=\int_{z} q(z|x)\log\frac{P(x|z)P(z)}{q(z|x)}\,dz$$
Since P(z) is a known prior distribution, we jointly optimize P(x|z) and q(z|x) to maximize $L_b$. Why optimize the two jointly?
Suppose we only used P(x|z) to maximize $L_b$: we would not know how $KL\big(q(z|x)\,\|\,P(z|x)\big)$ changes, and therefore could not tell whether the likelihood actually increases.
If we instead fix P(x|z) and optimize $L_b$ only through q(z|x), then because P(x) does not depend on q(z|x), the likelihood stays fixed; as $L_b$ grows, $KL\big(q(z|x)\,\|\,P(z|x)\big)$ must shrink, as shown in the figure below:

So we jointly optimize P(x|z) and q(z|x) to maximize $L_b$. In the ideal case the KL term approaches 0 while $L_b$ keeps growing, so $\log P(x)$ increases as well, and q(z|x) ends up approximating P(z|x).
Now that the goal is clearly to optimize $L_b$, let us rewrite its expression:
$$L_b=\int_{z} q(z|x)\log\frac{P(x|z)P(z)}{q(z|x)}\,dz=\int_{z} q(z|x)\log\frac{P(z)}{q(z|x)}\,dz+\int_{z} q(z|x)\log P(x|z)\,dz=-KL\big(q(z|x)\,\|\,P(z)\big)+E_{q(z|x)}\big[\log P(x|z)\big]$$
After this transformation, $L_b$ splits into two terms. The first term is $-KL\big(q(z|x)\,\|\,P(z)\big)$, which is always ≤ 0, so the optimization goal becomes minimizing $KL\big(q(z|x)\,\|\,P(z)\big)$ while maximizing the second term:

In the first term, $q(z|x)$ is the encoder distribution, assumed to be a multivariate Gaussian (the GMM-style view). A neural network takes the input x and predicts the mean and variance of the Gaussian over z, while P(z) is our prior distribution, usually a standard normal. Let us now derive this KL term.
Because VAE uses a multivariate normal distribution with independent components, it is enough to derive the univariate case. The figure and derivation come from https://zhuanlan.zhihu.com/p/34998569.
$$KL\big(N(\mu,\sigma^2)\,\|\,N(0,1)\big)=\int N(x;\mu,\sigma^2)\,\log\frac{N(x;\mu,\sigma^2)}{N(x;0,1)}\,dx=\frac{1}{2}\int N(x;\mu,\sigma^2)\Big[-\log\sigma^{2}+x^{2}-\frac{(x-\mu)^{2}}{\sigma^{2}}\Big]\,dx$$
The result splits into three integrals. The first is $-\log\sigma^2$ times the probability density, integrated (the density integrates to 1), so it equals $-\log\sigma^2$. The second is the second moment of the normal distribution, which is $\mu^2+\sigma^2$. The third, by definition, is minus the variance divided by the variance, i.e. $-1$. So the total result is:
$$KL\big(N(\mu,\sigma^2)\,\|\,N(0,1)\big)=\frac{1}{2}\big(-\log\sigma^{2}+\mu^{2}+\sigma^{2}-1\big)$$
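As a quick sanity check (not part of the original post), this closed-form result can be compared against PyTorch's built-in analytic KL divergence; the values of mu and sigma below are arbitrary:

import torch
from torch.distributions import Normal, kl_divergence

# Arbitrary example values for the encoder's mean and standard deviation
mu, sigma = torch.tensor(0.5), torch.tensor(1.3)

# Closed-form KL( N(mu, sigma^2) || N(0, 1) ) derived above
closed_form = 0.5 * (-torch.log(sigma ** 2) + mu ** 2 + sigma ** 2 - 1)

# PyTorch's analytic KL between the same two Gaussians
reference = kl_divergence(Normal(mu, sigma), Normal(0.0, 1.0))

print(closed_form.item(), reference.item())  # the two values should agree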
Therefore, the loss function used in the network is the formula below (also called the KLD loss). Pay special attention: here $\delta_{i}$ denotes the log of the variance, whereas in the formula above the variance itself appears, so convert between the two as you read:
$$L_{KLD}=-\frac{1}{2}\sum_{i=1}^{k}\big(1+\delta_{i}-\mu_{i}^{2}-e^{\delta_{i}}\big),\qquad \delta_{i}=\log\sigma_{i}^{2}$$
This is easier to digest together with the code. Why does the model output the log variance? Because the variance must be greater than 0, while the raw network output may be negative; applying exp makes it always positive, and that value is used as the final variance.
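A minimal sketch of this KLD loss in PyTorch, assuming the encoder outputs tensors mu and logvar of shape (batch, latent_dim), as in the implementation further below:

import torch

def kld_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dims, averaged over the batch."""
    # -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2); logvar.exp() recovers the variance
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kld.mean()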
Secondly, maximizing the second term means maximizing $\log P(x|z)$ under $q(z|x)$: we sample z according to q(z|x), and from that z the decoder should be able to reproduce x. This is exactly what an ordinary Auto-Encoder does, so we compute the MSE loss between the generated sample and the real sample:

Li Hongyi explains VAE from the GMM point of view: https://www.youtube.com/watch?v=8zomhgKrsmQ&t=2780s, strongly recommended.
At this point the whole structure of VAE is clear. For data generation we introduce an additional prior distribution p(z) and a new distribution q(z|x), which we push toward p(z) by reducing their KL divergence; this forms the first part of the VAE loss, the KLD loss. The other part maximizes log P(x|z), i.e. a z sampled from q(z|x) should reconstruct the input well; this forms the second part of the loss, the MSE loss.
Since the encoder and decoder are both neural networks, the EM algorithm is not used here; instead, both objectives are optimized simultaneously with gradient descent to find the optimal solution. After training the VAE, the encoder's output distribution q(z|x) is very close to p(z), the input can be reconstructed well from values sampled from that distribution, and p(z) is the standard normal distribution. So we can sample directly from p(z), feed the samples to the decoder, and the generated samples will closely resemble the input data.
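Putting the two parts together, here is a minimal sketch of the full training loss, assuming the model returns (mu, log_var, out) as in the implementation below; the kld_weight argument is a hypothetical knob for balancing the two terms:

import torch
import torch.nn.functional as F

def vae_loss(x, out, mu, log_var, kld_weight: float = 1.0):
    """Reconstruction (MSE) loss plus KL divergence to the standard normal prior."""
    mse = F.mse_loss(out, x)  # reconstruction term, corresponds to maximizing log P(x|z)
    kld = -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
    return mse + kld_weight * kld  # minimized jointly by gradient descent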
2. Model implementation
Following the derivation above, each sample x is first fed into the encoder to compute the mean and variance of q(z|x), which determines the distribution (a multivariate Gaussian). We want this distribution to approximate the prior P(z) (usually a standard normal); this part gives the KLD loss. Having obtained the distribution, we sample a latent variable z from q(z|x), feed it to the decoder to generate a sample, and compute the MSE loss between the generated sample and the real one.
The picture is from https://zhuanlan.zhihu.com/p/34998569

But note that this model is trained with gradient descent, and the sampling step is not differentiable, which would make training impossible. This is why the so-called Reparameterization trick is used. Its core idea: sampling a value Z from $N(\mu,\delta^{2})$ is equivalent to sampling a value $\varepsilon$ from the standard normal $N(0,1)$ and computing $Z=\mu+\varepsilon\times\delta$. This way, the sampling step no longer participates in backpropagation, and the whole model can be trained normally.
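A minimal illustration (not from the original post, with arbitrary scalar values) showing that gradients flow through a reparameterized sample:

import torch

mu = torch.tensor(0.3, requires_grad=True)        # predicted mean
log_var = torch.tensor(-1.2, requires_grad=True)  # predicted log variance

std = torch.exp(0.5 * log_var)  # standard deviation, always positive
eps = torch.randn(())           # noise drawn outside the computation graph
z = mu + eps * std              # reparameterized sample, differentiable w.r.t. mu and log_var

z.backward()
print(mu.grad, log_var.grad)    # both gradients exist, so the encoder can be trained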

3. Code
For the full code, see (issues and stars are welcome): https://github.com/Classmate-Huang/CV_GenerateModel/tree/master/VAE
The code implements two kinds of VAE and experiments on MNIST. Here we take the convolutional variational autoencoder as an example and walk through the code:
import torch
import torch.nn as nn


def ConvBnRelu(channel_in, channel_out):
    # Conv + BatchNorm + LeakyReLU module, downsamples by a factor of 2
    conv_bn_relu = nn.Sequential(
        nn.Conv2d(channel_in, channel_out, 3, stride=2, padding=1),
        nn.BatchNorm2d(channel_out),
        nn.LeakyReLU(0.2, inplace=True)
    )
    return conv_bn_relu


def DConvBnRelu(channel_in, channel_out):
    # ConvTranspose + BatchNorm + LeakyReLU module, upsamples by a factor of 2
    d_conv_bn_relu = nn.Sequential(
        nn.ConvTranspose2d(channel_in, channel_out, 3, stride=2, padding=1, output_padding=1),
        nn.BatchNorm2d(channel_out),
        nn.LeakyReLU(0.2, inplace=True)
    )
    return d_conv_bn_relu
class VariationAutoEncoder(nn.Module):
    ''' Conv VAE '''
    def __init__(self, in_channel=3, img_size=512, latent_dim=256):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            ConvBnRelu(in_channel, 96),
            ConvBnRelu(96, 128),
            ConvBnRelu(128, 256),
            ConvBnRelu(256, 256),
        )
        # Decoder
        self.decoder = nn.Sequential(
            DConvBnRelu(256, 256),
            DConvBnRelu(256, 128),
            DConvBnRelu(128, 96),
            # nn.ConvTranspose2d(96, 3, 3, stride=2, padding=1, output_padding=1),
            DConvBnRelu(96, 96),
            nn.Conv2d(96, in_channel, kernel_size=3, padding=1),
            nn.Tanh()
        )
        # Dimension of the latent code
        self.latent_dim = latent_dim
        self.img_size = img_size
        original_dim = 256 * (img_size // 16) ** 2
        # Map the flattened encoder features to the specified latent dimension
        self.fc_mu = nn.Linear(original_dim, latent_dim)
        self.fc_var = nn.Linear(original_dim, latent_dim)
        # Used to restore the original resolution before decoding
        self.fc_recover = nn.Linear(latent_dim, original_dim)

    def reparameterize(self, mu, logvar):
        ''' Reparameterization trick '''
        std = torch.exp(0.5 * logvar)   # standard deviation
        eps = torch.randn_like(std)     # sample from a standard normal distribution
        return eps * std + mu           # obtain z

    def forward(self, x):
        # Encoding stage
        fea = self.encoder(x)
        fea = torch.flatten(fea, start_dim=1)
        # Split the encoded features into the mu and log_var of the latent Gaussian
        mu = self.fc_mu(fea)
        log_var = self.fc_var(fea)
        # Get the latent code by sampling with the reparameterization trick
        z = self.reparameterize(mu, log_var)
        # Decode: restore the original resolution, then send the code through the decoder
        fea = self.fc_recover(z).view(-1, 256, self.img_size // 16, self.img_size // 16)
        out = self.decoder(fea)
        return mu, log_var, out

    def sample(self, num_sample, device):
        ''' Sampling '''
        # Sample from the standard normal distribution (after training, q(z|x) is close to it)
        z = torch.randn(num_sample, self.latent_dim).to(device)
        # Decode to generate new samples
        fea = self.fc_recover(z).view(-1, 256, self.img_size // 16, self.img_size // 16)
        out = self.decoder(fea)
        return out
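For completeness, a minimal training-loop sketch (not from the linked repository; the dataloader, image size, and hyperparameters are placeholder assumptions) showing how the three model outputs feed the two loss terms:

import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Placeholder setup: images are assumed resized to 32x32 and scaled to [-1, 1] to match the Tanh output
model = VariationAutoEncoder(in_channel=1, img_size=32, latent_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
    for x, _ in dataloader:                      # `dataloader` is assumed to exist (e.g. MNIST)
        x = x.to(device)
        mu, log_var, out = model(x)
        # Reconstruction term (maximize log P(x|z)) plus KLD term (pull q(z|x) toward N(0, I))
        mse = F.mse_loss(out, x)
        kld = -0.5 * torch.mean(torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1))
        loss = mse + kld
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# After training, generate new images directly from the prior p(z)
samples = model.sample(num_sample=16, device=device)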
Experimental results
① Reconstruction results:


② Generation results:

4. Summary
The essence of VAE is to find a connection between a prior distribution p(z) and the true distribution p(x) of the sample space, so that new samples can be generated according to p(x|z). This is hard to do directly, so VAE introduces an additional distribution q(z|x), produced by a neural network (the encoder) that outputs a mean and variance and thus determines a distribution. We push q(z|x) toward p(z), and require that values sampled from q(z|x) reconstruct the input sample x, which maximizes $E_{q(z|x)}\big[\log P(x|z)\big]$; this raises the ELBO and hence the overall likelihood.
This also reflects a general feature of generative models: model the true data distribution from training data, then generate more data from that model and distribution.
Seen from another angle, compared with a plain AE, VAE injects Gaussian noise into the latent code during reconstruction and forces the latent code to follow a certain distribution, so that we can sample a code from that distribution to generate new samples.
See Li Hongyi's VAE lecture: https://www.youtube.com/watch?v=8zomhgKrsmQ&t=2780s
For beginners, VAE is genuinely hard to understand, but as you learn more you will find its design and ideas (the ELBO, the reparameterization trick) really clever; one has to marvel at the charm of mathematics. VAE now has a large number of variants and is widely used in many fields.