PyTorch Implementation of Logistic Regression
2022-07-26 08:54:00 【Miracle Fan】
1. Import the relevant APIs
import torch
import torch.nn as nn
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
2. Prepare the data
bc = datasets.load_breast_cancer()
X, y = bc.data, bc.target
n_samples, n_features = X.shape  # number of samples and number of features per sample
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)  # hold out 20% as the test set
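A quick way to check what this step produces is to print the shapes. The sketch below repeats the same calls; the counts in the comments are those of the copy of the breast cancer dataset bundled with scikit-learn (569 samples, 30 features, label 0 = malignant, label 1 = benign), and `class_counts` is a name introduced only for illustration.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

bc = datasets.load_breast_cancer()
X, y = bc.data, bc.target
n_samples, n_features = X.shape

# Same split as above: 20% test set, fixed seed for reproducibility
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=1234
)

# Hypothetical helper variable: number of samples per class (0 = malignant, 1 = benign)
class_counts = np.bincount(y)

print(n_samples, n_features)        # 569 30
print(X_train.shape, X_test.shape)  # (455, 30) (114, 30)
print(class_counts)                 # [212 357]
```

scikit-learn rounds the test size up (ceil(0.2 × 569) = 114), which leaves 455 training rows. With 30 features, the logistic regression model below ends up with 30 weights plus one bias.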
3. Data preprocessing
3.1 Feature scaling
# Call fit_transform(train_data) first, then transform(test_data).
# Calling transform(test_data) on a scaler that has not been fitted yet raises an error.
# If you call fit_transform(train_data) and then fit_transform(test_data) instead of transform(test_data), the test set is still standardized, but with its own mean and standard deviation, so the two sets are no longer on the same scale. (Be sure to avoid this.)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
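StandardScaler standardizes each feature as z = (x - mean) / std, where "fit" learns the mean and standard deviation and "transform" applies them. A NumPy-only sketch on made-up toy numbers (not the breast cancer data) shows why the test set must reuse the training statistics; the population standard deviation (ddof=0) is used, as StandardScaler does.

```python
import numpy as np

# Hypothetical toy data: 4 training rows and 2 test rows, 2 features each
X_train = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
X_test = np.array([[2.5, 25.0], [5.0, 50.0]])

# "fit": per-feature mean and population std, learned from the training set only
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)  # ddof=0

# "transform": apply the *training* statistics to both sets
X_train_scaled = (X_train - mean) / std
X_test_scaled = (X_test - mean) / std

# The mistake described above: re-fitting on the test set uses its own statistics
X_test_refit = (X_test - X_test.mean(axis=0)) / X_test.std(axis=0)

print(X_test_scaled)  # first test row lands at 0 because it equals the training mean
print(X_test_refit)   # same rows, but mapped to -1 and 1: a different scale
```

The test row [2.5, 25.0] equals the training mean, so with the correct transform it maps to [0, 0]; after re-fitting on the test set it maps to [-1, -1] instead, which is the inconsistency the comments above warn about.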
3.2 Data type conversion
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))
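The `.astype(np.float32)` cast matters because NumPy arrays default to float64, `torch.from_numpy` keeps that dtype, and `nn.Linear` parameters are float32 by default. The sketch below uses random toy data (the seed and shapes are arbitrary) to show the mismatch, and also that `from_numpy` shares memory with the source array rather than copying it.

```python
import numpy as np
import torch
import torch.nn as nn

X64 = np.random.default_rng(0).normal(size=(4, 3))  # NumPy default dtype: float64

t64 = torch.from_numpy(X64)                     # stays float64
t32 = torch.from_numpy(X64.astype(np.float32))  # float32, matches nn.Linear weights

layer = nn.Linear(3, 1)  # weights and bias are float32 by default

out = layer(t32)  # dtypes match, so this works

# Feeding float64 input to float32 weights raises a RuntimeError
try:
    layer(t64)
    mismatch_error = False
except RuntimeError:
    mismatch_error = True

# from_numpy shares memory: changing the array changes the tensor too
arr = np.zeros(3, dtype=np.float32)
shared = torch.from_numpy(arr)
arr[0] = 7.0
print(shared)  # tensor([7., 0., 0.])
```

Note that `astype` itself returns a new array, so in the article's code the tensors share memory with the float32 copies, not with the original float64 arrays.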
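3.3 Data shape adjustment
tensor.view() works like NumPy's reshape(): here it turns each 1-D label vector of shape (n_samples,) into a column vector of shape (n_samples, 1), matching the shape of the model output.
# view() reshapes the tensor without copying its data, similar to NumPy reshape()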
y_train = y_train.view(y_train.shape[0], 1)
y_test = y_test.view(y_test.shape[0], 1)
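Why the reshape is needed: the model returns one probability per sample with shape (n_samples, 1), and `nn.BCELoss` expects the target to have the same shape as its input. A small sketch with hypothetical toy labels and probabilities shows the reshape, two equivalent spellings of it, and the error raised when the shapes differ.

```python
import torch
import torch.nn as nn

y = torch.tensor([0.0, 1.0, 1.0, 0.0])              # toy labels, shape (4,)
probs = torch.tensor([[0.2], [0.7], [0.9], [0.4]])  # model-style output, shape (4, 1)

y_col = y.view(y.shape[0], 1)  # shape (4, 1), as in the code above
y_alt1 = y.view(-1, 1)         # -1 lets PyTorch infer that dimension
y_alt2 = y.unsqueeze(1)        # insert a size-1 axis at position 1

criterion = nn.BCELoss()
loss = criterion(probs, y_col)  # shapes match: (4, 1) and (4, 1)

# Without the reshape, target shape (4,) differs from input shape (4, 1)
try:
    criterion(probs, y)
    shape_error = False
except (ValueError, RuntimeError):
    shape_error = True

print(y.shape, y_col.shape)  # torch.Size([4]) torch.Size([4, 1])
print(shape_error)           # True
```

`view(-1, 1)` and `unsqueeze(1)` are common shorthands for the same column reshape.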
4. Build the model
4.1 Define the basic network model
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()  # initialize the nn.Module base class
        self.linear = nn.Linear(n_input_features, 1)  # one linear layer: n_input_features inputs -> a single output

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))  # sigmoid maps the linear output to a value in (0, 1)
        return y_pred

model = Model(n_features)
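This model is exactly logistic regression: for an input row x it outputs sigmoid(w·x + b), where sigmoid(z) = 1 / (1 + e^(-z)). The sketch below re-creates the same class, feeds it random toy inputs (seed chosen arbitrarily), and checks the output against the formula written out by hand.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    # Same structure as above: one linear layer followed by a sigmoid
    def __init__(self, n_input_features):
        super().__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

torch.manual_seed(0)
n_features = 30                 # the breast cancer dataset has 30 features
model = Model(n_features)
x = torch.randn(5, n_features)  # 5 random toy samples

W = model.linear.weight  # shape (1, 30)
b = model.linear.bias    # shape (1,)

with torch.no_grad():
    manual = 1.0 / (1.0 + torch.exp(-(x @ W.T + b)))  # sigmoid(x·w + b) written out
    out = model(x)

n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 31 = 30 weights + 1 bias
```

Every output lies strictly between 0 and 1, which is why it can be read as the probability of class 1.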
4.2 Define the loss and optimizer
num_epochs = 100
lr = 0.01
criterion = nn.BCELoss()  # binary classification, so use binary cross-entropy
optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # lr is the learning rate, a hyperparameter
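Binary cross-entropy averages -[y·log(p) + (1 - y)·log(1 - p)] over the samples, so confident wrong predictions are penalized heavily. The sketch below checks that formula against `nn.BCELoss` on hypothetical toy values, and also shows `nn.BCEWithLogitsLoss`, which takes raw scores and applies the sigmoid internally; PyTorch documents it as the more numerically stable option, at the cost of removing the sigmoid from `forward`.

```python
import torch
import torch.nn as nn

# Toy predicted probabilities and labels, both shape (4, 1)
p = torch.tensor([[0.9], [0.2], [0.6], [0.1]])
y = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

# BCE = -mean( y*log(p) + (1-y)*log(1-p) )
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()
builtin = nn.BCELoss()(p, y)

# BCEWithLogitsLoss works on raw scores (logits) instead of probabilities
logits = torch.tensor([[2.0], [-1.5], [0.3], [-3.0]])
with_logits = nn.BCEWithLogitsLoss()(logits, y)
two_step = nn.BCELoss()(torch.sigmoid(logits), y)

print(f"{manual.item():.4f} {builtin.item():.4f}")
```

Both printed values agree (about 0.2362 for these toy numbers), and `with_logits` matches `two_step` up to floating-point error.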
5. Model training
for epoch in range(num_epochs):
    # Forward pass: compute predictions and the loss
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)
    # Backward pass: compute gradients, then update the parameters
    loss.backward()
    optimizer.step()
    # Reset the gradients; otherwise they accumulate into the next iteration's backward()
    optimizer.zero_grad()
    if (epoch+1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')
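Two mechanics in this loop are easy to check in isolation: `backward()` adds to existing gradients rather than replacing them (hence `zero_grad()` somewhere between one `backward()` and the next), and `optimizer.step()` for plain SGD applies w <- w - lr * grad. The sketch uses a hypothetical two-element parameter and a toy linear loss instead of the real model.

```python
import torch

# Hypothetical toy parameter and input, to isolate the optimizer mechanics
w = torch.tensor([1.0, -2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])
lr = 0.01
optimizer = torch.optim.SGD([w], lr=lr)

def loss_fn():
    return (w * x).sum()  # gradient with respect to w is x = [3, 4]

# 1) Gradients accumulate across backward() calls unless cleared
loss_fn().backward()
first_grad = w.grad.clone()   # [3, 4]
loss_fn().backward()
accumulated = w.grad.clone()  # [6, 8]: added to, not replaced
optimizer.zero_grad()

# 2) One SGD step (no momentum, no weight decay) is w <- w - lr * grad
loss_fn().backward()
with torch.no_grad():
    expected = w - lr * w.grad  # [0.97, -2.04]
optimizer.step()
```

Calling `zero_grad()` after `step()`, as the loop above does, is equivalent to calling it at the start of each iteration: either way the gradients are cleared before the next `backward()`.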
6. Model test
with torch.no_grad():
    y_predicted = model(X_test)
    y_predicted_cls = y_predicted.round()  # round the probabilities in (0, 1) to class labels 0 or 1
    acc = y_predicted_cls.eq(y_test).sum() / float(y_test.shape[0])
    print(f'accuracy: {acc.item():.4f}')
y_predicted_cls.eq(y_test) is True (counted as 1) wherever the predicted class equals the test label. Summing it counts the correct predictions, and dividing by the number of test samples gives the accuracy on the test set.
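The accuracy calculation can be traced step by step on a handful of hypothetical probabilities and labels. Rounding is the same as thresholding at 0.5 (torch.round rounds exactly 0.5 to even, i.e. to 0, matching `p > 0.5`), so the sketch also computes the accuracy with an explicit threshold for comparison.

```python
import torch

# Hypothetical test-set probabilities and true labels, both shape (6, 1)
y_predicted = torch.tensor([[0.92], [0.10], [0.65], [0.30], [0.80], [0.45]])
y_test = torch.tensor([[1.0], [0.0], [0.0], [0.0], [1.0], [1.0]])

y_predicted_cls = y_predicted.round()  # [1, 0, 1, 0, 1, 0]
correct = y_predicted_cls.eq(y_test)   # bool tensor: True where prediction == label
n_correct = correct.sum()              # True counts as 1 -> 4
acc = n_correct / float(y_test.shape[0])

# Same result with an explicit 0.5 threshold
acc_threshold = ((y_predicted > 0.5).float() == y_test).float().mean()

print(f'accuracy: {acc.item():.4f}')  # accuracy: 0.6667
```

Four of the six toy predictions match their labels, giving 4 / 6 ≈ 0.6667.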