6. Logistic model
2022-07-05 05:41:00 【A big pigeon】
The logistic model is used to solve classification problems.
For example, whether a student passes the exam in Section 5 is a binary classification problem.
The model outputs P, the probability of passing the exam, which lies between 0 and 1.
The original linear model outputs values over R (the set of real numbers), so that raw output has to be mapped into the range [0, 1].
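One common way to perform that mapping is the logistic (sigmoid) function, 1 / (1 + e^(-z)). The snippet below is a minimal standard-library sketch and is not part of the original script; the helper name `sigmoid` and the sample inputs are chosen only for illustration.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function: maps a real number z to a value between 0 and 1."""
    # Branch on the sign so exp() never receives a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

# Illustrative raw linear outputs, from strongly negative to strongly positive.
for z in (-10.0, -1.0, 0.0, 1.0, 10.0):
    print(f"z = {z:6.1f} -> sigmoid(z) = {sigmoid(z):.4f}")
```

A raw output of 0 maps to exactly 0.5, large negative values approach 0, and large positive values approach 1.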
Only two changes are needed to turn the linear model into the logistic model: apply sigmoid() to the model's output, and change the loss function to BCELoss (binary cross-entropy).
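To make the loss concrete, binary cross-entropy for one sample with label y and predicted probability p is -(y*log(p) + (1-y)*log(1-p)). The sketch below computes it by hand with the standard library only. The function name `bce_sum` and the probabilities are hypothetical, chosen for illustration, and are not outputs of the trained model.

```python
import math

def bce_sum(probs, labels):
    """Summed binary cross-entropy: -sum(y*log(p) + (1-y)*log(1-p))."""
    total = 0.0
    for p, y in zip(probs, labels):
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total

# Hypothetical predicted pass probabilities for three students,
# paired with their labels (1 = pass, 0 = fail).
probs = [0.2, 0.4, 0.9]
labels = [0.0, 0.0, 1.0]

loss = bce_sum(probs, labels)
print(f"summed BCE = {loss:.4f}")
print(f"mean BCE   = {loss / len(probs):.4f}")
```

The loss shrinks as each predicted probability moves toward its label, which is what the optimizer exploits during training. Summing versus averaging only rescales the loss, and therefore the gradient, by the number of samples.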
Complete code :
import torch
import numpy as np
import matplotlib.pyplot as plt

# 1. Prepare the data. Both tensors are 2-D, with one sample per row.
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])

# 2. Design the model as a class that inherits from nn.Module in order to use its methods
class LogisticRegressionModel(torch.nn.Module):
    # Initialization
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # Linear unit: 1 input feature, 1 output

    # Feedforward method
    def forward(self, x):
        # self.linear(x) invokes the module's __call__(), which in turn runs its forward().
        # torch.sigmoid is used here because F.sigmoid is deprecated.
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = LogisticRegressionModel()

# 3. Loss and optimizer
# reduction='sum' adds up the per-sample losses instead of averaging them
# (it replaces the deprecated size_average=False).
criterion = torch.nn.BCELoss(reduction='sum')
# model.parameters() returns the parameters to be optimized; lr is the learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. Training loop
for epoch in range(1000):
    # Forward pass
    y_pred = model(x_data)
    # Compute the loss
    loss = criterion(y_pred, y_data)
    print("epoch={},loss={}".format(epoch, loss.item()))
    optimizer.zero_grad()  # Reset the gradients to zero
    # Backpropagation
    loss.backward()
    # Update the parameters
    optimizer.step()

# Test: check the trained model's prediction on an unseen input
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = {} for x = {}'.format(y_test.data, x_test.data))

# Plot the predicted pass probability against hours
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [.5, .5], c='r')  # Red reference line at probability 0.5
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
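The red line in the plot marks a probability of 0.5. For a one-feature model of the form sigmoid(w*x + b), the curve crosses that line where w*x + b = 0, that is, at x = -b/w. The following is a minimal sketch using hypothetical values for w and b. In the script above, the learned values could be read with `model.linear.weight.item()` and `model.linear.bias.item()`.

```python
import math

def decision_boundary(w: float, b: float) -> float:
    """Input x at which sigmoid(w * x + b) equals 0.5, i.e. where w * x + b = 0."""
    if w == 0:
        raise ValueError("w must be non-zero for a boundary to exist")
    return -b / w

# Hypothetical parameters, not values produced by the training run above.
w, b = 1.2, -3.0
x_star = decision_boundary(w, b)
p_at_boundary = 1.0 / (1.0 + math.exp(-(w * x_star + b)))
print(f"boundary at x = {x_star:.2f}, probability there = {p_at_boundary:.2f}")
```

With a positive w, inputs above the boundary are predicted as a pass (probability above 0.5) and inputs below it as a fail.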