PyTorch Daily Practice: Predicting Which Titanic Passengers Survived
2022-07-31 05:16:00 【qq_50749521】
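Training data:
Survived is the output label; the other columns (age, sex, name and so on) are candidate inputs. Some columns have missing values (in train.csv these are Age, Cabin and Embarked), so the data has to be cleaned first. The code below sidesteps most of that by using only Pclass, Sex, SibSp, Parch and Fare as inputs.
At test time, the model takes a passenger's features and outputs whether that passenger survived (0 or 1).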
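The width of the network's first layer (Linear(6, 4)) comes from the encoding step: pd.get_dummies passes the four numeric columns through unchanged and splits the string column Sex into Sex_female and Sex_male, giving 6 columns. The sketch below shows this on three made-up passengers (hypothetical values, not rows from train.csv). Filling a missing Fare with the median is an assumption for illustration, not the author's cleaning method.

```python
import pandas as pd

features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]

# Three hypothetical passengers; the second one has a missing Fare
raw = pd.DataFrame({
    "Pclass": [3, 1, 2],
    "Sex": ["male", "female", "female"],
    "SibSp": [1, 0, 0],
    "Parch": [0, 0, 2],
    "Fare": [7.25, None, 26.0],
})

def encode(df: pd.DataFrame) -> pd.DataFrame:
    """Fill missing fares with the median, then one-hot encode Sex."""
    out = df[features].copy()
    out["Fare"] = out["Fare"].fillna(out["Fare"].median())
    # Numeric columns pass through; Sex becomes Sex_female / Sex_male (as floats)
    return pd.get_dummies(out, dtype=float)

X = encode(raw)
print(list(X.columns))
print(X.shape)  # (3, 6): 4 numeric columns + 2 dummy columns
```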
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class TitanicDataset(Dataset):
    def __init__(self, filepath):
        xy = pd.read_csv(filepath)
        self.len = xy.shape[0]
        features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]
        # get_dummies one-hot encodes Sex (Sex_female, Sex_male): 6 input columns in total.
        # dtype=float keeps the dummies numeric; newer pandas returns bool dummies, which
        # would turn the combined array into dtype=object and break torch.from_numpy.
        x = pd.get_dummies(xy[features], dtype=float).to_numpy(dtype=np.float32)
        self.x_data = torch.from_numpy(x)
        self.y_data = torch.from_numpy(xy['Survived'].to_numpy(dtype=np.float32))

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len

batch_size = 32
dataset = TitanicDataset('Dataset\\titanic\\train.csv')
train_loader = DataLoader(dataset=dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=0)
# Number of mini-batches per epoch (the last batch may hold fewer than batch_size samples)
batch = len(train_loader)
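With the default drop_last=False, a DataLoader yields ceil(N / batch_size) batches per epoch, which is what len(train_loader) returns. The short plain-Python sketch below (the helper names are hypothetical) checks the numbers for the 891 rows of train.csv and shows a dataset size where rounding would undercount the batches.

```python
import math

def num_batches(n_samples: int, batch_size: int) -> int:
    """Mini-batches per epoch when the final partial batch is kept (drop_last=False)."""
    return math.ceil(n_samples / batch_size)

def last_batch_size(n_samples: int, batch_size: int) -> int:
    """Size of the final, possibly partial, batch."""
    rem = n_samples % batch_size
    return rem if rem else batch_size

print(num_batches(891, 32), last_batch_size(891, 32))  # 28 27
# Rounding happens to give 28 for 891 rows, but for 900 rows it drops a batch:
print(round(900 / 32))        # 28
print(num_batches(900, 32))   # 29
```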
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(6, 4)  # 6 inputs: 4 numeric columns + Sex_female/Sex_male
        self.linear2 = torch.nn.Linear(4, 2)
        self.linear3 = torch.nn.Linear(2, 1)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        # The last layer must use sigmoid, not ReLU: BCELoss needs a probability in (0, 1),
        # and ReLU outputs 0 (with zero gradient) for negative inputs
        x = self.sigmoid(self.linear3(x))
        return x
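The comment on the last layer can be checked numerically. This NumPy sketch uses hypothetical logits standing in for the output of linear3 and compares the two activations: sigmoid always lands strictly inside (0, 1) with a nonzero slope, while ReLU produces 0 or values above 1 and passes no gradient back for negative inputs.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def relu(z):
    return np.maximum(z, 0.0)

def sigmoid_grad(z):
    s = sigmoid(z)
    return s * (1.0 - s)          # strictly positive for every finite z

def relu_grad(z):
    return (z > 0).astype(float)  # 0 for z <= 0: no gradient flows back

logits = np.array([-3.0, 0.0, 2.5])  # hypothetical outputs of linear3

print(sigmoid(logits))       # every value strictly between 0 and 1
print(relu(logits))          # 0, 0 and 2.5: log(0) blows up, 2.5 is not a probability
print(sigmoid_grad(logits))  # all positive
print(relu_grad(logits))     # 0, 0, 1
```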
mymodel = Model()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(mymodel.parameters(), lr = 0.01)
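With reduction='mean', BCELoss averages the binary cross-entropy over the batch, loss = -(1/N) Σ [y·log(p) + (1 - y)·log(1 - p)], where p is the sigmoid output and y the 0/1 label. The NumPy sketch below spells out that formula as a worked example; it is not PyTorch's implementation, which additionally clamps the log terms at -100.

```python
import numpy as np

def bce_mean(p, y):
    """Mean binary cross-entropy between predicted probabilities p and 0/1 labels y."""
    p = np.asarray(p, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

# A survivor predicted at 0.9 and a non-survivor predicted at 0.2
loss = bce_mean([0.9, 0.2], [1, 0])
print(round(loss, 4))  # 0.1643 = (-ln 0.9 - ln 0.8) / 2

# Confident correct predictions cost little, confident wrong ones cost a lot
print(bce_mean([0.99], [1]), bce_mean([0.1], [1]))
```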
epoch_list = []
loss_list = []
sum_loss = 0

if __name__ == '__main__':
    for epoch in range(500):
        # train_loader yields shuffled mini-batches of samples with their labels
        for index, data in enumerate(train_loader, 0):
            inputs, labels = data        # both are tensors
            inputs = inputs.float()
            labels = labels.float()
            y_pred = mymodel(inputs)
            y_pred = y_pred.squeeze(-1)  # (batch, 1) -> (batch,) to match labels
            loss = criterion(y_pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            print('epoch = ', epoch + 1, 'index = ', index + 1, 'loss = ', loss.item())
        # Record the mean batch loss for this epoch
        epoch_list.append(epoch)
        loss_list.append(sum_loss / batch)
        print(sum_loss / batch)
        sum_loss = 0
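Each step of the loop above runs forward pass, loss, zero_grad, backward and step, then averages the batch losses per epoch. To make those optimizer calls concrete, here is a NumPy sketch that performs the same cycle by hand for a single sigmoid unit (plain logistic regression, a simplification of the three-layer network) on synthetic data; the data, learning rate and epoch count are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic, linearly separable data (hypothetical, not Titanic)
n, d = 200, 3
X = rng.normal(size=(n, d))
true_w = np.array([1.5, -2.0, 0.5])
y = (X @ true_w > 0).astype(float)

w = np.zeros(d)
b = 0.0
lr, batch_size, epochs = 0.1, 32, 50

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mean_bce(p, t):
    eps = 1e-12  # keep log() finite
    return float(-np.mean(t * np.log(p + eps) + (1 - t) * np.log(1 - p + eps)))

epoch_losses = []
for epoch in range(epochs):
    order = rng.permutation(n)              # like shuffle=True
    sum_loss, n_batches = 0.0, 0
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        xb, yb = X[idx], y[idx]
        p = sigmoid(xb @ w + b)             # forward pass
        sum_loss += mean_bce(p, yb)
        n_batches += 1
        # "backward": gradient of mean BCE through a sigmoid output is (p - y) / batch
        # (computed fresh each batch, which is what zero_grad() ensures in PyTorch)
        err = (p - yb) / len(idx)
        grad_w = xb.T @ err
        grad_b = err.sum()
        # "step": plain SGD update, parameter -= lr * gradient
        w -= lr * grad_w
        b -= lr * grad_b
    epoch_losses.append(sum_loss / n_batches)  # mean loss per epoch

print(epoch_losses[0], epoch_losses[-1])
```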
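test_x = pd.read_csv('Dataset\\titanic\\test.csv')
features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]
test_features = test_x[features].copy()
# test.csv has one missing Fare; without filling it (here with the median, an assumption)
# the model would output NaN for that passenger
test_features['Fare'] = test_features['Fare'].fillna(test_features['Fare'].median())
test_x_data = torch.from_numpy(
    pd.get_dummies(test_features, dtype=float).to_numpy(dtype=np.float32))
mymodel.eval()
with torch.no_grad():  # inference only, no gradient tracking needed
    y_test_pred = mymodel(test_x_data)
len_y = y_test_pred.shape[0]
y = []
for i in range(len_y):
    # probability < 0.5 -> did not survive (0), otherwise survived (1)
    if y_test_pred[i].item() < 0.5:
        y.append(0)
    else:
        y.append(1)
for i in range(len(y)):
    print(y[i])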
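The element-by-element loop can be replaced by a single vectorized comparison; in PyTorch that would be (y_test_pred.squeeze(-1) >= 0.5).long(). The NumPy sketch below applies the same rule (probability >= 0.5 means survived) to a few hypothetical probabilities; the function name to_labels is made up for illustration.

```python
import numpy as np

def to_labels(probs, threshold=0.5):
    """Map survival probabilities to 0/1 labels: >= threshold -> 1, otherwise 0."""
    return (np.asarray(probs, dtype=float) >= threshold).astype(int)

print(to_labels([0.12, 0.5, 0.93]).tolist())  # [0, 1, 1]
```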
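Finally, save the predictions y in the same format as gender_submission.csv (two columns, PassengerId and Survived) and submit the file to Kaggle.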
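One way to write that file is a two-column pandas DataFrame saved without the index. The sketch below uses hypothetical IDs and predictions and a temporary directory; in the script above the inputs would be test_x["PassengerId"] and the list y, and write_submission is a made-up helper name.

```python
import os
import tempfile

import pandas as pd

def write_submission(passenger_ids, predictions, path):
    """Write a Titanic submission: one row per passenger, columns PassengerId and Survived."""
    sub = pd.DataFrame({"PassengerId": passenger_ids, "Survived": predictions})
    sub.to_csv(path, index=False)  # no extra index column
    return sub

# Hypothetical values for illustration
out_path = os.path.join(tempfile.mkdtemp(), "gender_submission.csv")
write_submission([892, 893, 894], [0, 1, 0], out_path)
with open(out_path) as f:
    print(f.read())
```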
This is just basic practice for now; I'll improve it step by step later...