Reproducing Recommendation Models (IV): Multi-Task Models ESMM and MMOE
2022-06-29 12:39:00 【GoAI】
Multi-task models: ESMM and MMOE

This is the fourth installment of the recommendation-model reproduction series. It uses the torch_rechub framework for model building and introduces two multi-task models for recommender systems, ESMM and MMOE, covering both architecture explanations and code practice, with reference to other articles.
1 ESMM
1.1 ESMM Background

- Sample selection bias: the conventional CVR training set is constructed from clicked impressions only, so its sampled distribution does not accurately reflect the full impression space served at inference time.
- Data sparsity: click samples make up only a small proportion of impression samples, leaving the CVR task with little training data.
1.2 ESMM Principle

- Solution: based on multi-task learning, introduce the auxiliary CTR and CTCVR tasks to eliminate sample selection bias and alleviate data sparsity.
- Three prediction tasks:
  - pCTR: click-through rate prediction model
  - pCVR: post-click conversion rate prediction model
  - pCTCVR: click-and-conversion rate prediction model
$$\underbrace{p(y=1, z=1 \mid x)}_{pCTCVR} = \underbrace{p(y=1 \mid x)}_{pCTR} \times \underbrace{p(z=1 \mid y=1, x)}_{pCVR}$$

where $x$ denotes an impression, $y$ a click, and $z$ a conversion.
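Dividing both sides by pCTR makes explicit how ESMM derives pCVR over the full impression space:

$$pCVR = p(z=1 \mid y=1, x) = \frac{p(y=1, z=1 \mid x)}{p(y=1 \mid x)} = \frac{pCTCVR}{pCTR}$$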

The main task and auxiliary task share features, and the loss function is constructed from the CTCVR and CTR labels:

$$L(\theta_{cvr}, \theta_{ctr}) = \sum_{i=1}^{N} l\big(y_{i}, f(\boldsymbol{x}_{i}; \theta_{ctr})\big) + \sum_{i=1}^{N} l\big(y_{i} \,\&\, z_{i}, f(\boldsymbol{x}_{i}; \theta_{ctr}) \times f(\boldsymbol{x}_{i}; \theta_{cvr})\big)$$

- Resolving sample selection bias: during training, the model only needs to predict pCTCVR and pCTR to update its parameters. Because pCTCVR and pCTR are both estimated over the complete impression space, the pCVR obtained from them via the formula above is free of sample selection bias.
- Resolving data sparsity: the shared embedding layer lets the CVR subtask also learn from impressions that were never clicked, which alleviates the sparsity of its training data.
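A minimal sketch of this loss, assuming binary cross-entropy as the per-sample loss $l$ (the function name and signature below are illustrative, not part of torch_rechub):

import torch
import torch.nn.functional as F

def esmm_loss(ctr_pred, ctcvr_pred, click, conversion):
    # ctr_pred / ctcvr_pred: predicted probabilities in (0, 1), shape (batch,)
    # click / conversion: 0/1 float labels over the full impression space
    ctr_loss = F.binary_cross_entropy(ctr_pred, click)
    # the CTCVR label (y & z) is 1 only for impressions that were clicked AND converted
    ctcvr_loss = F.binary_cross_entropy(ctcvr_pred, click * conversion)
    return ctr_loss + ctcvr_loss

Note that pCVR is supervised only implicitly, through the pCTCVR = pCTR × pCVR product, so no loss term is ever computed on the click-only subspace.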
1.3 ESMM Model Optimization

- In the paper, each subtask's independent tower is a plain MLP; you can substitute a different model according to your own scenario, e.g. DeepFM or DIN.
- Introduce a dynamic weighting mechanism to optimize the loss.
- Longer sequence-dependency models can be built. For example, Meituan's AITM for the credit-card business models the user conversion funnel impression -> click -> application -> card approval -> activation.
1.4 ESMM Code Implementation
import torch
from torch_rechub.basic.layers import MLP, EmbeddingLayer


class ESMM(torch.nn.Module):

    def __init__(self, user_features, item_features, cvr_params, ctr_params):
        super().__init__()
        self.user_features = user_features
        self.item_features = item_features
        self.embedding = EmbeddingLayer(user_features + item_features)
        self.tower_dims = user_features[0].embed_dim + item_features[0].embed_dim
        # Build the CVR and CTR towers
        self.tower_cvr = MLP(self.tower_dims, **cvr_params)
        self.tower_ctr = MLP(self.tower_dims, **ctr_params)

    def forward(self, x):
        # Sum-pool the user-field and item-field embeddings separately
        embed_user_features = self.embedding(x, self.user_features,
                                             squeeze_dim=False).sum(dim=1)
        embed_item_features = self.embedding(x, self.item_features,
                                             squeeze_dim=False).sum(dim=1)
        input_tower = torch.cat((embed_user_features, embed_item_features), dim=1)
        cvr_logit = self.tower_cvr(input_tower)
        ctr_logit = self.tower_ctr(input_tower)
        cvr_pred = torch.sigmoid(cvr_logit)
        ctr_pred = torch.sigmoid(ctr_logit)
        # pCTCVR = pCTR * pCVR
        ctcvr_pred = torch.mul(ctr_pred, cvr_pred)
        ys = [cvr_pred, ctr_pred, ctcvr_pred]
        return torch.cat(ys, dim=1)
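A minimal usage sketch (the feature definitions, vocabulary sizes, and layer dims are illustrative; real ones depend on your dataset):

from torch_rechub.basic.features import SparseFeature

user_features = [SparseFeature("user_id", vocab_size=10000, embed_dim=16)]
item_features = [SparseFeature("item_id", vocab_size=5000, embed_dim=16)]
model = ESMM(user_features, item_features,
             cvr_params={"dims": [64, 32]},
             ctr_params={"dims": [64, 32]})
# forward expects a dict of feature tensors and returns a (batch, 3) tensor:
# [cvr_pred, ctr_pred, ctcvr_pred]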
2 MMOE

2.1 MMOE Background

- Multi-task models learn the commonalities and differences between tasks, which can improve both modeling quality and efficiency.
- Multi-task design patterns (a minimal sketch of the hard-sharing pattern follows this list):
  - Hard parameter sharing: the bottom is a shared hidden network that learns patterns common to all tasks, while task-specific fully connected layers on top learn each task's own patterns.
  - Soft parameter sharing: the bottom is not one shared network but multiple towers, and different weights are assigned to the different towers.
  - Task-sequence dependency modeling: suited to tasks with an inherent sequential dependency.
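A minimal sketch of hard parameter sharing, assuming two binary classification tasks (the class name and layer sizes are illustrative, not from torch_rechub):

import torch
import torch.nn as nn

class SharedBottom(nn.Module):

    def __init__(self, input_dim, bottom_dims=(128, 64), tower_dim=32, n_task=2):
        super().__init__()
        layers, in_dim = [], input_dim
        for d in bottom_dims:
            # shared hidden layers learn patterns common to all tasks
            layers += [nn.Linear(in_dim, d), nn.ReLU()]
            in_dim = d
        self.bottom = nn.Sequential(*layers)
        # task-specific towers learn each task's own patterns
        self.towers = nn.ModuleList(
            nn.Sequential(nn.Linear(in_dim, tower_dim), nn.ReLU(),
                          nn.Linear(tower_dim, 1))
            for _ in range(n_task))

    def forward(self, x):
        h = self.bottom(x)  # one shared representation for every task
        return torch.cat([torch.sigmoid(tower(h)) for tower in self.towers], dim=1)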
2.2 MoE and MMOE Model Principles

2.2.1 MoE Model (Mixture of Experts)

- Model principle: aggregates the outputs of multiple Experts, with a gating network (an attention-style mechanism) producing the weight of each Expert.
- Characteristics: model ensembling, attention mechanism, multi-head mechanism.
2.2.2 MMOE Model

- Based on the OMoE (one-gate MoE) model, but every task has its own gating network over the Experts (formalized below).
- Characteristics:
  - Avoids task conflicts: each task's gate adapts independently, selecting the combination of Experts most helpful to the current task.
  - Models the relationships between tasks.
  - Flexible parameter sharing.
  - Fast convergence during training.
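Formally (following the MMOE paper), the output for task $k$ is a gate-weighted sum of the Expert outputs fed into that task's tower:

$$f^{k}(x) = \sum_{i=1}^{n} g^{k}(x)_{i}\, f_{i}(x), \qquad y^{k} = h^{k}\big(f^{k}(x)\big), \qquad g^{k}(x) = \operatorname{softmax}\big(W_{g^{k}} x\big)$$

where the $f_{i}$ are the $n$ shared Experts, $g^{k}$ is task $k$'s gating network, and $h^{k}$ is task $k$'s tower. The forward pass in the code below implements exactly this weighting.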
import torch
import torch.nn as nn
from torch_rechub.basic.layers import MLP, EmbeddingLayer, PredictionLayer


class MMOE(torch.nn.Module):

    def __init__(self, features, task_types, n_expert, expert_params, tower_params_list):
        super().__init__()
        self.features = features
        self.task_types = task_types
        self.n_task = len(task_types)  # number of tasks
        self.n_expert = n_expert
        self.embedding = EmbeddingLayer(features)
        self.input_dims = sum([fea.embed_dim for fea in features])
        # Expert networks shared by all tasks
        self.experts = nn.ModuleList(
            MLP(self.input_dims, output_layer=False, **expert_params)
            for i in range(self.n_expert))
        # One gating network per task: a softmax over the n_expert Experts
        self.gates = nn.ModuleList(
            MLP(self.input_dims, output_layer=False, **{
                "dims": [self.n_expert],
                "activation": "softmax"
            }) for i in range(self.n_task))
        # One tower per task
        self.towers = nn.ModuleList(
            MLP(expert_params["dims"][-1], **tower_params_list[i])
            for i in range(self.n_task))
        self.predict_layers = nn.ModuleList(
            PredictionLayer(task_type) for task_type in task_types)

    def forward(self, x):
        embed_x = self.embedding(x, self.features, squeeze_dim=True)
        # expert_outs: (batch, n_expert, expert_dim)
        expert_outs = [expert(embed_x).unsqueeze(1) for expert in self.experts]
        expert_outs = torch.cat(expert_outs, dim=1)
        # each gate_out: (batch, n_expert, 1)
        gate_outs = [gate(embed_x).unsqueeze(-1) for gate in self.gates]
        ys = []
        for gate_out, tower, predict_layer in zip(gate_outs, self.towers, self.predict_layers):
            # weight each Expert's output by its gate value, then sum over Experts
            expert_weight = torch.mul(gate_out, expert_outs)
            expert_pooling = torch.sum(expert_weight, dim=1)
            # task-specific tower
            tower_out = tower(expert_pooling)
            # logit -> probability
            y = predict_layer(tower_out)
            ys.append(y)
        return torch.cat(ys, dim=1)
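A minimal usage sketch (feature definitions and parameter values are illustrative; the task_type strings follow torch_rechub's PredictionLayer convention as I understand it):

from torch_rechub.basic.features import SparseFeature

features = [SparseFeature("user_id", vocab_size=10000, embed_dim=16),
            SparseFeature("item_id", vocab_size=5000, embed_dim=16)]
model = MMOE(features,
             task_types=["classification", "classification"],  # e.g. CTR and CVR
             n_expert=8,
             expert_params={"dims": [64, 32]},
             tower_params_list=[{"dims": [16]}, {"dims": [16]}])
# forward expects a dict of feature tensors and returns (batch, n_task) probabilities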
3 Summary

This article introduced the principles and code practice of the ESMM and MMOE multi-task learning models:

- ESMM model: introduces the auxiliary CTR and CTCVR tasks to solve the sample selection bias and data sparsity problems. It is built on a two-tower structure, the two towers can use different models chosen for your own scenario, and the subnetworks support arbitrary substitution.
- MMOE model: builds on the OMoE model by giving every task its own gating network over the Experts. The lower level is the basic MoE structure and the upper level consists of task-specific towers, so each task is decoupled in its choice of Expert combination, with flexible parameter sharing and fast training convergence.