[Linear Neural Networks] Softmax Regression
2022-07-31 04:23:00 【PBemmm】
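One-hot encoding
Typically used for classification problems, where the categories are discrete.
The idea is simple: use n states to represent n categories, where exactly one state takes the value 1 and all the others are 0.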
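As a small illustration of the encoding described above, here is a sketch in plain Python. The helper name `one_hot` and the three-class example are made up for this illustration and are not from the original post.

```python
def one_hot(index, num_classes):
    """Return a list of length num_classes with a 1 at position index."""
    if not 0 <= index < num_classes:
        raise ValueError("index must be in [0, num_classes)")
    return [1 if i == index else 0 for i in range(num_classes)]


# Hypothetical labels for a 3-class problem.
labels = [0, 2, 1]
encoded = [one_hot(label, 3) for label in labels]
print(encoded)  # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```

Each encoded vector sums to 1, so it can also be read as a probability distribution that puts all its mass on the true class.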
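Cross-entropy
The loss measures the difference between the true probability distribution and the predicted probability distribution.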
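For reference, the standard definition can be written out as follows. The symbols are chosen for this note and are not from the original post: $y$ is the one-hot label vector, $\hat{y}$ the predicted probabilities, $q$ the number of classes, and $c$ the index of the true class.

```latex
H(y, \hat{y}) = -\sum_{j=1}^{q} y_j \log \hat{y}_j
% With a one-hot label (y_c = 1 and y_j = 0 for j \neq c) only one term survives:
H(y, \hat{y}) = -\log \hat{y}_c
```

So with one-hot labels the loss depends only on the probability the model assigns to the true class.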
Loss functions
Squared loss (L2 loss)

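In the figure, the green curve is the likelihood function and the yellow curve is the gradient.
When the prediction is far from the origin (a large error), the gradient is large, so the parameter update can be large as well. This motivates the L1 loss.
Absolute loss (L1 loss)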

Huber's robust loss
A loss function that combines the L1 loss and the L2 loss.

Softmax implemented from scratch
Reading the data
import torch
from IPython import display
from d2l import torch as d2l
batch_size = 256
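train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
The dataset is Fashion-MNIST: images from 10 classes, with 6,000 training images per class. The training set has 60,000 images, the test set has 10,000, and the batch size is 256.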
num_inputs = 784
num_outputs = 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
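b = torch.zeros(num_outputs, requires_grad=True)
Each image is 28 * 28 pixels and is flattened into a 784-dimensional vector. The output layer has 10 units, one per class.
The weight W is a (784, 10) matrix and each input x is a vector of length 784. For the outputs O (indices 1 to 10), each Oj is
Oj = < W[:, j], X > + bj, where W[:, j] is column j of W. So the number of columns of W is num_outputs, which is the number of outputs O (classes).
Likewise, b needs one entry per output O.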
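The shape bookkeeping above can be checked with a toy example in plain Python. The sizes (4 inputs, 3 outputs) and the numbers in `W`, `b`, and `x` are made up for illustration; the post itself uses 784 inputs and 10 outputs.

```python
num_inputs, num_outputs = 4, 3  # toy sizes, not the 784 and 10 used in the post

# W has one row per input feature and one column per output (class).
W = [[float(i + j) for j in range(num_outputs)] for i in range(num_inputs)]
b = [0.0, 1.0, 2.0]        # one bias per output
x = [1.0, 2.0, 3.0, 4.0]   # one input vector of length num_inputs


def column(matrix, j):
    """Return column j of a matrix stored as a list of rows."""
    return [row[j] for row in matrix]


# o_j = <W[:, j], x> + b_j for every output j.
o = [
    sum(w * xi for w, xi in zip(column(W, j), x)) + b[j]
    for j in range(num_outputs)
]
print(o)  # [20.0, 31.0, 42.0]
```

The output has one entry per column of `W`, which is why `W` needs `num_outputs` columns and `b` needs `num_outputs` entries.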
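softmax
As the name suggests, softmax is the counterpart of hardmax: taking the maximum of a sequence in the usual way is hardmax. In classification, labels are one-hot encoded and the outputs are interpreted as confidences. Because softmax exponentiates its inputs (powers of e), we only care whether the predicted value, and therefore the confidence, of the correct class is large enough, and not about the individual incorrect classes. The model can then separate the true class from the other classes by a margin.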
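To make the hardmax versus softmax contrast concrete, here is a sketch in plain Python on a single score vector. The function names and the scores are invented for this illustration. Subtracting the maximum score before exponentiating is a common numerical-stability trick that does not change the result; the original post does not mention it.

```python
import math


def hardmax(scores):
    """Put all the mass on the largest score (a one-hot vector)."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return [1.0 if i == best else 0.0 for i in range(len(scores))]


def stable_softmax(scores):
    """Exponentiate and normalize so the outputs are positive and sum to 1."""
    shift = max(scores)  # exp(s - shift) cannot overflow
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 5.0]
print(hardmax(scores))  # [1.0, 0.0, 0.0] would be wrong; the max is at index 2
print(stable_softmax(scores))
```

Here `hardmax(scores)` returns `[0.0, 0.0, 1.0]`, while `stable_softmax(scores)` gives the last class most of the probability mass and still leaves a small, differentiable share for the others.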
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # broadcasting is applied here
Model
def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
Cross-entropy
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
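y_hat[[0, 1], y]
y holds the index of the true class for each sample. y_hat[[0, 1], y] picks, for each row, the predicted probability of its true class: 0.1 from the first row [0.1, 0.3, 0.6] and 0.5 from the second row.
def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(len(y_hat)), y])
cross_entropy(y_hat, y)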
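The same two-sample example can be worked through in plain Python, which makes it easy to check the numbers by hand. The accuracy computation at the end is an addition for illustration and is not part of the original post.

```python
import math

y = [0, 2]                                   # true class index per sample
y_hat = [[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]   # predicted probabilities

# Cross-entropy per sample: minus the log of the probability of the true class.
losses = [-math.log(row[label]) for row, label in zip(y_hat, y)]
print([round(loss, 4) for loss in losses])  # [2.3026, 0.6931]

# Predicted class is the index of the largest probability in each row.
predictions = [max(range(len(row)), key=row.__getitem__) for row in y_hat]
accuracy = sum(p == t for p, t in zip(predictions, y)) / len(y)
print(predictions, accuracy)  # [2, 2] 0.5
```

The first sample is misclassified (probability 0.1 on the true class) and gets the larger loss; the second is classified correctly and gets the smaller one.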
Concise implementation of softmax
Imports
import torch
from torch import nn
from d2l import torch as d2l
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
Initializing model parameters
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights);
nn.Flatten() reshapes the network input, flattening each image into a vector.
Loss function
loss = nn.CrossEntropyLoss(reduction='none')
Optimization algorithm
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
Training
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
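d2l.train_ch3 hides the training loop. As a rough illustration of what minibatch SGD on softmax regression involves, here is a self-contained sketch in plain Python on a made-up toy dataset. It is not d2l's implementation: the function names, the toy data, and the hyperparameters (`lr=0.5`, `num_epochs=200`, `batch_size=3`) are all assumptions chosen for this example. It uses the fact that the gradient of the cross-entropy loss with respect to output j is the predicted probability minus the one-hot label.

```python
import math
import random


def softmax(scores):
    shift = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_proba(W, b, x):
    scores = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
              for j in range(len(b))]
    return softmax(scores)


def mean_loss(W, b, data):
    return sum(-math.log(predict_proba(W, b, x)[t]) for x, t in data) / len(data)


def train(data, num_inputs, num_outputs, lr=0.5, num_epochs=200,
          batch_size=3, seed=0):
    rng = random.Random(seed)
    # Small random weights and zero biases, as in the from-scratch version.
    W = [[rng.gauss(0, 0.01) for _ in range(num_outputs)]
         for _ in range(num_inputs)]
    b = [0.0] * num_outputs
    data = list(data)
    for _ in range(num_epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            grad_W = [[0.0] * num_outputs for _ in range(num_inputs)]
            grad_b = [0.0] * num_outputs
            for x, label in batch:
                probs = predict_proba(W, b, x)
                for j in range(num_outputs):
                    # d(loss)/d(o_j) = p_j - y_j for softmax + cross-entropy
                    delta = probs[j] - (1.0 if j == label else 0.0)
                    grad_b[j] += delta
                    for i in range(num_inputs):
                        grad_W[i][j] += x[i] * delta
            # SGD step with the gradient averaged over the minibatch.
            for j in range(num_outputs):
                b[j] -= lr * grad_b[j] / len(batch)
                for i in range(num_inputs):
                    W[i][j] -= lr * grad_W[i][j] / len(batch)
    return W, b


# Toy data: 2 features, 3 well-separated classes (made up for this sketch).
data = [([1.0, 0.0], 0), ([0.9, 0.1], 0), ([0.0, 1.0], 1),
        ([0.1, 0.9], 1), ([-1.0, -1.0], 2), ([-0.9, -1.1], 2)]
W, b = train(data, num_inputs=2, num_outputs=3)
final_loss = mean_loss(W, b, data)
correct = sum(
    max(range(3), key=predict_proba(W, b, x).__getitem__) == t for x, t in data
)
print(final_loss, correct / len(data))
```

With untrained weights the mean loss is close to log(3) because all three classes get roughly equal probability; training should push it well below that on this separable toy data.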