[Linear Neural Network] softmax regression
2022-07-31 04:29:00 【PBemmm】
One-hot encoding
Generally used for classification problems, where the labels are discrete categories.
The idea is simple: n categories are represented by a vector of n states, in which exactly one state takes the value 1 and all the others are 0.
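As a small illustration of the encoding described above, the sketch below uses `torch.nn.functional.one_hot`. The label values and the class count of 3 are made up for the example.

```python
import torch
import torch.nn.functional as F

# Hypothetical class indices for four samples, with 3 possible classes.
labels = torch.tensor([0, 2, 1, 2])

# Each row contains a single 1, in the column given by the label.
one_hot = F.one_hot(labels, num_classes=3)
print(one_hot)
```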
Cross entropy
Cross entropy measures the difference between the true probability distribution and the predicted probability distribution, and this difference is used as the loss.
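Written out, the usual form of this loss is shown below. The symbols are assumptions for illustration and do not come from the original post: q is the number of classes, y is the one-hot label, y-hat holds the predicted probabilities, and c is the index of the true class.

```latex
l(\mathbf{y}, \hat{\mathbf{y}})
  = -\sum_{j=1}^{q} y_j \log \hat{y}_j
  = -\log \hat{y}_{c},
  \qquad \text{since } y_c = 1 \text{ and } y_j = 0 \text{ for } j \neq c .
```

With one-hot labels, only the predicted probability of the true class contributes to the loss.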
Loss function
L2 Loss

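The green curve is the likelihood function and the yellow curve is the gradient.
Far from the origin (that is, when the prediction is far from the target) the gradient is large, so the parameter update can be very large. This motivates L1 Loss.
L1 Loss (absolute value loss)
Huber's Robust Loss
A loss function that combines L1 Loss and L2 Loss: it behaves like L2 Loss when the error is small and like L1 Loss when the error is large.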
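The three losses can be compared with a minimal pure-Python sketch. The function names and the threshold `delta=1.0` are assumptions chosen for illustration and are not taken from the post.

```python
def l2_loss(y, y_hat):
    """Squared loss: 0.5 * (y - y_hat)^2."""
    return 0.5 * (y - y_hat) ** 2


def l1_loss(y, y_hat):
    """Absolute value loss: |y - y_hat|."""
    return abs(y - y_hat)


def huber_loss(y, y_hat, delta=1.0):
    """Quadratic for small errors, linear once |error| exceeds delta."""
    err = abs(y - y_hat)
    if err <= delta:
        return 0.5 * err ** 2
    return delta * (err - 0.5 * delta)


# Compare a small error (0.5) with a large one (3.0), target y = 0.
for y_hat in (0.5, 3.0):
    print(y_hat, l2_loss(0.0, y_hat), l1_loss(0.0, y_hat), huber_loss(0.0, y_hat))
```

For the large error, the squared loss grows much faster than the other two, which is why its parameter updates can be large far from the target.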

Softmax implementation from scratch
Read data
```python
import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
```
The dataset is Fashion-MNIST: 10 classes of images, with 6000 training images per class. The training set has 60000 images, the test set has 10000, and the batch size is 256.
```python
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
```
Each image is 28*28 pixels and is flattened into a 784-dimensional vector; the output layer has 10 classes.
Here the weight W is a (784, 10) matrix and the input x is a vector of length 784. For each output Oj (j = 1, ..., 10):
Oj = <W[:, j], x> + bj
so the number of columns of W is num_outputs, which is the number of outputs (classes) O.
Likewise, b has one entry per output O.
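The shapes can be checked with a short sketch. The batch of 2 random inputs and the index `j = 3` are made up for illustration.

```python
import torch

num_inputs, num_outputs = 784, 10
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs))
b = torch.zeros(num_outputs)

# A made-up batch of 2 flattened images.
X = torch.rand(2, num_inputs)

# One matrix product computes every output at once; b is broadcast over rows.
O = torch.matmul(X, W) + b

# Output j of the first sample is the inner product of x with column j of W.
j = 3
o_j = torch.dot(X[0], W[:, j]) + b[j]
print(O.shape, torch.allclose(O[0, j], o_j, atol=1e-5))
```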
softmax
As the name suggests, softmax is the counterpart of hardmax, where hardmax simply takes the maximum value of a sequence. In classification the labels use one-hot encoding, and softmax introduces a confidence for each class. Because softmax exponentiates the outputs (powers of e), we only care whether the predicted value and confidence of the correct class are large enough, and we do not care about the incorrect classes. This lets the model separate the true class from the other classes.
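To make the contrast with hardmax concrete, here is a minimal sketch using PyTorch's built-in `torch.softmax`. The scores are made up for illustration.

```python
import torch

# Made-up raw scores for one sample over three classes.
o = torch.tensor([1.0, 2.0, 5.0])

hard = torch.max(o)               # "hardmax": only the largest value survives
probs = torch.softmax(o, dim=0)   # softmax: a confidence for every class

print(hard, probs, probs.sum())
```

The largest score still receives the highest confidence, but the other classes keep a small, non-zero share, and the confidences sum to 1.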
```python
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # The broadcast mechanism is applied here
```
Model
```python
def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
```
Cross entropy
```python
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
```
y records the index of the true class for each sample. y_hat[[0, 1], y] returns the predicted probability at the true class index of each row: 0.1 from the first row [0.1, 0.3, 0.6] and 0.5 from the second row [0.3, 0.2, 0.5].
```python
def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)
```
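As a sanity check, the hand-written loss can be compared with PyTorch's `F.nll_loss`, which expects log-probabilities. This is a sketch that reuses the example tensors above; the variable names `manual` and `builtin` are introduced here for illustration.

```python
import torch
import torch.nn.functional as F

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])


def cross_entropy(y_hat, y):
    # Pick the predicted probability of the true class in each row.
    return -torch.log(y_hat[range(len(y_hat)), y])


manual = cross_entropy(y_hat, y)
# nll_loss takes log-probabilities, so apply the log first.
builtin = F.nll_loss(torch.log(y_hat), y, reduction='none')
print(manual, builtin)
```

Both give -log(0.1) for the first sample and -log(0.5) for the second.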
Simple implementation of softmax
Import
```python
import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
```
Initialize model parameters
```python
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)
```
nn.Flatten() adjusts the shape of the network input: it flattens each 28*28 image into a 784-dimensional vector before the linear layer.
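The effect of `nn.Flatten()` on shapes can be seen in a short sketch. The batch of 4 random single-channel images is made up for illustration.

```python
import torch
from torch import nn

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

# A made-up batch of 4 single-channel 28x28 images.
X = torch.rand(4, 1, 28, 28)

flat = net[0](X)   # nn.Flatten keeps dim 0 (the batch) and flattens the rest
out = net(X)       # one row of 10 raw scores per image
print(flat.shape, out.shape)
```

Note that this network outputs raw scores and has no softmax layer of its own; `nn.CrossEntropyLoss` applies the log-softmax internally.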
Loss function
```python
loss = nn.CrossEntropyLoss(reduction='none')
```
Optimization algorithm
```python
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
```
Training
```python
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
```
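The post does not show what `d2l.train_ch3` does internally. The sketch below is an assumption about what one training epoch of this kind typically looks like, not the d2l implementation. It runs on synthetic random data instead of Fashion-MNIST so that it is self-contained, and `train_epoch` is a hypothetical helper name.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for Fashion-MNIST: 512 random "images", 10 classes.
X = torch.rand(512, 1, 28, 28)
y = torch.randint(0, 10, (512,))
data_iter = DataLoader(TensorDataset(X, y), batch_size=256, shuffle=True)

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)


def train_epoch(net, data_iter, loss, trainer):
    """Hypothetical helper: one pass over the data, returns the mean loss."""
    total_loss, total_samples = 0.0, 0
    for X_batch, y_batch in data_iter:
        l = loss(net(X_batch), y_batch)   # per-sample losses (reduction='none')
        trainer.zero_grad()
        l.mean().backward()               # reduce to a scalar before backward
        trainer.step()
        total_loss += l.sum().item()
        total_samples += y_batch.numel()
    return total_loss / total_samples


for epoch in range(3):
    print(epoch, train_epoch(net, data_iter, loss, trainer))
```

Because the loss is created with `reduction='none'`, it returns one value per sample, so it has to be averaged (or summed) before calling `backward()`.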