6. Logistic model
2022-07-05 05:41:00 【A big pigeon】
The logistic model is used to solve classification problems.
For example, the question from section 5, whether a student passes the exam, is a binary classification problem.
The output is the probability P of passing the exam, which lies between 0 and 1.
The original linear model outputs values in R (the set of real numbers); a sigmoid function maps that output into the range between 0 and 1.
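As a minimal sketch of that mapping, the snippet below applies the standard logistic (sigmoid) function, 1 / (1 + exp(-z)), to a few sample linear outputs using only Python's math module. The helper name sigmoid and the sample values are illustrative assumptions, not part of the original code.

```python
import math

def sigmoid(z):
    """Standard logistic function: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

# Hypothetical linear-model outputs (any real numbers)
linear_outputs = [-5.0, -1.0, 0.0, 1.0, 5.0]
probabilities = [sigmoid(z) for z in linear_outputs]

for z, p in zip(linear_outputs, probabilities):
    print("z = {:5.1f} -> P = {:.4f}".format(z, p))
```

Large negative inputs land near 0, large positive inputs land near 1, and an input of 0 gives exactly 0.5.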
Only two changes are needed to turn the linear model into the logistic model: apply sigmoid() to the model output, and change the loss function to BCELoss (binary cross-entropy loss).
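To make the second change concrete, here is a small sketch of what binary cross-entropy computes: for a label y and a predicted probability p, the per-sample loss is -(y*log(p) + (1-y)*log(1-p)). The function name bce_sum and the sample predictions are hypothetical. The losses are summed rather than averaged, which mirrors the choice made for BCELoss in the code that follows.

```python
import math

def bce_sum(y_pred, y_true):
    """Summed binary cross-entropy for probabilities in (0, 1) and 0/1 labels."""
    total = 0.0
    for p, y in zip(y_pred, y_true):
        # Only one of the two terms is non-zero for a 0/1 label
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total

# Same labels as the training data: fail, fail, pass
y_true = [0, 0, 1]

# Hypothetical predictions: one confident and correct, one undecided
confident = [0.1, 0.2, 0.9]
undecided = [0.5, 0.5, 0.5]

print("confident loss:", bce_sum(confident, y_true))
print("undecided loss:", bce_sum(undecided, y_true))
```

Predictions that are closer to the true labels give a smaller loss, which is what the optimizer exploits during training.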
Complete code:
import torch
import numpy as np
import matplotlib.pyplot as plt

# 1. Prepare the data; note that both tensors are in matrix form (one sample per row)
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[0], [0], [1]])

# 2. Design the model as a class that inherits from nn.Module in order to use its methods
class LogisticRegressionModel(torch.nn.Module):
    # Initialization
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)  # Linear is a linear unit: one input feature, one output

    # Feedforward method
    def forward(self, x):
        # self.linear(x) invokes linear's __call__() method, which in turn runs its forward()
        y_pred = torch.sigmoid(self.linear(x))  # torch.sigmoid replaces the deprecated F.sigmoid
        return y_pred

model = LogisticRegressionModel()
# 3. Loss and optimizer
# reduction='sum' replaces the deprecated size_average=False: the per-sample losses are summed, not averaged
criterion = torch.nn.BCELoss(reduction='sum')
# Optimizer: model.parameters() returns the parameters to be optimized; lr is the learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 4. Training process
for epoch in range(1000):
    # Feedforward
    y_pred = model(x_data)
    # Calculate the loss
    loss = criterion(y_pred, y_data)
    print("epoch={},loss={}".format(epoch, loss.item()))
    optimizer.zero_grad()  # Zero the gradients left over from the previous step
    # Backpropagation
    loss.backward()
    # Update the parameters
    optimizer.step()
# Test: evaluate the trained model on a new input
x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = {} for x = {}'.format(y_test.data, x_test.data))
# Plot the predicted probability of passing against hours of study
x = np.linspace(0,10,200)
x_t = torch.Tensor(x).view((200, 1))
y_t = model(x_t)
y = y_t.data.numpy()
plt.plot(x, y)
plt.plot([0,10], [.5,.5], c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.show()
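The red line at 0.5 in the plot marks the usual decision threshold: inputs whose predicted probability reaches 0.5 are classified as a pass. As a rough sketch, the point where sigmoid(w*x + b) crosses 0.5 is x = -b/w. The values of w and b below are hypothetical placeholders, not the parameters learned by the training run above, and predict_label is an illustrative helper name.

```python
import math

def sigmoid(z):
    """Standard logistic function."""
    return 1.0 / (1.0 + math.exp(-z))

def predict_label(x, w, b, threshold=0.5):
    """Return 1 (pass) if the predicted probability reaches the threshold, else 0 (fail)."""
    return 1 if sigmoid(w * x + b) >= threshold else 0

# Hypothetical parameters; a trained model would supply its own (w must be non-zero)
w, b = 2.0, -5.0

# The probability is exactly 0.5 where w*x + b == 0
boundary = -b / w
print("decision boundary at x = {:.2f} hours".format(boundary))

hours_list = [1.0, 2.0, 3.0, 4.0]
labels = [predict_label(h, w, b) for h in hours_list]
for h, label in zip(hours_list, labels):
    print("hours = {:.1f} -> label = {}".format(h, label))
```

With a real model, w and b could be read from model.linear.weight.item() and model.linear.bias.item() after training.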