当前位置:网站首页>Embedding的两种实现方式torch代码
Embedding的两种实现方式torch代码
2022-08-03 05:29:00 【WGS.】
详细原理请看文末链接(tf、torch实现embedding):
以DNN为例全部代码:
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from itertools import combinations
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
# import torch.utils.data as Data
from torchkeras import summary
from tools import *
def sparseFeature(feat, vocabulary_size, embed_dim):
    """Describe one sparse (categorical) field for embedding.

    Args:
        feat: Column name of the categorical feature.
        vocabulary_size: Number of embedding rows for this field.
        embed_dim: Embedding vector size for this field.

    Returns:
        A dict with keys ``'spare'`` (the column name; the key spelling is
        kept as-is since callers may look it up), ``'vocabulary_size'`` and
        ``'embed_dim'``.
    """
    column_spec = dict(spare=feat, vocabulary_size=vocabulary_size, embed_dim=embed_dim)
    return column_spec
def denseFeature(feat):
    """Describe one dense (numeric) field.

    Args:
        feat: Column name of the numeric feature.

    Returns:
        A single-entry dict ``{'dense': feat}``.
    """
    column_spec = dict(dense=feat)
    return column_spec
class NN_emb_offsets(nn.Module):
    """DNN CTR model that embeds every sparse field with one shared ``nn.Embedding``.

    Each field's category ids are shifted by a per-field offset (the
    cumulative sum of the preceding fields' sizes), so all fields index
    disjoint row ranges of a single embedding table. The flattened
    embeddings are concatenated with the dense values and fed to an MLP
    ending in a sigmoid.

    Expected input: a float tensor whose first ``len(dense_feature_columns)``
    columns are dense values and whose remaining columns hold one
    integer-valued category id per sparse field.
    """

    def __init__(self, spare_feature_columns, dense_feature_columns, feature_fields,
                 emb_dim=10, dnn_hidden_units=(64, 32), dropout=0.5):
        """Build the shared embedding table and the MLP head.

        Args:
            spare_feature_columns: Sparse field descriptors; only their count is used.
            dense_feature_columns: Dense field descriptors; only their count is used.
            feature_fields: Per-field vocabulary sizes (sequence or 1-D array of
                ints), one entry per sparse field, in input-column order.
            emb_dim: Embedding size shared by every sparse field.
            dnn_hidden_units: Widths of the hidden Linear layers.
            dropout: Dropout probability after each hidden layer.
        """
        super(NN_emb_offsets, self).__init__()
        self.spare_feature_columns = spare_feature_columns
        self.dense_feature_columns = dense_feature_columns
        self.spare_field_num = len(spare_feature_columns)
        self.dense_field_num = len(dense_feature_columns)
        self.emb_dim = emb_dim
        field_sizes = np.asarray(feature_fields, dtype=np.int64)
        # Start row of each field in the shared table: 0, n0, n0+n1, ...
        # ``np.long`` no longer exists in NumPy >= 1.24; int64 is the portable spelling.
        self.offsets = np.array((0, *np.cumsum(field_sizes)[:-1]), dtype=np.int64)
        # Embedding layer: one table spanning every field's id range, plus one extra row.
        self.embedding = nn.Embedding(int(field_sizes.sum()) + 1, embedding_dim=emb_dim)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)
        # DNN layer: [Linear -> BatchNorm -> ReLU -> Dropout] * len(hidden) -> Linear(1)
        self.dnn_hidden_units = dnn_hidden_units
        dnn_layers = []
        input_dim = self.dense_field_num + self.spare_field_num * emb_dim
        for hidden in dnn_hidden_units:
            dnn_layers.append(nn.Linear(input_dim, hidden))
            dnn_layers.append(nn.BatchNorm1d(hidden))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.Dropout(p=dropout))
            input_dim = hidden
        dnn_layers.append(nn.Linear(input_dim, 1))
        self.DNN = nn.Sequential(*dnn_layers)

    def forward(self, inputs):
        """Return predicted probabilities of shape ``(batch,)``.

        Args:
            inputs: ``(batch, dense_field_num + spare_field_num)`` tensor; dense
                values first, then one category id per sparse field (cast to
                long here, so ids must be integer-valued).
        """
        dense_inputs = inputs[:, :self.dense_field_num]
        sparse_inputs = inputs[:, self.dense_field_num:].long()
        # Shift each field's ids into its own slice of the shared table.
        sparse_inputs = sparse_inputs + sparse_inputs.new_tensor(self.offsets).unsqueeze(0)
        sparse_embed = self.embedding(sparse_inputs)  # (batch, spare_field_num, emb_dim)
        # Flatten per sample so the width is always spare_field_num * emb_dim.
        sparse_embed = sparse_embed.flatten(start_dim=1)
        x = torch.cat([dense_inputs, sparse_embed], dim=1)  # (batch, dense + spare * emb_dim)
        dnn_out = self.DNN(x)  # (batch, 1)
        return torch.sigmoid(dnn_out).squeeze(dim=-1)  # (batch,)
