当前位置:网站首页>推荐模型复现(三):召回模型YoutubeDNN、DSSM
推荐模型复现(三):召回模型YoutubeDNN、DSSM
2022-06-29 11:48:00 【GoAI】
本章为推荐模型复现第三章,使用torch_rechub框架进行模型搭建,主要介绍推荐系统召回模型YoutubeDNN、DSSM,包括结构讲解与代码实战,参考其他文章。
推荐方向资料推荐:
1. DSSM
1.1 DSSM模型原理
DSSM(Deep Structured Semantic Model),由微软研究院提出,利用深度神经网络将文本表示为低维度的向量,应用于文本相似度匹配场景下的一个算法。不仅局限于文本,在其他可以计算相似性计算的场景,例如推荐系统中。根据用户搜索行为中query(文本搜索)和doc(要匹配的文本)的日志数据,使用深度学习网络将query和doc映射到相同维度的语义空间中,即query侧特征的embedding和doc侧特征的embedding,从而得到语句的低维语义向量表达sentence embedding,用于预测两句话的语义相似度。
1.2 DSSM结构

模型结构:user侧塔和item侧塔分别经过各自的DNN得到embedding,再计算两者之间的相似度
特点:
- user和item两侧最终得到的embedding维度需要保持一致
- 对物料库中所有item计算相似度时,使用负采样进行近似计算
- 在海量的候选数据进行召回的场景下,速度很快
**缺点:**双塔的结构无法考虑两侧特征之间的交互信息,在一定程度上牺牲掉模型的部分精准性。
1.3 正负样本构建
正样本:以内容推荐为例,选“用户点击”的item为正样本。最多考虑一下用户停留时长,将“用户误点击”排除在外
负样本:user与item不匹配的样本,为负样本。
- 全局随机采样: 从全局候选item里面随机抽取一定数量作为召回模型的负样本,但可能会导致长尾现象。
- 全局随机采样+热门打压:对一些热门item进行适当的采样,减少热门对搜索的影响,提高模型对相似item的区分能力。
- Hard Negative增强样本:选取一部分匹配度适中的item,增加模型在训练时的难度
- Batch内随机选择:利用其他样本的正样本在batch内随机采样作为自己的负样本
1.4 DSSM的代码
class DSSM(torch.nn.Module):
def __init__(self, user_features, item_features, user_params, item_params, temperature=1.0):
super().__init__()
self.user_features = user_features
self.item_features = item_features
self.temperature = temperature
self.user_dims = sum([fea.embed_dim for fea in user_features])
self.item_dims = sum([fea.embed_dim for fea in item_features])
self.embedding = EmbeddingLayer(user_features + item_features)
self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
self.item_mlp = MLP(self.item_dims, output_layer=False, **item_params)
self.mode = None
def forward(self, x):
user_embedding = self.user_tower(x)
item_embedding = self.item_tower(x)
if self.mode == "user":
return user_embedding
if self.mode == "item":
return item_embedding
# 计算余弦相似度
y = torch.mul(user_embedding, item_embedding).sum(dim=1)
return torch.sigmoid(y)
def user_tower(self, x):
if self.mode == "item":
return None
input_user = self.embedding(x, self.user_features, squeeze_dim=True)
# user DNN
user_embedding = self.user_mlp(input_user)
user_embedding = F.normalize(user_embedding, p=2, dim=1)
return user_embedding
def item_tower(self, x):
if self.mode == "user":
return None
input_item = self.embedding(x, self.item_features, squeeze_dim=True)
# item DNN
item_embedding = self.item_mlp(input_item)
item_embedding = F.normalize(item_embedding, p=2, dim=1)
return item_embedding
2. YoutubeDNN
2.1 YoutubeDNN模型原理
YoutubeDNN是Youtube用于做视频推荐的落地模型,可谓推荐系统中的经典,其大体思路为召回阶段使用多个简单模型筛除大量相关度较低的样本,排序阶段使用较为复杂的模型获取精准的推荐结果。
2.2 YoutubeDNN结构

召回部分: 主要的输入是用户的点击历史数据,输出是与该用户相关的一个候选视频集合;
精排部分: 主要方法是特征工程, 模型设计和训练方法;
线下评估:采用一些常用的评估指标,通过A/B实验观察用户真实行为;
2.2.1 YoutubeDNN召回模型

