Iteratively Reading Datasets with XGBoost
2022-06-27 06:35:00 [Datawhale]
Datawhale practical content
Source: Coggle Data Science
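When training on large-scale datasets, reading the data iteratively, one batch at a time, is a very suitable approach, and PyTorch supports this style of loading. Below we introduce how to read data iteratively in XGBoost.
Iterating over in-memory data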
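To make the idea of iterative reading concrete, here is a minimal, library-free Python sketch: a generator pulls fixed-size batches lazily from any iterable (for example, an open file), so only the current batch has to be held in memory. The helper name `iter_batches` is hypothetical; it is not an XGBoost or PyTorch API.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def iter_batches(rows: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most `batch_size` items from `rows`.

    Items are pulled lazily, so a file object or generator is never
    materialized in full; only the current batch lives in memory.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    it = iter(rows)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:  # source exhausted
            return
        yield batch


if __name__ == "__main__":
    # A generator stands in for a large on-disk dataset.
    source = (i * i for i in range(10))
    for batch in iter_batches(source, batch_size=4):
        print(batch)
```

The last batch is simply shorter when the row count is not a multiple of `batch_size`; the XGBoost iterators in this article follow the same pattern.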
```python
import numpy as np
import pandas as pd
import xgboost as xgb


class IterLoadForDMatrix(xgb.core.DataIter):
    def __init__(self, df=None, features=None, target=None, batch_size=256 * 1024):
        self.features = features
        self.target = target
        self.df = df
        self.batch_size = batch_size
        # number of batches, rounded up so the last partial batch is included
        self.batches = int(np.ceil(len(df) / self.batch_size))
        self.it = 0  # start at the first batch
        super().__init__()

    def reset(self):
        '''Reset the iterator'''
        self.it = 0

    def next(self, input_data):
        '''Yield next batch of data.'''
        if self.it == self.batches:
            return 0  # Return 0 when there are no more batches.
        a = self.it * self.batch_size
        b = min((self.it + 1) * self.batch_size, len(self.df))
        dt = pd.DataFrame(self.df.iloc[a:b])
        input_data(data=dt[self.features], label=dt[self.target])  # , weight=dt['weight'])
        self.it += 1
        return 1
```

Usage (this approach is better suited to GPU training):
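```python
# train, train_idx and FEATURES come from your own pipeline:
# the training DataFrame, the row indices to use, and the feature column names.
Xy_train = IterLoadForDMatrix(train.loc[train_idx], FEATURES, 'target')
# XGBoost >= 1.7 deprecates DeviceQuantileDMatrix in favor of xgb.QuantileDMatrix.
dtrain = xgb.DeviceQuantileDMatrix(Xy_train, max_bin=256)
```

Reference documentation: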
https://xgboost.readthedocs.io/en/latest/python/examples/quantile_data_iterator.html
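`IterLoadForDMatrix` only implements two callbacks; XGBoost decides when to call them. The sketch below mimics that contract without XGBoost installed: `BatchIter` repeats the same batch arithmetic over plain sequences instead of a DataFrame, and `drive` is a hypothetical consumer that calls `reset()` and then `next()` until it returns 0. Both names are illustrative assumptions, not XGBoost APIs, and the real consumer is XGBoost's C++ core, which may walk the iterator more than once.

```python
import math
from typing import Any, Callable, List, Sequence, Tuple


class BatchIter:
    """Hypothetical stand-in for IterLoadForDMatrix with the same reset/next contract."""

    def __init__(self, rows: Sequence[Any], labels: Sequence[Any], batch_size: int):
        self.rows = rows
        self.labels = labels
        self.batch_size = batch_size
        # Same arithmetic as the original: number of batches, rounded up.
        self.batches = math.ceil(len(rows) / batch_size)
        self.it = 0

    def reset(self) -> None:
        self.it = 0

    def next(self, input_data: Callable[..., None]) -> int:
        if self.it == self.batches:
            return 0  # no more batches
        a = self.it * self.batch_size
        b = min((self.it + 1) * self.batch_size, len(self.rows))
        input_data(data=self.rows[a:b], label=self.labels[a:b])
        self.it += 1
        return 1


def drive(data_iter: BatchIter, passes: int = 1) -> List[Tuple[int, int]]:
    """Hypothetical consumer: run `passes` full passes, recording each batch's sizes."""
    seen: List[Tuple[int, int]] = []

    def input_data(data, label):
        seen.append((len(data), len(label)))

    for _ in range(passes):
        data_iter.reset()  # every pass starts from the first batch
        while data_iter.next(input_data):
            pass
    return seen


if __name__ == "__main__":
    n = 1_000_000
    it = BatchIter(range(n), range(n), batch_size=256 * 1024)
    print(it.batches)           # 4
    print(drive(it, passes=1))  # three batches of 262144 rows, then 213568
```

With the original default `batch_size=256*1024` (262,144 rows) and 1,000,000 rows, `ceil(1_000_000 / 262_144) = 4`: three full batches plus a final batch of 213,568 rows. The `reset()` callback is what makes repeated passes possible; without it, every pass after the first would return 0 immediately.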
Iterative loading of external (on-disk) data
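```python
import os
from typing import Callable, List

import xgboost
from sklearn.datasets import load_svmlight_file


class Iterator(xgboost.DataIter):
    def __init__(self, svm_file_paths: List[str]):
        self._file_paths = svm_file_paths
        self._it = 0
        # XGBoost writes its external-memory cache files under this prefix.
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def next(self, input_data: Callable):
        if self._it == len(self._file_paths):
            # return 0 to let XGBoost know this is the end of iteration
            return 0
        # Load one file per call, so only this batch is held in memory.
        X, y = load_svmlight_file(self._file_paths[self._it])
        input_data(data=X, label=y)
        self._it += 1
        return 1

    def reset(self):
        """Reset the iterator to its beginning"""
        self._it = 0
```

Usage (this approach is better suited to CPU training):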
```python
it = Iterator(["file_0.svm", "file_1.svm", "file_2.svm"])
Xy = xgboost.DMatrix(it)

# Other tree methods including ``hist`` and ``gpu_hist`` also work, but have some
# caveats, as noted in the XGBoost external-memory tutorial linked below.
booster = xgboost.train({"tree_method": "approx"}, Xy)
```

Reference documentation:
https://xgboost.readthedocs.io/en/stable/tutorials/external_memory.html
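The external-memory iterator assumes the data has already been split into several svmlight/libsvm files such as `file_0.svm`, `file_1.svm`. If it has not, a minimal stdlib sketch like the one below can produce such shards. The function name `write_svm_shards` is hypothetical; it writes 1-based feature indices (the common libsvm convention) and omits zero entries. If scikit-learn is available, `sklearn.datasets.dump_svmlight_file` is the more robust option.

```python
import os
import tempfile
from typing import List, Sequence


def write_svm_shards(
    rows: Sequence[Sequence[float]],
    labels: Sequence[float],
    n_shards: int,
    out_dir: str,
) -> List[str]:
    """Split (rows, labels) into up to `n_shards` svmlight files; return their paths.

    Each line is '<label> <index>:<value> ...' with 1-based indices;
    zero-valued features are omitted, as usual for this sparse format.
    """
    if len(rows) != len(labels):
        raise ValueError("rows and labels must have the same length")
    if n_shards <= 0:
        raise ValueError("n_shards must be positive")
    per_shard = -(-len(rows) // n_shards)  # ceiling division
    paths: List[str] = []
    for s in range(n_shards):
        chunk = range(s * per_shard, min((s + 1) * per_shard, len(rows)))
        if len(chunk) == 0:
            break  # fewer rows than shards: stop early
        path = os.path.join(out_dir, f"file_{s}.svm")
        with open(path, "w") as f:
            for i in chunk:
                feats = " ".join(
                    f"{j + 1}:{float(v)!r}" for j, v in enumerate(rows[i]) if v != 0
                )
                f.write(f"{float(labels[i])!r} {feats}".rstrip() + "\n")
        paths.append(path)
    return paths


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        paths = write_svm_shards([[0.5, 0.0], [0.0, 2.0], [1.0, 1.0]], [1, 0, 1], 2, d)
        print([os.path.basename(p) for p in paths])  # ['file_0.svm', 'file_1.svm']
        with open(paths[0]) as f:
            print(f.read(), end="")
```

The returned paths can then be passed to the iterator's constructor in place of the hard-coded file names.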
