6. Logistic model
2022-07-05 05:41:00 【A big pigeon】
The logistic model is used to solve classification problems.
For example, if the question from Section 5 is changed to whether a student passes the exam, it becomes a binary classification problem.
The model outputs the probability P of passing the exam, so P must lie between 0 and 1.
The output of the original linear model can be any real number (its range is ℝ, the set of real numbers), so that output has to be mapped into the interval (0, 1).
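The usual choice for this mapping is the sigmoid (logistic) function σ(z) = 1 / (1 + e^(−z)), which is what torch.sigmoid computes. Below is a minimal pure-Python sketch of it. The helper name `sigmoid` is our own, not a PyTorch API. It shows that large negative inputs approach 0, large positive inputs approach 1, and 0 maps to exactly 0.5.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic (sigmoid) function: maps any real z into the open interval (0, 1)."""
    # Note: math.exp(-z) raises OverflowError for z below about -709; fine for these inputs
    return 1.0 / (1.0 + math.exp(-z))

# Evaluate a few points across the real line
for z in [-10.0, -1.0, 0.0, 1.0, 10.0]:
    print(f"sigmoid({z:+.1f}) = {sigmoid(z):.5f}")
```

The curve is symmetric around 0.5 (σ(−z) = 1 − σ(z)), which is why 0.5 is the natural threshold between the two classes.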
Only two changes are needed to turn the linear model into a logistic model: wrap the model's output in sigmoid(), and change the loss function to BCELoss (binary cross-entropy loss).
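For a predicted probability ŷ and a label y ∈ {0, 1}, binary cross-entropy is −[y·log ŷ + (1 − y)·log(1 − ŷ)]. When the loss is summed rather than averaged over the samples, as in the code below, the per-sample values are simply added up. The sketch recomputes it by hand in plain Python to show why it suits probabilities: confident correct predictions cost little, confident mistakes cost a lot. The function name `bce_loss` is hypothetical, and unlike PyTorch's BCELoss it does no clamping of the log terms, so every prediction must lie strictly between 0 and 1.

```python
import math

def bce_loss(y_pred, y_true, reduction="sum"):
    """Binary cross-entropy over paired lists of probabilities (strictly in (0, 1)) and 0/1 labels."""
    losses = [
        -(y * math.log(p) + (1 - y) * math.log(1 - p))
        for p, y in zip(y_pred, y_true)
    ]
    total = sum(losses)
    # "sum" adds the per-sample losses; "mean" averages them
    return total if reduction == "sum" else total / len(losses)

labels = [0, 0, 1]
print(bce_loss([0.1, 0.2, 0.9], labels))  # mostly right and confident: small loss
print(bce_loss([0.9, 0.8, 0.1], labels))  # confidently wrong: large loss
```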
Complete code:
import torch
import numpy as np
import matplotlib.pyplot as plt

# 1. Prepare the data; note that both are column matrices of shape (3, 1)
x_data = torch.Tensor([[1.0], [2.0], [3.0]])  # hours of study
y_data = torch.Tensor([[0.0], [0.0], [1.0]])  # 1 = passed, 0 = failed

# 2. Design the model as a class that inherits from nn.Module so it can use its methods
class LogisticRegressionModel(torch.nn.Module):
    # Initialization
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # Linear is a linear unit computing w * x + b

    # Forward (feed-forward) method
    def forward(self, x):
        # Calling self.linear(x) invokes linear.__call__(), which runs linear.forward();
        # torch.sigmoid (F.sigmoid is deprecated) then squashes the result into (0, 1)
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = LogisticRegressionModel()
# 3. Loss function and optimizer
criterion = torch.nn.BCELoss(reduction='sum')  # sum the per-sample losses instead of averaging (replaces the deprecated size_average=False)
# Optimizer: model.parameters() returns the parameters to be optimized; lr is the learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. Training loop
for epoch in range(1000):
    # Forward pass
    y_pred = model(x_data)
    # Compute the loss
    loss = criterion(y_pred, y_data)
    print("epoch={}, loss={}".format(epoch, loss.item()))
    optimizer.zero_grad()  # clear the gradients left over from the previous step
    # Backpropagation
    loss.backward()
    # Update (optimize) the parameters
    optimizer.step()
# Test: check how well the trained model does on an unseen input
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = {} for x = {}'.format(y_test.data, x_test.data))

# Plot the predicted probability of passing against hours of study
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')  # red line marks the 0.5 probability threshold
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
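In the PyTorch version, autograd computes the gradients and SGD applies them. To see what that amounts to, here is a plain-Python sketch of the same training on the same three data points. It rests on a few assumptions: the helper names `sigmoid` and `train_logistic` are hypothetical, w and b start at zero (torch.nn.Linear initializes them randomly, so the numbers will not match a PyTorch run exactly), and the gradient uses the known simplification that for p = σ(z) the derivative of BCE with respect to z is p − y. It also locates where the curve crosses the red 0.5 line: σ(z) = 0.5 exactly when w·x + b = 0, i.e. at x = −b/w.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def train_logistic(xs, ys, lr=0.01, epochs=1000):
    """Full-batch gradient descent on the summed BCE loss of p = sigmoid(w * x + b)."""
    w, b = 0.0, 0.0  # assumption: zero init; torch.nn.Linear starts from random values
    for _ in range(epochs):
        grad_w, grad_b = 0.0, 0.0
        for x, y in zip(xs, ys):
            p = sigmoid(w * x + b)
            # For p = sigmoid(z), d(BCE)/dz = p - y; chain rule through z = w*x + b
            grad_w += (p - y) * x
            grad_b += p - y
        # Same update as plain SGD's step(): parameter -= lr * gradient
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

xs, ys = [1.0, 2.0, 3.0], [0, 0, 1]  # hours of study and pass/fail labels from the article
w, b = train_logistic(xs, ys)  # same lr and epoch count as the PyTorch code
print(f"w = {w:.4f}, b = {b:.4f}")
print(f"P(pass | 4 hours) = {sigmoid(w * 4.0 + b):.4f}")
# The predicted probability crosses 0.5 where w*x + b = 0
print(f"decision boundary: x = {-b / w:.4f}")
```

Training longer or with a larger learning rate pushes the boundary toward the middle of the gap between the last failing point (x = 2) and the first passing point (x = 3), while the curve keeps getting steeper because the data are perfectly separable.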