Pytorch Daily Practice - Predicting Surviving Passengers on the Titanic
2022-07-31 06:32:00 【qq_50749521】
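Training data:
Survived is the output label; the other columns (Age, Sex, Name, etc.) can be treated as inputs. Some values are missing, so the data has to be cleaned before training. The code below uses only five columns as input features: Pclass, Sex, SibSp, Parch and Fare.
At test time, the goal is to feed in a passenger's features and predict whether they survived (0 or 1).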
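To see where the model's 6 input features come from, here is a minimal sketch on a tiny made-up DataFrame (the three rows are invented for illustration, not taken from train.csv). A missing Fare is filled with the column median, one simple choice among many, and `pd.get_dummies` expands the string column Sex into two indicator columns, giving 4 + 2 = 6 features.

```python
import numpy as np
import pandas as pd

features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]

# Hypothetical sample rows in the same column layout as the Kaggle CSVs
df = pd.DataFrame({
    "Pclass": [3, 1, 2],
    "Sex": ["male", "female", "female"],
    "SibSp": [1, 0, 0],
    "Parch": [0, 0, 2],
    "Fare": [7.25, np.nan, 30.0],
})

x = df[features].copy()
# Fill the missing fare with the median of the known fares
x["Fare"] = x["Fare"].fillna(x["Fare"].median())
# One-hot encode the only string column; numeric columns pass through unchanged
encoded = pd.get_dummies(x, dtype=float)

print(list(encoded.columns))
print(encoded.to_numpy(dtype=np.float32).shape)
```

The dummy columns are appended after the numeric ones, so the column order is Pclass, SibSp, Parch, Fare, Sex_female, Sex_male.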
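```python
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class TitanicDataset(Dataset):
    def __init__(self, filepath):
        xy = pd.read_csv(filepath)
        self.len = xy.shape[0]
        features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]
        # One-hot encode Sex into Sex_female / Sex_male -> 6 input columns in total.
        # dtype=float avoids a bool/int/float column mix (newer pandas returns bool
        # dummies by default), which NumPy turns into an object array that
        # torch.from_numpy cannot convert.
        x = pd.get_dummies(xy[features], dtype=float)
        self.x_data = torch.from_numpy(x.to_numpy(dtype=np.float32))
        self.y_data = torch.from_numpy(xy['Survived'].to_numpy(dtype=np.float32))

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = TitanicDataset('Dataset\\titanic\\train.csv')
batch_size = 32
train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=0)
# Mini-batches per epoch: ceil(len(dataset) / batch_size), since drop_last defaults to False
batch = len(train_loader)
```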
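Counting batches with `np.round(len / batch_size)` only matches the real number of mini-batches by coincidence: `DataLoader` keeps the last, smaller batch (its `drop_last` argument defaults to False), so the count is a ceiling, and `len(train_loader)` returns exactly that. A quick check with synthetic `TensorDataset`s; the sizes 891 and 900 are chosen for illustration.

```python
import math

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 32
results = {}
for n in (891, 900):
    # n fake samples with 6 features each and a dummy label
    ds = TensorDataset(torch.zeros(n, 6), torch.zeros(n))
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)
    results[n] = (len(loader), math.ceil(n / batch_size), int(np.round(n / batch_size)))
    print(n, results[n])
```

For 891 samples all three counts agree (28), but for 900 samples the loader yields 29 batches while `np.round` gives 28, which would skew the per-epoch average loss.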
```python
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(6, 4)  # 6 input features after one-hot encoding
        self.linear2 = torch.nn.Linear(4, 2)
        self.linear3 = torch.nn.Linear(2, 1)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # The last layer must use sigmoid, not ReLU: BCELoss expects probabilities
        # in [0, 1], and ReLU outputs 0 (with zero gradient) for any negative input.
        x = self.sigmoid(self.linear3(x))
        return x
```
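The same 6-4-2-1 architecture can also be written with `torch.nn.Sequential`. The sketch below is a rewrite for illustration, not the author's code; it feeds a random batch of 5 fake passengers through the network to confirm the shapes: the output is (N, 1), which is why the training loop squeezes the last dimension before comparing it with the (N,) label vector.

```python
import torch

torch.manual_seed(0)

# Same layer sizes as Model: 6 -> 4 -> 2 -> 1, ReLU in between, sigmoid at the end
seq_model = torch.nn.Sequential(
    torch.nn.Linear(6, 4), torch.nn.ReLU(),
    torch.nn.Linear(4, 2), torch.nn.ReLU(),
    torch.nn.Linear(2, 1), torch.nn.Sigmoid(),
)

x = torch.randn(5, 6)        # 5 fake passengers, 6 features each
probs = seq_model(x)         # shape (5, 1)
flat = probs.squeeze(-1)     # shape (5,), matching a label vector
print(probs.shape, flat.shape)
print(bool(((flat > 0) & (flat < 1)).all()))  # sigmoid keeps outputs in (0, 1)
```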
```python
mymodel = Model()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(mymodel.parameters(), lr=0.01)
epoch_list = []
loss_list = []
sum_loss = 0
```
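`BCELoss` with `reduction='mean'` averages the binary cross-entropy -[y*log(p) + (1 - y)*log(1 - p)] over the batch. The sketch below checks the built-in loss against a hand-written version; the probabilities and labels are made up.

```python
import torch

p = torch.tensor([0.9, 0.2, 0.6, 0.4])   # predicted survival probabilities (made up)
y = torch.tensor([1.0, 0.0, 0.0, 1.0])   # true labels (made up)

criterion = torch.nn.BCELoss(reduction='mean')
builtin = criterion(p, y)

# Binary cross-entropy written out by hand and averaged over the batch
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()

print(builtin.item(), manual.item())
print(torch.allclose(builtin, manual))
```

Confident wrong predictions dominate the loss: the two samples predicted at 0.6 and 0.4 for the opposite label contribute far more than the two well-predicted ones.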
```python
if __name__ == '__main__':
    for epoch in range(500):
        # train_loader yields shuffled mini-batches of (inputs, labels)
        for index, data in enumerate(train_loader, 0):
            inputs, labels = data  # both inputs and labels are tensors
            inputs = inputs.float()
            labels = labels.float()
            y_pred = mymodel(inputs)
            y_pred = y_pred.squeeze(-1)  # (batch, 1) -> (batch,) to match labels
            loss = criterion(y_pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            print('epoch = ', epoch + 1, 'index = ', index + 1, 'loss = ', loss.item())
        epoch_list.append(epoch)
        loss_list.append(sum_loss / batch)  # average loss over this epoch's mini-batches
        print(sum_loss / batch)
        sum_loss = 0
```
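The loop above only tracks the loss. A minimal sketch of also measuring accuracy on labelled data; the helper name `accuracy` is hypothetical and the 0.5 threshold mirrors the one used on the test set. The demo uses a one-neuron stand-in model whose output is simply sigmoid(x), so the result can be checked by hand.

```python
import torch


def accuracy(model, x, y, threshold=0.5):
    """Fraction of samples whose thresholded prediction equals the 0/1 label."""
    with torch.no_grad():
        probs = model(x).squeeze(-1)
    preds = (probs >= threshold).float()
    return (preds == y).float().mean().item()


# Stand-in model: Linear(1, 1) with weight 1 and bias 0, followed by sigmoid
toy = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Sigmoid())
with torch.no_grad():
    toy[0].weight.fill_(1.0)
    toy[0].bias.fill_(0.0)

x = torch.tensor([[-2.0], [1.0], [3.0]])
y = torch.tensor([0.0, 1.0, 0.0])
acc = accuracy(toy, x, y)
print(acc)
```

In the training script, the same helper could be called as `accuracy(mymodel, dataset.x_data.float(), dataset.y_data.float())` after the epoch loop to get the training accuracy.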
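```python
test_x = pd.read_csv('Dataset\\titanic\\test.csv')
features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]
test_features = test_x[features].copy()
# test.csv has a missing Fare value; fill it, otherwise the model outputs NaN for that row
test_features["Fare"] = test_features["Fare"].fillna(test_features["Fare"].median())
test_x_data = torch.from_numpy(
    pd.get_dummies(test_features, dtype=float).to_numpy(dtype=np.float32))
with torch.no_grad():  # inference only, no gradients needed
    y_test_pred = mymodel(test_x_data).squeeze(-1)
# Threshold the predicted survival probability at 0.5
y = [0 if p < 0.5 else 1 for p in y_test_pred.tolist()]
for label in y:
    print(label)
```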
Finally, save the predictions y in the format of gender_submission.csv (a PassengerId column and a Survived column) and submit the file to Kaggle.
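A minimal sketch of writing that file, assuming `test_x` is the DataFrame read from test.csv and `y` is the list of 0/1 predictions. The helper name `write_submission`, the demo IDs and predictions, and the output file names are placeholders.

```python
import os
import tempfile

import pandas as pd


def write_submission(passenger_ids, predictions, path):
    """Write a two-column submission file: PassengerId, Survived."""
    submission = pd.DataFrame({"PassengerId": passenger_ids,
                               "Survived": predictions})
    submission.to_csv(path, index=False)  # no extra index column
    return submission


# Demo with placeholder values; in the script this would be
# write_submission(test_x["PassengerId"], y, "submission.csv")
demo_path = os.path.join(tempfile.gettempdir(), "submission_demo.csv")
sub = write_submission([892, 893, 894], [0, 1, 0], demo_path)
with open(demo_path) as f:
    print(f.read())
```

Passing `index=False` matters: without it pandas writes an unnamed index column, and the file no longer has the two-column layout.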
I've only just started practicing the basics; I'll keep improving this gradually.