[Linear Neural Network] softmax regression
2022-07-31 04:29:00 【PBemmm】
One-hot encoding
Typically used for classification problems, where the labels are discrete categories.
It is very simple: n categories are represented by an n-dimensional vector in which exactly one entry is 1 and all the others are 0.
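A minimal sketch using PyTorch's `torch.nn.functional.one_hot`, assuming the labels are integer class indices in the range [0, n). The four labels and three classes are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Integer class labels for 4 samples, 3 possible classes (illustrative values)
labels = torch.tensor([0, 2, 1, 2])

# Each label becomes a length-3 vector with a single 1 at the label's index
one_hot = F.one_hot(labels, num_classes=3)
print(one_hot)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0],
#         [0, 0, 1]])
```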
Cross entropy
Cross entropy measures the difference between the true probability distribution and the predicted probability distribution, and this difference is used as the loss.
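For a one-hot label vector $\mathbf{y}$ and a predicted distribution $\hat{\mathbf{y}}$ over $q$ classes (symbols introduced here for illustration), the standard cross-entropy loss is shown below. Since $y_j = 0$ for every class except the true class $c$, the sum collapses to a single term:

```latex
\ell(\mathbf{y}, \hat{\mathbf{y}}) = -\sum_{j=1}^{q} y_j \log \hat{y}_j = -\log \hat{y}_c
```

So the loss only looks at the predicted probability of the true class, which is what the from-scratch `cross_entropy` function in this post computes.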
Loss function
L2 Loss

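In the figure, the green curve is the likelihood function and the yellow curve is the gradient.
For L2 loss the gradient grows linearly with the distance from the origin (the error), so when the prediction is far from the label the parameter updates can be very large. This motivates the L1 loss.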
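For reference, the standard L2 loss for a label $y$ and a prediction $y'$, and its gradient with respect to $y'$, are:

```latex
\ell(y, y') = \frac{1}{2}(y - y')^2, \qquad \frac{\partial \ell}{\partial y'} = y' - y
```

The magnitude of the gradient equals the error $|y - y'|$, so it grows without bound as the prediction moves away from the label.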
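L1 loss (absolute-value loss)
Its gradient is a constant ±1, so parameter updates stay bounded no matter how large the error is. However, it is not differentiable at 0, and the gradient flips between −1 and +1 near the optimum, which can make training unstable toward the end.
Huber's robust loss
It combines L1 loss and L2 loss: it behaves like L2 loss when the error is small and like L1 loss when the error is large.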
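To compare the three losses, the sketch below defines each one (the function names and the threshold `delta=1.0` for Huber are assumptions for illustration) and uses autograd to evaluate the gradient with respect to the prediction at a few points, with the label fixed at 0:

```python
import torch

def l2_loss(y_hat, y):
    # 0.5 * squared error; gradient w.r.t. y_hat is (y_hat - y)
    return 0.5 * (y_hat - y) ** 2

def l1_loss(y_hat, y):
    # Absolute error; gradient w.r.t. y_hat is sign(y_hat - y)
    return (y_hat - y).abs()

def huber_loss(y_hat, y, delta=1.0):
    # Quadratic for small errors, linear for large ones (shifted so the pieces meet)
    err = (y_hat - y).abs()
    return torch.where(err <= delta,
                       0.5 * err ** 2,
                       delta * (err - 0.5 * delta))

def grad_wrt_pred(loss_fn, y_hat_values, y=0.0):
    # d loss / d y_hat at each prediction, computed with autograd
    y_hat = torch.tensor(y_hat_values, requires_grad=True)
    loss_fn(y_hat, torch.tensor(y)).sum().backward()
    return y_hat.grad

preds = [-3.0, -0.5, 0.5, 3.0]
for name, fn in [("L2", l2_loss), ("L1", l1_loss), ("Huber", huber_loss)]:
    print(name, grad_wrt_pred(fn, preds).tolist())
# L2 [-3.0, -0.5, 0.5, 3.0]
# L1 [-1.0, -1.0, 1.0, 1.0]
# Huber [-1.0, -0.5, 0.5, 1.0]
```

The Huber gradient follows L2 near 0 and is capped at ±1 like L1 far away. PyTorch also provides `torch.nn.HuberLoss` with a `delta` parameter; with `delta=1.0` it uses the same piecewise formula.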

Implementing softmax regression from scratch
Read data
import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
The dataset is Fashion-MNIST: 10 classes of images with 6,000 training images per class, so the training set has 60,000 images and the test set has 10,000. The batch size is 256.

num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
Each image is 28 × 28 pixels and is flattened into a 784-dimensional vector; the output layer has 10 outputs, one per class.
Here the weight W is a (784, 10) matrix and an input x is a vector of length 784. Each of the 10 outputs $o_1, \dots, o_{10}$ is computed as
$o_j = \langle \mathbf{W}_{:,j}, \mathbf{x} \rangle + b_j$,
so the number of columns of W is num_outputs, the number of outputs (classes). Likewise, b has one entry per output.
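To make the formula concrete, the sketch below uses a random batch as a stand-in for real images (an assumption for illustration) and checks that a single matrix product `X @ W + b` computes every $o_j$ at once, with output $j$ using column $j$ of `W`:

```python
import torch

num_inputs, num_outputs = 784, 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs))
b = torch.zeros(num_outputs)

# Hypothetical batch of 2 flattened "images" (random values, not Fashion-MNIST)
X = torch.rand(2, num_inputs)

O = X @ W + b      # shape (2, 10): one row of 10 outputs per sample
print(O.shape)     # torch.Size([2, 10])

# Output j of sample 0: inner product of x with column j of W, plus b_j
j = 3
o_j = torch.dot(W[:, j], X[0]) + b[j]
print(torch.allclose(O[0, j], o_j, atol=1e-6))  # True
```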
softmax
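As the name suggests, softmax is the "soft" counterpart of hardmax, which simply picks out the maximum value of a sequence. In classification, labels are one-hot encoded and the model outputs a confidence for each class. Softmax exponentiates every output and normalizes, $\mathrm{softmax}(\mathbf{o})_j = \frac{\exp(o_j)}{\sum_k \exp(o_k)}$, so all outputs become positive and sum to 1. Because of the exponential, training only needs the predicted probability (confidence) of the correct class to be large enough; it does not care about the exact values for the incorrect classes. The model thus learns to separate the true class from the others.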
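The from-scratch softmax in this post exponentiates the outputs directly, which can overflow for large values. A common remedy, sketched here as an addition rather than the post's own code (the name `stable_softmax` is hypothetical), subtracts each row's maximum first. The result is unchanged because the factor $e^{-\max}$ cancels between numerator and denominator:

```python
import torch

def stable_softmax(X):
    # Subtract the row-wise max so the largest exponent is exp(0) = 1 (no overflow)
    X_shifted = X - X.max(dim=1, keepdim=True).values
    X_exp = torch.exp(X_shifted)
    partition = X_exp.sum(dim=1, keepdim=True)  # shape (n, 1), broadcast over columns
    return X_exp / partition

logits = torch.tensor([[1.0, 2.0, 3.0],
                       [1000.0, 1001.0, 1002.0]])  # exp(1000.) alone would be inf
probs = stable_softmax(logits)
print(probs.sum(dim=1))                    # each row sums to (approximately) 1
print(torch.allclose(probs[0], probs[1]))  # True: a constant shift does not matter
print(torch.allclose(probs, torch.softmax(logits, dim=1)))  # True: matches the built-in
```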
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)  # row sums, shape (n, 1)
    return X_exp / partition  # Broadcasting divides each row by its own sum
Model
def net(X):
    # Flatten each image to a 784-vector, apply the linear layer, then softmax
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
Cross entropy
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
y records the index of the true class for each sample. y_hat[[0, 1], y] picks, for each row, the predicted probability of the true class: 0.1 from the first row [0.1, 0.3, 0.6] (true class 0) and 0.5 from the second row [0.3, 0.2, 0.5] (true class 2).
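The same selection can also be written with `torch.gather`, which some readers may find more explicit. A small sketch on the example tensors:

```python
import torch

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6],
                      [0.3, 0.2, 0.5]])

# Advanced indexing: row i, column y[i]
picked = y_hat[[0, 1], y]
# gather along dim 1 with one index per row (index must have shape (n, 1))
gathered = y_hat.gather(1, y.unsqueeze(1)).squeeze(1)

print(picked)                         # tensor([0.1000, 0.5000])
print(torch.equal(picked, gathered))  # True
```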
def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)
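On the example, cross_entropy(y_hat, y) gives −log 0.1 ≈ 2.3026 for the first sample and −log 0.5 ≈ 0.6931 for the second. The self-contained check below redefines the function for convenience and compares it with PyTorch's `torch.nn.functional.nll_loss` applied to log-probabilities:

```python
import math
import torch
import torch.nn.functional as F

def cross_entropy(y_hat, y):
    # Negative log of the predicted probability of the true class, per sample
    return -torch.log(y_hat[range(len(y_hat)), y])

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6],
                      [0.3, 0.2, 0.5]])

per_sample = cross_entropy(y_hat, y)
expected = torch.tensor([-math.log(0.1), -math.log(0.5)])
print(torch.allclose(per_sample, expected, atol=1e-5))  # True

# nll_loss expects log-probabilities; reduction='none' keeps one loss per sample
reference = F.nll_loss(torch.log(y_hat), y, reduction='none')
print(torch.allclose(per_sample, reference))  # True
```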
Concise implementation of softmax regression
Import
import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
Initialize model parameters
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    # Draw the linear layer's weights from N(0, 0.01^2)
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);
nn.Flatten() reshapes each input image (1 × 28 × 28) into a 784-dimensional vector before it reaches the linear layer.
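To see what `nn.Flatten()` does, the sketch below feeds a random tensor shaped like one Fashion-MNIST batch (256 × 1 × 28 × 28, an assumed stand-in for real data) through the same network:

```python
import torch
from torch import nn

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    # Initialize only the linear layer's weights from N(0, 0.01^2)
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

X = torch.rand(256, 1, 28, 28)  # random stand-in for one batch of images
print(nn.Flatten()(X).shape)    # torch.Size([256, 784])
print(net(X).shape)             # torch.Size([256, 10]): one output per class
```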
Loss function
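`nn.CrossEntropyLoss` takes the raw outputs (logits) of the last linear layer and applies log-softmax internally, which is why the concise network ends in `nn.Linear` with no softmax layer. With `reduction='none'` it returns one loss per sample. A small check with random logits (illustrative values, not real Fashion-MNIST outputs) shows it matches softmax followed by the from-scratch cross entropy:

```python
import torch
from torch import nn

torch.manual_seed(0)
logits = torch.randn(4, 10)          # raw network outputs for 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 1])  # true class indices

loss = nn.CrossEntropyLoss(reduction='none')
builtin = loss(logits, labels)       # one loss per sample (no averaging)

# Manual version: softmax, then the negative log-probability of the true class
probs = torch.softmax(logits, dim=1)
manual = -torch.log(probs[range(len(labels)), labels])

print(builtin.shape)                               # torch.Size([4])
print(torch.allclose(builtin, manual, atol=1e-6))  # True
```

Letting the loss handle the softmax is also more numerically stable than taking `log` of probabilities that may underflow to 0.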
loss = nn.CrossEntropyLoss(reduction='none')
Optimization algorithm
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
Training
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
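`d2l.train_ch3` comes from the d2l package and also tracks accuracy and plots progress. Its core is roughly the standard PyTorch training loop sketched below. The synthetic data, the 3 epochs, and the helper name `train_epoch` are assumptions for illustration, not the d2l source:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for Fashion-MNIST: 512 random "images", 10 classes
X = torch.rand(512, 1, 28, 28)
y = torch.randint(0, 10, (512,))
train_iter = DataLoader(TensorDataset(X, y), batch_size=256, shuffle=True)

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

def train_epoch(net, train_iter, loss, trainer):
    # Hypothetical helper: one pass over the data, returns the mean training loss
    net.train()
    total, count = 0.0, 0
    for X_batch, y_batch in train_iter:
        l = loss(net(X_batch), y_batch)  # per-sample losses, shape (batch,)
        trainer.zero_grad()
        l.mean().backward()              # average so the gradient scale does not depend on batch size
        trainer.step()
        total += l.sum().item()
        count += y_batch.numel()
    return total / count

num_epochs = 3
history = [train_epoch(net, train_iter, loss, trainer) for _ in range(num_epochs)]
print(history)  # mean loss per epoch
```

Because the labels here are random, the loss stays near log 10 ≈ 2.3; on real data it should decrease over the epochs.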