当前位置:网站首页>FiBiNet torch复现
FiBiNet torch复现
2022-08-03 05:29:00 【WGS.】
网络详解
torch复现
''' Squeeze-and-Excitation Attention '''
class SENetLayer(nn.Module):
    """Squeeze-and-Excitation re-weighting of field embeddings.

    Input:  [batch_size, field_num, emb_dim]
    Output: [batch_size, field_num, emb_dim]

    Args:
        field_num: number of feature fields (size of dim 1 of the input).
        reduction_ratio: the excitation bottleneck has
            ``max(1, field_num // reduction_ratio)`` units.
        pooling: statistic used to squeeze the embedding dimension,
            either ``'mean'`` or ``'max'``.
    """

    def __init__(self, field_num, reduction_ratio=3, pooling='mean'):
        super().__init__()
        self.pooling = pooling
        reduced_size = max(1, field_num // reduction_ratio)
        # Two bias-free linear layers: field_num -> reduced_size -> field_num.
        self.excitation = nn.Sequential(
            nn.Linear(field_num, reduced_size, bias=False),
            nn.ReLU(),
            nn.Linear(reduced_size, field_num, bias=False),
            nn.ReLU(),
        )

    def forward(self, inputs):
        """Scale every field embedding by its learned weight.

        Raises:
            ValueError: if ``self.pooling`` is neither ``'mean'`` nor ``'max'``.
        """
        # Squeeze: collapse the embedding dimension -> [batch_size, field_num]
        if self.pooling == 'mean':
            Z = torch.mean(inputs, dim=-1)
        elif self.pooling == 'max':
            # Tensor.max(dim=...) returns a (values, indices) pair; only the
            # values tensor can be fed to the excitation MLP.
            Z = inputs.max(dim=-1).values
        else:
            raise ValueError('pooling type unknown')
        # Excitation: per-field weights -> [batch_size, field_num]
        A = self.excitation(Z)
        # Re-weight: broadcast the weights over emb_dim
        # -> [batch_size, field_num, emb_dim]
        return inputs * A.unsqueeze(-1)
''' BilinearInteraction '''
class BilinearInteraction(nn.Module):
    """Pairwise bilinear interaction between field embeddings.

    Input:  [batch_size, field_num, emb_dim]
    Output: [batch_size, field_num * (field_num - 1) / 2, emb_dim]

    Every unordered pair (i, j), i < j, of fields yields one output row
    ``(v_i · W) ⊙ v_j``, so the output has field_num * (field_num - 1) / 2
    rows, ordered as produced by ``itertools.combinations``.

    Args:
        field_num: number of feature fields.
        emb_dim: embedding size of each field.
        bilinear_type: how the projection ``W`` is shared:
            ``"field_all"`` - one W for all pairs;
            ``"field_each"`` - one W_i per left-hand field i;
            ``"field_interaction"`` - one W_ij per field pair.

    Raises:
        ValueError: if ``bilinear_type`` is not one of the three options.
    """

    def __init__(self, field_num, emb_dim, bilinear_type="field_interaction"):
        super().__init__()
        self.bilinear_type = bilinear_type
        if bilinear_type == "field_all":
            # A single W shared by every pair.
            self.bilinear_layer = nn.Linear(emb_dim, emb_dim, bias=False)
        elif bilinear_type == "field_each":
            # One W_i per field.
            self.bilinear_layer = nn.ModuleList(
                [nn.Linear(emb_dim, emb_dim, bias=False) for _ in range(field_num)]
            )
        elif bilinear_type == "field_interaction":
            # One W_ij per unordered field pair.
            self.bilinear_layer = nn.ModuleList(
                [nn.Linear(emb_dim, emb_dim, bias=False)
                 for _ in combinations(range(field_num), 2)]
            )
        else:
            raise ValueError('bilinear type unknown')

    def forward(self, inputs):
        """Return the stacked pairwise interactions of ``inputs``."""
        # One [batch_size, 1, emb_dim] tensor per field.
        fields = torch.split(inputs, 1, dim=1)
        pairs = list(combinations(range(len(fields)), 2))

        if self.bilinear_type == "field_all":
            # p = v_i · W ⊙ v_j. The projection depends only on i, so it is
            # computed once for all fields instead of once per pair.
            projected = torch.split(self.bilinear_layer(inputs), 1, dim=1)
            p = [projected[i] * fields[j] for i, j in pairs]
        elif self.bilinear_type == "field_each":
            # p = v_i · W_i ⊙ v_j. The last field never appears on the left
            # of a pair, so only the first len(fields) - 1 are projected.
            projected = [self.bilinear_layer[i](fields[i])
                         for i in range(len(fields) - 1)]
            p = [projected[i] * fields[j] for i, j in pairs]
        elif self.bilinear_type == "field_interaction":
            # p = v_i · W_ij ⊙ v_j, one dedicated layer per pair.
            p = [self.bilinear_layer[k](fields[i]) * fields[j]
                 for k, (i, j) in enumerate(pairs)]
        else:
            raise ValueError('bilinear type unknown')
        return torch.cat(p, dim=1)
class FiBiNet(nn.Module):
    """FiBiNET model: embeddings -> SENet -> bilinear interactions -> DNN.

    Args:
        sparse_fields: one integer per sparse field. The shared embedding
            table has ``sum(sparse_fields) + 1`` rows and the ids of field k
            are shifted by the sum of the preceding entries.
        dense_fields_num: number of dense columns at the start of each row.
        emb_dim: embedding size.
        reduction_ratio: bottleneck ratio passed to ``SENetLayer``.
        pooling: squeeze type passed to ``SENetLayer``.
        bilinear_type: weight-sharing mode passed to ``BilinearInteraction``.
        dnn_hidden_units: widths of the hidden DNN layers.
        dropout: dropout probability after every hidden layer.
        use_bn: whether to insert ``BatchNorm1d`` after every hidden linear.
    """

    def __init__(self, sparse_fields, dense_fields_num, emb_dim=10, reduction_ratio=3, pooling='mean', bilinear_type='field_interaction',
                 dnn_hidden_units=(64, 32), dropout=0.5, use_bn=True):
        super(FiBiNet, self).__init__()
        self.dense_fields_num = dense_fields_num
        self.sparse_field_num = len(sparse_fields)
        # Start row of each field inside the shared embedding table.
        self.offsets = np.array((0, *np.cumsum(sparse_fields)[:-1]), dtype=np.longlong)

        # Embedding layer (one table shared by all sparse fields)
        self.embedding = nn.Embedding(sum(sparse_fields) + 1, embedding_dim=emb_dim)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

        # SENet layer
        self.SE = SENetLayer(field_num=self.sparse_field_num, reduction_ratio=reduction_ratio, pooling=pooling)

        # Bilinear layer (applied to both the raw and the SE-weighted embeddings)
        self.Bilinear = BilinearInteraction(field_num=self.sparse_field_num, emb_dim=emb_dim, bilinear_type=bilinear_type)

        # DNN layer
        self.dnn_hidden_units = dnn_hidden_units
        # Two bilinear outputs, each with F * (F - 1) / 2 rows of emb_dim
        # values, are concatenated with the dense features.
        input_dim = self.sparse_field_num * (self.sparse_field_num - 1) * emb_dim + self.dense_fields_num
        dnn_layers = []
        for hidden in dnn_hidden_units:
            dnn_layers.append(nn.Linear(input_dim, hidden))
            if use_bn:
                dnn_layers.append(nn.BatchNorm1d(hidden))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.Dropout(p=dropout))
            input_dim = hidden
        dnn_layers.append(nn.Linear(input_dim, 1))
        self.DNN = nn.Sequential(*dnn_layers)

    def forward(self, inputs):
        """Return click probabilities of shape [batch_size].

        Args:
            inputs: [batch_size, dense_fields_num + sparse_field_num]; the
                dense columns come first, followed by the sparse ids.

        Raises:
            ValueError: if the number of input columns does not match.
        """
        expected_cols = self.dense_fields_num + self.sparse_field_num
        if inputs.shape[1] != expected_cols:
            raise ValueError(
                'expected %d input columns (%d dense + %d sparse), got %d'
                % (expected_cols, self.dense_fields_num, self.sparse_field_num, inputs.shape[1])
            )
        # Split on the configured dense width instead of a hard-coded 13 so
        # the slicing agrees with the DNN input size computed in __init__.
        dense_inputs = inputs[:, :self.dense_fields_num]
        sparse_inputs = inputs[:, self.dense_fields_num:]

        # nn.Embedding needs integer (long) indices.
        sparse_inputs = sparse_inputs.long()
        sparse_inputs = sparse_inputs + sparse_inputs.new_tensor(self.offsets).unsqueeze(0)
        sparse_emb = self.embedding(sparse_inputs)  # (batch, field_num, emb_dim)

        se_out = self.SE(sparse_emb)  # (batch, field_num, emb_dim)
        se_bilinear_out = self.Bilinear(se_out).flatten(start_dim=1)
        bilinear_out = self.Bilinear(sparse_emb).flatten(start_dim=1)

        dnn_inp = torch.cat([bilinear_out, se_bilinear_out, dense_inputs], dim=1)
        dnn_out = self.DNN(dnn_inp)  # (batch, 1)
        dnn_out = dnn_out.squeeze(-1)
        return torch.sigmoid(dnn_out)
