6. Logistic model
2022-07-05 05:41:00 【A big pigeon】
The logistic model is used to solve classification problems.
For example, the problem from Section 5 (predicting whether a student passes the exam) is a binary classification problem.

The output is the probability P of passing the exam, so P must lie between 0 and 1.
The output range of the original linear model is R (the set of real numbers), so that output has to be mapped into the interval (0, 1). The logistic (sigmoid) function σ(z) = 1 / (1 + e^(-z)) performs this mapping.
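To see how the sigmoid squashes any real number into (0, 1), here is a small plain-Python sketch. It is an illustration only, not part of the original PyTorch code: the `sigmoid` helper is a hypothetical name, written in a numerically stable form (an assumption of this sketch) so that large negative inputs do not overflow `math.exp`.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function 1 / (1 + e^(-z)), computed without overflow."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    # For negative z, rewrite as e^z / (1 + e^z) so exp() never sees a huge argument
    e = math.exp(z)
    return e / (1.0 + e)

# Large negative inputs approach 0, large positive inputs approach 1,
# and sigmoid(0) is exactly 0.5, which is the usual decision threshold.
for z in [-10.0, -1.0, 0.0, 1.0, 10.0]:
    print(z, round(sigmoid(z), 4))
```

The curve is symmetric around 0: sigmoid(-z) = 1 - sigmoid(z), which is why 0.5 is the natural cut-off between the two classes.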

Only two changes are needed to turn the linear model into a logistic model: apply sigmoid() to the model's output, and change the loss function to BCELoss (binary cross-entropy loss).
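A minimal sketch of what binary cross-entropy computes, written in plain Python for illustration. The `bce_loss` helper is hypothetical; it follows the standard formula -[y·log(p) + (1-y)·log(1-p)], which `torch.nn.BCELoss` also uses (PyTorch additionally clamps the log terms for numerical safety, which this sketch does not do, so predictions must lie strictly inside (0, 1)).

```python
import math

def bce_loss(preds, targets, reduction="sum"):
    """Binary cross-entropy over paired predictions p and labels y (0 or 1)."""
    terms = [-(y * math.log(p) + (1 - y) * math.log(1 - p))
             for p, y in zip(preds, targets)]
    total = sum(terms)
    # "sum" matches BCELoss(reduction='sum'); anything else averages, like the default 'mean'
    return total if reduction == "sum" else total / len(terms)

# A confident correct prediction costs little; a confident wrong one costs a lot.
print(bce_loss([0.9], [1.0]))  # about 0.105
print(bce_loss([0.1], [1.0]))  # about 2.303
```

This asymmetry is the reason BCE is preferred over squared error here: it penalizes confidently wrong probabilities much more heavily.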


Complete code:
import torch
import numpy as np
import matplotlib.pyplot as plt

# 1. Prepare the data; inputs and labels are both column matrices of shape (N, 1)
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[0.0], [0.0], [1.0]])

# 2. Design the model as a class that inherits from nn.Module so it can use its methods
class LogisticRegressionModel(torch.nn.Module):
    # Initialization
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # a linear unit: w * x + b

    # Forward (feedforward) pass
    def forward(self, x):
        # self.linear(x) calls Linear.__call__(), which runs Linear.forward();
        # torch.sigmoid (replacing the deprecated F.sigmoid) maps the result into (0, 1)
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = LogisticRegressionModel()

# 3. Loss function and optimizer
# Sum the BCE loss over the samples instead of averaging it
# (reduction='sum' replaces the deprecated size_average=False)
criterion = torch.nn.BCELoss(reduction='sum')
# model.parameters() returns the parameters to optimize; lr is the learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. Training loop
for epoch in range(1000):
    y_pred = model(x_data)            # forward pass
    loss = criterion(y_pred, y_data)  # compute the loss
    print("epoch={}, loss={}".format(epoch, loss.item()))
    optimizer.zero_grad()             # clear gradients left over from the previous step
    loss.backward()                   # backpropagation
    optimizer.step()                  # update the parameters

# Test: check the trained model on an unseen input
x_test = torch.tensor([[4.0]])
with torch.no_grad():
    y_test = model(x_test)
print('y_pred = {} for x = {}'.format(y_test.item(), x_test.item()))

# Plot the predicted probability of passing for 0 to 10 hours of study
x = np.linspace(0, 10, 200)
x_t = torch.tensor(x, dtype=torch.float32).view((200, 1))
with torch.no_grad():
    y = model(x_t).numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')  # decision threshold at P = 0.5
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
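To make explicit what `loss.backward()` and `optimizer.step()` do in this model, here is a dependency-free sketch of the same training loop with hand-written gradients. It is an illustration under stated assumptions, not the author's code: the parameters start at zero (PyTorch's `Linear` uses random initialization), the loss is the summed BCE, and `train_logistic` and `sigmoid` are hypothetical helper names. It relies on the known result that for sigmoid followed by BCE, the derivative of each sample's loss with respect to the linear output z is simply p - y.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train_logistic(xs, ys, lr=0.01, epochs=1000):
    """Plain gradient descent on the summed BCE loss, mirroring the PyTorch loop."""
    w, b = 0.0, 0.0  # assumption: zero init instead of PyTorch's random init
    losses = []
    for _ in range(epochs):
        preds = [sigmoid(w * x + b) for x in xs]  # forward pass
        loss = -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                    for p, y in zip(preds, ys))
        losses.append(loss)
        # dLoss/dz_i = p_i - y_i, so by the chain rule:
        grad_w = sum((p - y) * x for p, y, x in zip(preds, ys, xs))
        grad_b = sum(p - y for p, y in zip(preds, ys))
        w -= lr * grad_w  # equivalent of optimizer.step() for plain SGD
        b -= lr * grad_b
    return w, b, losses

w, b, losses = train_logistic([1.0, 2.0, 3.0], [0.0, 0.0, 1.0])
print(f"w={w:.4f}, b={b:.4f}, final loss={losses[-1]:.4f}")
print(f"P(pass | 4 hours) = {sigmoid(w * 4.0 + b):.4f}")
```

Because the initial parameters differ from PyTorch's random ones, the exact numbers will not match the PyTorch run, but the loss should decrease steadily and w should become positive, so more study hours give a higher predicted probability of passing.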