[Linear Neural Network] softmax regression
2022-07-31 04:29:00 【PBemmm】
One-hot encoding
Generally used in classification problems, where the values to encode are discrete categories.
The idea is simple: n states represent n categories. Exactly one state takes the value 1, and all the others are 0.
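As a small illustration, here is a minimal sketch of one-hot encoding in plain Python. The helper name `one_hot` and the three-class example are assumptions made for illustration and do not come from the original post.

```python
def one_hot(index, num_classes):
    """Return a list of length num_classes with a 1 at position index and 0 elsewhere."""
    if not 0 <= index < num_classes:
        raise ValueError("index must be in [0, num_classes)")
    return [1 if i == index else 0 for i in range(num_classes)]


# Three hypothetical classes, each identified by its index.
print(one_hot(0, 3))  # [1, 0, 0]
print(one_hot(2, 3))  # [0, 0, 1]
```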
Cross entropy
Cross-entropy measures the difference between the true probability distribution and the predicted probability distribution, and this difference is used as the loss.
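Written out for a true distribution y and a predicted distribution ŷ over q classes (the symbols y, ŷ and q are notation assumed here, not taken from the post), the cross-entropy is:

```latex
H(y, \hat{y}) = -\sum_{j=1}^{q} y_j \log \hat{y}_j
```

When y is one-hot with its 1 at the true class c, only one term of the sum survives, so the loss reduces to $-\log \hat{y}_c$.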
Loss function
L2 Loss

The green curve is the likelihood function and the yellow curve is the gradient.
When the prediction is far from the origin, the gradient is large, so the parameter updates may also be large; this motivates the L1 loss.
L1 Loss (absolute value loss)
Huber's Robust Loss
A loss function that combines L1 loss and L2 loss.
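To make the three losses concrete, the sketch below writes them in plain Python for a single prediction. The forms used here (L2 as half the squared error, and Huber with a switch-over threshold of 1) are common conventions assumed for illustration; the post itself does not give the formulas.

```python
def l2_loss(y, y_hat):
    """Squared loss: 0.5 * (y - y_hat)^2. Its gradient grows linearly with the error."""
    return 0.5 * (y - y_hat) ** 2


def l1_loss(y, y_hat):
    """Absolute value loss: |y - y_hat|. Its gradient has constant magnitude 1."""
    return abs(y - y_hat)


def huber_loss(y, y_hat, delta=1.0):
    """Huber loss: quadratic for small errors, linear for large ones.

    delta is the (assumed) switch-over point between the two regimes.
    """
    err = abs(y - y_hat)
    if err <= delta:
        return 0.5 * err ** 2
    return delta * (err - 0.5 * delta)


# Compare the losses for a small and a large error.
for err in (0.5, 3.0):
    print(err, l2_loss(0.0, err), l1_loss(0.0, err), huber_loss(0.0, err))
```

For the small error the Huber value matches the L2 loss, and for the large error it grows linearly like the L1 loss.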

Implementing softmax from scratch
Read data
    import torch
    from IPython import display
    from d2l import torch as d2l

    batch_size = 256
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
The dataset is Fashion-MNIST: 10 classes of images, with 6,000 training images per class. The training set has 60,000 images, the test set has 10,000, and the batch size is 256.
    num_inputs = 784
    num_outputs = 10

    W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
    b = torch.zeros(num_outputs, requires_grad=True)
Each image is 28 × 28 pixels, flattened into a 784-dimensional vector, and the output layer has 10 classes.
Here the weight W is a (784, 10) matrix and an input x is a vector of length 784. For the outputs o_1, ..., o_10, each o_j is given by
o_j = <W[:, j], x> + b_j, so the number of columns of W is num_outputs, which is the number of outputs (classes).
Likewise, b has one entry per output.
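The indexing can be checked on a tiny example in which each output uses one column of W. The sketch below uses plain Python lists with 3 inputs and 2 outputs instead of 784 and 10; the numbers and the helper name `linear_outputs` are made up for illustration.

```python
def linear_outputs(x, W, b):
    """Compute o_j = <W[:, j], x> + b_j for every output j.

    x: list of length num_inputs
    W: num_inputs rows, each a list of num_outputs weights
    b: list of length num_outputs
    """
    num_inputs, num_outputs = len(W), len(W[0])
    assert len(x) == num_inputs and len(b) == num_outputs
    return [
        sum(x[i] * W[i][j] for i in range(num_inputs)) + b[j]
        for j in range(num_outputs)
    ]


x = [1.0, 2.0, 3.0]                        # one input with 3 features
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # shape (3, 2): one column per output
b = [0.5, -1.0]                            # one bias per output
print(linear_outputs(x, W, b))             # [4.5, 4.0]
```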
softmax
As the name suggests, softmax is the counterpart of hardmax, where hardmax is simply the ordinary maximum of a sequence. Classification uses one-hot encoding for the labels, and softmax introduces a confidence for each class. Softmax exponentiates each output (base e) and normalizes the results. We only care that the predicted confidence of the correct class is large enough, not how the remainder is spread over the incorrect classes. This lets the model separate the true class from the other classes.
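In formula form, for a vector of outputs o with one entry per class (the symbols o, ŷ and q are notation assumed here), softmax is:

```latex
\hat{y}_j = \mathrm{softmax}(o)_j = \frac{\exp(o_j)}{\sum_{k=1}^{q} \exp(o_k)}
```

Each $\hat{y}_j$ is positive and the entries sum to 1, so they can be read as confidences. Because exp is increasing, the largest o_j still receives the largest $\hat{y}_j$, which is the entry hardmax would pick.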
    def softmax(X):
        X_exp = torch.exp(X)
        partition = X_exp.sum(1, keepdim=True)  # row sums, one per sample
        return X_exp / partition  # The broadcast mechanism is applied here
Model
    def net(X):
        # Flatten each image to a row of length 784, then apply the linear layer and softmax
        return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
Cross entropy
    y = torch.tensor([0, 2])
    y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
    y_hat[[0, 1], y]
y records the index of the true class of each sample. y_hat[[0, 1], y] returns the predicted probability at the true class index: 0.1 from the first row [0.1, 0.3, 0.6] and 0.5 from the second row [0.3, 0.2, 0.5].
    def cross_entropy(y_hat, y):
        # Negative log of the predicted probability of each sample's true class
        return -torch.log(y_hat[range(len(y_hat)), y])

    cross_entropy(y_hat, y)
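As a sanity check on the example above, the same two losses can be reproduced in plain Python without PyTorch. This is only a sketch that mirrors the indexing; the helper name `cross_entropy_rows` is hypothetical.

```python
import math


def cross_entropy_rows(y_hat, y):
    """For each row, return -log of the predicted probability of its true class."""
    return [-math.log(row[label]) for row, label in zip(y_hat, y)]


y = [0, 2]
y_hat = [[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]
losses = cross_entropy_rows(y_hat, y)
print([round(v, 4) for v in losses])  # [2.3026, 0.6931]
```

The first sample gets the larger loss because the model assigned only 0.1 to its true class.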
Concise implementation of softmax
Import
    import torch
    from torch import nn
    from d2l import torch as d2l

    batch_size = 256
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
Initialize model parameters
    net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

    def init_weights(m):
        if type(m) == nn.Linear:
            nn.init.normal_(m.weight, std=0.01)

    net.apply(init_weights);
nn.Flatten() reshapes each 28 × 28 input image into a 784-dimensional vector before it reaches the linear layer.
Loss function
    loss = nn.CrossEntropyLoss(reduction='none')
Optimization algorithm
    trainer = torch.optim.SGD(net.parameters(), lr=0.1)
Training
    num_epochs = 10
    d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
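Note that the concise network ends in nn.Linear with no softmax layer: nn.CrossEntropyLoss expects raw, unnormalized scores (logits) and applies log-softmax internally. One reason for fusing the two steps is numerical stability. The plain-Python sketch below illustrates the idea with the log-sum-exp trick; it is an illustration of the principle, not PyTorch's actual code, and the function name is hypothetical.

```python
import math


def cross_entropy_from_logits(logits, label):
    """Compute -log(softmax(logits)[label]) without exponentiating large numbers.

    Subtracting the maximum logit leaves softmax unchanged but keeps exp() in range.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(o - m) for o in logits))
    return log_sum_exp - logits[label]


# Calling math.exp(1000.0) directly would raise OverflowError; this version is fine.
print(cross_entropy_from_logits([1000.0, 0.0, -1000.0], 0))  # 0.0
print(round(cross_entropy_from_logits([0.0, 0.0], 1), 4))    # 0.6931
```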