当前位置:网站首页>回归预测 | MATLAB实现TPA-LSTM(时间注意力注意力机制长短期记忆神经网络)多输入单输出
回归预测 | MATLAB实现TPA-LSTM(时间注意力注意力机制长短期记忆神经网络)多输入单输出
2022-08-01 10:11:00 【机器学习之心】
回归预测 | MATLAB实现TPA-LSTM(时间注意力注意力机制长短期记忆神经网络)多输入单输出
预测效果





基本介绍
注意力机制模仿人脑,更加注重重要信息,而忽略相对无用的信息,已被广泛应用于自然语言处理、图像及语音识别中,近年来也被广泛应用于各类预测问题。传统注意力机制注重不同时间点的权重分布,在每个时间步只含有一个变量时有较好的效果。但对于区域内的多风电机组功率预测,每个时间步都含有多个变量,各个变量之间可能存在复杂的非线性内在联系,且每个变量序列都有自己的特征和周期,难以单独选取某个时间步作为注意重点。而TPA则由多个一维CNN滤波器从BiLSTM隐藏状态行向量抽取特征,使得模型能够从不同时间步学习多变量之间的互相依赖关系。
环境介绍
运行环境,Matlab2020b。
程序设计
- 完整程序下载:TPA-LSTM
% 数据集 列为特征,行为样本数目
%% 数据导入及处理
load('./Train.mat')
Train.weekend = dummyvar(Train.weekend);
Train.month = dummyvar(Train.month);
Train = movevars(Train,{
'weekend','month'},'After','demandLag');
Train.ts = [];
% Train.hour = dummyvar(Train.hour);
%自己主动观察右侧工作区变量格式,对前面数据进行更改替换
Train(1,:) =[];
y = Train.demand;
x = Train{
:,2:5};
[xnorm,xopt] = mapminmax(x',0,1);
[ynorm,yopt] = mapminmax(y',0,1);
%
% xnorm = [xnorm;Train.weekend';Train.month'];
%%
% x = x';
xnorm = xnorm(:,1:1000);
ynorm = ynorm(1:1000);
k = 24; % 滞后长度
% 转换成2-D image
for i = 1:length(ynorm)-k
Train_xNorm(:,i,:) = xnorm(:,i:i+k-1);
Train_yNorm(i) = ynorm(i+k-1);
Train_y(i) = y(i+k-1);
end
Train_yNorm= Train_yNorm';
ytest = Train.demand(1001:1170);
xtest = Train{
1001:1170,2:5};
[xtestnorm] = mapminmax('apply', xtest',xopt);
[ytestnorm] = mapminmax('apply',ytest',yopt);
% xtestnorm = [xtestnorm; Train.weekend(1001:1170,:)'; Train.month(1001:1170,:)'];
xtest = xtest';
for i = 1:length(ytestnorm)-k
Test_xNorm(:,i,:) = xtestnorm(:,i:i+k-1);
Test_yNorm(i) = ytestnorm(i+k-1);
Test_y(i) = ytest(i+k-1);
end
Test_yNorm = Test_yNorm';
clear k i x y
%
Train_xNorm = dlarray(Train_xNorm,'CBT');
Train_yNorm = dlarray(Train_yNorm,'BC');
Test_xNorm = dlarray(Test_xNorm,'CBT');
Test_yNorm = dlarray(Test_yNorm,'BC');
%% 训练集和验证集划分
TrainSampleLength = length(Train_yNorm);
validatasize = floor(TrainSampleLength * 0.1);
Validata_xNorm = Train_xNorm(:,end - validatasize:end,:);
Validata_yNorm = Train_yNorm(:,TrainSampleLength-validatasize:end);
Validata_y = Train_y(TrainSampleLength-validatasize:end);
%参数设置
inputSize = size(Train_xNorm,1); %数据输入x的特征维度
outputSize = 1; %数据输出y的维度
numhidden_units1=50;
[params,~] = paramsInit(numhidden_units1,inputSize,outputSize); % 导入初始化参数
[~,validatastate] = paramsInit(numhidden_units1,inputSize,outputSize); % 导入初始化参数
[~,TestState] = paramsInit(numhidden_units1,inputSize,outputSize); % 导入初始化参数
% 训练相关参数
TrainOptions;
numIterationsPerEpoch = floor((TrainSampleLength-validatasize)/minibatchsize);
LearnRate = 0.01;
%% Loop over epochs.
figure
start = tic;
lineLossTrain = animatedline('color','r');
validationLoss = animatedline('color',[0 0 0]./255,'Marker','o','MarkerFaceColor',[150 150 150]./255);
xlabel("Iteration")
ylabel("Loss")
% epoch 更新
iteration = 0;
for epoch = 1 : numEpochs
[~,state] = paramsInit(numhidden_units1,inputSize,outputSize); % 每轮epoch,state初始化
disp(['Epoch: ', int2str(epoch)])
% 作图(训练过程损失图)--------------------------********————————————————————————————————————————————————
D = duration(0,0,toc(start),'Format','hh:mm:ss');
addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
if iteration == 1 || mod(iteration,validationFrequency) == 0
addpoints(validationLoss,iteration,double(gather(extractdata(lossValidation))))
end
title("Epoch: " + epoch + ", Elapsed: " + string(D))
legend('训练集','验证集')
drawnow
end
% 每轮epoch 更新学习率
if mod(epoch,5) == 0
LearnRate = LearnRate * LearnRateDropFactor;
end
end
%% 训练集
Predict_yNorm = TPAModelPredict(gpuArray(Train_xNorm),params,TestState);
Predict_yNorm = extractdata(Predict_yNorm);
Predict_y = mapminmax('reverse',Predict_yNorm,yopt);
%
figure
plot(Predict_y,'-.','Color',[50 100 180]./255,'linewidth',1.5,'Markersize',3,'MarkerFaceColor',[50 100 180]./255);
hold on
plot(Train_y,'-.','Color',[150 150 150]./255,'linewidth',1.5,'Markersize',3,'MarkerFaceColor',[150 150 150]./255)
legend('训练集预测值','训练集实际值')
%% 验证集
Predict_yNorm = TPAModelPredict(gpuArray(Validata_xNorm),params,TestState);
Predict_yNorm = extractdata(Predict_yNorm);
Predict_y = mapminmax('reverse',Predict_yNorm,yopt);
%
figure
plot(Predict_y,'-.','Color',[255 0 0]./255,'linewidth',1.5,'Markersize',3,'MarkerFaceColor',[255 0 0]./255);
hold on
plot(Validata_y,'--','Color',[150 150 150]./255,'linewidth',1.5,'Markersize',3,'MarkerFaceColor',[0 0 0]./255)
legend('验证集预测值','验证集实际值')
%% predict(传统递归测试集)
% clear Predict_yNorm
% % Test_xNorm = extractdata(Test_xNorm);
% for i = 1:size(Test_xNorm,2)
% if i ==1
% [a,Teststate] = TPAModelPredict(dlarray(Test_xNorm(:,i,:),'CBT'),params,TestState);
%% predict(直接测试集)
Predict_yNorm = TPAModelPredict(dlarray(Test_xNorm,'CBT'),params,TestState);
Predict_YNorm = extractdata(Predict_yNorm);
Predict_y = mapminmax('reverse',Predict_YNorm,yopt);
figure
plot(Predict_y,'-.','Color',[0 0 255]./255,'linewidth',1.5,'Markersize',3,'MarkerFaceColor',[0 0 255]./255);
hold on
plot(Test_y,'--','Color',[150 150 150]./255,'linewidth',1.5,'Markersize',3,'MarkerFaceColor',[0 0 0]./255)
legend('测试集预测值','测试集实际值')
参考资料
[1] https://blog.csdn.net/kjm13182345320/article/details/125644313?spm=1001.2014.3001.5501
[2] https://blog.csdn.net/kjm13182345320/article/details/125637228?spm=1001.2014.3001.5501
[3] https://download.csdn.net/download/kjm13182345320/85661169?spm=1001.2014.3001.5501
边栏推荐
- 什么是步进电机?40张图带你了解!
- Drawing arrows of WPF screenshot control (5) "Imitation WeChat"
- Naive Bayes--Study Notes--Basic Principles and Code Implementation
- rpm and yum
- AC与瘦AP的WLAN组网实验
- AI篮球裁判火了,走步算得特别准,就问哈登慌不慌
- 数仓分层简介(实时数仓架构)
- Mysql索引相关的知识复盘一
- DBPack SQL Tracing 功能及数据加密功能详解
- 已解决(pip安装库报错)Consider using the-- user option or check the permissions.
猜你喜欢

gc的意义和触发条件

Batch大小不一定是2的n次幂!ML资深学者最新结论

世界第4疯狂的科学家,在103岁生日那天去世了

【钛晨报】国家统计局:7月制造业PMI为49%;玖富旗下理财产品涉嫌欺诈,涉及390亿元;国内航线机票燃油附加费8月5日0时起下调

Node's traditional and advanced practices for formatting time (moment)

Mysql index related knowledge review one

使用ESP32驱动QMA7981读取三轴加速度(带例程)

什么是步进电机?40张图带你了解!

C language game - minesweeper

C#/VB.NET convert PPT or PPTX to image
随机推荐
How I secured 70,000 ETH and won a 6 million bug bounty
WLAN networking experiment of AC and thin AP
Guangyu Mingdao was selected into the list of pilot demonstration projects for the development of digital economy industry in Chongqing in 2022
mysql在cmd的登录及数据库与表的基本操作
[Software Architecture Mode] The difference between MVVM mode and MVC mode
小程序毕设作品之微信美食菜谱小程序毕业设计成品(2)小程序功能
如何解决 chrome 浏览器标签过多无法查看到标题的情况
JWT
阿里腾讯面试一二
数仓分层简介(实时数仓架构)
Shell:条件测试操作
slice、splice、split傻傻分不清
STM32 Personal Notes - Embedded C Language Optimization
InputStream转成String
Quantify daily work metrics
EasyRecovery热门免费数据检测修复软件
MFC实现交通图导航系统
Mysql索引相关的知识复盘一
Node's traditional and advanced practices for formatting time (moment)
SkiaSharp's WPF self-painted five-ring bouncing ball (case version)