Torch learning notes (6) -- logistic regression model (self training)
2022-07-03 18:22:00 【ZRX_ GIS】
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Logistic regression (LR)
# LR is a linear binary classification model.
# y = f(wx + b), where f(x) = 1 / (1 + e**(-x)); f is known as the sigmoid function or logistic function.
# What LR does: it maps the input to the interval (0, 1).
# To classify, the output is "rounded" (thresholded at 0.5), which gives a binary decision.
# Linear regression analyzes the relationship between the independent variable x and a scalar y.
# LR analyzes the relationship between the independent variable x and a probability y.
# LR can also be rewritten as the "log-odds regression" model: ln(y / (1 - y)) = wx + b
# Machine learning steps: data (collection, cleaning, splitting, preprocessing), model, loss, optimizer
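The two formulas in the notes above, y = f(wx + b) and ln(y / (1 - y)) = wx + b, are inverses of each other. The following is a minimal sketch using only the standard library; the parameters `w` and `b` are made-up values chosen for illustration, not taken from the model trained in this post.

```python
import math


def sigmoid(z):
    """Logistic function: maps any real z into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def logit(p):
    """Log-odds ln(p / (1 - p)); this is the inverse of the sigmoid."""
    return math.log(p / (1.0 - p))


# Hypothetical parameters of a one-dimensional model y = sigmoid(w * x + b).
w, b = 2.0, -1.0

for x in (-3.0, 0.5, 3.0):
    z = w * x + b                   # linear part
    p = sigmoid(z)                  # probability of class 1
    label = 1 if p >= 0.5 else 0    # "rounding" at the 0.5 threshold
    print(f"x={x:+.1f}  z={z:+.1f}  p={p:.4f}  class={label}  logit(p)={logit(p):+.4f}")
```

In each printed row, `logit(p)` recovers `z` up to floating-point error, which is what the log-odds form of the model states.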
# LR case study
# data
sample_num = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_num, 2)  # template of shape (100, 2): 100 samples, 2 features
x0 = torch.normal(mean_value * n_data, 1) + bias  # class 0: normally distributed samples
y0 = torch.zeros(sample_num)  # class 0 labels
x1 = torch.normal(-mean_value * n_data, 1) + bias  # class 1: normally distributed samples
y1 = torch.ones(sample_num)  # class 1 labels
train_x = torch.cat((x0, x1), 0)  # concatenate both classes along dim 0
train_y = torch.cat((y0, y1), 0)
# model
class LR(nn.Module):
    def __init__(self):
        super(LR, self).__init__()
        self.features = nn.Linear(2, 1)  # 2 input features -> 1 output
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.features(x)
        x = self.sigmoid(x)
        return x  # predicted probability of class 1
# Instantiate LR (note: this rebinds the name LR from the class to the instance)
LR = LR()
# Loss
loss_fn = nn.BCELoss()
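As a rough illustration of what a binary cross-entropy loss with mean reduction computes, here is a standard-library sketch. The helper `bce_loss` and the sample numbers are hypothetical and only mirror the formula; this is not PyTorch's implementation, which additionally clamps the log terms for numerical stability.

```python
import math


def bce_loss(probs, targets):
    """Mean binary cross-entropy for probabilities in (0, 1) and 0/1 targets."""
    total = 0.0
    for p, y in zip(probs, targets):
        # Only one of the two terms is non-zero for a hard 0/1 label.
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(probs)


# Made-up predictions and labels, for illustration only.
probs = [0.9, 0.2, 0.6]
targets = [1.0, 0.0, 1.0]
print(bce_loss(probs, targets))
```

Confident correct predictions (0.9 for label 1, 0.2 for label 0) add little to the loss, while the hesitant 0.6 adds the most.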
# optim
lr = 0.01 # Learning rate
optimizer = torch.optim.SGD(LR.parameters(), lr=lr, momentum=0.9)
# train
for iteration in range(1000):
    y_pred = LR(train_x)  # forward pass
    loss = loss_fn(y_pred.squeeze(), train_y)
    optimizer.zero_grad()  # clear gradients accumulated in the previous iteration
    loss.backward()
    optimizer.step()
    # plot
    if iteration % 20 == 0:
        mask = y_pred.ge(0.5).float().squeeze()  # classify with a 0.5 threshold
        correct = (mask == train_y).sum()  # number of correct classifications
        acc = correct.item() / train_y.size(0)  # accuracy
        plt.scatter(x0.data.numpy()[:, 0], x0.data.numpy()[:, 1], c='r', label='class 0')
        plt.scatter(x1.data.numpy()[:, 0], x1.data.numpy()[:, 1], c='b', label='class 1')
        w0, w1 = LR.features.weight[0]
        w0, w1 = float(w0.item()), float(w1.item())
        plot_b = float(LR.features.bias[0].item())
        plot_x = np.arange(-6, 6, 0.1)
        plot_y = (-w0 * plot_x - plot_b) / w1  # decision boundary: w0*x + w1*y + b = 0
        plt.xlim(-5, 7)
        plt.ylim(-5, 7)
        plt.plot(plot_x, plot_y)
        plt.legend()
        plt.show()
        plt.pause(0.5)
        if acc >= 0.99:
            break
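The plotting step draws the line where the model outputs exactly 0.5, that is, where w0*x + w1*y + b = 0, and the accuracy is the fraction of thresholded predictions that match the labels. A small standard-library sketch of both calculations follows; the helper names, weights, and sample values are hypothetical and are not results from the training run above.

```python
def boundary_y(x, w0, w1, b):
    """Solve w0*x + w1*y + b = 0 for y (requires w1 != 0)."""
    return (-w0 * x - b) / w1


def accuracy(probs, targets, threshold=0.5):
    """Fraction of samples whose thresholded prediction equals the label."""
    hits = 0
    for p, y in zip(probs, targets):
        predicted = 1.0 if p >= threshold else 0.0
        if predicted == y:
            hits += 1
    return hits / len(targets)


# Hypothetical weights and bias, chosen only for illustration.
w0, w1, b = -1.0, -1.0, 2.0
for x in (-2.0, 0.0, 2.0):
    y = boundary_y(x, w0, w1, b)
    # Every point on the boundary makes the linear part vanish.
    print(x, y, w0 * x + w1 * y + b)

# Made-up probabilities and labels: three of the four predictions are correct.
print(accuracy([0.9, 0.2, 0.6, 0.4], [1.0, 0.0, 0.0, 0.0]))
```

Because the sigmoid of 0 is 0.5, points on this line sit exactly at the classification threshold, which is why the script plots `(-w0 * plot_x - plot_b) / w1`.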