6. Logistic model
2022-07-05 05:41:00 【A big pigeon】
The logistic model is used to solve classification problems.
If the question from Section 5 is changed to whether a student passes the exam, it becomes a binary classification problem.

The output is the probability P of passing the exam, so P must lie between 0 and 1.
The output range of the original linear model is R (the set of real numbers), so that output has to be mapped into the interval (0, 1).
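The mapping from R into (0, 1) is done by the logistic (sigmoid) function σ(z) = 1 / (1 + e^(-z)), so the logistic model computes ŷ = σ(wx + b). Below is a minimal pure-Python sketch of σ, independent of PyTorch; the function name `sigmoid` is only for illustration (the PyTorch call used in the code later is `torch.sigmoid`).

```python
import math


def sigmoid(z: float) -> float:
    """Logistic (sigmoid) function: maps any real z into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


# Large negative inputs approach 0, large positive inputs approach 1, and 0 maps to 0.5
for z in [-10.0, -1.0, 0.0, 1.0, 10.0]:
    print(z, round(sigmoid(z), 4))
```

The function is increasing and symmetric around 0 (σ(-z) = 1 - σ(z)), which is why 0.5 is the natural threshold between the two classes.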

Only two changes are needed to turn the linear model into the logistic model: wrap the model output in sigmoid(), and change the loss function to BCELoss (binary cross-entropy).
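For labels y in {0, 1} and predicted probabilities ŷ, binary cross-entropy is -Σ [y·log(ŷ) + (1 - y)·log(1 - ŷ)]: a sample with y = 1 contributes -log(ŷ), and a sample with y = 0 contributes -log(1 - ŷ). The sketch below computes the summed version by hand, which is what the code below requests from BCELoss with `reduction='sum'`. The helper name `bce_sum` and the probabilities are hypothetical, chosen only to illustrate the formula.

```python
import math


def bce_sum(y_pred, y_true):
    """Binary cross-entropy summed over all samples:
    -sum(y * log(p) + (1 - y) * log(1 - p))."""
    total = 0.0
    for p, y in zip(y_pred, y_true):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total


# Hypothetical predicted probabilities for the three training points (labels 0, 0, 1)
probs = [0.2, 0.4, 0.9]
labels = [0, 0, 1]
print(bce_sum(probs, labels))
```

The loss shrinks as ŷ moves toward the true label and grows without bound as ŷ approaches the wrong extreme, which pushes the sigmoid output toward 0 or 1 for confident samples.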


Complete code:
```python
import torch
import numpy as np
import matplotlib.pyplot as plt

# 1. Prepare the data; note that everything is in matrix (column) form
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0.0], [0.0], [1.0]])

# 2. Design the model as a class that inherits from nn.Module so it can use its methods
class LogisticRegressionModel(torch.nn.Module):
    # Initialization
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # Linear is a linear unit: wx + b

    # Feedforward method
    def forward(self, x):
        # self.linear(x) calls linear's __call__() method, which runs its forward();
        # torch.sigmoid (F.sigmoid is deprecated) maps the result into (0, 1)
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = LogisticRegressionModel()

# 3. Loss and optimizer
# reduction='sum' adds up the per-sample losses instead of averaging them
# (it replaces the deprecated size_average=False)
criterion = torch.nn.BCELoss(reduction='sum')
# model.parameters() returns the parameters to optimize; lr is the learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. Training process
for epoch in range(1000):
    # Feedforward
    y_pred = model(x_data)
    # Compute the loss
    loss = criterion(y_pred, y_data)
    print("epoch={}, loss={}".format(epoch, loss.item()))
    optimizer.zero_grad()  # zero the gradients from the previous step
    # Backpropagation
    loss.backward()
    # Update the parameters
    optimizer.step()

# Test: check the trained model on a new input
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = {} for x = {}'.format(y_test.data, x_test.data))

# Plot the predicted probability of passing against hours of study
x = np.linspace(0, 10, 200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0, 10], [0.5, 0.5], c='r')  # red line: 0.5 decision threshold
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
```
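The red line at 0.5 in the plot marks the usual classification threshold. Because the sigmoid is increasing and equals 0.5 exactly at 0, the predicted probability σ(wx + b) crosses 0.5 where wx + b = 0, that is, at x = -b/w (assuming w ≠ 0). The sketch below illustrates this; the helper names and the values of w and b are hypothetical, not results of the training run above.

```python
def decision_boundary(w: float, b: float) -> float:
    """Input x at which sigmoid(w*x + b) == 0.5, i.e. where w*x + b == 0 (requires w != 0)."""
    return -b / w


def predict_pass(x: float, w: float, b: float) -> bool:
    """Predict 'pass' when the probability is at least 0.5,
    which is equivalent to w*x + b >= 0 because sigmoid is increasing."""
    return w * x + b >= 0


# Hypothetical parameters, NOT the values learned by the training loop above
w, b = 1.5, -3.75
print(decision_boundary(w, b))  # hours of study where the curve crosses the red line
print(predict_pass(4.0, w, b))
```

With the trained model, the actual w and b could be read with `model.linear.weight.item()` and `model.linear.bias.item()` and passed to the same helpers.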