Reproducing Recommendation Models (I): Getting Familiar with the Torch-RecHub Framework and Its Usage
2022-06-29 12:39:00 【GoAI】
This is the first article in the recommendation model series. It mainly uses PyTorch to reproduce recommendation models and introduces the Torch-RecHub framework and how to use it.
1 The Torch-RecHub Framework
Torch-RecHub is a lightweight recommendation model framework built on PyTorch.
1.1 Framework Overview
- Core positioning: easy to use and extend, able to reproduce practical recommendation models used in industry, and focused on reproducing models for a broad ("pan-ecosystem") range of recommendation scenarios
- Engineering: built on native PyTorch classes and functions; model training is decoupled from model definition and there is no basemodel; implementations stay faithful to the ideas of the original papers so that learners can get started quickly
- Learning references: draws on the features of outstanding open-source frameworks such as DeepCTR and FuxiCTR
1.2 Main Features
- scikit-learn-style, easy-to-use API (fit, predict), ready to use out of the box
- Model training and model definition are decoupled, which makes the framework easy to extend; different training mechanisms can be configured for different types of models
- Supports pandas DataFrame, Dict, and other data types, lowering the barrier to entry
- Highly modular, with common Layers that are easy to call and assemble into new models
  - LR, MLP, FM, FFM, CIN
  - target-attention, self-attention, transformer
- Supports common ranking models
  - WideDeep, DeepFM, DIN, DCN, xDeepFM, etc.
- Supports common recall models
  - DSSM, YoutubeDNN, YoutubeDSSM, FacebookEBR, MIND, etc.
- Rich multi-task learning support
  - SharedBottom, ESMM, MMOE, PLE, AITM, and other models
  - Dynamic loss weighting mechanisms such as GradNorm, UWL, and MetaBalance
- Focuses on a broader, ecosystem-oriented range of recommendation scenarios
- Supports rich training mechanisms
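The decoupling of model definition from training, together with the fit/predict style, can be illustrated without the library. The plain-Python sketch below is hypothetical (LogisticModel and SimpleTrainer are invented names, not Torch-RecHub classes): the model only holds parameters and a forward pass, while a separate trainer owns the optimization loop and exposes fit and predict.

```python
import math
from typing import List, Sequence


class LogisticModel:
    """Model definition only: parameters and a forward pass, no training logic."""

    def __init__(self, n_features: int) -> None:
        self.weights = [0.0] * n_features
        self.bias = 0.0

    def forward(self, x: Sequence[float]) -> float:
        z = self.bias + sum(w * xi for w, xi in zip(self.weights, x))
        return 1.0 / (1.0 + math.exp(-z))


class SimpleTrainer:
    """Training logic only: works with any model exposing weights, bias and forward()."""

    def __init__(self, model: LogisticModel, lr: float = 0.5, n_epoch: int = 200) -> None:
        self.model = model
        self.lr = lr
        self.n_epoch = n_epoch

    def fit(self, xs: List[Sequence[float]], ys: List[int]) -> "SimpleTrainer":
        for _ in range(self.n_epoch):
            for x, y in zip(xs, ys):
                # For sigmoid + binary cross-entropy, the gradient w.r.t. the logit is (p - y)
                err = self.model.forward(x) - y
                for i, xi in enumerate(x):
                    self.model.weights[i] -= self.lr * err * xi
                self.model.bias -= self.lr * err
        return self

    def predict(self, xs: List[Sequence[float]]) -> List[float]:
        return [self.model.forward(x) for x in xs]


# Usage: scikit-learn-style chaining of fit and predict on a toy 1-feature dataset
xs = [[0.0], [0.2], [0.8], [1.0]]
ys = [0, 0, 1, 1]
preds = SimpleTrainer(LogisticModel(n_features=1)).fit(xs, ys).predict(xs)
```

Swapping in a different model only requires the same weights/bias/forward interface, and a different training mechanism only requires a different trainer class, which mirrors the split between models and trainers described in this section.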
1.3 Torch-RecHub Architecture Design
Torch-RecHub consists mainly of a data processing module, a model layer module, and a trainer module:
- Data processing module
  - Feature handling: DenseFeature (for numerical features), SparseFeature (for categorical features), SequenceFeature (for sequence features)
  - Data structure: DataGenerator (a data generator used to create the training, validation, and test sets)
- Model layer module
  - Ranking models: WideDeep, DeepFM, DCN, xDeepFM, DIN, DIEN, SIM
  - Recall models: DSSM, YoutubeDNN, YoutubeSBC, FaceBookDSSM, Gru4Rec, MIND, SASRec, ComiRec
  - Multi-task models: SharedBottom, ESMM, MMOE, PLE, AITM
- Trainer module
  - CTRTrainer: training and evaluation of fine-ranking models
  - MTLTrainer: training and evaluation of multi-task ranking models
  - MatchTrainer: training and evaluation of recall models
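To make the data-processing side concrete, the sketch below uses hypothetical stand-in dataclasses (DenseFeat and SparseFeat, not the library's DenseFeature and SparseFeature) to compute the input sizes that a wide linear part and a deep MLP part would see, assuming the deep part concatenates all sparse embeddings. The feature counts match the Criteo example in Section 2; vocab_size=10 is only a placeholder.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class DenseFeat:
    """Hypothetical stand-in for a numerical feature: one float per sample."""
    name: str


@dataclass
class SparseFeat:
    """Hypothetical stand-in for a categorical feature backed by an embedding table."""
    name: str
    vocab_size: int
    embed_dim: int


def wide_input_dim(features: List[DenseFeat]) -> int:
    # Each dense feature contributes a single value
    return len(features)


def deep_input_dim(features: List[SparseFeat]) -> int:
    # Embeddings of all sparse features are concatenated, so their dims add up
    return sum(f.embed_dim for f in features)


# 13 dense columns I1..I13; 26 sparse columns C1..C26 plus one "_cat" copy per dense column
dense = [DenseFeat(f"I{i}") for i in range(1, 14)]
sparse = [SparseFeat(f"C{i}", vocab_size=10, embed_dim=16) for i in range(1, 27)]
sparse += [SparseFeat(f"I{i}_cat", vocab_size=10, embed_dim=16) for i in range(1, 14)]
```

With these counts the wide side sees 13 values and the deep side sees 39 × 16 = 624 embedding values per sample.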
2 Using Torch-RecHub
This example uses a small sample of the Criteo dataset with only 115 records. Criteo is an online advertising dataset released by Criteo Labs. It contains millions of click feedback records for display ads and can serve as a basis for click-through rate (CTR) prediction. The dataset has 40 columns: the first column is the label, where 1 means the ad was clicked and 0 means it was not; the remaining columns are 13 dense features and 26 sparse features.
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.utils.data import DataGenerator
from torch_rechub.trainers import CTRTrainer
from torch_rechub.models.ranking import WideDeep

data_path = './data/criteo/criteo_sample.csv'
# Load the dataset
data = pd.read_csv(data_path)
data.head()

|  | label | I1 | I2 | I3 | I4 | I5 | I6 | I7 | I8 | I9 | ... | C17 | C18 | C19 | C20 | C21 | C22 | C23 | C24 | C25 | C26 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.0 | 0 | 104.0 | 27.0 | 1990.0 | 142.0 | 4.0 | 32.0 | 37.0 | ... | e5ba7672 | 25c88e42 | 21ddcdc9 | b1252a9d | 0e8585d2 | NaN | 32c7478e | 0d4a6d1a | 001f3601 | 92c878de |
| 1 | 0 | 0.0 | -1 | 63.0 | 40.0 | 1470.0 | 61.0 | 4.0 | 37.0 | 46.0 | ... | e5ba7672 | d3303ea5 | 21ddcdc9 | b1252a9d | 7633c7c8 | NaN | 32c7478e | 17f458f7 | 001f3601 | 71236095 |
| 2 | 0 | 0.0 | 370 | 4.0 | 1.0 | 1787.0 | 65.0 | 14.0 | 25.0 | 489.0 | ... | 3486227d | 642f2610 | 55dd3565 | b1252a9d | 5c8dc711 | NaN | 423fab69 | 45ab94c8 | 2bf691b1 | c84c4aec |
| 3 | 1 | 19.0 | 10 | 30.0 | 10.0 | 1.0 | 3.0 | 33.0 | 47.0 | 126.0 | ... | e5ba7672 | a78bd508 | 21ddcdc9 | 5840adea | c2a93b37 | NaN | 32c7478e | 1793a828 | e8b83407 | 2fede552 |
| 4 | 0 | 0.0 | 0 | 36.0 | 22.0 | 4684.0 | 217.0 | 9.0 | 35.0 | 135.0 | ... | e5ba7672 | 7ce63c71 | NaN | NaN | af5dc647 | NaN | dbb486d7 | 1793a828 | NaN | NaN |

5 rows × 40 columns
dense_features = [f for f in data.columns.tolist() if f[0] == "I"]
sparse_features = [f for f in data.columns.tolist() if f[0] == "C"]
# Fill NaN values: the string '-996' for sparse features, 0 for dense features
data[sparse_features] = data[sparse_features].fillna('-996')
data[dense_features] = data[dense_features].fillna(0)

def convert_numeric_feature(val):
    # Discretize a numeric value: log-square bucketing for values > 2, shift smaller values by -2
    v = int(val)
    if v > 2:
        return int(np.log(v)**2)
    else:
        return v - 2
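To see what convert_numeric_feature does to the long-tailed integer counts, the stdlib re-implementation below applies the same rule (math.log is the natural logarithm, like np.log) to a few illustrative inputs chosen here; 1990 is the I5 value of the first row in the table above.

```python
import math


def convert_numeric_feature(val: float) -> int:
    """Same rule as above: values > 2 go to int((ln v)^2), smaller values are shifted by -2."""
    v = int(val)
    if v > 2:
        return int(math.log(v) ** 2)
    return v - 2


# Illustrative raw values and the bucket each one lands in
samples = [-1, 0, 1, 2, 3, 10, 100, 1990]
buckets = {v: convert_numeric_feature(v) for v in samples}
```

Values of 2 or less land in small non-positive buckets (2 → 0, 0 → -2, -1 → -3), while large values are strongly compressed: 10 falls into bucket 5, 100 into bucket 21, and 1990 into bucket 57.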
# Add a discretized categorical copy of each dense feature
for feat in dense_features:
    sparse_features.append(feat + "_cat")
    data[feat + "_cat"] = data[feat].apply(convert_numeric_feature)
# Normalize dense features to [0, 1]
sca = MinMaxScaler()
data[dense_features] = sca.fit_transform(data[dense_features])
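MinMaxScaler rescales each dense column independently with x' = (x - min) / (max - min), using the min and max seen during fit_transform. Because the loop above creates the _cat buckets before scaling, the buckets are computed from the raw values. A minimal stdlib sketch of the formula for one column (min_max_scale is a name made up here):

```python
from typing import List


def min_max_scale(column: List[float]) -> List[float]:
    """Rescale one column to [0, 1] as (x - min) / (max - min)."""
    lo, hi = min(column), max(column)
    span = hi - lo
    if span == 0:
        # Constant column: every value maps to 0.0, matching MinMaxScaler's zero-range handling
        return [0.0 for _ in column]
    return [(x - lo) / span for x in column]


scaled = min_max_scale([0.0, 5.0, 10.0])
```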
# Label-encode sparse features as integer ids
for feat in sparse_features:
    lbe = LabelEncoder()
    data[feat] = lbe.fit_transform(data[feat])
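LabelEncoder replaces each distinct value with an integer id assigned in sorted order, so the ids of a column run from 0 to nunique() - 1; this is why vocab_size=data[feature_name].nunique() is enough for each embedding table. The stdlib sketch below (label_encode is a hypothetical helper) applies the same idea to the first five C20 values of the table after the NaN fill; on the full dataset the ids would differ because the vocabulary is larger.

```python
from typing import Dict, List, Tuple


def label_encode(column: List[str]) -> Tuple[List[int], Dict[str, int]]:
    """Map each distinct value to an integer id; ids follow sorted order of the values."""
    vocab = {value: idx for idx, value in enumerate(sorted(set(column)))}
    return [vocab[value] for value in column], vocab


# First five C20 values from the table, with NaN already filled as '-996'
encoded, vocab = label_encode(["b1252a9d", "b1252a9d", "b1252a9d", "5840adea", "-996"])
```

Here '-996' sorts first and gets id 0, '5840adea' gets 1 and 'b1252a9d' gets 2, so the encoded column is [2, 2, 2, 1, 0] with a vocabulary of size 3.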
# Define the feature objects used by the model
dense_feas = [DenseFeature(feature_name) for feature_name in dense_features]
sparse_feas = [SparseFeature(feature_name, vocab_size=data[feature_name].nunique(), embed_dim=16) for feature_name in sparse_features]
y = data["label"]
del data["label"]
x = data
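Each SparseFeature(..., vocab_size, embed_dim=16) corresponds to an embedding table with vocab_size rows of length 16, and a label-encoded id simply selects one row; conceptually this is what torch.nn.Embedding does. The toy class below (EmbeddingTable is hypothetical, with small random initial weights) shows the lookup without PyTorch:

```python
import random
from typing import List


class EmbeddingTable:
    """Toy embedding table: vocab_size rows, each a vector of length embed_dim."""

    def __init__(self, vocab_size: int, embed_dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.rows = [[rng.gauss(0.0, 0.01) for _ in range(embed_dim)] for _ in range(vocab_size)]

    def lookup(self, ids: List[int]) -> List[List[float]]:
        # Each label-encoded id selects one row of the table
        return [self.rows[i] for i in ids]


# Three distinct values (vocab_size=3), embedded in 16 dimensions
table = EmbeddingTable(vocab_size=3, embed_dim=16)
vecs = table.lookup([2, 2, 2, 1, 0])
```

Equal ids return the same vector, and an id greater than or equal to vocab_size would raise IndexError, which is why the vocabulary size must cover every encoded id.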
# Build the data generator
data_generator = DataGenerator(x, y)
batch_size = 2048
# Split the data into a 70% training set, 10% validation set, and 20% test set
train_dataloader, val_dataloader, test_dataloader = data_generator.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=batch_size)

Output:
the samples of train : val : test are 80 : 11 : 24
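The printed 80 : 11 : 24 split can be reproduced by hand, assuming the generator truncates the train and validation sizes with int() and gives the remainder to the test set (an assumption about DataGenerator's internals, not a documented rule). The same arithmetic explains the 1/1 progress bars during training: with batch_size = 2048, every split fits in a single batch.

```python
import math
from typing import List, Tuple


def split_sizes(n: int, split_ratio: List[float]) -> Tuple[int, int, int]:
    """Assumed rule: truncate train and validation sizes, the test set gets the remainder."""
    train = int(n * split_ratio[0])
    val = int(n * split_ratio[1])
    test = n - train - val
    return train, val, test


sizes = split_sizes(115, [0.7, 0.1])
# Number of batches per split when batch_size is 2048
batches = [math.ceil(s / 2048) for s in sizes]
```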
# Configure the parameters of the multilayer perceptron (MLP) module
mlp_params = {
    "dims": [256, 128],
    "dropout": 0.2,
    "activation": "relu"}
# Build the WideDeep model
model = WideDeep(wide_features=dense_feas, deep_features=sparse_feas, mlp_params=mlp_params)

learning_rate = 1e-3
weight_decay = 1e-3
device = 'cuda:0'  # use 'cpu' if no GPU is available
save_dir = './models/'
epoch = 2
optimizer_params = {
    "lr": learning_rate,
    "weight_decay": weight_decay}
# Build the trainer
ctr_trainer = CTRTrainer(model, optimizer_params=optimizer_params, n_epoch=epoch, earlystop_patience=10, device=device, model_path=save_dir)
# Train the model
ctr_trainer.fit(train_dataloader, val_dataloader)

Output:
epoch: 0
train: 100%| 1/1 [00:12<00:00, 12.33s/it]
validation: 100%| 1/1 [00:10<00:00, 10.66s/it]
epoch: 0 validation: auc: 0.35
epoch: 1
train: 100%| 1/1 [00:10<00:00, 10.71s/it]
validation: 100%| 1/1 [00:10<00:00, 10.69s/it]
epoch: 1 validation: auc: 0.35
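Conceptually, WideDeep(wide_features=dense_feas, deep_features=sparse_feas, ...) adds a linear "wide" logit computed over the dense features to an MLP "deep" logit computed over the concatenated sparse embeddings, then applies a sigmoid. The plain-Python forward pass below is a sketch of that idea with random, untrained weights; it assumes the MLP ends in a single output unit and omits dropout, and it is not Torch-RecHub's implementation.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def linear(x: List[float], w: Matrix, b: List[float]) -> List[float]:
    """Fully connected layer: out[j] = b[j] + sum_i x[i] * w[i][j]."""
    return [b[j] + sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(b))]


def relu(v: List[float]) -> List[float]:
    return [max(0.0, a) for a in v]


def random_layer(rng: random.Random, n_in: int, n_out: int) -> Tuple[Matrix, List[float]]:
    """Untrained weights drawn from N(0, 0.1^2), zero biases."""
    w = [[rng.gauss(0.0, 0.1) for _ in range(n_out)] for _ in range(n_in)]
    return w, [0.0] * n_out


def wide_deep_forward(dense_x: List[float], embeds: List[float], dims: List[int], seed: int = 0) -> float:
    """Click probability = sigmoid(wide logit + deep logit); dropout is omitted (inference)."""
    rng = random.Random(seed)
    # Wide part: one linear unit over the (scaled) dense features
    w_wide, b_wide = random_layer(rng, len(dense_x), 1)
    y_wide = linear(dense_x, w_wide, b_wide)[0]
    # Deep part: ReLU MLP over the concatenated embeddings, then a single output unit
    h = embeds
    for d in dims:
        w, b = random_layer(rng, len(h), d)
        h = relu(linear(h, w, b))
    w_out, b_out = random_layer(rng, len(h), 1)
    y_deep = linear(h, w_out, b_out)[0]
    return 1.0 / (1.0 + math.exp(-(y_wide + y_deep)))


# Sizes from this example: 13 dense inputs, 39 sparse features x 16 dims = 624 deep inputs
p = wide_deep_forward([0.5] * 13, [0.01] * 624, dims=[256, 128])
```

The input sizes follow this example: 13 dense features on the wide side, and 39 sparse features (26 C columns plus 13 _cat columns) × 16 = 624 embedding values on the deep side, fed through hidden layers of 256 and 128 units as in mlp_params.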
Model evaluation
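Both the validation log and ctr_trainer.evaluate report AUC, the probability that a randomly chosen clicked sample is scored above a randomly chosen non-clicked one. A value below 0.5, such as the 0.35 in the validation log, means the model ranked negatives above positives more often than chance would; with only 11 validation samples, this estimate is very noisy. The pairwise sketch below computes AUC directly from that definition (roc_auc is a hypothetical helper, O(P·N), meant only for illustration):

```python
from typing import List


def roc_auc(labels: List[int], scores: List[float]) -> float:
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Three of the four (positive, negative) pairs are ordered correctly
auc = roc_auc([1, 0, 1, 0], [0.9, 0.3, 0.4, 0.6])
```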
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
print(f'test auc: {auc}')
Output:
validation: 100%| 1/1 [00:10<00:00, 10.60s/it]
test auc: 0.6203703703703703

3 Summary
This task mainly introduced the Torch-RecHub framework and its basic usage:
- Torch-RecHub is built mainly on PyTorch and sklearn. It is easy to use and extend, can reproduce practical recommendation models used in industry, is highly modular, provides common Layers, and supports common ranking models, recall models, and multi-task learning;
- Usage: build data loaders with DataGenerator, construct a lightweight model, train it with the unified trainer, and finally evaluate the model.