- 输入层是用户观看视频序列的embedding mean pooling、搜索词的embedding mean pooling、地理位置embedding、用户特征;
- 输入层给到三层激活函数位ReLU的全连接层,然后得到用户向量;
- 最后,经过softmax层,得到每个视频的观看概率。
2.2.2 训练数据选取
- 采样方式:负采样(类似于skip-gram的采样)
- 样本来源:来自于全部的YouTube用户观看记录,包含用户从其他渠道观看的视频
注意:
- 训练数据中对于每个用户选取相同的样本数,保证用户样本在损失函数中的权重;
- 避免让模型知道不该知道的信息,即信息泄露
2.2.3 Example Age特征
- what:Example Age为 视频年龄的特征,即视频的发布时间
- 背景:由于用户对新视频的观看特点,导致视频的播放预测值期望不准确。
- 作用:捕捉视频的生命周期,让模型学习到用户对新颖内容的bias,消除热度偏见。
- 操作:在线上预测时,将example age全部设为0或一个小的负值,不依赖于各个视频的上传时间。
- 好处:将example age设置为常数值,在计算用户向量时只需要一次;对不同的视频,对应的example age所在范围一致,只依赖训练数据选取的时间跨度,便于归一化操作。
2.3 YoutubeDNN代码
import torch
import torch.nn.functional as F
from torch_rechub.basic.layers import MLP, EmbeddingLayer
from tqdm import tqdm
class YoutubeDNN(torch.nn.Module):
def __init__(self, user_features, item_features, neg_item_feature, user_params, temperature=1.0):
super().__init__()
self.user_features = user_features
self.item_features = item_features
self.neg_item_feature = neg_item_feature
self.temperature = temperature
self.user_dims = sum([fea.embed_dim for fea in user_features])
self.embedding = EmbeddingLayer(user_features + item_features)
self.user_mlp = MLP(self.user_dims, output_layer=False, **user_params)
self.mode = None
def forward(self, x):
user_embedding = self.user_tower(x)
item_embedding = self.item_tower(x)
if self.mode == "user":
return user_embedding
if self.mode == "item":
return item_embedding
# 计算相似度
y = torch.mul(user_embedding, item_embedding).sum(dim=2)
y = y / self.temperature
return y
def user_tower(self, x):
# 用于inference_embedding阶段
if self.mode == "item":
return None
input_user = self.embedding(x, self.user_features, squeeze_dim=True)
user_embedding = self.user_mlp(input_user).unsqueeze(1)
user_embedding = F.normalize(user_embedding, p=2, dim=2)
if self.mode == "user":
return user_embedding.squeeze(1)
return user_embedding
def item_tower(self, x):
if self.mode == "user":
return None
pos_embedding = self.embedding(x, self.item_features, squeeze_dim=False)
pos_embedding = F.normalize(pos_embedding, p=2, dim=2)
if self.mode == "item":
return pos_embedding.squeeze(1)
neg_embeddings = self.embedding(x, self.neg_item_feature, squeeze_dim=False).squeeze(1)
neg_embeddings = F.normalize(neg_embeddings, p=2, dim=2)
return torch.cat((pos_embedding, neg_embeddings), dim=1)
3. 总结
- DSSM为双塔模型,user与item分别经过的DNN得到embedding,再计算两者之间的相似度。训练样本,正样本为正确的搜索目标,负样本为全局采样+热门打击所得到的负样本。
- YoutubeDNN在双塔模型基础上进行了改进,召回阶段使用多个简单模型筛除大量相关度较低的样本,排序阶段使用较为复杂的模型获取精准的推荐结果。
参考:
边栏推荐
- ShanDong Multi-University Training #3
- What is the main account of Chia Tai futures used for 4 quotation software?
- 解决问题:ModuleNotFoundError: No module named ‘pip‘
- 百度云盘不限速下载大文件(2021-11亲测有效)
- 535. encryption and decryption of tinyurl: design a URL simplification system
- oracle 19c : change the user sys/system username pasword under Linux
- GBase8s数据库select有ORDER BY 子句2
- An interpretable geometric depth learning model for structure based protein binding site prediction
- Gbase8s database into external clause
- GBase8s数据库select有HAVING 子句
猜你喜欢

【综合案例】信用卡虚拟交易识别
![Jerry's about TWS channel configuration [chapter]](/img/2c/58a49dea7a7931c4d1f055548c2493.png)
Jerry's about TWS channel configuration [chapter]

爱可可AI前沿推介(6.29)

【JUC系列】同步工具类之ThreadLocal

Principle and process of MySQL master-slave replication

After class assignment of module 5 of the construction practice camp

An interpretable geometric depth learning model for structure based protein binding site prediction

Pro test! Centos7 deploy PHP + spool

每周推荐短视频:爱因斯坦是怎样思考问题的?

Dragon Book tiger Book whale Book gnawing? Try the monkey book with Douban score of 9.5
随机推荐
How do I open an account now? Is there a faster and safer opening channel
速看|期待已久的2022年广州助理检测工程师真题解析终于出炉
对p值的理解
Inferiority complex and transcendence the meaning of life to you
GBase8s数据库FOR READ ONLY 子句
模糊图片变清晰,一键双色图片,快速整理本地图片...这8个在线图片工具申请加入你的收藏夹!
Jerry's configuration of TWS cross pairing [chapter]
Some printer driver PPD files of Lenovo Lingxiang lenovoimage
求大数的阶乘 ← C语言
Wang Yingqi, founder of ones, talks to fortune (Chinese version): is there any excellent software in China?
Introduction to multi project development - business scenario Association basic introduction test payroll
缓存一致性,删除缓存,写入缓存,缓存击穿,缓存穿透,缓存雪崩
Wonderful! Miaoying technology fully implements Zadig to help container construction, and fully embraces kubernetes and Yunyuan
bison使用error死循环的记录
Unexpected ‘debugger‘ statement no-debugger
Jerry's about TWS channel configuration [chapter]
面试突击61:说一下MySQL事务隔离级别?
GBase8s数据库select有ORDER BY 子句2
Ttchat x Zadig open source co creates helm access scenarios, and environmental governance can be done!
Understanding of P value