当前位置:网站首页>【深度学习】自编码器
【深度学习】自编码器
2022-07-25 09:22:00 【繁星¹⁸⁹⁵】
自编码器
自编码器的结构

- 为避免自编码器学出:原始数据*1=重建数据,这种无用的结构,要点是中间层要比输入层低维,强制其在编码时产生信息损失。
- 还有一种做法是:去噪自编码器,也就是对输入数据加入人为随机噪声,然后让解码器重建出没有噪声的图像,那这样就可以避免自编码学出乘1这种结构,因为乘1不能恢复没有噪声的图像。

自编码器的原理理解
- 自编码器输入是高维数据,中间隐藏层是低维特征,输出是和原始输入数据相同维度的重建数据。
- 输入的数据看起来是随机的,但其实存在着某种结构。某种规律
- 自编码器本质是在做投影,将高维数据投影成低维特征。编码器部分的网络存储了高维数据到低维特征的投影函数(这个函数可能不是常规的可解析的)。解码器部分的网络存储了低维特征到高维数据的反投影函数。
举个例子:
- 对于一系列3维的输入数据,其在空间中的分布如下图表示。
- 当然,对于呈现这样分布的每个点我们可以用3维点(x,y,z)来表示,但是我们可以明显地发现这些数据分布其实是被限制在1维螺旋中的,也就是说,只要我们记住这个螺旋函数f,那么每个点我们用1维特征(θ)就能表示。也即,θ=f(x,y,z)。
- 这里,3维的(x,y,z)就是输入自编码器的高维数据,1维θ就是自编码器学习得到(可以说是投影产生)的低维特征,投影函数f就是自编码器的编码器本身,通过网络结构和参数表示。
- 解码器则将低维特征反投影回高维数据,本质上在学这么一个函数(x,y,z)=g(θ)。
- f 和 g 就是数学中的函数和反函数

- 这里的例子中,点(3,0,0.5)就可以用2Π来表示,(x,y,z)与θ的关系则存储与自编码器网络中,
- 这里(x,y,z)与θ的关系有着显式的解析函数,但实际应用中,高维数据和低维特征的关系可能及其复杂,其函数关系无法用解析式表达,此时就可以通过神经网络来表达,神经网络的高参数量和非线性使其具备很强的拟合能力,只有参数量够,理论上神经网络可以拟合任意的函数。

- 解码器: 从低维到高维

自编码器的作用
- 训练好的自编码器,其编码器部分将是一个高效的特征提取器,提取到的特征可以用于很多任务,如分类,压缩。
- 训练好的一组自编码器(包含编码器和解码器),可以用来做异常数据检测,如果数据异常,也就是其分布不是我们之前学习的那种规律,那么重建误差将会很大。反之我们可以根据重建误差很大来推断输入数据是异常的。
自编码器的代码实现
数据集是Minist
import os
import torch
from torch import nn, optim
from torch.autograd import Variable
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.utils import save_image
# 加载数据集
def get_data():
data_tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
train_data = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size, drop_last=True)
return train_loader
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20) # 均值
self.fc22 = nn.Linear(400, 20) # 方差
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encoder(self, x):
h1 = F.relu(self.fc1(x))
mu = self.fc21(h1)
logvar = self.fc22(h1)
return mu, logvar
def decoder(self, z):
h3 = F.relu(self.fc3(z))
x = F.tanh(self.fc4(h3))
return x
# 重新参数化
def reparametrize(self, mu, logvar):
std = logvar.mul(0.5).exp_() # 计算标准差
if torch.cuda.is_available():
eps = torch.cuda.FloatTensor(std.size()).normal_() # 从标准的正态分布中随机采样一个eps
else:
eps = torch.FloatTensor(std.size()).normal_()
eps = Variable(eps)
return eps.mul(std).add_(mu)
def forward(self, x):
mu, logvar = self.encoder(x)
z = self.reparametrize(mu, logvar)
return self.decoder(z), mu, logvar
def loss_function(recon_x, x, mu, logvar):
MSE = reconstruction_function(recon_x, x)
# loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
KLD = torch.sum(KLD_element).mul_(-0.5)
# KL divergence
return MSE + KLD
def to_img(x):
x = (x + 1.) * 0.5
x = x.clamp(0, 1)
x = x.view(x.size(0), 1, 28, 28)
return x
if __name__ == '__main__':
# 超参数设置
batch_size = 128
lr = 1e-3
epoches = 100
model = VAE()
if torch.cuda.is_available():
model.cuda()
train_data = get_data()
reconstruction_function = nn.MSELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(epoches):
for img, _ in train_data:
img = img.view(img.size(0), -1)
img = Variable(img)
if torch.cuda.is_available():
img = img.cuda()
# forward
output, mu, logvar = model(img)
loss = loss_function(output, img, mu, logvar)/img.size(0)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("epoch=", epoch, loss.data.float())
if (epoch+1) % 10 == 0:
print("epoch = {}, loss is {}".format(epoch+1, loss.data))
pic = to_img(output.cpu().data)
if not os.path.exists('./vae_img1'):
os.mkdir('./vae_img1')
save_image(pic, './vae_img1/image_{}.png'.format(epoch + 1))
torch.save(model, './vae.pth')
总结
- 自编码器本质就是将高维数据投影成低维特征,编码器网络拟合着投影函数,解码器网络拟合着反投影函数。
- 神经网络本质就是一个拟合函数。
边栏推荐
- About C and OC
- 关于学生管理系统(注册,登录,学生端)
- The shortest path problem Bellman Ford (single source shortest path) (illustration)
- How many regions can a positive odd polygon be divided into
- 【代码源】 每日一题 素数之欢(bfs)
- Numpy array attribute, shape changing function, basic operation
- Flex layout syntax and use cases
- How to deploy the jar package to the server? Note: whether the startup command has nohup or not has a lot to do with it
- Swagger2显示get接口有问题,加注解就能解决
- [code source] daily one question non decreasing 01 sequence
猜你喜欢

CoreData存储待办事项

微信小程序初步了解及实现底部导航栏

初识Opencv4.X----图像直方图均衡

Operation 7.19 sequence table

Redis installation (Ubuntu)

cf #785(div2) C. Palindrome Basis

The jar package has been launched on Alibaba cloud server and the security group has been opened, but postman still can't run. What should we do

Swagger2 shows that there is a problem with the get interface, which can be solved with annotations

一张图讲解 SQL Join 左连 又连

【代码源】每日一题 算的我头都大啦
随机推荐
解决esp8266无法连接手机和电脑热点的问题
基于人脸识别的树莓派门禁系统
[GKCTF 2021]easynode
如何将Jar包部署到服务器,注:启动命令有无nohup有很大关系
初识Opencv4.X----方框滤波
基于stm32的恒功率无线充电
laravel 调用第三方 发送邮件 (php)
深入解读C语言随机数函数和如何实现随机数
文件--初识
Machine learning -- detailed introduction of standardscaler (), transform (), fit () in sklearn package
[code source] a prime number of fun every day (BFS)
OC--继承和多态and指针
作业7.15 shell脚本
Redis string structure command
cf #785(div2) C. Palindrome Basis
Numpy - Construction of array
How to obtain location information (longitude and latitude) by uni app
Data control language (DCL)
pdf2Image Pdf文件存为jpg NodeJs实现
作业7.19 顺序表