[AI in Practice] Building an Air Quality Prediction Model with xgboost.XGBRegressor (Part 1)
2022-07-03 03:13:00 [szZack]
1. xgboost.XGBRegressor in detail
The full parameter list of xgboost.XGBRegressor is documented at https://xgboost.readthedocs.io/en/latest/python/python_api.html?highlight=XGBRegressor#xgboost.XGBRegressor
The XGBRegressor class:
class xgboost.XGBRegressor(*, objective='reg:squarederror', **kwargs)
The core parameters are:
n_estimators (int) – Number of gradient boosted trees. Equivalent to number of boosting rounds.
max_depth (Optional[int]) – Maximum tree depth for base learners.
learning_rate (Optional[float]) – Boosting learning rate (xgb’s “eta”)
verbosity (Optional[int]) – The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
tree_method (Optional[str]) – Specify which tree method to use. Defaults to auto. If this parameter is set to default, XGBoost will choose the most conservative option available. It's recommended to study this option in the tree method section of the parameters documentation.
n_jobs (Optional[int]) – Number of parallel threads used to run xgboost. When used with other Scikit-Learn algorithms like grid search, you may choose which algorithm to parallelize and balance the threads. Creating thread contention will significantly slow down both algorithms.
gamma (Optional[float]) – (min_split_loss) Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight (Optional[float]) – Minimum sum of instance weight(hessian) needed in a child.
subsample (Optional[float]) – Subsample ratio of the training instance.
colsample_bytree (Optional[float]) – Subsample ratio of columns when constructing each tree.
scale_pos_weight (Optional[float]) – Balancing of positive and negative weights.
2. Building an air quality prediction model with xgboost.XGBRegressor
2.1 Dependencies
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold  # k-fold cross-validation
from sklearn.model_selection import GridSearchCV  # grid search
from sklearn.metrics import make_scorer
import os
import sys
import time
import math
from sklearn.metrics import r2_score
from sklearn.ensemble import GradientBoostingRegressor
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn", lineno=193)
from sklearn.multioutput import MultiOutputRegressor
import xgboost as xgb
import joblib
from sklearn.preprocessing import MinMaxScaler
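2.2 Building the air quality prediction model
Model
xgboost.XGBRegressor is the base model. It is wrapped in MultiOutputRegressor to produce a multi-hour output (multi-target regression): MultiOutputRegressor fits one independent XGBRegressor per column of y, i.e. one per forecast hour. The core model code is as follows: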
def fit_model(self, x, y, learning_rate=0.05,
              n_estimators=500,
              max_depth=7,
              min_child_weight=1,
              gamma=0.0,
              subsample=0.8,
              colsample_bytree=0.8,
              scale_pos_weight=0.8):
    model = xgb.XGBRegressor(learning_rate=learning_rate,
                             n_estimators=n_estimators,
                             max_depth=max_depth,
                             min_child_weight=min_child_weight,
                             gamma=gamma,
                             subsample=subsample,
                             colsample_bytree=colsample_bytree,
                             scale_pos_weight=scale_pos_weight,
                             random_state=42,
                             # Train on GPU #2 (requires a CUDA build of XGBoost);
                             # XGBoost >= 2.0 spells this tree_method='hist', device='cuda:2'
                             tree_method='gpu_hist',
                             gpu_id=2)
    # One XGBRegressor is fitted per output column (per forecast hour)
    multioutput = MultiOutputRegressor(model).fit(x, y)
    return multioutput
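Inputs and outputs of fit_model:
- Input x: shape (N, W, 24), where N is the number of days, W is the number of feature dimensions, and 24 is the number of input hours. XGBRegressor only accepts 2-D feature matrices, so x has to be flattened to (N, W*24) before it is passed to fit.
- Output y: shape (N, 24), where N is the number of days and 24 is the number of output (forecast) hours.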
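The article's load_data is not shown in this part, so the following is only a sketch of one way to build x with shape (N, W, 24) and y with shape (N, 24) from hourly data and then flatten x. Assumptions: make_daily_samples is a hypothetical helper, samples are non-overlapping daily windows (step=24), and each target window starts right after its input window.

```python
import numpy as np


def make_daily_samples(features, target, n_input=24, n_output=24, step=24):
    """Hypothetical helper: slice hourly data into (x, y) samples.

    features: (T, W) array of hourly feature values, T = number of hours
    target:   (T,) array of the hourly factor to forecast (e.g. O3)
    Returns x with shape (N, W, n_input) and y with shape (N, n_output).
    """
    xs, ys = [], []
    last_start = len(features) - n_input - n_output
    for start in range(0, last_start + 1, step):
        # Past n_input hours, transposed from (n_input, W) to (W, n_input)
        xs.append(features[start:start + n_input].T)
        # The n_output hours immediately following the input window
        ys.append(target[start + n_input:start + n_input + n_output])
    return np.stack(xs), np.stack(ys)


# Five days of dummy hourly data with three features
hours, n_features = 24 * 5, 3
features = np.arange(hours * n_features, dtype=float).reshape(hours, n_features)
target = np.arange(hours, dtype=float)

x, y = make_daily_samples(features, target)
# Flatten to (N, W * n_input) so XGBRegressor can consume it
x_2d = x.reshape(len(x), -1)
```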
2.3 Core code
# Air-quality model based on XGBRegressor
class AQXGB:
    def __init__(self, factor, n_input, n_output, version):
        self.n_input = n_input    # number of input hours
        self.n_output = n_output  # number of output (forecast) hours
        self.version = version
        self.factor = factor      # air-quality factor to forecast, e.g. O3
        # Directory for the machine-learning training data (models are saved here too)
        os.makedirs('./ml_data/', exist_ok=True)

    def train(self, train_data_path, test_data_path):
        # load_data is not shown in this part; it must return x flattened to
        # (N, W * n_input) and y with shape (N, n_output)
        x, y = self.load_data(self.version, 'train', train_data_path, self.n_input, self.n_output)
        train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2, random_state=2022)
        model = self.fit_model(train_x, train_y)
        pre_y = model.predict(test_x)
        # Coefficient of determination (R^2) on the 20% hold-out split
        r2 = self.performance_metric(test_y, pre_y)
        print('test_r2 = ', r2)

        # Evaluate on the separate test file
        x, y = self.load_data(self.version, 'test', test_data_path, self.n_input, self.n_output)
        pre_y = model.predict(x)
        r2 = self.performance_metric(y, pre_y)
        print('val_r2 = ', r2)

        # Save the model
        joblib.dump(model, './ml_data/xgb_%s_%d_%d_%s.model' % (self.factor, self.n_input, self.n_output, self.version))

    def performance_metric(self, y_true, y_predict):
        # Choose the evaluation metrics you need
        # R^2 (averaged over the output hours when y is 2-D)
        score = r2_score(y_true, y_predict)
        # RMSE
        mse = np.mean((y_predict - y_true) ** 2)
        print('RMSE: ', mse ** 0.5)
        # MAE
        mae = np.mean(np.abs(y_predict - y_true))
        print('MAE: ', mae)
        # SMAPE
        smape = self.smape(y_true, y_predict)
        print('SMAPE: ', smape)
        return score

    def smape(self, A, F):
        A = np.asarray(A, dtype=float).reshape(-1)
        F = np.asarray(F, dtype=float).reshape(-1)
        denom = np.abs(A) + np.abs(F)
        # Treat 0/0 (actual and forecast both 0) as zero error instead of NaN
        ratio = np.divide(2 * np.abs(F - A), denom, out=np.zeros_like(denom), where=denom != 0)
        return ratio.mean()

    def fit_model(self, x, y, learning_rate=0.05,
                  n_estimators=500,
                  max_depth=7,
                  min_child_weight=1,
                  gamma=0.0,
                  subsample=0.8,
                  colsample_bytree=0.8,
                  scale_pos_weight=0.8):
        model = xgb.XGBRegressor(learning_rate=learning_rate,
                                 n_estimators=n_estimators,
                                 max_depth=max_depth,
                                 min_child_weight=min_child_weight,
                                 gamma=gamma,
                                 subsample=subsample,
                                 colsample_bytree=colsample_bytree,
                                 scale_pos_weight=scale_pos_weight,
                                 random_state=42,
                                 # Train on GPU #2 (requires a CUDA build of XGBoost);
                                 # XGBoost >= 2.0 spells this tree_method='hist', device='cuda:2'
                                 tree_method='gpu_hist',
                                 gpu_id=2)
        # One XGBRegressor is fitted per output column (per forecast hour)
        multioutput = MultiOutputRegressor(model).fit(x, y)
        return multioutput
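The metrics printed by performance_metric can be checked by hand on a tiny example. This sketch re-implements them as standalone functions (the names rmse, mae and smape are illustrative, not part of the article's code) and uses a 2x2 array so every value is easy to verify; note that SMAPE here is a fraction in [0, 2], not a percentage.

```python
import numpy as np
from sklearn.metrics import r2_score


def rmse(y_true, y_pred):
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))


def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_pred - y_true)))


def smape(y_true, y_pred):
    a = np.asarray(y_true, dtype=float).reshape(-1)
    f = np.asarray(y_pred, dtype=float).reshape(-1)
    denom = np.abs(a) + np.abs(f)
    # 0/0 (both values 0) counts as zero error
    ratio = np.divide(2 * np.abs(f - a), denom, out=np.zeros_like(denom), where=denom != 0)
    return float(ratio.mean())


# Two "days" with two forecast hours each; errors are 10, 0, 0 and -20
y_true = np.array([[80.0, 100.0], [60.0, 40.0]])
y_pred = np.array([[90.0, 100.0], [60.0, 20.0]])

# RMSE = sqrt(125), MAE = 7.5, SMAPE = (20/170 + 40/60) / 4
# r2_score averages R^2 over columns: (0.5 + (1 - 400/1800)) / 2
r2 = r2_score(y_true, y_pred)
```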
2.4 Model training
Training code
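if __name__ == "__main__":
    if len(sys.argv) == 7:
        # Train the model, e.g.
        # python3 src/train_xgb_model.py data/train_data.csv data/test_data.csv O3 24 24 v2
        aq_model = AQXGB(sys.argv[3], int(sys.argv[4]), int(sys.argv[5]), sys.argv[6])
        aq_model.train(sys.argv[1], sys.argv[2])
    else:
        print('Usage: python3 src/train_xgb_model.py <train_csv> <test_csv> <factor> <n_input> <n_output> <version>')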
Training command
The model takes the past 24 hours of feature data as input and outputs the O3 forecast for the next 24 hours:
python3 src/train_xgb_model.py data/train_data.csv data/test_data.csv O3 24 24 v2
2.5 Data format
- Format: CSV file
- Example:
air_pressure,CO,humidity,AQI,monitoring_time,NO2,O3,PM10,PM25,SO2,station_number,air_temperature,wind_direction,wind_speed,longitude,latitude,station_type_name
1013.0,0.3,59.0,69.0,2019-02-01 00:00:00,15.0,80.0,88.0,26.0,8.0,xxx监测站,-0.4,205.8,1.1,116.97810856433719,36.61655020673796,shik
1013.0,0.3,58.0,68.0,2019-02-01 01:00:00,15.0,80.0,86.0,26.0,8.0,xxx监测站,-0.5,179.4,1.0,116.97810856433719,36.61655020673796,shik
1012.0,0.3,62.0,72.0,2019-02-01 02:00:00,15.0,80.0,94.0,26.0,8.0,xxx监测站,-0.9,175.7,0.8,116.97810856433719,36.61655020673796,shik
1011.0,0.3,64.0,76.0,2019-02-01 03:00:00,15.0,80.0,102.0,26.0,8.0,xxx监测站,-1.0,166.9,0.9,116.97810856433719,36.61655020673796,shik
1011.0,0.3,65.0,80.0,2019-02-01 04:00:00,15.0,80.0,110.0,26.0,8.0,xxx监测站,-0.8,191.1,0.9,116.97810856433719,36.61655020673796,shik
1011.0,0.3,66.0,84.0,2019-02-01 05:00:00,15.0,80.0,117.0,26.0,8.0,xxx监测站,-1.1,211.4,1.0,116.97810856433719,36.61655020673796,shik
1011.0,0.3,68.0,85.0,2019-02-01 06:00:00,15.0,80.0,119.0,26.0,8.0,xxx监测站,-1.4,137.3,1.3,116.97810856433719,36.61655020673796,shik
1011.0,0.3,68.0,65.75,2019-02-01 07:00:00,15.0,80.0,130.6,26.0,8.0,xxx监测站,-1.3,147.0,1.5,116.97810856433719,36.61655020673796,shik
1011.0,0.3,58.0,46.5,2019-02-01 08:00:00,15.0,80.0,142.2,26.0,8.0,xxx监测站,0.7,157.0,1.4,116.97810856433719,36.61655020673796,shik
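A sketch of reading this CSV with pandas into an hourly feature matrix and target vector, which could then be windowed into daily samples. Assumptions: only the first three sample rows are embedded, the identifier and timestamp columns (monitoring_time, station_number, station_type_name) are excluded from the features, and O3 itself is kept as a feature because past O3 values are part of the input; the article's actual feature selection is not shown.

```python
import io

import pandas as pd

# First three rows of the sample above, embedded for a self-contained demo
csv_text = """air_pressure,CO,humidity,AQI,monitoring_time,NO2,O3,PM10,PM25,SO2,station_number,air_temperature,wind_direction,wind_speed,longitude,latitude,station_type_name
1013.0,0.3,59.0,69.0,2019-02-01 00:00:00,15.0,80.0,88.0,26.0,8.0,xxx监测站,-0.4,205.8,1.1,116.97810856433719,36.61655020673796,shik
1013.0,0.3,58.0,68.0,2019-02-01 01:00:00,15.0,80.0,86.0,26.0,8.0,xxx监测站,-0.5,179.4,1.0,116.97810856433719,36.61655020673796,shik
1012.0,0.3,62.0,72.0,2019-02-01 02:00:00,15.0,80.0,94.0,26.0,8.0,xxx监测站,-0.9,175.7,0.8,116.97810856433719,36.61655020673796,shik
"""

df = pd.read_csv(io.StringIO(csv_text), parse_dates=["monitoring_time"])
# Keep each station's records in chronological order
df = df.sort_values(["station_number", "monitoring_time"]).reset_index(drop=True)

target_col = "O3"
# Assumption: identifier and timestamp columns are not model features
non_feature_cols = ["monitoring_time", "station_number", "station_type_name"]
feature_cols = [c for c in df.columns if c not in non_feature_cols]

features = df[feature_cols].to_numpy(dtype=float)  # shape (T, W)
target = df[target_col].to_numpy(dtype=float)      # shape (T,)
```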
3. Further reading
[AI in Practice] Speeding up XGBRegressor training: training XGBRegressor in seconds on a GPU