class NN_emb(nn.Module):
    """DNN CTR model with a separate ``nn.Embedding`` per sparse field.

    Each sparse field is looked up in its own embedding table. The results
    are concatenated with the dense values and fed to an MLP ending in a
    sigmoid.

    Expected input: a float tensor whose first ``len(dense_feature_columns)``
    columns are dense values and whose remaining columns hold one
    integer-valued category id per sparse field.
    """

    def __init__(self, spare_feature_columns, dense_feature_columns, hidden_units=(128, 64, 32), emb_dim=10, droup_out=0.):
        """Build one embedding table per sparse field and the MLP head.

        Args:
            spare_feature_columns: Sparse field descriptors (e.g. from
                ``sparseFeature``): mappings with ``'vocabulary_size'`` and,
                optionally, ``'embed_dim'``.
            dense_feature_columns: Dense field descriptors; only their count is used.
            hidden_units: Widths of the hidden Linear layers.
            emb_dim: Embedding size for any field whose descriptor has no ``'embed_dim'``.
            droup_out: Dropout probability after each hidden layer.
        """
        super(NN_emb, self).__init__()
        self.spare_feature_columns = spare_feature_columns
        self.dense_feature_columns = dense_feature_columns
        self.spare_field_num = len(spare_feature_columns)
        self.dense_field_num = len(dense_feature_columns)
        # Embedding: one table per field; keys 'embed_layer{i}' follow input-column order.
        self.embedding_layer = nn.ModuleDict({
            'embed_layer{}'.format(i): nn.Embedding(feat['vocabulary_size'], feat.get('embed_dim', emb_dim))
            for i, feat in enumerate(self.spare_feature_columns)})
        for embedding in self.embedding_layer.values():
            torch.nn.init.xavier_uniform_(embedding.weight.data)
        fc_layers = []
        # Width of the concatenated input: dense values plus each table's real embedding size
        # (fields may use different embed_dim values).
        input_dim = self.dense_field_num + sum(
            embedding.embedding_dim for embedding in self.embedding_layer.values())
        for fc_dim in hidden_units:
            fc_layers.append(nn.Linear(input_dim, fc_dim))
            fc_layers.append(nn.BatchNorm1d(fc_dim))
            fc_layers.append(nn.ReLU())
            fc_layers.append(nn.Dropout(p=droup_out))
            input_dim = fc_dim
        fc_layers.append(nn.Linear(input_dim, 1))
        self.DNN = nn.Sequential(*fc_layers)

    def forward(self, inputs):
        """Return predicted probabilities of shape ``(batch,)``.

        Args:
            inputs: ``(batch, dense_field_num + spare_field_num)`` tensor; dense
                values first, then one category id per sparse field.
        """
        dense_inputs = inputs[:, :self.dense_field_num]
        # nn.Embedding requires integer (long) indices.
        sparse_inputs = inputs[:, self.dense_field_num:].long()
        sparse_embed = torch.cat([self.embedding_layer['embed_layer{}'.format(i)](sparse_inputs[:, i])
                                  for i in range(self.spare_field_num)], dim=1)  # (batch, sum of embed dims)
        x = torch.cat([dense_inputs, sparse_embed], dim=1)
        # DNN Layer
        dnn_out = self.DNN(x)  # (batch, 1)
        return torch.sigmoid(dnn_out).squeeze(dim=-1)  # (batch,)
