Reproducing Recommendation Models (I): Getting Familiar with the Torch-RecHub Framework and Its Usage
2022-06-29 12:39:00 【GoAI】
This is the first post in a series on recommendation models. The series uses PyTorch to reproduce recommendation models, and this post introduces the Torch-RecHub framework and how to use it.
1 The Torch-RecHub Framework
Torch-RecHub is a lightweight PyTorch framework for recommendation models.
1.1 Framework Overview
- Core positioning: easy to use and extend, able to reproduce recommendation models that are practical in industry, and focused on reproducing models across the broader recommendation ecosystem
- Engineering: built on native PyTorch classes and functions; model training and model definition are decoupled and there is no basemodel class, so the code stays faithful to each paper's ideas while letting users get started quickly
- Learning reference: draws on the strengths of excellent open-source frameworks such as DeepCTR and FuxiCTR
1.2 Main Features
- scikit-learn-style, easy-to-use API (fit, predict) that works out of the box
- Model training and model definition are decoupled, which makes the framework easy to extend; different training mechanisms can be configured for different types of models
- Supports pandas DataFrame, Dict and other data types, lowering the cost of getting started
- Highly modular, with common Layers that are easy to call and assemble into new models
  - LR, MLP, FM, FFM, CIN
  - target-attention, self-attention, transformer
- Supports common ranking models
  - WideDeep, DeepFM, DIN, DCN, xDeepFM, etc.
- Supports common recall models
  - DSSM, YoutubeDNN, YoutubeDSSM, FacebookEBR, MIND, etc.
- Rich multi-task learning support
  - SharedBottom, ESMM, MMOE, PLE, AITM and other models
  - GradNorm, UWL, MetaBalance and other dynamic loss-weighting mechanisms
- Focus on recommendation scenarios across the broader ecosystem
- Supports a rich set of training mechanisms
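The decoupling of model definition from training, together with scikit-learn-style `fit`/`predict` entry points, can be illustrated with a framework-free sketch. The `LogisticRegression` and `Trainer` classes below are hypothetical stand-ins, not Torch-RecHub classes: the model only defines its parameters, forward pass and gradients, while the trainer owns the loop, so any other model exposing the same three members could be trained without changing the trainer.

```python
import math


class LogisticRegression:
    """Hypothetical model: parameters, forward pass and gradients only, no training loop."""

    def __init__(self, n_features):
        self.params = [0.0] * (n_features + 1)  # last entry is the bias

    def forward(self, x):
        # zip stops at len(x), so the bias is added separately
        z = sum(w * xi for w, xi in zip(self.params, x)) + self.params[-1]
        return 1.0 / (1.0 + math.exp(-z))

    def grad(self, x, y):
        # Gradient of binary cross-entropy with respect to each weight and the bias
        err = self.forward(x) - y
        return [err * xi for xi in x] + [err]


class Trainer:
    """Hypothetical trainer: works with any model exposing params, forward() and grad()."""

    def __init__(self, model, lr=0.5, n_epoch=200):
        self.model, self.lr, self.n_epoch = model, lr, n_epoch

    def fit(self, X, y):
        for _ in range(self.n_epoch):
            for xi, yi in zip(X, y):  # plain per-sample SGD
                g = self.model.grad(xi, yi)
                self.model.params = [p - self.lr * gi for p, gi in zip(self.model.params, g)]
        return self

    def predict(self, X):
        return [self.model.forward(xi) for xi in X]


X = [[0.0], [0.2], [0.8], [1.0]]
y = [0, 0, 1, 1]
preds = Trainer(LogisticRegression(n_features=1)).fit(X, y).predict(X)
print([round(p) for p in preds])  # [0, 0, 1, 1]
```

Swapping in a different model only requires a new class with the same `params`/`forward`/`grad` members; the training loop stays untouched.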
1.3 Torch-RecHub Architecture Design
Torch-RecHub consists mainly of a data processing module, a model layer module and a trainer module:
- Data processing module
  - Feature handling: DenseFeature (for numerical features), SparseFeature (for categorical features), SequenceFeature (for sequence features)
  - Data structure: DataGenerator (a data generator that creates the training, validation and test sets)
- Model layer module
  - Ranking models: WideDeep, DeepFM, DCN, xDeepFM, DIN, DIEN, SIM
  - Recall models: DSSM, YoutubeDNN, YoutubeSBC, FaceBookDSSM, Gru4Rec, MIND, SASRec, ComiRec
  - Multi-task models: SharedBottom, ESMM, MMOE, PLE, AITM
- Trainer module
  - CTRTrainer: training and evaluation of fine-ranking models
  - MTLTrainer: training and evaluation of multi-task ranking models
  - MatchTrainer: training and evaluation of recall models
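To show the idea behind the three feature types without depending on the library, the sketch below uses hypothetical `DenseFeat`, `SparseFeat` and `SequenceFeat` dataclasses. These are not Torch-RecHub's `DenseFeature`, `SparseFeature` and `SequenceFeature`, whose exact signatures and internals may differ. A dense value is passed through as-is, a sparse ID selects one row of a `vocab_size × embed_dim` embedding table, and a sequence of IDs is mean-pooled into one vector (mean pooling is an assumption chosen for illustration). The tables are random Python lists rather than trainable tensors.

```python
import random
from dataclasses import dataclass, field

random.seed(42)


@dataclass
class DenseFeat:
    name: str  # numeric value used as-is


@dataclass
class SparseFeat:
    name: str
    vocab_size: int
    embed_dim: int
    table: list = field(init=False)  # one embed_dim-sized row per category ID

    def __post_init__(self):
        self.table = [[random.gauss(0, 0.01) for _ in range(self.embed_dim)]
                      for _ in range(self.vocab_size)]

    def lookup(self, idx):
        return self.table[idx]


@dataclass
class SequenceFeat(SparseFeat):
    def lookup(self, ids):
        # Average the embedding rows of all IDs in the sequence
        rows = [self.table[i] for i in ids]
        return [sum(col) / len(rows) for col in zip(*rows)]


def build_input(sample, features):
    """Concatenate dense values and (pooled) embeddings into one flat vector."""
    vec = []
    for feat in features:
        if isinstance(feat, DenseFeat):
            vec.append(float(sample[feat.name]))
        else:
            vec.extend(feat.lookup(sample[feat.name]))
    return vec


features = [
    DenseFeat("I1"),
    SparseFeat("C1", vocab_size=10, embed_dim=4),
    SequenceFeat("hist_items", vocab_size=50, embed_dim=4),
]
x = build_input({"I1": 0.3, "C1": 7, "hist_items": [3, 8, 8]}, features)
print(len(x))  # 1 + 4 + 4 = 9
```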
2 Using Torch-RecHub
We use a small sample of the Criteo dataset that contains only 115 records. The Criteo dataset is an online advertising dataset released by Criteo Labs. It contains millions of click-feedback records for display ads and serves as a basis for click-through-rate (CTR) prediction. Each record has 40 columns: the first column is the label, where 1 means the ad was clicked and 0 means it was not, and the remaining columns are 13 dense features and 26 sparse features.
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.utils.data import DataGenerator
from torch_rechub.trainers import CTRTrainer
from torch_rechub.models.ranking import WideDeep

data_path = './data/criteo/criteo_sample.csv'
# Import dataset
data = pd.read_csv(data_path)
data.head()

|  | label | I1 | I2 | I3 | I4 | I5 | I6 | I7 | I8 | I9 | ... | C17 | C18 | C19 | C20 | C21 | C22 | C23 | C24 | C25 | C26 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.0 | 0 | 104.0 | 27.0 | 1990.0 | 142.0 | 4.0 | 32.0 | 37.0 | ... | e5ba7672 | 25c88e42 | 21ddcdc9 | b1252a9d | 0e8585d2 | NaN | 32c7478e | 0d4a6d1a | 001f3601 | 92c878de |
| 1 | 0 | 0.0 | -1 | 63.0 | 40.0 | 1470.0 | 61.0 | 4.0 | 37.0 | 46.0 | ... | e5ba7672 | d3303ea5 | 21ddcdc9 | b1252a9d | 7633c7c8 | NaN | 32c7478e | 17f458f7 | 001f3601 | 71236095 |
| 2 | 0 | 0.0 | 370 | 4.0 | 1.0 | 1787.0 | 65.0 | 14.0 | 25.0 | 489.0 | ... | 3486227d | 642f2610 | 55dd3565 | b1252a9d | 5c8dc711 | NaN | 423fab69 | 45ab94c8 | 2bf691b1 | c84c4aec |
| 3 | 1 | 19.0 | 10 | 30.0 | 10.0 | 1.0 | 3.0 | 33.0 | 47.0 | 126.0 | ... | e5ba7672 | a78bd508 | 21ddcdc9 | 5840adea | c2a93b37 | NaN | 32c7478e | 1793a828 | e8b83407 | 2fede552 |
| 4 | 0 | 0.0 | 0 | 36.0 | 22.0 | 4684.0 | 217.0 | 9.0 | 35.0 | 135.0 | ... | e5ba7672 | 7ce63c71 | NaN | NaN | af5dc647 | NaN | dbb486d7 | 1793a828 | NaN | NaN |
5 rows × 40 columns
dense_features = [f for f in data.columns.tolist() if f[0] == "I"]
sparse_features = [f for f in data.columns.tolist() if f[0] == "C"]
# Fill missing values: NaNs in sparse features become the string '-996', NaNs in dense features become 0
data[sparse_features] = data[sparse_features].fillna('-996')
data[dense_features] = data[dense_features].fillna(0)

def convert_numeric_feature(val):
    v = int(val)
    if v > 2:
        return int(np.log(v)**2)
    else:
        return v - 2
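`convert_numeric_feature` turns each dense value into a coarse integer bucket: values up to 2 are simply shifted by -2, and larger values are mapped to ⌊(ln v)²⌋, so buckets get wider as values grow. The stdlib version below uses `math.log` in place of `np.log` (both compute the natural logarithm of a scalar) and prints the buckets for a few values taken from the sample rows above, including the -1 that appears in column I2.

```python
import math


def convert_numeric_feature(val):
    """Map a raw dense value to a coarse integer bucket (stdlib version of the snippet above)."""
    v = int(val)
    if v > 2:
        # Bucket width grows with v: floor((ln v)^2)
        return int(math.log(v) ** 2)
    # Small values are only shifted, so -1, 0, 1, 2 stay distinct buckets
    return v - 2


samples = [-1, 0, 1, 2, 3, 104.0, 1990.0]
buckets = {s: convert_numeric_feature(s) for s in samples}
print(buckets)  # {-1: -3, 0: -2, 1: -1, 2: 0, 3: 1, 104.0: 21, 1990.0: 57}
```

In the tutorial these bucket IDs become new `<feature>_cat` columns, which are then label-encoded and embedded like any other sparse feature.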
# Normalize: add a discretized "_cat" copy of every dense feature (treated as sparse), then min-max scale the dense features
for feat in dense_features:
    sparse_features.append(feat + "_cat")
    data[feat + "_cat"] = data[feat].apply(convert_numeric_feature)
sca = MinMaxScaler()  # scale dense features to [0, 1]
data[dense_features] = sca.fit_transform(data[dense_features])
# Label-encode the sparse features (including the new "_cat" features)
for feat in sparse_features:
    lbe = LabelEncoder()
    data[feat] = lbe.fit_transform(data[feat])
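`MinMaxScaler` (with its default range) rescales each dense column to [0, 1] via (x - min) / (max - min), and `LabelEncoder` replaces each distinct value with its index in the sorted list of classes. The hypothetical stdlib functions `min_max_scale` and `label_encode` below mimic that default behavior on small inputs taken from the sample rows above. Because the encoded IDs run from 0 to nunique - 1 (NaNs were already filled), `vocab_size=data[feature_name].nunique()` in the next step gives each sparse feature exactly one embedding row per category.

```python
def min_max_scale(values):
    """Scale a column to [0, 1], like MinMaxScaler with the default feature_range."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        # MinMaxScaler maps a constant column to 0 instead of dividing by zero
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]


def label_encode(values):
    """Assign each distinct value its index in sorted order, like LabelEncoder."""
    classes = sorted(set(values))
    index = {c: i for i, c in enumerate(classes)}
    return [index[v] for v in values], classes


i5 = [1990.0, 1470.0, 1787.0, 1.0, 4684.0]  # column I5 from the five sample rows
print(min_max_scale(i5))

codes, classes = label_encode(["e5ba7672", "3486227d", "e5ba7672", "-996"])
print(codes, classes)  # [2, 1, 2, 0] ['-996', '3486227d', 'e5ba7672']
print(len(classes))    # vocab size 3; the largest code is 2
```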
# Get the final data: feature definitions, label and model inputs
dense_feas = [DenseFeature(feature_name) for feature_name in dense_features]
sparse_feas = [SparseFeature(feature_name, vocab_size=data[feature_name].nunique(), embed_dim=16) for feature_name in sparse_features]
y = data["label"]
del data["label"]
x = data

# Build a data generator
data_generator = DataGenerator(x, y)

batch_size = 2048
# Split the data into a 70% training set, a 10% validation set and a 20% test set
train_dataloader, val_dataloader, test_dataloader = data_generator.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=batch_size)

the samples of train : val : test are 80 : 11 : 24
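The printed counts follow from simple arithmetic if, as assumed here, the training and validation sizes are the truncated products of the sample count and each ratio, and the test set takes whatever remains (115 × 0.7 = 80.5 → 80, 115 × 0.1 = 11.5 → 11, 115 - 80 - 11 = 24). `split_sizes` is a hypothetical helper, not Torch-RecHub's code; it only checks that this reading matches the output. It also explains why each epoch in the training log runs a single step: with `batch_size = 2048`, every split fits in one batch.

```python
import math


def split_sizes(n_samples, split_ratio):
    """Hypothetical helper: truncate the train/val fractions, give the remainder to the test set."""
    train_len = int(n_samples * split_ratio[0])
    val_len = int(n_samples * split_ratio[1])
    test_len = n_samples - train_len - val_len
    return train_len, val_len, test_len


train_len, val_len, test_len = split_sizes(115, [0.7, 0.1])
print(train_len, val_len, test_len)  # 80 11 24

# Number of batches per epoch for the training split
print(math.ceil(train_len / 2048))  # 1
```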
# Configure the parameters of the multilayer perceptron (MLP) module
mlp_params = {
    "dims": [256, 128],
    "dropout": 0.2,
    "activation": "relu"}
# Build the WideDeep model
model = WideDeep(wide_features=dense_feas, deep_features=sparse_feas, mlp_params=mlp_params)

learning_rate = 1e-3
weight_decay = 1e-3
device = 'cuda:0'  # use 'cpu' if no GPU is available
save_dir = './models/'
epoch = 2
optimizer_params = {
    "lr": learning_rate,
    "weight_decay": weight_decay}
# Build the trainer
ctr_trainer = CTRTrainer(model, optimizer_params=optimizer_params, n_epoch=epoch, earlystop_patience=10, device=device, model_path=save_dir)

# Train the model
ctr_trainer.fit(train_dataloader, val_dataloader)

epoch: 0
train: 100%|██████████| 1/1 [00:12<00:00, 12.33s/it]
validation: 100%|██████████| 1/1 [00:10<00:00, 10.66s/it]
epoch: 0 validation: auc: 0.35
epoch: 1
train: 100%|██████████| 1/1 [00:10<00:00, 10.71s/it]
validation: 100%|██████████| 1/1 [00:10<00:00, 10.69s/it]
epoch: 1 validation: auc: 0.35
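Here the dense features feed the wide part and the sparse features feed the deep part. The sketch below shows one forward pass under the common Wide & Deep formulation, which is assumed here rather than taken from Torch-RecHub's source: the wide logit is a linear term over the dense inputs, the deep logit comes from an MLP over the concatenated embeddings, and the sum goes through a sigmoid to give a click probability. Everything is plain Python with random, untrained weights; the feature names, vocabulary sizes and layer widths are hypothetical, and dropout is omitted.

```python
import math
import random

random.seed(0)


def linear(x, weights, bias):
    """Dense layer: one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def relu(v):
    return [max(0.0, a) for a in v]


def rand_matrix(rows, cols):
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


EMBED_DIM = 3
VOCAB = {"C1": 4, "C2": 5}  # hypothetical vocab sizes of two sparse features
embeddings = {name: rand_matrix(size, EMBED_DIM) for name, size in VOCAB.items()}

# Wide part: linear (logistic-regression-style) term over two dense features
wide_w, wide_b = rand_matrix(1, 2), [0.0]
# Deep part: MLP over the concatenated embeddings (hidden width 4, then 1 output)
hidden_w, hidden_b = rand_matrix(4, EMBED_DIM * len(VOCAB)), [0.0] * 4
out_w, out_b = rand_matrix(1, 4), [0.0]


def wide_deep_forward(dense_x, sparse_ids):
    wide_logit = linear(dense_x, wide_w, wide_b)[0]
    # Look up one embedding row per sparse feature and concatenate them
    deep_in = [v for name in VOCAB for v in embeddings[name][sparse_ids[name]]]
    deep_logit = linear(relu(linear(deep_in, hidden_w, hidden_b)), out_w, out_b)[0]
    # Sum the two logits and squash to a click probability
    return 1.0 / (1.0 + math.exp(-(wide_logit + deep_logit)))


p = wide_deep_forward([0.42, 0.1], {"C1": 2, "C2": 0})
print(f"predicted CTR: {p:.4f}")
```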
Model evaluation
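`CTRTrainer` reports AUC on the validation and test sets. As a reminder of what that number means, the framework-free sketch below computes ROC-AUC as the fraction of (positive, negative) pairs in which the positive sample gets the higher score, with ties counted as half. It is an illustration rather than the trainer's implementation; on binary labels it should agree with scikit-learn's `roc_auc_score`. Since 0.5 corresponds to random ordering, the validation AUC of 0.35 and the test AUC of about 0.62 in this tutorial come from only 11 and 24 samples and are best read as a check that the pipeline runs, not as a measure of model quality.

```python
def roc_auc(labels, scores):
    """AUC = P(score of a random positive > score of a random negative), ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for sp in pos:
        for sn in neg:
            if sp > sn:
                wins += 1.0
            elif sp == sn:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(roc_auc([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.6]))  # 0.75
```

The pairwise loop is O(P × N), which is fine for a sanity check; sorting-based implementations scale better on large evaluation sets.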
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
print(f'test auc: {auc}')
validation: 100%|██████████| 1/1 [00:10<00:00, 10.60s/it]
test auc: 0.62037037037037033

3 Summary
This post mainly introduced the Torch-RecHub framework and its basic usage:
- Torch-RecHub is built mainly on PyTorch and sklearn. It is easy to use and extend, can reproduce recommendation models that are practical in industry, is highly modular with common Layers, and supports common ranking models, recall models and multi-task learning;
- Usage: build the data loaders with DataGenerator, construct a lightweight model, train it with the unified trainer, and finally evaluate the model.