当前位置:网站首页>Pytorch每日一练——预测泰坦尼克号船上的生存乘客
Pytorch每日一练——预测泰坦尼克号船上的生存乘客
2022-07-31 05:16:00 【qq_50749521】
训练数据:
Survived是输出标签,其他年龄、性别、名字等等都当做输入。当然会有数据缺失的情况,需要提前进行清洗。
测试的目的就是输入样本特征,输出是否能生存下来(0或1)
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class DiabetesDataset(Dataset):
def __init__(self, filepath):
xy = pd.read_csv(filepath)
self.len = xy.shape[0]
features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]
self.x_data = torch.from_numpy(np.array(pd.get_dummies(xy[features])))
self.y_data = torch.from_numpy(np.array(xy['Survived']))
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
dataset = DiabetesDataset('Dataset\\titanic\\train.csv')
train_loader = DataLoader(dataset = dataset,
batch_size = 32,
shuffle = True,
num_workers = 0)
batch_size = 32
batch = np.round(dataset.__len__() / batch_size)
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(6, 4)
self.linear2 = torch.nn.Linear(4, 2)
self.linear3 = torch.nn.Linear(2, 1)
self.relu = torch.nn.ReLU()
self.sigmoid = torch.nn.Sigmoid()
def forward(self, x):
x = self.relu(self.linear1(x))
x = self.relu(self.linear2(x))
x = self.sigmoid(self.linear3(x))#注意最后一步不能使用relu,避免无法计算梯度
return x
mymodel = Model()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(mymodel.parameters(), lr = 0.01)
epoch_list = []
loss_list = []
sum_loss = 0
if __name__ == '__main__':
for epoch in range(500):
for index, data in enumerate(train_loader, 0): #train_loader存的是分割组合后的小批量训练样本和对应的标签
inputs, labels = data #inputs labels都是张量
inputs = inputs.float()
labels = labels.float()
y_pred = mymodel(inputs)
y_pred = y_pred.squeeze(-1)
loss = criterion(y_pred, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
sum_loss += loss.item()
print('epoch = ', epoch + 1,'index = ', index+1, 'loss = ', loss.item())
epoch_list.append(epoch)
loss_list.append(sum_loss/batch)
print(sum_loss/batch)
sum_loss = 0
test_x = pd.read_csv('Dataset\\titanic\\test.csv')
features = ["Pclass", "Sex", "SibSp", "Parch", "Fare"]
test_x_data = torch.from_numpy(np.array(pd.get_dummies(test_x[features])))
test_x_data = test_x_data.float()
y_test_pred = mymodel(test_x_data)
len_y = y_test_pred.shape[0]
y = []
for i in range(len_y):
if(y_test_pred[i].item()<0.5):
y.append(0)
else:
y.append(1)
for i in range(len(y)):
print(y[i])
最后把输出的y保存到gender_submission.csv中,提交kaggle即可。
刚开始练习基础,后面再慢慢改进…
边栏推荐
- Principle analysis of famous website msdn.itellyou.cn
- Using IIS10 to build an asp website in win11
- 活体检测FaceBagNet阅读笔记
- Pure shell implementation of text replacement
- js中的全局作用域与函数作用域
- cocos2d-x-3.2 create project method
- 为什么bash中的read要配合while才能读取/dev/stdin的内容
- qt:cannot open C:\Users\某某某\AppData\Local\Temp\main.obj.15576.16.jom for write
- VS connects to MYSQL through ODBC (2)
- TransactionTemplate transaction programmatic way
猜你喜欢

一文速学-玩转MySQL获取时间、格式转换各类操作方法详解
![[Cloud native] Simple introduction and use of microservice Nacos](/img/06/b0594208d5b0cbf3ae8edd80ec12c4.png)
[Cloud native] Simple introduction and use of microservice Nacos

The latest MySql installation teaching, very detailed

np.fliplr与np.flipud

WeChat applet source code acquisition and decompilation method

VS2017 connects to MYSQL

After unicloud is released, the applet prompts that the connection to the local debugging service failed. Please check whether the client and the host are under the same local area network.

动态规划(一)| 斐波那契数列和归递

Multi-Modal Face Anti-Spoofing Based on Central Difference Networks学习笔记

unicloud 云开发记录
随机推荐
微信小程序源码获取与反编译方式
cocos2d-x-3.2创建项目方法
人脸识别AdaFace学习笔记
2021年京东数据分析工程师秋招笔试编程题
DC-CDN学习笔记
微信小程序启动优化
quick-3.5 无法使用模拟器修改
MySQL面试题大全(陆续更新)
VS connects to MYSQL through ODBC (2)
The latest MySql installation teaching, very detailed
360 加固 file path not exists.
The server time zone value ‘й‘ is unrecognized or represents more than one time zone
powershell统计文件夹大小
After unicloud is released, the applet prompts that the connection to the local debugging service failed. Please check whether the client and the host are under the same local area network.
VS2017连接MYSQL
SSH自动重连脚本
使用ps | egrep时过滤排除掉egrep自身
VS connects to MYSQL through ODBC (1)
Sqlite column A data is copied to column B
VTK:Could not locate vtkTextRenderer object.