1. Usage of Common Loss Functions
2022-07-29 05:22:00 【MY头发乱了】
Preface
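This post covers common ways to define a loss function in PyTorch, including multi-class cross-entropy, mean squared error, and binary cross-entropy. A loss function serves two purposes: 1. it measures how far the model's outputs are from the labels, i.e. how inconsistent the predictions are with the true values; 2. it is the objective that the neural network's optimizer minimizes. The smaller the loss, the closer the predictions are to the true values on the data the loss is computed over (a low training loss alone does not guarantee that the model generalizes or is robust).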
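To make the "optimization objective" role concrete, here is a minimal sketch (not from the original post) of a loss being driven down: a one-feature linear model fitted with `torch.nn.MSELoss` and SGD to synthetic points on the line y = 2x + 1. The data, learning rate, and step count are illustrative assumptions; the loop uses the same `zero_grad` / `backward` / `step` sequence as the training code later in the post.

```python
import torch

torch.manual_seed(0)  # reproducible initial weights

# Synthetic data (assumption): 50 points on the line y = 2x + 1
x = torch.linspace(-1, 1, 50).unsqueeze(1)  # shape (50, 1)
y = 2 * x + 1

model = torch.nn.Linear(1, 1)  # learns one weight and one bias
loss_fn = torch.nn.MSELoss()
optim = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(100):
    loss = loss_fn(model(x), y)  # how far predictions are from targets
    optim.zero_grad()            # clear gradients from the previous step
    loss.backward()              # compute d(loss)/d(parameters)
    optim.step()                 # move parameters against the gradient
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
print(model.weight.item(), model.bias.item())  # should approach 2 and 1
```

As the loss shrinks, the learned weight and bias approach the values that generated the data, which is exactly the "smaller loss, closer predictions" relationship described above.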
I. What are L1 loss (MAE), L2 loss (MSE), smooth L1 loss, and cross-entropy loss?
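For predictions p and targets t, the regression losses average an elementwise penalty on the error d = p - t: |d| for L1 (MAE), d² for L2 (MSE), and for smooth L1 a quadratic 0.5·d²/β when |d| < β and a linear |d| - 0.5·β otherwise, so it behaves like MSE near zero and like MAE for outliers. Cross-entropy instead scores a vector of class logits z against the true class c as -log softmax(z)_c. The plain-Python sketch below spells these definitions out; the function names are illustrative, not a library API, and `beta=1.0` mirrors the default of PyTorch's `SmoothL1Loss`.

```python
import math

def mae(pred, target):
    # L1 loss: mean of absolute errors
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)

def mse(pred, target):
    # L2 loss: mean of squared errors
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def smooth_l1(pred, target, beta=1.0):
    # Quadratic for |error| < beta, linear otherwise (PyTorch SmoothL1Loss convention)
    total = 0.0
    for p, t in zip(pred, target):
        d = abs(p - t)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(pred)

def cross_entropy(logits, label):
    # -log(softmax(logits)[label]) for one sample, via a stable log-sum-exp
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

pred = [2.5, 0.0, 2.0]
target = [3.0, -0.5, 4.0]
print(mae(pred, target))        # 1.0
print(mse(pred, target))        # 1.5
print(smooth_l1(pred, target))  # about 0.583
print(cross_entropy([1.0, 2.0, 3.0], 2))
```

With errors of 0.5, 0.5 and 2.0, MSE (1.5) is dominated by the single large error, while smooth L1 (about 0.583) grows only linearly with it.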
II. Usage steps
1. Loss function options
Code (example):
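import torch

# Define the loss function (its gradients drive the parameter updates)
loss_fn = torch.nn.CrossEntropyLoss()  # multi-class cross-entropy: pass raw logits, no softmax on the output
# loss_fn = torch.nn.MSELoss()            # mean squared error
# loss_fn = torch.nn.BCELoss()            # binary cross-entropy (input must be probabilities in [0, 1])
# loss_fn = torch.nn.BCEWithLogitsLoss()  # binary cross-entropy that applies sigmoid internally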
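`CrossEntropyLoss` expects raw logits of shape (N, C) and integer class indices of shape (N,); internally it is `log_softmax` followed by the negative log-likelihood loss. The sketch checks that equivalence on random logits and shows why the network output should not be passed through softmax first: a confidently correct prediction gets a loss near 0, but applying softmax beforehand inflates it to roughly 0.55. The tensors are made-up example values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)          # raw network outputs: 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1])  # class indices (int64), not one-hot

loss_fn = torch.nn.CrossEntropyLoss()
ce = loss_fn(logits, labels)
# Same value in two explicit steps: log_softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)
print(ce.item(), manual.item())

# Pitfall: a confident, correct prediction for class 0
confident = torch.tensor([[10.0, 0.0, 0.0]])
target = torch.tensor([0])
good = loss_fn(confident, target)                       # close to 0
bad = loss_fn(torch.softmax(confident, dim=1), target)  # softmax effectively applied twice
print(good.item(), bad.item())
```

Because the extra softmax squeezes the inputs into [0, 1], the loss can never get close to 0 again, so gradients keep pushing even on samples the model already classifies correctly.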
2. Full code
Code (example):
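import os

import matplotlib.pyplot as plt  # for the optional live loss plot
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader  # the core class for loading data; yields an iterable of batches
from torchvision import datasets, transforms


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()  # initialize the parent nn.Module
        self.fc1 = torch.nn.Sequential(
            torch.nn.Linear(784, 256),
            torch.nn.BatchNorm1d(256),
            torch.nn.ReLU())
        # nn.Sequential chains the layers: linear -> batch norm -> activation.
        # nn.Linear(in_features, out_features, bias=True) applies y = x @ W.T + b.
        # 784: size of each input sample, i.e. 28*28 pixels (w*h).
        # 256: size of each output sample of this layer.
        # BatchNorm1d(256): normalizes each feature over the batch to zero mean and unit
        #   variance (then applies a learnable scale and shift), keeping inputs in the
        #   sensitive range of the nonlinearity and easing vanishing gradients.
        # nn.ReLU(): ReLU activation.
        self.fc2 = torch.nn.Sequential(
            torch.nn.Linear(256, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU())
        self.fc3 = torch.nn.Linear(128, 10)

    def forward(self, x):  # the forward pass
        # (N, C, H, W) = (batch size, channels, height, width) --> (N, C*H*W)
        # x.size(0) is the batch size, i.e. the number of rows after flattening;
        # -1 lets PyTorch infer the number of columns from the remaining elements,
        # so each image becomes one row that the fully connected layers accept.
        # reshape returns a view when possible and copies otherwise.
        x = torch.reshape(x, [x.size(0), -1])  # N, 784
        y = self.fc1(x)  # N, 256
        y = self.fc2(y)  # N, 128
        # Return raw logits: CrossEntropyLoss and BCEWithLogitsLoss apply softmax/sigmoid
        # internally. Apply torch.softmax(logits, 1) yourself only where probabilities
        # are needed (e.g. with MSELoss).
        return self.fc3(y)  # N, 10


if __name__ == '__main__':
    save_params = "./save_params/params.pth"  # parameters only (state_dict)
    save_net = "./save_params/net.pth"        # the whole network object
    transf = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.5], std=[0.5])])
    # transforms.Compose applies the listed transforms in order.
    # transforms.ToTensor() converts a uint8 image in [0, 255] to a float tensor in [0, 1].
    # transforms.Normalize computes (image - mean) / std, mapping [0, 1] to [-1, 1].
    train_data = datasets.MNIST("./data", train=True, transform=transf, download=True)  # training set
    test_data = datasets.MNIST("./data", train=False, transform=transf, download=False)  # test set
    # Batches of 100 images; shuffle=True reshuffles the training data every epoch,
    # so the model sees the samples in a different order each time.
    train_loader = DataLoader(train_data, batch_size=100, shuffle=True)  # load data
    test_loader = DataLoader(test_data, batch_size=100, shuffle=False)
    # DataLoader:
    # - can use multiple worker processes (num_workers > 0) to speed up batch loading
    # - wraps the dataset in a simple iterable of (input, label) batches that is easy to use and extend
    print(train_data.data.shape)
    print(train_data.targets.shape)
    print(test_data.data.shape)
    print(test_data.targets.shape)
    print(test_data.classes)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = Net().to(device)  # move the model's parameters and buffers to the chosen device
    if os.path.exists(save_params):
        net.load_state_dict(torch.load(save_params, map_location=device))  # load parameters only
        print("Parameters loaded")
    else:
        print("No params!")
    # Load the whole network (structure + parameters); recent PyTorch versions
    # need weights_only=False to unpickle a full model:
    # net = torch.load(save_net, weights_only=False).to(device)
    # loss_fn = torch.nn.CrossEntropyLoss()  # multi-class cross-entropy: call loss_fn(out, y) with class indices
    # loss_fn = torch.nn.MSELoss()           # mean squared error: compare torch.softmax(out, 1) with the one-hot target
    # loss_fn = torch.nn.BCELoss()           # binary cross-entropy: needs torch.sigmoid(out), values in [0, 1]
    loss_fn = torch.nn.BCEWithLogitsLoss()   # binary cross-entropy that applies sigmoid to its input automatically
    # optim = torch.optim.SGD(net.parameters(), lr=1e-3)  # plain SGD: slower but relatively stable
    optim = torch.optim.Adam(net.parameters(), lr=1e-3)  # create the optimizer
    # torch.optim.Adam(params, lr):
    #   params - iterable of parameters to optimize, or dicts defining parameter groups
    #   lr     - learning rate (step size)
    # Record the training loss for an optional live plot
    a = []
    b = []
    # plt.ion()
    net.train()
    for epoch in range(1):
        for i, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            # One-hot targets with a fixed width of 10 (one column per class),
            # independent of which labels happen to appear in this batch
            y_ = F.one_hot(y, num_classes=10).float()
            out = net(x)  # forward pass: raw logits, N x 10
            # loss = loss_fn(out, y)  # use this form with CrossEntropyLoss
            loss = loss_fn(out, y_)  # compute the loss
            optim.zero_grad()  # clear the current gradients
            loss.backward()    # compute the gradients
            optim.step()       # take one step along the gradients
            # a.append(i)
            # b.append(loss.item())
            # plt.clf()
            # plt.plot(a, b)
            # plt.pause(0.1)
            if i % 50 == 0:
                print("loss", loss.item())
    # plt.ioff()
    # plt.show()
    # Evaluation
    eval_loss = 0
    eval_acc = 0
    net.eval()  # BatchNorm switches to its running statistics
    with torch.no_grad():  # no gradients are needed for evaluation
        for i, (x, y) in enumerate(test_loader):
            x = x.to(device)  # input to the network
            y = y.to(device)
            y_ = F.one_hot(y, num_classes=10).float()
            out = net(x)
            # loss = loss_fn(out, y)  # use this form with CrossEntropyLoss
            loss = loss_fn(out, y_)
            eval_loss += loss.item() * y.size(0)
            # argmax of the logits equals argmax of the softmax probabilities
            eval_acc += (y == torch.argmax(out, 1)).sum().item()
    avg_loss = eval_loss / len(test_data)
    avg_acc = eval_acc / len(test_data)
    print(avg_loss)
    print(avg_acc)
    os.makedirs("./save_params", exist_ok=True)
    torch.save(net.state_dict(), save_params)  # save parameters only
    torch.save(net, save_net)                  # save the whole network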
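`BCEWithLogitsLoss` (and likewise `BCELoss` or `MSELoss`) needs a target with the same shape as the network output, here N × 10. Building it as `torch.zeros(len(y), max(y) + 1)` ties its width to the largest label that happens to be in the batch, so a batch without a 9 yields fewer than 10 columns and the loss call fails with a shape mismatch. The sketch compares that construction with `torch.nn.functional.one_hot` and a fixed `num_classes=10`; the label batch is a made-up example. (`CrossEntropyLoss` takes the class indices directly and needs no one-hot target at all.)

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 10
y = torch.tensor([0, 3, 3, 7])  # a batch that happens to contain no label 8 or 9

# Width depends on the batch: max(y) + 1 == 8 here, not 10
fragile = torch.zeros(len(y), int(y.max()) + 1)
fragile[torch.arange(len(y)), y] = 1

# Width fixed by the number of classes, matching the network's 10 outputs
robust = F.one_hot(y, num_classes=NUM_CLASSES).float()

print(fragile.shape, robust.shape)  # torch.Size([4, 8]) torch.Size([4, 10])

logits = torch.zeros(len(y), NUM_CLASSES)  # stand-in for the network output
loss = torch.nn.BCEWithLogitsLoss()(logits, robust)  # shapes match, so this works
# torch.nn.BCEWithLogitsLoss()(logits, fragile) would raise a shape-mismatch error
print(loss.item())
```

With all-zero logits every sigmoid output is 0.5, so each element contributes log 2 regardless of its target, which makes the printed loss easy to sanity-check.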
Summary
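loss_fn = torch.nn.CrossEntropyLoss()   # multi-class cross-entropy: no activation on the output; pass raw logits and class indices
loss_fn = torch.nn.MSELoss()            # mean squared error: apply an activation (e.g. softmax) when comparing against one-hot targets
loss_fn = torch.nn.BCELoss()            # binary cross-entropy: apply sigmoid to the output first (inputs must be in [0, 1])
loss_fn = torch.nn.BCEWithLogitsLoss()  # binary cross-entropy that applies sigmoid to its input automatically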
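A quick check of the last three summary lines: `BCEWithLogitsLoss` on raw logits gives the same value as `BCELoss` on `sigmoid(logits)` (PyTorch's documentation notes the fused version is more numerically stable), and `MSELoss` compares softmax probabilities against the one-hot target. The logits and labels are made-up example values. For a single-label task such as MNIST, the BCE losses treat each of the 10 classes as an independent yes/no decision, whereas `CrossEntropyLoss` assumes exactly one class per sample.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)  # raw outputs: 4 samples x 10 classes
targets = F.one_hot(torch.tensor([3, 0, 9, 1]), num_classes=10).float()  # same shape as logits

with_logits = torch.nn.BCEWithLogitsLoss()(logits, targets)      # sigmoid applied internally
plain = torch.nn.BCELoss()(torch.sigmoid(logits), targets)       # sigmoid applied by hand; inputs in [0, 1]
mse = torch.nn.MSELoss()(torch.softmax(logits, dim=1), targets)  # probabilities vs. one-hot

print(with_logits.item(), plain.item(), mse.item())
```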