当前位置:网站首页>推荐系统——An Embedding Learning Framework for Numerical Features in CTR Prediction
推荐系统——An Embedding Learning Framework for Numerical Features in CTR Prediction
2022-07-25 23:08:00 【只会git clone的程序员】
前言
论文:https://arxiv.org/pdf/2012.08986.pdf
收录:KDD‘21
机构:华为
论文解读
这个老哥写的真的很好了,点这里:知乎
代码
论文的代码链接挂了并且源码我看了是华为mindxxx啥深度学习框架写的,而且看了看跟论文的公式也没对齐…干脆我就对着论文用pytorch写了一份
import torch
import torch.nn as nn
class AudisEconder(nn.Module):
r"""Args: in_dim: the dimension of input tensor out_dim: the dimension of output tensor H_j: the number of Meta_embeddings alpha: the factor of skip-connection t: Temperature Coefficient """
def __init__(self, in_dim, out_dim, H_j=20, alpha=0.1, t=1e-5):
super(AudisEconder, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.w_j = nn.Linear(in_dim, H_j)
self.leak_relu = nn.LeakyReLU()
self.W_j = nn.Linear(H_j, H_j)
self.alpha = alpha
self.t = t
self.softmax = nn.Softmax(dim=-1)
self.ME = nn.Parameter(torch.randn(H_j, out_dim))
def forward(self, x):
h_j = self.leak_relu(self.w_j(x))
x_hat_j = self.W_j(h_j) + self.alpha * h_j
x_hat_j_h = self.softmax(x_hat_j / self.t)
e_j = x_hat_j_h @ self.ME
return e_j
if __name__ == '__main__':
input = torch.rand(16, 10)
model = AudisEconder(in_dim=10, out_dim=128)
out = model(input)
print(out.shape)
边栏推荐
猜你喜欢

Network Security Learning (XIV) IP protocol

QT operation to solve large amount of duplicate data

How to obtain the cash flow data of advertising services to help analyze the advertising effect?

AI chief architect 12 AICA industrial landing analysis under the industrial production process optimization scenario

Custom MVC principle

About using NPM command under the terminal, the installation error problem is solved (my own experience)

PHP JSON variable array problem

Notification(状态栏通知)详解

Mongodb features, differences with MySQL, and application scenarios

Details of notification (status bar notification)
随机推荐
@Autowired注解 required属性
Network Security Learning (XV) ARP
Stack and stack class
Week 2: convolutional neural network
2021-09-30
js正则表达式匹配ip地址(ip地址正则表达式验证)
AI首席架构师12-AICA-工业生产过程优化场景下产业落地解析
Understanding of forward proxy and reverse proxy
向下扎根,向上生长,探寻华为云AI的“根”力量
内存分页与调优,内核与用户空间
anaconda安装教程环境变量(如何配置环境变量)
Summary of common PHP functions
Zero crossing position search of discrete data (array)
Enabling partners, how can Amazon cloud technology "get on the horse and get a ride"?
Deep recursion, deep search DFS, backtracking, paper cutting learning.
Review of static routing
PHP binary array is sorted by a field in it
QT operation to solve large amount of duplicate data
IPFs of Internet Protocol
Mongodb的特点、与MySQL的差别、以及应用场景