Implementing logistic regression with PyTorch
2022-07-26 08:54:00 【Miracle Fan】
Logistic Regression in PyTorch
1. Import the required APIs
import torch
import torch.nn as nn
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
2. Prepare the data
bc = datasets.load_breast_cancer()
X, y = bc.data, bc.target
n_samples, n_features = X.shape  # number of samples and number of features per sample
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)  # hold out 20% of the data as the test set
3. Data preprocessing
3.1 Feature scaling
# Call fit_transform() on the training data first, then transform() on the test data.
# Calling transform() before the scaler has been fitted raises a NotFittedError.
# Calling fit_transform() on the test data instead of transform() would still standardize it, but with the test set's own mean and standard deviation, so the two sets would no longer be on the same scale. Be sure to avoid this.
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
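To make the comments above concrete, here is a minimal plain-Python sketch of what fitting and transforming do for a single feature. The helper names fit and transform and the toy numbers are made up for illustration and are not part of scikit-learn. StandardScaler applies the same idea to every column, using the population standard deviation.

```python
import statistics

def fit(values):
    """Learn the mean and population standard deviation of one feature."""
    return statistics.mean(values), statistics.pstdev(values)

def transform(values, mean, std):
    """Standardize values with previously learned statistics."""
    return [(v - mean) / std for v in values]

train = [2.0, 4.0, 6.0, 8.0]   # toy training feature (hypothetical)
test = [5.0, 11.0]             # toy test feature (hypothetical)

# Correct: the statistics come from the training data only
mean, std = fit(train)
train_scaled = transform(train, mean, std)
test_scaled = transform(test, mean, std)

# Incorrect: refitting on the test set puts it on a different scale
wrong_mean, wrong_std = fit(test)
test_refit = transform(test, wrong_mean, wrong_std)

print(test_scaled)   # scaled with the training mean and standard deviation
print(test_refit)    # scaled with the test set's own statistics
```

In this toy case the value 5.0 maps to 0 under the training statistics but to about -1 when the scaler is refit on the test set, so the same raw measurement would reach the model as two different inputs.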
3.2 Data type conversion
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))
3.3 Data format adjustment
tensor.view() is similar to NumPy's array.reshape(). Here it reshapes the labels from a 1-D tensor into a single column of shape (n_samples, 1), which matches the shape of the model output.
# Similar to reshape()
y_train = y_train.view(y_train.shape[0], 1)
y_test = y_test.view(y_test.shape[0], 1)
4. Build the model
4.1 Define the basic network model
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()  # initialize the nn.Module base class
        self.linear = nn.Linear(n_input_features, 1)  # a single linear layer: n_input_features inputs, one output value

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))  # sigmoid maps the output to a value between 0 and 1
        return y_pred

model = Model(n_features)
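The forward pass of this model is a weighted sum of the features followed by a sigmoid. The sketch below reproduces that computation for one sample in plain Python. The weights, bias, and inputs are hypothetical values chosen for illustration, not parameters learned by the model above.

```python
import math

def sigmoid(z):
    """Map any real number to the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def forward(x, weights, bias):
    """Linear combination of the features followed by a sigmoid."""
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return sigmoid(z)

# Hypothetical parameters, chosen only for illustration
weights = [0.5, -0.25]
bias = 0.0

print(forward([0.0, 0.0], weights, bias))   # z = 0, so the output is exactly 0.5
print(forward([4.0, 0.0], weights, bias))   # z = 2, so the output is above 0.5
print(forward([0.0, 8.0], weights, bias))   # z = -2, so the output is below 0.5
```

An output above 0.5 is read as class 1 and an output below 0.5 as class 0, which is why the test step can obtain class labels by rounding.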
4.2 Define loss and optimizer
num_epochs = 100
lr = 0.01
criterion = nn.BCELoss()  # this is binary classification, so use binary cross-entropy
optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # lr is the learning rate, a hyperparameter
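Binary cross-entropy averages -(y*log(p) + (1-y)*log(1-p)) over the samples, where p is the predicted probability and y is the 0/1 label. The following plain-Python sketch evaluates that formula on made-up predictions to show how the loss behaves. It assumes every probability lies strictly between 0 and 1 and is not the PyTorch implementation.

```python
import math

def binary_cross_entropy(probs, labels):
    """Mean of -(y*log(p) + (1-y)*log(1-p)) over all samples."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)

labels = [1, 0, 1, 0]   # hypothetical labels

uncertain = binary_cross_entropy([0.5, 0.5, 0.5, 0.5], labels)
confident_right = binary_cross_entropy([0.9, 0.1, 0.9, 0.1], labels)
confident_wrong = binary_cross_entropy([0.1, 0.9, 0.1, 0.9], labels)

print(f"{uncertain:.4f}")         # 0.6931, which is ln(2)
print(f"{confident_right:.4f}")   # 0.1054
print(f"{confident_wrong:.4f}")   # 2.3026
```

Confident correct predictions give a small loss, while confident wrong predictions are penalized heavily. A loss near 0.69 means the model is doing no better than predicting 0.5 for every sample.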
5. Train the model
for epoch in range(num_epochs):
    # Forward pass: compute the predictions and the loss
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)
    # Backward pass: compute the gradients and update the parameters
    loss.backward()
    optimizer.step()
    # Clear the gradients so they do not accumulate into the next iteration
    optimizer.zero_grad()
    if (epoch+1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')
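The three steps of the loop (forward pass, backward pass, parameter update) can also be written out by hand for a one-feature logistic regression. The sketch below is plain Python on a made-up dataset, with the gradient of the mean binary cross-entropy computed analytically in place of autograd. The data, learning rate, and variable names are assumptions for illustration only.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce(probs, labels):
    """Mean binary cross-entropy over all samples."""
    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for p, y in zip(probs, labels)) / len(labels)

# Toy one-feature dataset (hypothetical): larger x tends to mean class 1
xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
ys = [0, 0, 1, 0, 1, 1]

w, b = 0.0, 0.0
lr = 0.1
losses = []

for epoch in range(100):
    # Forward pass: predictions and loss
    probs = [sigmoid(w * x + b) for x in xs]
    losses.append(bce(probs, ys))

    # Backward pass: gradient of the mean BCE with respect to w and b
    grad_w = sum((p - y) * x for p, y, x in zip(probs, ys, xs)) / len(xs)
    grad_b = sum(p - y for p, y in zip(probs, ys)) / len(xs)

    # Parameter update, the role played by optimizer.step()
    w -= lr * grad_w
    b -= lr * grad_b

print(f"first loss = {losses[0]:.4f}, last loss = {losses[-1]:.4f}")
```

Here the gradients are recomputed from scratch on every iteration, so nothing needs to be cleared. PyTorch instead adds each new gradient to the stored .grad values, which is why the loop above calls optimizer.zero_grad() once per iteration.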
6. Test the model
with torch.no_grad():
    y_predicted = model(X_test)
    y_predicted_cls = y_predicted.round()  # round the outputs in 0~1 to get the predicted class
    acc = y_predicted_cls.eq(y_test).sum() / float(y_test.shape[0])
    print(f'accuracy: {acc.item():.4f}')
y_predicted_cls.eq(y_test) is True wherever the predicted class equals the test label. Summing these values counts the correct predictions, and dividing by the number of test samples gives the accuracy on the test set.
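The same accuracy calculation can be traced on a handful of values in plain Python. The probabilities and labels below are made up for illustration, and the accuracy helper is a hypothetical stand-in for the tensor operations used above.

```python
def accuracy(probs, labels):
    """Fraction of samples whose thresholded prediction matches the label."""
    # For values in [0, 1], p > 0.5 yields the same classes as rounding half to even
    predicted = [1 if p > 0.5 else 0 for p in probs]
    correct = sum(1 for pred, y in zip(predicted, labels) if pred == y)
    return correct / len(labels)

probs = [0.92, 0.15, 0.60, 0.40, 0.08]   # hypothetical model outputs
labels = [1, 0, 0, 0, 0]                 # hypothetical true classes

print(accuracy(probs, labels))   # 4 of 5 predictions are correct -> 0.8
```

Only the third sample is misclassified: its output of 0.60 rounds to class 1 while the true label is 0.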