[Linear Neural Network] softmax regression
2022-07-31 04:29:00 【PBemmm】
One-hot encoding
Generally used for classification problems, where the class labels are discrete
The idea is simple: n categories are represented by a vector of n states, in which exactly one state takes the value 1 and all the others are 0
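As a small illustration, the encoding can be sketched in plain Python. The helper name one_hot and the three example classes are assumptions made for this sketch; they do not come from the original post.

```python
def one_hot(index, num_classes):
    """Return a list with a 1 at position `index` and 0 everywhere else."""
    if not 0 <= index < num_classes:
        raise ValueError("index must be in [0, num_classes)")
    return [1 if i == index else 0 for i in range(num_classes)]


# Three hypothetical classes, e.g. cat=0, dog=1, bird=2
labels = [0, 2, 1]
encoded = [one_hot(y, 3) for y in labels]
print(encoded)
```

Each encoded row sums to 1, which is what lets it be read as a probability distribution over the classes.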
Cross entropy
Cross entropy measures the difference between the true probability distribution and the predicted one, and this difference is used as the loss
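Written out, for a one-hot label vector y and predicted probabilities y_hat over q classes, the cross-entropy loss is commonly given as follows. The symbols q and c are notation chosen for this note, not taken from the original post.

```latex
l(\mathbf{y}, \hat{\mathbf{y}}) = -\sum_{j=1}^{q} y_j \log \hat{y}_j = -\log \hat{y}_{c},
\quad \text{where } c \text{ is the index of the true class.}
```

Because y is one-hot, only the term of the true class survives in the sum.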
Loss function
L2 Loss
In the plot, the green curve is the likelihood function and the yellow curve is the gradient
When the prediction error is far from the origin, the gradient is large, so the parameter update may be large as well. This motivates L1 Loss
Absolute value loss L1 Loss
Huber's Robust Loss
A loss function that combines L1 Loss and L2 Loss: it is quadratic for small errors and linear for large ones
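The three losses can be compared side by side with a short plain-Python sketch. The function names, the 0.5 factor in the squared loss, and the threshold delta=1.0 are assumptions chosen for illustration; the original post does not state them.

```python
def l2_loss(y, y_hat):
    """Squared loss: 0.5 * (y - y_hat)^2."""
    return 0.5 * (y - y_hat) ** 2


def l1_loss(y, y_hat):
    """Absolute value loss: |y - y_hat|."""
    return abs(y - y_hat)


def huber_loss(y, y_hat, delta=1.0):
    """Quadratic for |error| <= delta, linear beyond it."""
    err = abs(y - y_hat)
    if err <= delta:
        return 0.5 * err ** 2
    return delta * (err - 0.5 * delta)


# Compare the three losses for a small, a medium and a large error
for err in (0.5, 1.0, 3.0):
    print(err, l2_loss(0.0, err), l1_loss(0.0, err), huber_loss(0.0, err))
```

For the large error the Huber value grows like the L1 loss, while for the small error it matches the L2 loss.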
Softmax implemented from scratch
Read data
import torch
from IPython import display
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
The dataset is Fashion-MNIST: 10 categories of images, with 6000 training images and 1000 test images per category, so the training set has 60000 images and the test set has 10000. The batch size is 256
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)
Each image is 28*28 pixels and is flattened into a 784-dimensional vector; the output layer has 10 units, one per category
The weight W is a matrix of shape (784, 10) and the input x is a vector of length 784. Each of the 10 outputs o_j is computed as
o_j = <W[:, j], x> + b_j, so the number of columns of W is num_outputs, which is the number of outputs (categories)
Likewise, b has one entry per output
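To make the indexing concrete, here is a toy-sized sketch in plain Python, with 3 inputs and 2 outputs instead of 784 and 10. The helper name affine and the numbers are made up for illustration.

```python
def affine(x, W, b):
    """Compute o_j = sum_i x[i] * W[i][j] + b[j] for every output j."""
    num_inputs, num_outputs = len(W), len(W[0])
    assert len(x) == num_inputs and len(b) == num_outputs
    return [sum(x[i] * W[i][j] for i in range(num_inputs)) + b[j]
            for j in range(num_outputs)]


x = [1.0, 2.0, 3.0]      # one flattened input
W = [[1.0, 0.0],         # shape (num_inputs, num_outputs) = (3, 2)
     [0.0, 1.0],
     [1.0, 1.0]]
b = [0.5, -1.0]          # one bias per output

o = affine(x, W, b)
print(o)
```

Each output uses one column of W and one entry of b, which is why W needs as many columns as there are outputs.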
softmax
As the name suggests, softmax is the soft counterpart of hardmax, where hardmax simply takes the maximum value of a sequence. In classification, the labels are one-hot encoded and the model outputs a confidence for each class. Softmax applies the exponential to each output and normalizes the results so that they sum to 1. We only care that the predicted confidence of the correct class is large enough, not about the exact values assigned to the incorrect classes. In this way the model can separate the true class from the others.
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # The broadcast mechanism is applied here
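The properties of softmax can be checked on a single row with a plain-Python sketch. Note one deliberate difference from the function in the post: this version subtracts the row maximum before exponentiating, a commonly used trick to avoid overflow that the original does not include. The name softmax_row is an assumption of this sketch.

```python
import math


def softmax_row(row):
    """Softmax of one row of scores; subtracting the max avoids overflow."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


scores = [1.0, 2.0, 3.0]
probs = softmax_row(scores)
print(probs, sum(probs))

# Shifting every score by the same constant leaves the result unchanged;
# without the max subtraction, math.exp(1000.0) would overflow.
big = softmax_row([1000.0, 1001.0, 1002.0])
print(big)
```

The outputs are positive, sum to 1, and keep the ordering of the input scores.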
Model
def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
Cross entropy
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]
y records the index of the true category for each sample. y_hat[[0, 1], y] picks, from each row, the predicted probability at that index: 0.1 from the first row [0.1, 0.3, 0.6] and 0.5 from the second row [0.3, 0.2, 0.5]
def cross_entropy(y_hat, y):
    return -torch.log(y_hat[range(len(y_hat)), y])

cross_entropy(y_hat, y)
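The arithmetic behind this call can be reproduced with the standard library alone. This is only a sketch that mirrors the two example rows using plain lists instead of tensors.

```python
import math

y = [0, 2]
y_hat = [[0.1, 0.3, 0.6],
         [0.3, 0.2, 0.5]]

# Pick the probability of the true class in each row, then take -log
losses = [-math.log(row[label]) for row, label in zip(y_hat, y)]
print(losses)
```

The first sample, where the true class only received probability 0.1, gets a loss of roughly 2.30; the second, with probability 0.5, gets roughly 0.69.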
Concise implementation of softmax
Import
import torch
from torch import nn
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
Initialize model parameters
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);
nn.Flatten() reshapes each 28*28 input image into a 784-dimensional vector before it reaches the linear layer
Loss function
loss = nn.CrossEntropyLoss(reduction='none')
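nn.CrossEntropyLoss takes the raw outputs (logits) of the network rather than softmax probabilities, and with reduction='none' it returns one loss value per sample instead of an average. The plain-Python sketch below imitates that behavior under stated assumptions: the helper name and the toy logits are invented here, and the log-sum-exp formulation is one standard way to compute the loss stably, not a copy of PyTorch's internals.

```python
import math


def cross_entropy_from_logits(logits, label):
    """-log softmax(logits)[label], computed as logsumexp(logits) - logits[label]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[label]


batch_logits = [[2.0, 1.0, 0.1],
                [0.5, 2.5, 0.3]]
batch_labels = [0, 1]

# One loss per sample, analogous to reduction='none'
per_example = [cross_entropy_from_logits(o, y)
               for o, y in zip(batch_logits, batch_labels)]
# Averaging afterwards, analogous to reduction='mean'
mean_loss = sum(per_example) / len(per_example)
print(per_example, mean_loss)
```

Keeping the per-sample values leaves the choice of how to reduce them (mean or sum) to the training loop.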
Optimization algorithm
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
Training
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)