全部代码
emb:https://wangguisen.blog.csdn.net/article/details/122697991
网络结构
网络结构即上文"torch复现"部分给出的 SENetLayer、BilinearInteraction 和 FiBiNet 三个类。
封装训练
def printlog(info):
    """Print ``info`` converted to a string on its own line."""
    message = str(info)
    print(message)
def _run_epoch(net, loader, loss_function, optimizer, training):
    """Run one pass over ``loader`` and collect losses and predictions.

    When ``training`` is true the network is put in train mode and the
    optimizer is stepped on every batch; otherwise the network is put in
    eval mode and gradients are disabled.

    Returns:
        (losses, probs, labels, preds): per-batch loss values, and flat
        Python lists of predicted probabilities, true labels and 0/1
        predictions thresholded at 0.5.
    """
    net.train(training)
    losses, probs, labels, preds = [], [], [], []
    with torch.set_grad_enabled(training):
        for x, y in loader:
            if training:
                optimizer.zero_grad()
            pred_probs = net(x)
            loss = loss_function(pred_probs, y.float().detach())
            if training:
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            probs.extend(pred_probs.tolist())
            labels.extend(y.tolist())
            # Store plain floats rather than a list of 0-d tensors.
            preds.extend((pred_probs > 0.5).float().tolist())
    return losses, probs, labels, preds


