PyTorch: Temperature Prediction
2022-07-06 09:16:00 【想成为风筝】
# Path to temps.csv on the author's machine; change it to wherever the file lives
data_path = r'E:\PyTorch\02\02.2020深度学习-PyTorch实战\代码+资料\神经网络实战分类与回归任务\temps.csv'
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
Data = pd.read_csv(data_path)
print(Data)
print(Data.head())
del Data['friend']
print(Data)
print('Data shape:', Data.shape)
# Process the date columns
import datetime
years = Data['year']
months = Data['month']
days = Data['day']
# Build 'YYYY-M-D' strings and parse them into datetime objects
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]
print(dates[:5])
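The loop above builds a `'YYYY-M-D'` string for each row and parses it with `strptime`. As a minimal sketch (on a made-up three-row frame, not the real temps.csv), pandas can also assemble the dates directly from columns named `year`, `month` and `day` via `pd.to_datetime`; both approaches should give the same timestamps.

```python
import datetime

import pandas as pd

# Toy frame with the same year/month/day column names as temps.csv (values are made up)
toy_df = pd.DataFrame({"year": [2016, 2016, 2016], "month": [1, 1, 2], "day": [1, 2, 29]})

# Loop version, as in the article: build strings, then parse them one by one
loop_dates = [
    datetime.datetime.strptime(f"{int(y)}-{int(m)}-{int(d)}", "%Y-%m-%d")
    for y, m, d in zip(toy_df["year"], toy_df["month"], toy_df["day"])
]

# Vectorized alternative: pd.to_datetime assembles dates from year/month/day columns
vec_dates = pd.to_datetime(toy_df[["year", "month", "day"]])

print(vec_dates.tolist())
```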
# Plotting (the plots below are commented out)
# Set the default style
# plt.style.use('fivethirtyeight')
# plt.figure(dpi=400)
# #plt.subplot(2,2,1)
# plt.plot(dates,Data['actual'])
# plt.xlabel('')
# plt.ylabel('Temperature')
# plt.title('Max Temp')
# plt.show()
#
# plt.figure(dpi=400)
# #plt.subplot(2,2,2)
# plt.plot(dates,Data['temp_1'])
# plt.xlabel('')
# plt.ylabel('Temperature')
# plt.title('Previous Max Temp')
# plt.show()
#
# plt.figure(dpi=400)
# #plt.subplot(2,2,3)
# plt.plot(dates,Data['temp_2'])
# plt.xlabel('')
# plt.ylabel('Temperature')
# plt.title('Two Days Prior Max Temp')
#
# plt.show()
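The commented-out code draws each series in its own figure, while the `plt.subplot(2,2,n)` calls hint at a grid layout. A sketch of that grid with `plt.subplots`, assuming synthetic temperatures in place of the `actual`, `temp_1` and `temp_2` columns (the numbers are invented for illustration):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Synthetic stand-in data, NOT taken from temps.csv
toy_dates = pd.date_range("2016-01-01", periods=60, freq="D")
rng = np.random.default_rng(0)
toy_actual = 50 + 10 * np.sin(np.arange(60) / 10) + rng.normal(0, 2, 60)
panels = {
    "Max Temp": toy_actual,
    "Previous Max Temp": np.roll(toy_actual, 1),        # shifted by one day (wraps around)
    "Two Days Prior Max Temp": np.roll(toy_actual, 2),  # shifted by two days (wraps around)
}

# One figure with a 2x2 grid instead of three separate figures
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))
for ax, (title, values) in zip(axes.flat, panels.items()):
    ax.plot(toy_dates, values)
    ax.set_title(title)
    ax.set_ylabel("Temperature")
    ax.tick_params(axis="x", labelrotation=45)  # keep date labels from overlapping
axes[1, 1].axis("off")  # only three panels are used
fig.tight_layout()
# plt.show() or fig.savefig("temps.png") would then display or save the figure
```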
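# The 'week' column is special: it holds strings (day names), so one-hot encode it
# dtype=float keeps the dummy columns numeric, so np.array(Data) below stays float64
Data = pd.get_dummies(Data, dtype=float)
print(Data.head())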
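`pd.get_dummies` only expands non-numeric columns, so numeric columns pass through unchanged and `week` becomes one 0/1 column per distinct day name. A small sketch on a hypothetical three-row frame:

```python
import pandas as pd

# Hypothetical miniature of the data: one numeric column and the string column 'week'
toy_df = pd.DataFrame({"temp_1": [45, 44, 41], "week": ["Fri", "Sat", "Sun"]})

# Only 'week' is expanded; dtype=float makes the indicator columns 0.0/1.0
encoded = pd.get_dummies(toy_df, dtype=float)
print(encoded.columns.tolist())  # ['temp_1', 'week_Fri', 'week_Sat', 'week_Sun']
print(encoded)
```

Column names follow the `<column>_<value>` pattern, and the new indicator columns are appended after the untouched ones.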
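# Labels: the actual maximum temperature
labels = np.array(Data['actual'])
# Features: every remaining column
Data = Data.drop('actual', axis=1)
Data_columns_name = list(Data.columns)
Data = np.array(Data)
print(Data.shape)
from sklearn import preprocessing
# Standardize each feature to zero mean and unit variance
input_Data = preprocessing.StandardScaler().fit_transform(Data)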
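`StandardScaler` rescales each column to zero mean and unit variance, using the population standard deviation. The sketch below is an assumed NumPy equivalent for illustration (the function name `standardize` is hypothetical); like `StandardScaler`, it leaves constant columns unscaled instead of dividing by zero.

```python
import numpy as np


def standardize(features):
    """Column-wise z-score: (x - mean) / std, using the population std (ddof=0)."""
    mean = features.mean(axis=0)
    std = features.std(axis=0)          # ndarray.std defaults to ddof=0
    std = np.where(std == 0, 1.0, std)  # leave constant columns unscaled instead of dividing by zero
    return (features - mean) / std


X_toy = np.array([[1.0, 2.0, 5.0],
                  [3.0, 4.0, 5.0],
                  [5.0, 9.0, 5.0]])
Z = standardize(X_toy)
print(Z.mean(axis=0))  # approximately [0, 0, 0]
print(Z.std(axis=0))   # [1, 1, 0]; the constant third column becomes all zeros
```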
# Build the model
x = torch.tensor(input_Data, dtype=float)
# Reshape labels to (N, 1) so they line up with the (N, 1) predictions
y = torch.tensor(labels, dtype=float).reshape(-1, 1)
"""
# Version 1: build the network by hand (kept for reference, not executed)
# Weight initialization: 13 input features -> 128 hidden units -> 1 output
weights = torch.randn((13, 128), dtype=float, requires_grad=True)
biases = torch.randn(128, dtype=float, requires_grad=True)
weights1 = torch.randn((128, 1), dtype=float, requires_grad=True)
biases1 = torch.randn(1, dtype=float, requires_grad=True)
lr = 0.001
losses = []
for i in range(1000):
    # Forward pass: hidden layer with ReLU, then a linear output layer
    hidden = x.mm(weights) + biases
    hidden = torch.relu(hidden)
    predictions = hidden.mm(weights1) + biases1
    # Mean squared error
    loss = torch.mean((predictions - y) ** 2)
    losses.append(loss.data.numpy())
    # Print the loss
    if i % 100 == 0:
        print('loss:', loss)
    loss.backward()
    # Update the parameters
    weights.data.add_(-lr * weights.grad.data)
    biases.data.add_(-lr * biases.grad.data)
    weights1.data.add_(-lr * weights1.grad.data)
    biases1.data.add_(-lr * biases1.grad.data)
    # Gradients accumulate, so clear them after every iteration
    weights.grad.data.zero_()
    biases.grad.data.zero_()
    weights1.grad.data.zero_()
    biases1.grad.data.zero_()
"""
# Version 2: a more concise network built from torch.nn modules
input_size = input_Data.shape[1]
hidden_size = 128
output_size = 1
batch_size = 16
my_nn = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_size, output_size)
)
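To make the architecture concrete: each `Linear(in, out)` layer holds `in*out` weights plus `out` biases, and `Sigmoid` has none. Assuming 13 input features (the size hard-coded as `torch.randn((13,128))` in the hand-written version), the network has 13*128 + 128 + 128*1 + 1 = 1921 trainable parameters, which this sketch checks:

```python
import torch

# Assumption: 13 input features, matching the hand-written version
input_size, hidden_size, output_size = 13, 128, 1
model = torch.nn.Sequential(
    torch.nn.Linear(input_size, hidden_size),   # 13*128 weights + 128 biases
    torch.nn.Sigmoid(),                         # no parameters
    torch.nn.Linear(hidden_size, output_size),  # 128*1 weights + 1 bias
)
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 1921

batch = torch.zeros(16, input_size)
print(model(batch).shape)  # torch.Size([16, 1])
```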
cost = torch.nn.MSELoss(reduction='mean')
optimizer= torch.optim.Adam(my_nn.parameters(),lr=0.001)
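One shape detail matters for `MSELoss`: `my_nn` returns predictions of shape `(batch, 1)`, while `labels[start:end]` is 1-D. If the labels are not reshaped, the two tensors broadcast to `(batch, batch)` and the loss compares every prediction with every label. A minimal demonstration with hand-picked numbers:

```python
import torch

mse = torch.nn.MSELoss(reduction="mean")
prediction = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1), like the output of my_nn
target = torch.tensor([1.0, 2.0, 3.0])            # shape (3,), like a slice of labels

# Mismatched shapes broadcast to (3, 3): every prediction is compared with every target.
# PyTorch emits a UserWarning about the differing sizes here.
wrong = mse(prediction, target)
# Reshaping the target to (3, 1) compares element i with element i only.
right = mse(prediction, target.reshape(-1, 1))
print(wrong.item(), right.item())  # about 1.3333 and 0.0
```

Even though the predictions are perfect, the broadcast version reports a nonzero loss (12 / 9), which is why the labels need the `(batch, 1)` shape.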
# Training
losses = []
for i in range(1000):
    batch_loss = []
    # Mini-batch training: step through the data batch_size rows at a time
    for start in range(0, len(input_Data), batch_size):
        end = start + batch_size if start + batch_size < len(input_Data) else len(input_Data)
        # Inputs and labels need no gradients; labels are reshaped to (batch, 1) to match the predictions
        xx = torch.tensor(input_Data[start:end], dtype=torch.float)
        yy = torch.tensor(labels[start:end], dtype=torch.float).reshape(-1, 1)
        prediction = my_nn(xx)
        loss = cost(prediction, yy)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
    # Record and print the mean batch loss every 100 epochs
    if i % 100 == 0:
        losses.append(np.mean(batch_loss))
        print(i, np.mean(batch_loss))
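The inner loop slices `input_Data` into chunks of `batch_size` rows by hand. `torch.utils.data.TensorDataset` and `DataLoader` can do the same batching; the sketch below assumes random stand-in arrays (100 rows, 13 features) rather than the real data, and also turns on per-epoch shuffling, which the article's loop does not do:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
rng = np.random.default_rng(0)
# Hypothetical stand-ins for input_Data (100 rows, 13 features) and labels
toy_features = rng.normal(size=(100, 13)).astype(np.float32)
toy_labels = rng.normal(size=100).astype(np.float32)

dataset = TensorDataset(
    torch.from_numpy(toy_features),
    torch.from_numpy(toy_labels).reshape(-1, 1),  # (N, 1) to match the model output
)
# DataLoader replaces the manual start/end slicing and can reshuffle every epoch
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = torch.nn.Sequential(
    torch.nn.Linear(13, 128), torch.nn.Sigmoid(), torch.nn.Linear(128, 1)
)
cost_fn = torch.nn.MSELoss()
optimizer_toy = torch.optim.Adam(model.parameters(), lr=0.001)

epoch_losses = []
for epoch in range(5):
    batch_losses = []
    for xb, yb in loader:
        loss = cost_fn(model(xb), yb)
        optimizer_toy.zero_grad()
        loss.backward()
        optimizer_toy.step()
        batch_losses.append(loss.item())
    epoch_losses.append(float(np.mean(batch_losses)))

print(len(loader))  # 7 batches per epoch
```

With 100 rows and `batch_size=16`, the loader yields six full batches and a final batch of 4, just as the manual `end` computation clips the last slice.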
# Predict on the training data
x = torch.tensor(input_Data,dtype=torch.float)
predict = my_nn(x).data.numpy()
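`my_nn(x).data.numpy()` works, but it still records the forward pass for autograd. A sketch of a small helper (the name `predict_numpy` is hypothetical) that runs inference under `torch.no_grad()` and returns a flat array, plus a mean absolute error check; the article itself does not compute an error metric, so MAE here is an added illustration:

```python
import numpy as np
import torch


def predict_numpy(model, features):
    """Run the model on a 2-D array and return a flat NumPy array of predictions."""
    model.eval()           # no effect for Linear/Sigmoid, but matters for Dropout or BatchNorm
    with torch.no_grad():  # skip building the autograd graph
        out = model(torch.as_tensor(features, dtype=torch.float32))
    return out.reshape(-1).numpy()


# Tiny check with hand-set weights: prediction = x1 + x2
toy_model = torch.nn.Linear(2, 1)
with torch.no_grad():
    toy_model.weight.copy_(torch.tensor([[1.0, 1.0]]))
    toy_model.bias.zero_()

preds = predict_numpy(toy_model, np.array([[1.0, 2.0], [3.0, 4.0]]))
targets = np.array([3.0, 8.0])
mae = np.mean(np.abs(preds - targets))
print(preds, mae)  # [3. 7.] 0.5
```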
# Convert to datetime format (same conversion as above)
dates = [str(int(year)) + '-' +str(int(month)) + '-' + str(int(day)) for year,month,day in zip(years,months,days)]
dates = [datetime.datetime.strptime(date,'%Y-%m-%d') for date in dates]
# Create a table holding the dates and the corresponding true label values
true_data = pd.DataFrame(data={
'date':dates,'actual':labels})
# Create another table holding the dates and the model's predictions
months =Data[:,Data_columns_name.index('month')]
days = Data[:,Data_columns_name.index('day')]
years = Data[:,Data_columns_name.index('year')]
test_dates = [str(int(year)) + '-' +str(int(month)) + '-' + str(int(day)) for year,month,day in zip(years,months,days)]
test_dates = [datetime.datetime.strptime(date,'%Y-%m-%d') for date in test_dates]
# Table holding the dates and the corresponding predicted values
predict_data = pd.DataFrame(data={
'date':test_dates,'prediction':predict.reshape(-1)})
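`true_data` and `predict_data` share the `date` column, so they can be joined into one table to inspect the error per day. A sketch with hypothetical miniature versions of the two tables (the temperatures are made up):

```python
import pandas as pd

# Hypothetical miniature versions of true_data and predict_data
toy_dates = pd.to_datetime(["2016-01-01", "2016-01-02", "2016-01-03"])
true_df = pd.DataFrame({"date": toy_dates, "actual": [45.0, 44.0, 41.0]})
pred_df = pd.DataFrame({"date": toy_dates, "prediction": [46.0, 43.0, 41.5]})

# Join the two tables on the date so each row holds both values
merged = pd.merge(true_df, pred_df, on="date")
merged["error"] = merged["prediction"] - merged["actual"]
print(merged)
```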
plt.figure(dpi=400)
# True values
plt.plot(true_data['date'], true_data['actual'], 'b-', label='actual')
# Predicted values
plt.plot(predict_data['date'], predict_data['prediction'], 'ro', label='prediction')
# Rotate the date labels so they do not overlap
plt.xticks(rotation=60)
plt.legend()
plt.xlabel('Date')
plt.ylabel('Maximum Temperature (F)')
plt.title('Actual and Predicted Values')
plt.show()