if __name__ == '__main__':
    # NOTE(review): pd, np, stat_pnrate_pd and fit are not imported in this file;
    # presumably they come from ``from tools import *`` — confirm.
    data = pd.read_csv('./data/criteo_sampled_data_test.csv')
    stat_pnrate_pd(data=data, labname='label', message='criteo_sampled_data_test')
    # I1-I13: 13 numeric (dense) feature columns
    # C1-C26: 26 categorical (sparse) feature columns
    dense_cols = ['I' + str(i) for i in range(1, 14)]
    sparse_cols = ['C' + str(i) for i in range(1, 27)]
    data_X = data[dense_cols + sparse_cols]
    data_y = data['label']
    # Stratified splits: 1% test, then 1% of the remainder for validation.
    tmp_X, test_X, tmp_y, test_y = train_test_split(data_X, data_y, test_size=0.01, random_state=42, stratify=data_y)
    train_X, val_X, train_y, val_y = train_test_split(tmp_X, tmp_y, test_size=0.01, random_state=42, stratify=tmp_y)
    # NOTE(review): category ids are packed into the same float32 tensor as the dense
    # values; float32 holds integers exactly only up to 2**24 — confirm encoded ids stay below that.
    train_set = TensorDataset(torch.tensor(train_X.values).float(), torch.tensor(train_y.values).float())
    val_set = TensorDataset(torch.tensor(val_X.values).float(), torch.tensor(val_y.values).float())
    train_loader = DataLoader(dataset=train_set, batch_size=2048, shuffle=True)
    val_loader = DataLoader(dataset=val_set, batch_size=2048, shuffle=False)
    emb_dim = 8
    dense_feature_columns = [denseFeature(feat) for feat in dense_cols]
    # Vocabulary size = max id + 1, assuming ids are 0-based label encodings (nunique() would
    # undercount if ids are not contiguous).
    # spare_feature_columns = [sparseFeature(feat, data_X[feat].nunique(), emb_dim) for feat in sparse_cols]
    spare_feature_columns = [sparseFeature(feat, data_X[feat].max() + 1, emb_dim) for feat in sparse_cols]
    print(len(dense_feature_columns), dense_feature_columns)
    print(len(spare_feature_columns), spare_feature_columns)
    feature_fields = data_X[sparse_cols].max().values + 1
    # ``np.int`` was removed in NumPy 1.24; use an explicit fixed-width integer dtype.
    feature_fields = feature_fields.astype(np.int64)
    print(len(feature_fields), feature_fields)
    # Swap the two lines below to compare per-field tables vs. one shared table with offsets.
    # net = NN_emb(spare_feature_columns=spare_feature_columns, dense_feature_columns=dense_feature_columns, emb_dim=emb_dim)
    net = NN_emb_offsets(spare_feature_columns=spare_feature_columns, dense_feature_columns=dense_feature_columns, feature_fields=feature_fields, emb_dim=emb_dim)
    loss_function = nn.BCELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    fit(net, train_loader=train_loader, val_loader=val_loader,
        epochs=2, loss_function=loss_function, optimizer=optimizer, metric_name=['accuracy', 'auc'])
tf、torch实现embedding
https://wangguisen.blog.csdn.net/article/details/122697991
边栏推荐
猜你喜欢
随机推荐
PCB 多层板为什么都是偶数层?
prometheus 监控mysql数据库
PCB板上的字母代表哪些元器件?一文看全!
Prometheus monitors container, pod, email alerts
MySQL的on duplicate key update 的使用
MySQL的10种常用数据类型
【干货分享】PCB 板变形原因!不看不知道
TFS (Azure conversation) prohibit people checked out at the same time
在Zabbix5.4上使用ODBC监控Oracle数据库
JDBC从手写连接到引用DBCP和C3P0
2021-06-14
mysql事务与多版本并发控制
Cesium加载离线地图和离线地形
JumpServer如何传输文件以及复制剪切板
Oracle Common Commands - Basic Commands
【dllogger bug】AttributeError: module 'dllogger' has no attribute 'StdOutBackend'
ES 中时间日期类型 “yyyy-MM-dd HH:mm:ss” 的完全避坑指南
Shell脚本--信号发送与捕捉
界面仅允许扫码枪录入禁止手工键盘输入
Charles抓包显示<unknown>解决方案