def fit(net, train_loader, val_loader, epochs, loss_function, optimizer, metric_name):
    """Train ``net`` and evaluate it on ``val_loader`` after every epoch.

    Args:
        net: model mapping a feature batch to probabilities.
        train_loader: iterable of ``(x, y)`` training batches.
        val_loader: iterable of ``(x, y)`` validation batches.
        epochs: number of epochs to run.
        loss_function: callable ``(pred_probs, target) -> loss``.
        optimizer: optimizer stepping the parameters of ``net``.
        metric_name: metric names to record; supported values are
            ``'accuracy'`` and ``'auc'``.

    Returns:
        A DataFrame with one row per epoch and the columns ``epoch``,
        ``loss``, the requested metrics, ``val_loss`` and the ``val_``
        prefixed metrics.

    Raises:
        ValueError: if ``metric_name`` contains an unsupported metric.

    NOTE(review): ``accuracy_score`` / ``roc_auc_score`` are expected to be
    in scope (presumably sklearn.metrics) - confirm where they are imported.
    """
    supported = ("accuracy", "auc")
    unknown = [mn for mn in metric_name if mn not in supported]
    if unknown:
        raise ValueError("unsupported metric(s): %s" % ", ".join(map(str, unknown)))

    start_time = time.time()
    print("\n" + "********** start training **********")

    columns = ["epoch", "loss", *metric_name, "val_loss"] + ['val_' + mn for mn in metric_name]
    dfhistory = pd.DataFrame(columns=columns)

    for epoch in range(1, epochs + 1):
        printlog("Epoch {0} / {1}".format(epoch, epochs))
        step_start = time.time()

        # Training pass
        train_loss, train_pred_probs, train_y, train_pre = _run_epoch(
            net, train_loader, loss_function, optimizer, training=True)
        step_num = len(train_loss)

        # Validation pass
        val_loss, val_pred_probs, val_y, val_pre = _run_epoch(
            net, val_loader, loss_function, optimizer, training=False)

        # End of epoch: compute metrics and log them
        epoch_loss, epoch_val_loss = np.mean(train_loss), np.mean(val_loss)
        train_auc = roc_auc_score(y_true=train_y, y_score=train_pred_probs)
        train_acc = accuracy_score(y_true=train_y, y_pred=train_pre)
        val_auc = roc_auc_score(y_true=val_y, y_score=val_pred_probs)
        val_acc = accuracy_score(y_true=val_y, y_pred=val_pre)

        # Look metrics up by name so each value lands in the column with the
        # same name whatever order the caller listed them in.
        train_metrics = {"accuracy": train_acc, "auc": train_auc}
        val_metrics = {"accuracy": val_acc, "auc": val_auc}
        dfhistory.loc[epoch - 1] = (
            epoch, epoch_loss, *(train_metrics[mn] for mn in metric_name),
            epoch_val_loss, *(val_metrics[mn] for mn in metric_name),
        )

        step_end = time.time()
        # Report the full elapsed seconds of the epoch (not modulo 60).
        print("step_num: %s - %.1fs - loss: %.5f accuracy: %.5f auc: %.5f - val_loss: %.5f val_accuracy: %.5f val_auc: %.5f"
              % (step_num, step_end - step_start,
                 epoch_loss, train_acc, train_auc, epoch_val_loss, val_acc, val_auc))

    end_time = time.time()
    print('********** end of training run time: {:.0f}分 {:.0f}秒 **********'.format((end_time - start_time) // 60, (end_time - start_time) % 60))
    print()
    return dfhistory
def plot_metric(dfhistory, metric):
    """Plot the training and validation curves of ``metric`` per epoch.

    Reads the columns ``metric`` and ``'val_' + metric`` from ``dfhistory``
    and shows them on one figure with epochs numbered from 1.
    """
    val_name = f"val_{metric}"
    train_series = dfhistory[metric]
    val_series = dfhistory[val_name]
    x_axis = range(1, len(train_series) + 1)

    plt.plot(x_axis, train_series, 'bo--')
    plt.plot(x_axis, val_series, 'ro-')
    plt.title(f"Training and validation {metric}")
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend([f"train_{metric}", val_name])
    plt.show()
main
import numpy as np
import pandas as pd
import datetime
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from itertools import combinations
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
# import torch.utils.data as Data
from torchkeras import summary
from tools import *
if __name__ == '__main__':
    # model_test()
    data = pd.read_csv('./data/criteo_sampled_data_test.csv')

    # I1-I13: 13 numeric feature columns
    # C1-C26: 26 categorical feature columns
    dense_cols = [f'I{i}' for i in range(1, 14)]
    sparse_cols = [f'C{i}' for i in range(1, 27)]

    # NOTE(review): stat_pnrate_pd comes from `tools`; presumably it reports
    # the label distribution - confirm.
    stat_pnrate_pd(data=data, labname='label', message='criteo_sampled_data_test')

    data_X = data[dense_cols + sparse_cols]
    data_y = data['label']

    # Per-column size of each categorical field: largest id + 1.
    sparse_fields = (data_X[sparse_cols].max().values + 1).astype(np.int32)
    print(sparse_fields)
    dense_fields_num = 13

    # Two successive stratified 99% / 1% splits: test first, then validation.
    tmp_X, test_X, tmp_y, test_y = train_test_split(
        data_X, data_y, test_size=0.01, random_state=42, stratify=data_y)
    train_X, val_X, train_y, val_y = train_test_split(
        tmp_X, tmp_y, test_size=0.01, random_state=42, stratify=tmp_y)
    print(train_X.shape)
    print(val_X.shape)

    train_set = TensorDataset(
        torch.tensor(train_X.values).float(),
        torch.tensor(train_y.values).float(),
    )
    val_set = TensorDataset(
        torch.tensor(val_X.values).float(),
        torch.tensor(val_y.values).float(),
    )
    train_loader = DataLoader(dataset=train_set, batch_size=2048, shuffle=True)
    val_loader = DataLoader(dataset=val_set, batch_size=2048, shuffle=False)

    net = FiBiNet(
        sparse_fields=sparse_fields,
        dense_fields_num=dense_fields_num,
        emb_dim=10,
        reduction_ratio=3,
        pooling='mean',
        bilinear_type='field_interaction',
        dnn_hidden_units=(128, 128, 32),
        dropout=0.3,
        use_bn=True,
    )
    loss_function = nn.BCELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    dfhistory = fit(
        net=net,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=1,
        loss_function=loss_function,
        optimizer=optimizer,
        metric_name=['accuracy', 'auc'],
    )

    # # Evaluation plots
    # plot_metric(dfhistory, metric='loss')
    # plot_metric(dfhistory, metric='auc')
    # plot_metric(dfhistory, metric='accuracy')

    # # Prediction
    # pred_probs = net(torch.tensor(val_X.values).float()).data
    # print(pred_probs)
    # pred = torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs))
    # print(pred)
边栏推荐
- 记一次postgresql中使用正则表达式
- C#程序默认以管理员权限打开
- 【Personal summary】Key points of MES system development/management
- ESXI主机给虚拟机添加USB加密狗设备
- 5 个开源的 Rust Web 开发框架,你选择哪个?
- linux安装mysql
- 【地平线 开发板】实现模型转换并在地平线开发板上部署的全过程操作记录(魔改开发包)
- C # program with administrator rights to open by default
- C#切换输入法
- el-tree设置利用setCheckedNodessetCheckedKeys默认勾选节点,以及通过setChecked新增勾选指定节点
猜你喜欢
随机推荐
【项目案例】配置小型网络WLAN基本业务示例
ES 中时间日期类型 “yyyy-MM-dd HHmmss” 的完全避坑指南
SQLServer2019安装(Windows)
ClickHouse 数据插入、更新与删除操作 SQL
【个人总结】MES系统开发/管理要点
UniApp 自定义条件编译详细使用流程
使用Crontab调用Shell脚本执行expdp自动备份Oracle
PCB 多层板为什么都是偶数层?
【云原生 · Kubernetes】搭建Harbor仓库
C#通过WebBrowser对网页截图
UniApp 获取当前页面标题(navigationBarTitleText)
ClickHouse删除数据之delete问题详解
mysql事务与多版本并发控制
【IoU loss】IoU损失函数理解
VS Project Configuration Manager
el-table获取读取数据表中某一行的数据属性
npx 有什么作用跟意义?为什么要有 npx?什么场景使用?
界面仅允许扫码枪录入禁止手工键盘输入
2021新版idea过滤无用文件.idea .iml
MySQL的DATE_FORMAT()函数将Date转为字符串