Recommendation Model Reproduction (I): Getting Familiar with the Torch-RecHub Framework and Its Usage
2022-06-29 12:39:00 【GoAI】
This is the first article in the recommendation-model reproduction series. It uses PyTorch to reproduce recommendation models and introduces the Torch-RecHub framework and its usage.
1 The Torch-RecHub Framework
Torch-RecHub is a lightweight PyTorch-based framework for recommendation models.
1.1 Framework Overview
- Core positioning: easy to use and extend, able to reproduce practical industrial recommendation models, with a focus on model-reproduction research across the broader recommendation ecosystem
- Engineering: built on native PyTorch classes and functions; model training is decoupled from model definition (there is no base model), staying faithful to the ideas of the original papers while letting newcomers get started quickly
- Learning reference: draws on the features of excellent open-source frameworks such as DeepCTR and FuxiCTR
1.2 Main Features
- scikit-learn-style, easy-to-use API (fit, predict) that works out of the box
- Model training decoupled from model definition, making the framework easy to extend; different training mechanisms can be configured for different types of models
- Supports pandas DataFrame, Dict, and other input data types, lowering the barrier to entry
- Highly modular, supporting common Layers that are easy to call and assemble into new models
  - LR, MLP, FM, FFM, CIN
  - target-attention, self-attention, transformer
- Supports common ranking models
  - WideDeep, DeepFM, DIN, DCN, xDeepFM, etc.
- Supports common retrieval (recall) models
  - DSSM, YoutubeDNN, YoutubeDSSM, FacebookEBR, MIND, etc.
- Rich multi-task learning support
  - Models such as SharedBottom, ESMM, MMOE, PLE, and AITM
  - Dynamic loss-weighting mechanisms such as GradNorm, UWL, and MetaBalance
- Focuses on a broader range of recommendation scenarios
- Supports a rich set of training mechanisms
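The fit/predict style and the training/definition decoupling listed above can be pictured with a minimal pure-Python schematic. This is an illustration of the design idea only, not torch_rechub's actual classes: the trainer owns the training loop, while the model only defines the forward computation, so the same trainer can drive different models.

```python
# Schematic only: training loop decoupled from model definition.
# Not the real torch_rechub API.
class LinearModel:
    """Defines only the forward computation; knows nothing about training."""
    def __init__(self):
        self.w, self.b = 0.0, 0.0

    def forward(self, x):
        return self.w * x + self.b


class Trainer:
    """Owns the training loop; works with any model exposing forward()."""
    def __init__(self, model, lr=0.01):
        self.model, self.lr = model, lr

    def fit(self, xs, ys, epochs=2000):
        for _ in range(epochs):
            for x, y in zip(xs, ys):
                err = self.model.forward(x) - y
                self.model.w -= self.lr * err * x   # SGD step for w
                self.model.b -= self.lr * err       # SGD step for b

    def predict(self, xs):
        return [self.model.forward(x) for x in xs]


trainer = Trainer(LinearModel())
trainer.fit([0, 1, 2, 3], [1, 3, 5, 7])   # learn y = 2x + 1
print([round(p, 1) for p in trainer.predict([4, 5])])  # [9.0, 11.0]
```

Because the trainer only relies on a `forward()` method, swapping in a different model (or a different training mechanism) requires no change to the other side, which is the same decoupling torch_rechub applies to its models and trainers.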
1.3 Torch-RecHub Architecture
Torch-RecHub consists mainly of a data processing module, a model module, and a trainer module:
- Data processing module
  - Feature handling: DenseFeature (for numeric features), SparseFeature (for categorical features), SequenceFeature (for sequence features)
  - Data construction: DataGenerator (a data generator used to build the train/validation/test datasets)
- Model module
  - Ranking models: WideDeep, DeepFM, DCN, xDeepFM, DIN, DIEN, SIM
  - Retrieval models: DSSM, YoutubeDNN, YoutubeSBC, FaceBookDSSM, Gru4Rec, MIND, SASRec, ComiRec
  - Multi-task models: SharedBottom, ESMM, MMOE, PLE, AITM
- Trainer module
  - CTRTrainer: for training and evaluating ranking (CTR) models
  - MTLTrainer: for training and evaluating multi-task ranking models
  - MatchTrainer: for training and evaluating retrieval models
2 Using Torch-RecHub
The example uses a small sample of the Criteo dataset, containing only 115 rows. The full dataset was published by Criteo Labs for online advertising research: it holds millions of click-feedback records on display ads and serves as a standard benchmark for click-through-rate (CTR) prediction. Each row has 40 columns. The first column is the label, where a value of 1 means the ad was clicked and 0 means it was not; the remaining columns comprise 13 dense (numeric) features and 26 sparse (categorical) features.
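Before loading the data, it helps to know the column layout: the columns are named `label`, `I1`–`I13` for the dense features, and `C1`–`C26` for the sparse features, which is what the later `f[0] == "I"` / `f[0] == "C"` filters rely on. A small sketch of that schema (the column names are assumed from the description above):

```python
# Criteo sample column naming convention:
# 1 label column, 13 dense features I1..I13, 26 sparse features C1..C26.
columns = ["label"] + [f"I{i}" for i in range(1, 14)] + [f"C{i}" for i in range(1, 27)]

dense_cols = [c for c in columns if c.startswith("I")]
sparse_cols = [c for c in columns if c.startswith("C")]

print(len(columns), len(dense_cols), len(sparse_cols))  # 40 13 26
```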
```python
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.utils.data import DataGenerator
from torch_rechub.trainers import CTRTrainer
from torch_rechub.models.ranking import WideDeep

data_path = './data/criteo/criteo_sample.csv'
# Load the dataset
data = pd.read_csv(data_path)
```
```python
data.head()
```

|   | label | I1 | I2 | I3 | I4 | I5 | I6 | I7 | I8 | I9 | ... | C17 | C18 | C19 | C20 | C21 | C22 | C23 | C24 | C25 | C26 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0.0 | 0 | 104.0 | 27.0 | 1990.0 | 142.0 | 4.0 | 32.0 | 37.0 | ... | e5ba7672 | 25c88e42 | 21ddcdc9 | b1252a9d | 0e8585d2 | NaN | 32c7478e | 0d4a6d1a | 001f3601 | 92c878de |
| 1 | 0 | 0.0 | -1 | 63.0 | 40.0 | 1470.0 | 61.0 | 4.0 | 37.0 | 46.0 | ... | e5ba7672 | d3303ea5 | 21ddcdc9 | b1252a9d | 7633c7c8 | NaN | 32c7478e | 17f458f7 | 001f3601 | 71236095 |
| 2 | 0 | 0.0 | 370 | 4.0 | 1.0 | 1787.0 | 65.0 | 14.0 | 25.0 | 489.0 | ... | 3486227d | 642f2610 | 55dd3565 | b1252a9d | 5c8dc711 | NaN | 423fab69 | 45ab94c8 | 2bf691b1 | c84c4aec |
| 3 | 1 | 19.0 | 10 | 30.0 | 10.0 | 1.0 | 3.0 | 33.0 | 47.0 | 126.0 | ... | e5ba7672 | a78bd508 | 21ddcdc9 | 5840adea | c2a93b37 | NaN | 32c7478e | 1793a828 | e8b83407 | 2fede552 |
| 4 | 0 | 0.0 | 0 | 36.0 | 22.0 | 4684.0 | 217.0 | 9.0 | 35.0 | 135.0 | ... | e5ba7672 | 7ce63c71 | NaN | NaN | af5dc647 | NaN | dbb486d7 | 1793a828 | NaN | NaN |

5 rows × 40 columns
```python
dense_features = [f for f in data.columns.tolist() if f[0] == "I"]
sparse_features = [f for f in data.columns.tolist() if f[0] == "C"]

# Fill NaN values: use the string '-996' for sparse features and 0 for dense features
data[sparse_features] = data[sparse_features].fillna('-996')
data[dense_features] = data[dense_features].fillna(0)

def convert_numeric_feature(val):
    v = int(val)
    if v > 2:
        return int(np.log(v) ** 2)
    else:
        return v - 2

# Discretize each dense feature into a categorical copy, then min-max scale the originals
for feat in dense_features:
    sparse_features.append(feat + "_cat")
    data[feat + "_cat"] = data[feat].apply(lambda x: convert_numeric_feature(x))

sca = MinMaxScaler()  # scale dense features to [0, 1]
data[dense_features] = sca.fit_transform(data[dense_features])

# Label-encode the sparse features
for feat in sparse_features:
    lbe = LabelEncoder()
    data[feat] = lbe.fit_transform(data[feat])
```
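The `convert_numeric_feature` helper buckets long-tailed counts on a squared-log scale while mapping small values (≤ 2) at or below zero so they stay distinct. A quick standalone check of its behavior (reimplemented here with the standard library only):

```python
import math

def convert_numeric_feature(val):
    # Same bucketing rule as above: squared natural log for v > 2, v - 2 otherwise
    v = int(val)
    if v > 2:
        return int(math.log(v) ** 2)
    else:
        return v - 2

# Small values keep distinct codes at or below zero...
print([convert_numeric_feature(v) for v in (0, 1, 2)])   # [-2, -1, 0]
# ...while large values are heavily compressed
print(convert_numeric_feature(104), convert_numeric_feature(1990))  # 21 57
```

This is why the `_cat` copies can be treated as sparse features: nearby large counts fall into the same bucket, keeping the vocabulary small.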
```python
# Build the final feature definitions
dense_feas = [DenseFeature(feature_name) for feature_name in dense_features]
sparse_feas = [SparseFeature(feature_name, vocab_size=data[feature_name].nunique(), embed_dim=16) for feature_name in sparse_features]
y = data["label"]
del data["label"]
x = data

# Build the data generator
data_generator = DataGenerator(x, y)

batch_size = 2048
# Split the data into 70% training, 10% validation, and 20% test sets
train_dataloader, val_dataloader, test_dataloader = data_generator.generate_dataloader(split_ratio=[0.7, 0.1], batch_size=batch_size)
```

Output:

```
the samples of train : val : test are 80 : 11 : 24
```
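With `split_ratio=[0.7, 0.1]`, the remaining 20% becomes the test set; on 115 rows, integer truncation yields the 80 : 11 : 24 split reported above. A quick sketch of that arithmetic (the exact rounding used inside torch_rechub is assumed here):

```python
# Reproduce the 80 : 11 : 24 split arithmetic for 115 samples
n = 115
split_ratio = [0.7, 0.1]

n_train = int(n * split_ratio[0])   # truncates 80.5 -> 80
n_val = int(n * split_ratio[1])     # truncates 11.5 -> 11
n_test = n - n_train - n_val        # the remaining 24 rows form the test set

print(n_train, n_val, n_test)  # 80 11 24
```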
```python
# Configure the multilayer-perceptron module
mlp_params = {
    "dims": [256, 128],
    "dropout": 0.2,
    "activation": "relu",
}

# Build the WideDeep model
model = WideDeep(wide_features=dense_feas, deep_features=sparse_feas, mlp_params=mlp_params)

learning_rate = 1e-3
weight_decay = 1e-3
device = 'cuda:0'
save_dir = './models/'
epoch = 2
optimizer_params = {
    "lr": learning_rate,
    "weight_decay": weight_decay,
}

# Build the trainer
ctr_trainer = CTRTrainer(model, optimizer_params=optimizer_params, n_epoch=epoch, earlystop_patience=10, device=device, model_path=save_dir)

# Train the model
ctr_trainer.fit(train_dataloader, val_dataloader)
```

Output:

```
epoch: 0
train: 100%|██████████| 1/1 [00:12<00:00, 12.33s/it]
validation: 100%|██████████| 1/1 [00:10<00:00, 10.66s/it]
epoch: 0 validation: auc: 0.35
epoch: 1
train: 100%|██████████| 1/1 [00:10<00:00, 10.71s/it]
validation: 100%|██████████| 1/1 [00:10<00:00, 10.69s/it]
epoch: 1 validation: auc: 0.35
```
Evaluate the model:

```python
auc = ctr_trainer.evaluate(ctr_trainer.model, test_dataloader)
print(f'test auc: {auc}')
```

Output:

```
validation: 100%|██████████| 1/1 [00:10<00:00, 10.60s/it]
test auc: 0.62037037037037033
```

Summary
This article introduced the Torch-RecHub framework and its basic usage:
- Torch-RecHub is built mainly on PyTorch and sklearn. It is easy to use and extend, reproduces practical industrial recommendation models, is highly modular, supports common Layers, and covers ranking models, retrieval models, and multi-task learning;
- Usage: build the data loaders with DataGenerator, construct a lightweight model, train it with the unified trainer, and finally evaluate the model.