Reproducing FiBiNet in PyTorch
2022-08-03 05:29:00 【WGS.】
Network Details
PyTorch Implementation
''' Squeeze-and-Excitation Attention '''
class SENetLayer(nn.Module):
    """ Input: [batch_size, field_num, emb_dim]  Output: [batch_size, field_num, emb_dim] """
    def __init__(self, field_num, reduction_ratio=3, pooling='mean'):
        super().__init__()
        self.pooling = pooling
        reduced_size = max(1, field_num // reduction_ratio)
        # Excitation: two fully connected layers that squeeze and then restore the field dimension
        self.excitation = nn.Sequential(
            nn.Linear(field_num, reduced_size, bias=False),
            nn.ReLU(),
            nn.Linear(reduced_size, field_num, bias=False),
            nn.ReLU()
        )

    def forward(self, inputs):
        if self.pooling == 'mean':
            Z = torch.mean(inputs, dim=-1)     # Squeeze: [None, field_num]
        elif self.pooling == 'max':
            Z = inputs.max(dim=-1).values      # Squeeze: [None, field_num]; .max() returns (values, indices)
        else:
            raise Exception('pooling type unknown')
        A = self.excitation(Z)                 # Excitation: [None, field_num]
        V = inputs * A.unsqueeze(-1)           # Re-Weight: [None, field_num, emb_dim]
        return V
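As a quick shape check for the SENet block (a minimal sketch; the batch size, field count and embedding size below are made up for illustration and assume torch is already imported as in the main section):

se = SENetLayer(field_num=26, reduction_ratio=3, pooling='mean')
feat = torch.randn(4, 26, 10)     # [batch_size, field_num, emb_dim]
print(se(feat).shape)             # torch.Size([4, 26, 10]): same shape, fields re-weighted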
''' Bilinear Interaction '''
class BilinearInteraction(nn.Module):
    """
    Input:  [batch_size, field_num, emb_dim]
    Output: [batch_size, field_num * (field_num - 1) / 2, emb_dim]
    combinations() picks every pair of 2 fields out of field_num, giving
    field_num * (field_num - 1) / 2 pairs, so the number of output fields
    becomes field_num * (field_num - 1) / 2.
    """
    def __init__(self, field_num, emb_dim, bilinear_type="field_interaction"):
        super().__init__()
        self.bilinear_type = bilinear_type
        self.bilinear_layer = nn.ModuleList()
        # Field-All: all fields share a single W
        if self.bilinear_type == "field_all":
            self.bilinear_layer = nn.Linear(emb_dim, emb_dim, bias=False)
        # Field-Each: each field i has its own W_i
        elif self.bilinear_type == "field_each":
            for _ in range(field_num):
                self.bilinear_layer.append(nn.Linear(emb_dim, emb_dim, bias=False))
        # Field-Interaction: each field pair (i, j) has its own W_ij
        elif self.bilinear_type == "field_interaction":
            for _, _ in combinations(range(field_num), 2):
                self.bilinear_layer.append(nn.Linear(emb_dim, emb_dim, bias=False))
        else:
            raise Exception('bilinear type unknown')

    def forward(self, inputs):
        # Split along the field dimension: a tuple of field_num tensors, each of shape (None, 1, emb_dim)
        inputs = torch.split(inputs, 1, dim=1)
        # p = v_i \cdot W \odot v_j
        if self.bilinear_type == "field_all":
            p = [torch.mul(self.bilinear_layer(v_i), v_j)
                 for v_i, v_j in combinations(inputs, 2)]
        # p = v_i \cdot W_i \odot v_j
        elif self.bilinear_type == "field_each":
            p = [torch.mul(self.bilinear_layer[i](inputs[i]), inputs[j])
                 for i, j in combinations(range(len(inputs)), 2)]
        # p = v_i \cdot W_ij \odot v_j
        elif self.bilinear_type == "field_interaction":
            p = [torch.mul(self.bilinear_layer[i](v[0]), v[1])
                 for i, v in enumerate(combinations(inputs, 2))]
        return torch.cat(p, dim=1)
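A matching shape check for the bilinear layer (a minimal sketch with made-up sizes; it assumes combinations from itertools is imported as in the main section; 5 fields give 5 * 4 / 2 = 10 field pairs):

bi = BilinearInteraction(field_num=5, emb_dim=10, bilinear_type='field_interaction')
feat = torch.randn(4, 5, 10)      # [batch_size, field_num, emb_dim]
print(bi(feat).shape)             # torch.Size([4, 10, 10])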
class FiBiNet(nn.Module):
    def __init__(self, sparse_fields, dense_fields_num, emb_dim=10, reduction_ratio=3, pooling='mean',
                 bilinear_type='field_interaction', dnn_hidden_units=(64, 32), dropout=0.5, use_bn=True):
        super(FiBiNet, self).__init__()
        self.dense_fields_num = dense_fields_num
        self.sparse_field_num = len(sparse_fields)
        # Offsets map each field's local ids into one global id space for the shared embedding table
        self.offsets = np.array((0, *np.cumsum(sparse_fields)[:-1]), dtype=np.longlong)
        # self.offsets = nn.Parameter(torch.tensor([0] + feature_fields[:-1]).cumsum(0), requires_grad=False)
        # Embedding layer
        self.embedding = nn.Embedding(sum(sparse_fields) + 1, embedding_dim=emb_dim)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)
        # SENet layer
        self.SE = SENetLayer(field_num=self.sparse_field_num, reduction_ratio=reduction_ratio, pooling=pooling)
        # Bilinear layer
        self.Bilinear = BilinearInteraction(field_num=self.sparse_field_num, emb_dim=emb_dim, bilinear_type=bilinear_type)
        # DNN layer
        self.dnn_hidden_units = dnn_hidden_units
        dnn_layers = []
        # Two bilinear outputs (raw and SENet-weighted), each with field_num * (field_num - 1) / 2 fields,
        # are flattened and concatenated with the dense features:
        # 2 * field_num * (field_num - 1) / 2 * emb_dim + dense_fields_num
        input_dim = self.sparse_field_num * (self.sparse_field_num - 1) * emb_dim + self.dense_fields_num
        for hidden in dnn_hidden_units:
            dnn_layers.append(nn.Linear(input_dim, hidden))
            if use_bn:
                dnn_layers.append(nn.BatchNorm1d(hidden))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.Dropout(p=dropout))
            input_dim = hidden
        dnn_layers.append(nn.Linear(input_dim, 1))
        self.DNN = nn.Sequential(*dnn_layers)

    def forward(self, inputs):
        ''' Input: (None, field_num), dense columns first, then sparse columns '''
        dense_inputs, sparse_inputs = inputs[:, :self.dense_fields_num], inputs[:, self.dense_fields_num:]
        sparse_inputs = sparse_inputs.long()  # cast to long so it can be fed into nn.Embedding
        sparse_inputs = sparse_inputs + sparse_inputs.new_tensor(self.offsets).unsqueeze(0)
        sparse_emb = self.embedding(sparse_inputs)             # (None, field_num, emb_dim)
        se_out = self.SE(sparse_emb)                           # (None, field_num, emb_dim)
        se_bilinear_out = self.Bilinear(se_out).flatten(start_dim=1)
        bilinear_out = self.Bilinear(sparse_emb).flatten(start_dim=1)
        dnn_inp = torch.cat([bilinear_out, se_bilinear_out, dense_inputs], dim=1)
        dnn_out = self.DNN(dnn_inp)                            # (None, 1)
        dnn_out = dnn_out.squeeze(-1)
        return torch.sigmoid(dnn_out)
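A forward-pass smoke test for the full model (a hedged sketch: the three field cardinalities and the batch below are invented; the real sparse_fields come from the Criteo data in the main section):

toy_fields = np.array([10, 20, 30], dtype=np.int32)   # 3 sparse fields with hypothetical vocabulary sizes
toy_net = FiBiNet(sparse_fields=toy_fields, dense_fields_num=13, emb_dim=10)
toy_x = torch.cat([torch.randn(8, 13),                # 13 dense features
                   torch.randint(0, 10, (8, 3)).float()], dim=1)
print(toy_net(toy_x).shape)                           # torch.Size([8]), sigmoid click probabilities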
Full Code
emb: https://wangguisen.blog.csdn.net/article/details/122697991
Network Structure
The network structure is as shown above.
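To inspect the assembled structure directly, printing the module works; this is a minimal sketch and assumes net is the FiBiNet instance built in the main section below:

print(net)   # lists the Embedding table, SENetLayer, BilinearInteraction and the DNN Sequential stack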
Wrapping the Training Loop
def printlog(info):
# nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
# print("%s " % nowtime + "----------"*11 + '---')
print(str(info))
def fit(net, train_loader, val_loader, epochs, loss_function, optimizer, metric_name):
start_time = time.time()
print("\n" + "********** start training **********")
columns = ["epoch", "loss", *metric_name, "val_loss"] + ['val_' + mn for mn in metric_name]
dfhistory = pd.DataFrame(columns=columns)
    ''' Training '''
for epoch in range(1, epochs+1):
printlog("Epoch {0} / {1}".format(epoch, epochs))
step_start = time.time()
step_num = 0
train_loss = []
train_pred_probs, train_y, train_pre = [], [], []
net.train()
for batch, (x, y) in enumerate(train_loader):
step_num += 1
optimizer.zero_grad()
pred_probs = net(x)
loss = loss_function(pred_probs, y.float().detach())
loss.backward()
optimizer.step()
train_loss.append(loss.item())
train_pred_probs.extend(pred_probs.tolist())
train_y.extend(y.tolist())
            train_pre.extend(torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs)).tolist())  # store hard 0/1 predictions as plain floats
        ''' Validation '''
val_loss = []
val_pred_probs, val_y, val_pre = [], [], []
net.eval()
with torch.no_grad():
for batch, (x, y) in enumerate(val_loader):
pred_probs = net(x)
loss = loss_function(pred_probs, y.float().detach())
val_loss.append(loss.item())
val_pred_probs.extend(pred_probs.tolist())
val_y.extend(y.tolist())
                val_pre.extend(torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs)).tolist())
        ''' End of epoch: log metrics '''
epoch_loss, epoch_val_loss = np.mean(train_loss), np.mean(val_loss)
train_auc = roc_auc_score(y_true=train_y, y_score=train_pred_probs)
train_acc = accuracy_score(y_true=train_y, y_pred=train_pre)
val_auc = roc_auc_score(y_true=val_y, y_score=val_pred_probs)
val_acc = accuracy_score(y_true=val_y, y_pred=val_pre)
dfhistory.loc[epoch - 1] = (epoch, epoch_loss, train_acc, train_auc, epoch_val_loss, val_acc, val_auc)
step_end = time.time()
print("step_num: %s - %.1fs - loss: %.5f accuracy: %.5f auc: %.5f - val_loss: %.5f val_accuracy: %.5f val_auc: %.5f"
% (step_num, (step_end - step_start) % 60,
epoch_loss, train_acc, train_auc, epoch_val_loss, val_acc, val_auc))
end_time = time.time()
    print('********** end of training run time: {:.0f} min {:.0f} s **********'.format((end_time - start_time) // 60, (end_time - start_time) % 60))
print()
return dfhistory
def plot_metric(dfhistory, metric):
train_metrics = dfhistory[metric]
val_metrics = dfhistory['val_'+metric]
epochs = range(1, len(train_metrics) + 1)
plt.plot(epochs, train_metrics, 'bo--')
plt.plot(epochs, val_metrics, 'ro-')
plt.title('Training and validation '+ metric)
plt.xlabel("Epochs")
plt.ylabel(metric)
plt.legend(["train_"+metric, 'val_'+metric])
plt.show()
main
import numpy as np
import pandas as pd
import datetime
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from itertools import combinations
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
# import torch.utils.data as Data
from torchkeras import summary
from tools import *
if __name__ == '__main__':
pass
# model_test()
data = pd.read_csv('./data/criteo_sampled_data_test.csv')
    # I1-I13: 13 numerical (dense) feature columns in total
    # C1-C26: 26 categorical (sparse) feature columns in total
dense_cols = ['I' + str(i) for i in range(1, 14)]
sparse_cols = ['C' + str(i) for i in range(1, 27)]
stat_pnrate_pd(data=data, labname='label', message='criteo_sampled_data_test')
data_X = data[dense_cols + sparse_cols]
data_y = data['label']
sparse_fields = data_X[sparse_cols].max().values + 1
sparse_fields = sparse_fields.astype(np.int32)
print(sparse_fields)
dense_fields_num = 13
tmp_X, test_X, tmp_y, test_y = train_test_split(data_X, data_y, test_size=0.01, random_state=42, stratify=data_y)
train_X, val_X, train_y, val_y = train_test_split(tmp_X, tmp_y, test_size=0.01, random_state=42, stratify=tmp_y)
print(train_X.shape)
print(val_X.shape)
train_set = TensorDataset(torch.tensor(train_X.values).float(), torch.tensor(train_y.values).float())
val_set = TensorDataset(torch.tensor(val_X.values).float(), torch.tensor(val_y.values).float())
train_loader = DataLoader(dataset=train_set, batch_size=2048, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=2048, shuffle=False)
net = FiBiNet(sparse_fields=sparse_fields, dense_fields_num=dense_fields_num, emb_dim=10, reduction_ratio=3, pooling='mean', bilinear_type='field_interaction',
dnn_hidden_units=(128, 128, 32), dropout=0.3, use_bn=True)
loss_function = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
dfhistory = fit(net=net, train_loader=train_loader, val_loader=val_loader,
epochs=1, loss_function=loss_function, optimizer=optimizer, metric_name=['accuracy', 'auc'])
    # # Evaluation plots
# plot_metric(dfhistory, metric='loss')
# plot_metric(dfhistory, metric='auc')
# plot_metric(dfhistory, metric='accuracy')
    # # Prediction
# pred_probs = net(torch.tensor(val_X.values).float()).data
# print(pred_probs)
# pred = torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs))
# print(pred)