Recommendation Model Reproduction (I): Getting Familiar with the Torch-RecHub Framework and Its Usage
2022-06-29 12:39:00 【GoAI】
This is the first article in the recommendation model series. It uses PyTorch to reproduce recommendation models and introduces the Torch-RecHub framework and its usage.
1 The Torch-RecHub Framework
Torch-RecHub is a lightweight PyTorch framework for recommendation models.
1.1 Framework Overview
- Core positioning: easy to use and extend, able to reproduce recommendation models used in industry, and focused on reproducing models for a broad range of recommendation scenarios
- Engineering: built on native PyTorch classes and functions; model training and model definition are decoupled, with no basemodel abstraction, so implementations stay faithful to the original papers and newcomers can get started quickly
- Learning references: draws on the strengths of excellent open-source frameworks such as DeepCTR and FuxiCTR
1.2 Main Features
- Easy-to-use, scikit-learn-style API (fit, predict) that works out of the box
- Model training and model definition are decoupled, making the framework easy to extend; different training mechanisms can be configured for different types of models
- Accepts pandas DataFrame, Dict, and other data types, reducing the cost of getting started
- Highly modular, with support for common Layers that are easy to call and assemble into new models
  - LR, MLP, FM, FFM, CIN
  - target-attention, self-attention, transformer
- Supports common ranking models
  - WideDeep, DeepFM, DIN, DCN, xDeepFM, etc.
- Supports common retrieval models
  - DSSM, YoutubeDNN, YoutubeDSSM, FacebookEBR, MIND, etc.
- Rich multi-task learning support
  - Models such as SharedBottom, ESMM, MMOE, PLE, and AITM
  - Dynamic loss-weighting mechanisms such as GradNorm, UWL, and MetaBalance
- Focuses on a broader ecosystem of recommendation scenarios
- Supports a rich set of training mechanisms
1.3 Torch-RecHub Architecture
Torch-RecHub consists mainly of a data processing module, a model layer module, and a trainer module:
- Data processing module
  - Feature handling: DenseFeature (for numeric features), SparseFeature (for categorical features), SequenceFeature (for sequence features) — see the sketch after this list
  - Data construction: DataGenerator (a data generator used to build the training, validation, and test datasets)
- Model layer module
  - Ranking models: WideDeep, DeepFM, DCN, xDeepFM, DIN, DIEN, SIM
  - Retrieval models: DSSM, YoutubeDNN, YoutubeSBC, FaceBookDSSM, Gru4Rec, MIND, SASRec, ComiRec
  - Multi-task models: SharedBottom, ESMM, MMOE, PLE, AITM
- Trainer module
  - CTRTrainer: for training and evaluating fine-ranking models
  - MTLTrainer: for training and evaluating multi-task ranking models
  - MatchTrainer: for training and evaluating retrieval models
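As a small illustration of the feature classes in the data processing module, here is a sketch following the usage shown in Section 2 (the feature names, vocabulary size, and embedding dimension are made up for illustration):

```python
from torch_rechub.basic.features import DenseFeature, SparseFeature

# A numeric feature only needs a name
price = DenseFeature("price")
# A categorical feature also needs a vocabulary size and an embedding dimension
user_id = SparseFeature("user_id", vocab_size=10000, embed_dim=16)
```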
2 Using Torch-RecHub
This example uses a small sample of the Criteo dataset containing only 115 records. The full dataset, released by Criteo Labs, is an online advertising dataset with millions of click feedback records for display ads, and it serves as a standard benchmark for click-through rate (CTR) prediction. The dataset has 40 columns: the first column is the label, where a value of 1 means the ad was clicked and 0 means it was not. The remaining columns comprise 13 dense (numeric) features and 26 sparse (categorical) features.
```python
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.utils.data import DataGenerator
from torch_rechub.trainers import CTRTrainer
from torch_rechub.models.ranking import WideDeep

data_path = './data/criteo/criteo_sample.csv'
# Load the dataset
data = pd.read_csv(data_path)
data.head()
```

|   | label | I1 | I2 | I3 | I4 | I5 | I6 | I7 | I8 | I9 | ... | C17 | C18 | C19 | C20 | C21 | C22 | C23 | C24 | C25 | C26 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.0 | 0 | 104.0 | 27.0 | 1990.0 | 142.0 | 4.0 | 32.0 | 37.0 | ... | e5ba7672 | 25c88e42 | 21ddcdc9 | b1252a9d | 0e8585d2 | NaN | 32c7478e | 0d4a6d1a | 001f3601 | 92c878de |
| 1 | 0 | 0.0 | -1 | 63.0 | 40.0 | 1470.0 | 61.0 | 4.0 | 37.0 | 46.0 | ... | e5ba7672 | d3303ea5 | 21ddcdc9 | b1252a9d | 7633c7c8 | NaN | 32c7478e | 17f458f7 | 001f3601 | 71236095 |
| 2 | 0 | 0.0 | 370 | 4.0 | 1.0 | 1787.0 | 65.0 | 14.0 | 25.0 | 489.0 | ... | 3486227d | 642f2610 | 55dd3565 | b1252a9d | 5c8dc711 | NaN | 423fab69 | 45ab94c8 | 2bf691b1 | c84c4aec |
| 3 | 1 | 19.0 | 10 | 30.0 | 10.0 | 1.0 | 3.0 | 33.0 | 47.0 | 126.0 | ... | e5ba7672 | a78bd508 | 21ddcdc9 | 5840adea | c2a93b37 | NaN | 32c7478e | 1793a828 | e8b83407 | 2fede552 |
| 4 | 0 | 0.0 | 0 | 36.0 | 22.0 | 4684.0 | 217.0 | 9.0 | 35.0 | 135.0 | ... | e5ba7672 | 7ce63c71 | NaN | NaN | af5dc647 | NaN | dbb486d7 | 1793a828 | NaN | NaN |

5 rows × 40 columns
```python
dense_features = [f for f in data.columns.tolist() if f[0] == "I"]
sparse_features = [f for f in data.columns.tolist() if f[0] == "C"]

# Fill NaN values: sparse features are filled with the placeholder string '-996',
# dense features are filled with 0
data[sparse_features] = data[sparse_features].fillna('-996')
data[dense_features] = data[dense_features].fillna(0)

def convert_numeric_feature(val):
    # Discretize a numeric value: log-square bucket for values above 2,
    # shifted identity otherwise
    v = int(val)
    if v > 2:
        return int(np.log(v) ** 2)
    else:
        return v - 2
```
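To make the discretization concrete, here is a quick check of convert_numeric_feature on a few sample values (the values are chosen for illustration; the expected outputs follow directly from the definition above):

```python
# Quick check of the discretization on a few illustrative values
for v in [0, 1, 2, 3, 104, 1990]:
    print(v, "->", convert_numeric_feature(v))
# expected: 0 -> -2, 1 -> -1, 2 -> 0, 3 -> 1, 104 -> 21, 1990 -> 57
```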
```python
# Normalize: add a discretized copy of each dense feature as an extra sparse feature
for feat in dense_features:
    sparse_features.append(feat + "_cat")
    data[feat + "_cat"] = data[feat].apply(lambda x: convert_numeric_feature(x))

# Scale the dense features to [0, 1]
sca = MinMaxScaler()
data[dense_features] = sca.fit_transform(data[dense_features])

# Label-encode the sparse feature data
for feat in sparse_features:
    lbe = LabelEncoder()
    data[feat] = lbe.fit_transform(data[feat])

# Build the final feature definitions
dense_feas = [DenseFeature(feature_name) for feature_name in dense_features]
sparse_feas = [SparseFeature(feature_name, vocab_size=data[feature_name].nunique(), embed_dim=16) for feature_name in sparse_features]
```
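As a quick sanity check on the feature definitions just built (the counts follow from the preprocessing above: 13 dense features, plus 26 original and 13 discretized sparse features):

```python
# 13 dense features; 26 original sparse features + 13 "_cat" copies = 39
print(len(dense_feas), len(sparse_feas))  # expected: 13 39
```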
y = data["label"]
del data["label"]
x = data# Build a data generator
data_generator = DataGenerator(x, y)batch_size = 2048
# Separate data sets into training sets 70%、 Verification set 10% And test set 20%
train_dataloader, val_dataloader, test_dataloader = data_generator.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=batch_size)the samples of train : val : test are 80 : 11 : 24# Configure parameters of multilayer perceptron module
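The reported sample counts can also be checked directly (a sketch assuming generate_dataloader returns standard PyTorch DataLoader objects, whose .dataset attribute exposes the underlying dataset):

```python
# Verify the 70/10/20 split of the 115-sample dataset
print(len(train_dataloader.dataset), len(val_dataloader.dataset), len(test_dataloader.dataset))
# expected: 80 11 24
```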
```python
# Configure the parameters of the multilayer perceptron (MLP) module
mlp_params = {
    "dims": [256, 128],
    "dropout": 0.2,
    "activation": "relu",
}
# Build the WideDeep model
model = WideDeep(wide_features=dense_feas, deep_features=sparse_feas, mlp_params=mlp_params)

learning_rate = 1e-3
weight_decay = 1e-3
device = 'cuda:0'
save_dir = './models/'
epoch = 2
optimizer_params = {"lr": learning_rate, "weight_decay": weight_decay}

# Build the trainer
ctr_trainer = CTRTrainer(model, optimizer_params=optimizer_params, n_epoch=epoch, earlystop_patience=10, device=device, model_path=save_dir)
# Train the model
ctr_trainer.fit(train_dataloader, val_dataloader)
```

```
epoch: 0
train: 100%|██████████| 1/1 [00:12<00:00, 12.33s/it]
validation: 100%|██████████| 1/1 [00:10<00:00, 10.66s/it]
epoch: 0 validation: auc: 0.35
epoch: 1
train: 100%|██████████| 1/1 [00:10<00:00, 10.71s/it]
validation: 100%|██████████| 1/1 [00:10<00:00, 10.69s/it]
epoch: 1 validation: auc: 0.35
```
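Note that device = 'cuda:0' assumes a GPU is present; on a CPU-only machine, a common adjustment (not part of the original tutorial) is to fall back automatically:

```python
# Use the GPU when available, otherwise fall back to CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
```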
Finally, evaluate the trained model on the test set:

```python
# Model evaluation
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
print(f'test auc: {auc}')
```

```
validation: 100%|██████████| 1/1 [00:10<00:00, 10.60s/it]
test auc: 0.62037037037037033
```
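To reuse the trained model later, the saved weights can be reloaded. This is only a sketch: CTRTrainer was given model_path=save_dir, but the exact checkpoint file name below is hypothetical and should be checked against what the trainer actually writes.

```python
import os

# 'model.pth' is a hypothetical file name; list save_dir to find the real checkpoint
ckpt = os.path.join(save_dir, 'model.pth')
if os.path.exists(ckpt):
    model.load_state_dict(torch.load(ckpt, map_location='cpu'))
```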
3 Summary
This task mainly introduced the Torch-RecHub framework and its basic usage:
- The Torch-RecHub framework is built mainly on PyTorch and sklearn. It is easy to use and extend, can reproduce recommendation models used in industry, is highly modular, supports common Layers, and covers common ranking models, retrieval models, and multi-task learning;
- Usage: use DataGenerator to build the data loaders, construct a lightweight model, train it with the unified trainer, and finally evaluate the model.