[Linear Neural Network] softmax regression
2022-07-31 04:29:00 【PBemmm】
One-hot encoding
Generally used for classification problems, where the labels are discrete categories.
It is very simple: n states represent n categories; exactly one state takes the value 1 and all the others are 0.
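As a minimal sketch (the label values here are made up for illustration), PyTorch's built-in one_hot does this directly:

import torch
import torch.nn.functional as F

y = torch.tensor([0, 2, 1])           # three samples, 3 possible categories
F.one_hot(y, num_classes=3)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])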
Cross entropy
Cross-entropy measures the difference between the true probability distribution and the predicted probability distribution, and is used as the loss.
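Concretely, for a true distribution y and a predicted distribution ŷ over the classes, the cross-entropy loss is l(y, ŷ) = −Σ_j y_j log ŷ_j. Since y is one-hot, this reduces to −log ŷ_c, the negative log of the probability the model assigns to the correct class c (which is exactly what the cross_entropy function below computes).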
Loss function
L2 Loss
In the accompanying figure, the green curve is the likelihood function and the yellow curve is the gradient.
Far from the origin the gradient of the L2 loss is large, so parameter updates can become excessively large; this motivates the L1 loss.
L1 Loss (absolute value loss)
Huber's Robust Loss
A loss function that combines L1 Loss and L2 Loss: it behaves like L2 Loss near the origin and like L1 Loss far from it. A sketch of all three losses is given below.
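A minimal sketch of the three losses as element-wise PyTorch functions (the function names and the delta parameter are my own, for illustration):

import torch

def l2_loss(y_hat, y):
    # squared loss: the gradient grows linearly with the error
    return (y_hat - y) ** 2 / 2

def l1_loss(y_hat, y):
    # absolute loss: the gradient magnitude is constant away from zero
    return torch.abs(y_hat - y)

def huber_loss(y_hat, y, delta=1.0):
    # Huber's robust loss: quadratic near the origin, linear far from it
    err = torch.abs(y_hat - y)
    return torch.where(err <= delta, err ** 2 / 2, delta * (err - delta / 2))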
Softmax implementation from scratch
Read data
import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
The dataset is Fashion-MNIST: 10 classes of images with 6,000 training images per class, so the training set has 60,000 images and the test set has 10,000; the batch size is 256.
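A quick sanity check of one mini-batch (not part of the original post): each image comes back as a 1×28×28 tensor.

X, y = next(iter(train_iter))
print(X.shape, y.shape)  # torch.Size([256, 1, 28, 28]) torch.Size([256])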
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
Each image is 28×28 and is flattened into a 784-dimensional vector; the output layer has 10 classes.
Here the weight W is a (784, 10) matrix and the input x is a vector of length 784. For the outputs o_1, ..., o_10, each o_j corresponds to
o_j = ⟨W[:, j], x⟩ + b_j, so the number of columns of W equals num_outputs, i.e. the number of output classes.
Likewise, b has one entry per output o_j.
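In matrix form, a batch X of shape (batch_size, 784) gives O = XW + b of shape (batch_size, 10). A quick shape check with a random batch (for illustration only):

X = torch.normal(0, 1, (256, 784))   # a fake batch of flattened images
O = torch.matmul(X, W) + b           # (256, 784) @ (784, 10) + (10,) -> (256, 10)
print(O.shape)                       # torch.Size([256, 10])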
softmax
As the name suggests, softmax is the "soft" counterpart of hardmax, which simply picks the largest value of the sequence. In classification we use one-hot encoding and interpret the outputs as confidences; formally, softmax(o)_j = exp(o_j) / Σ_k exp(o_k), which turns the outputs into non-negative values that sum to 1. Thanks to the exponential, we only care whether the predicted confidence of the correct class is large enough and do not care about the incorrect classes, so the model pushes the true class away from the others.
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # the broadcast mechanism is applied here
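A quick check with a random input (for illustration) that every row of the result is a valid probability distribution:

X = torch.normal(0, 1, (2, 5))
X_prob = softmax(X)
print(X_prob, X_prob.sum(1))  # entries are non-negative and each row sums to 1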
Model
def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
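Because of reshape((-1, W.shape[0])), each 28×28 image in a batch is flattened into a 784-vector, so the network maps a (256, 1, 28, 28) batch to a (256, 10) matrix of class probabilities; a quick check (for illustration):

X, y = next(iter(train_iter))
print(net(X).shape)  # torch.Size([256, 10])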
Cross entropy
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
y records the indices of the true classes; y_hat[[0, 1], y] picks the element at the true-class index in each row, returning 0.1 from the first row [0.1, 0.3, 0.6] and 0.5 from the second row [0.3, 0.2, 0.5].
def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)
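On the toy example above this returns −log 0.1 ≈ 2.3026 for the first sample and −log 0.5 ≈ 0.6931 for the second, i.e. tensor([2.3026, 0.6931]).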
Concise implementation of softmax regression
Import
import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
Initialize model parameters
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);
nn.Flatten() adjusts the shape of the network input, flattening each 28×28 image into a 784-dimensional vector before the linear layer.
Loss function
loss = nn.CrossEntropyLoss(reduction='none')
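Note that nn.CrossEntropyLoss expects the raw outputs (logits) of nn.Linear and applies log-softmax internally, which is why the Sequential model above has no explicit softmax layer; reduction='none' keeps the per-sample losses rather than averaging them.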
Optimization algorithm
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
Training
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)