当前位置:网站首页>PyTorch temperature prediction
PyTorch temperature prediction
2022-07-06 12:00:00 【Want to be a kite】
PyTorch - Temperature prediction
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch

# Hard-coded absolute Windows path to temps.csv; the raw string keeps the
# backslashes literal.
# NOTE: the name shadows the builtin dir(); it is kept because the read_csv
# call below refers to it.
dir = r'E:\PyTorch\02\02.2020 Deep learning -PyTorch actual combat \ Code + Information \ Neural network combat classification and regression task \temps.csv'

# Load the table, then print it in full followed by a preview of its first rows.
Data = pd.read_csv(dir)
print(Data)
print(Data.head())

# Remove the 'friend' column in place and show what is left.
Data.drop(columns='friend', inplace=True)
print(Data)
print(' Data dimension :', Data.shape)
# Turn the separate year / month / day columns into datetime objects.
import datetime

years = Data['year']
months = Data['month']
days = Data['day']

# Each row is rendered as a 'Y-M-D' string and parsed with the matching format.
dates = [
    datetime.datetime.strptime(f'{int(year)}-{int(month)}-{int(day)}', '%Y-%m-%d')
    for year, month, day in zip(years, months, days)
]
print(dates[:5])
# Plotting (the exploratory plots below are commented out)
# Specify the default style
# plt.style.use('fivethirtyeight')
# plt.figure(dpi=400)
# #plt.subplot(2,2,1)
# plt.plot(dates,Data['actual'])
# plt.xlabel('')
# plt.ylabel('Temperature')
# plt.title('Max Temp')
# plt.show()
#
# plt.figure(dpi=400)
# #plt.subplot(2,2,2)
# plt.plot(dates,Data['temp_1'])
# plt.xlabel('')
# plt.ylabel('Temperature')
# plt.title('Previous Max Temp')
# plt.show()
#
# plt.figure(dpi=400)
# #plt.subplot(2,2,3)
# plt.plot(dates,Data['temp_2'])
# plt.xlabel('')
# plt.ylabel('Temperature')
# plt.title('Two Days Prior Max Temp')
#
# plt.show()
# Per the original note, the 'week' column is string-typed; get_dummies
# replaces such columns with one-hot indicator columns.
Data = pd.get_dummies(Data)
print(Data.head())

# Split the regression target ('actual') away from the feature table.
labels = np.array(Data['actual'])
Data = Data.drop(columns='actual')

# Remember the feature names: later code looks columns up by name after Data
# has been converted to a plain array.
Data_columns_name = Data.columns.tolist()
Data = np.array(Data)
print(Data.shape)

# Standardize the feature columns; the scaled copy is what the network trains on.
from sklearn import preprocessing
scaler = preprocessing.StandardScaler()
input_Data = scaler.fit_transform(Data)
# Build the model
# Full-dataset tensors for the standardized features (x) and the targets (y).
# NOTE(review): dtype=float presumably yields double-precision tensors — confirm.
# x is rebuilt with dtype=torch.float later in the script before the network
# uses it; y is only referenced inside the disabled code in the string below.
x = torch.tensor(input_Data,dtype=float)
y = torch.tensor(labels,dtype=float)
# The triple-quoted string below is a bare expression statement, so none of the
# code inside it runs. It preserves a hand-written version of the network:
# explicitly created weight/bias tensors, a ReLU hidden layer, a mean-squared
# loss, and manual gradient-descent updates with gradients zeroed each iteration.
""" # Customize the calculation method to build the model # Weight parameter initialization weights = torch.randn((13,128),dtype=float,requires_grad=True) biases = torch.randn(128,dtype=float,requires_grad=True) weights1 = torch.randn((128,1),dtype=float,requires_grad=True) biases1 = torch.randn(1,dtype=float,requires_grad=True) lr = 0.001 losses = [] for i in range(1000): hidden = x.mm(weights) + biases hidden = torch.relu(hidden) predictions = hidden.mm(weights1) + biases1 loss = torch.mean((predictions-y)**2) losses.append(loss.data.numpy()) # Print loss value if i % 100 == 0: print('loss:',loss) loss.backward() # Update parameters weights.data.add_(-lr*weights.grad.data) biases.data.add_(-lr*biases.grad.data) weights1.data.add_(-lr*weights1.grad.data) biases1.data.add_(-lr*biases1.grad.data) # Every iteration should clear the gradient weights.grad.data.zero_() biases.grad.data.zero_() weights1.grad.data.zero_() biases1.grad.data.zero_() """
# Compact version of the network, assembled from torch.nn layers.
batch_size = 16    # rows per optimizer step in the training loop
output_size = 1    # a single regression output
hidden_size = 128  # width of the hidden layer
input_size = input_Data.shape[1]  # one input per standardized feature column

# Linear -> Sigmoid -> Linear: one hidden layer feeding the output unit.
layers = [
    torch.nn.Linear(input_size, hidden_size),
    torch.nn.Sigmoid(),
    torch.nn.Linear(hidden_size, output_size),
]
my_nn = torch.nn.Sequential(*layers)

# Mean-squared-error objective, minimized with Adam at a 0.001 learning rate.
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)
cost = torch.nn.MSELoss(reduction='mean')
# Mini-batch training: 1000 passes over the rows in their stored order,
# batch_size rows per optimizer step.
losses = []  # mean batch loss, recorded on every 100th pass
for i in range(1000):
    batch_loss = []
    for start in range(0, len(input_Data), batch_size):
        # The final batch may hold fewer than batch_size rows.
        end = min(start + batch_size, len(input_Data))
        # Inputs and targets are plain data, so no gradients are tracked for them.
        xx = torch.tensor(input_Data[start:end], dtype=torch.float)
        # labels is 1-D while the network emits one column per row. Without the
        # explicit column axis MSELoss would broadcast (rows, 1) against (rows,)
        # into a (rows, rows) grid and optimize the wrong objective.
        yy = torch.tensor(labels[start:end], dtype=torch.float).reshape(-1, 1)
        prediction = my_nn(xx)
        loss = cost(prediction, yy)
        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
    if i % 100 == 0:
        epoch_loss = np.mean(batch_loss)
        losses.append(epoch_loss)
        print(i, epoch_loss)
# Run the trained network over the full standardized feature matrix.
x = torch.tensor(input_Data, dtype=torch.float)
predict = my_nn(x).data.numpy()

# Rebuild the datetime objects from the year / month / day values.
dates = [
    datetime.datetime.strptime(f'{int(year)}-{int(month)}-{int(day)}', '%Y-%m-%d')
    for year, month, day in zip(years, months, days)
]

# Table pairing each date with the observed target value.
true_data = pd.DataFrame(data={'date': dates, 'actual': labels})

# Read the date parts back out of the feature array, locating each column by name.
months = Data[:, Data_columns_name.index('month')]
days = Data[:, Data_columns_name.index('day')]
years = Data[:, Data_columns_name.index('year')]
test_dates = [
    datetime.datetime.strptime(f'{int(year)}-{int(month)}-{int(day)}', '%Y-%m-%d')
    for year, month, day in zip(years, months, days)
]

# Table pairing each date with the model's prediction, flattened to one dimension.
predict_data = pd.DataFrame(
    data={'date': test_dates, 'prediction': predict.reshape(-1)}
)
# Draw the observed series and the model's predictions on one figure.
plt.figure(dpi=400)

# Observed values, drawn with the 'b-' format string.
actual_dates = true_data['date']
actual_values = true_data['actual']
plt.plot(actual_dates, actual_values, 'b-', label='actual')

# Predicted values, drawn with the 'ro' format string.
predicted_dates = predict_data['date']
predicted_values = predict_data['prediction']
plt.plot(predicted_dates, predicted_values, 'ro', label='prediction')

plt.xticks()
plt.title('Actual and Predicted Values')
plt.xlabel('Date')
plt.ylabel('Maximum Temperature (F)')
plt.legend()
plt.show()
边栏推荐
- Apprentissage automatique - - régression linéaire (sklearn)
- ESP8266通过arduino IED连接巴法云(TCP创客云)
- 机器学习--线性回归(sklearn)
- Redis interview questions
- [NPUCTF2020]ReadlezPHP
- 【CDH】CDH/CDP 环境修改 cloudera manager默认端口7180
- Connexion sans mot de passe du noeud distribué
- Mall project -- day09 -- order module
- Linux yum安装MySQL
- Kaggle竞赛-Two Sigma Connect: Rental Listing Inquiries
猜你喜欢
[yarn] yarn container log cleaning
锂电池基础知识
arduino UNO R3的寄存器写法(1)-----引脚电平状态变化
Linux yum安装MySQL
PyTorch四种常用优化器测试
【yarn】CDP集群 Yarn配置capacity调度器批量分配
分布式節點免密登錄
Kaggle竞赛-Two Sigma Connect: Rental Listing Inquiries(XGBoost)
[yarn] CDP cluster yarn configuration capacity scheduler batch allocation
Principle and implementation of MySQL master-slave replication
随机推荐
Stage 4 MySQL database
FTP file upload file implementation, regularly scan folders to upload files in the specified format to the server, C language to realize FTP file upload details and code case implementation
4. Install and deploy Spark (Spark on YARN mode)
[yarn] yarn container log cleaning
Internet protocol details
nodejs连接Mysql
ESP8266使用arduino连接阿里云物联网
【Flink】CDH/CDP Flink on Yarn 日志配置
冒泡排序【C语言】
B tree and b+ tree of MySQL index implementation
高通&MTK&麒麟 手機平臺USB3.0方案對比
Mysql的索引实现之B树和B+树
Linux Yum install MySQL
Connexion sans mot de passe du noeud distribué
Redis interview questions
Common regular expression collation
C语言,log打印文件名、函数名、行号、日期时间
MySQL START SLAVE Syntax
Détails du Protocole Internet
Mysql database interview questions