当前位置:网站首页>AutoInt网络详解及pytorch复现
AutoInt网络详解及pytorch复现
2022-08-03 05:29:00 【WGS.】
网络详解
网络代码
# coding:utf-8
# @Email: [email protected]
# @Time: 2022/7/25 6:02 下午
# @File: AutoInt.py
import numpy as np
import pandas as pd
import datetime
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from itertools import combinations
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
# import torch.utils.data as Data
from torchkeras import summary
from sklearn.preprocessing import LabelEncoder
from tools import *
import BaseModel
class MultiheadAttention(nn.Module):
    """Multi-head self-attention applied across the field axis of an embedding tensor.

    Args:
        emb_dim: Size of each field embedding; must be a multiple of ``head_num``.
        head_num: Number of attention heads the embedding is split into.
        scaling: When True, attention logits are divided by
            ``sqrt(emb_dim // head_num)`` before the softmax.
        use_residual: When True, a learned linear projection of the input
            (``W_R``) is added to the attention output.
    """

    def __init__(self, emb_dim, head_num, scaling=True, use_residual=True):
        super().__init__()
        self.emb_dim = emb_dim
        self.head_num = head_num
        self.scaling = scaling
        self.use_residual = use_residual
        # Width of a single head.
        self.att_emb_size = emb_dim // head_num
        assert emb_dim % head_num == 0, "emb_dim must be divisible head_num"

        # Projection matrices; registration order (Q, K, V, then R) is kept
        # so parameter ordering and seeded initialisation stay the same.
        self.W_Q = nn.Parameter(torch.empty(emb_dim, emb_dim))
        self.W_K = nn.Parameter(torch.empty(emb_dim, emb_dim))
        self.W_V = nn.Parameter(torch.empty(emb_dim, emb_dim))
        if self.use_residual:
            self.W_R = nn.Parameter(torch.empty(emb_dim, emb_dim))

        # Explicit initialisation: the freshly allocated storage is
        # uninitialised and could otherwise lead to NaN results.
        for param in self.parameters():
            nn.init.xavier_uniform_(param)

    def _project_heads(self, inputs, weight):
        """Project ``inputs`` with ``weight`` and split the result into heads.

        Returns a tensor indexed as
        [head_num, batch_size, fields, emb_dim // head_num].
        """
        projected = torch.tensordot(inputs, weight, dims=([-1], [0]))
        return torch.stack(projected.split(self.att_emb_size, dim=2))

    def forward(self, inputs):
        """Run scaled dot-product self-attention over the fields.

        Args:
            inputs: Tensor indexed as [batch_size, fields, emb_dim].

        Returns:
            Tensor indexed as [batch_size, fields, emb_dim], after ReLU.
        """
        # 1./2. Linear projections to Q, K, V, each split into heads.
        head_q = self._project_heads(inputs, self.W_Q)
        head_k = self._project_heads(inputs, self.W_K)
        head_v = self._project_heads(inputs, self.W_V)

        # 3. Scaled dot-product attention.
        # scores: [head_num, batch_size, fields, fields]
        scores = torch.matmul(head_q, head_k.transpose(-2, -1))
        if self.scaling:
            scores = scores / self.att_emb_size ** 0.5
        # Softmax over the key axis turns the scores into attention weights.
        attn_w = F.softmax(scores, dim=-1)
        # Weighted sum of the values:
        # [head_num, batch_size, fields, emb_dim // head_num]
        attended = torch.matmul(attn_w, head_v)

        # 4. Concatenate the heads back along the embedding axis:
        # [batch_size, fields, emb_dim]
        merged = torch.cat(attended.unbind(dim=0), dim=-1)

        # Skip connection through a learned projection of the input.
        if self.use_residual:
            merged = merged + torch.tensordot(inputs, self.W_R, dims=([-1], [0]))

        # NOTE(review): the extracted source lost its indentation; ReLU is
        # applied unconditionally here (not only when use_residual is set)
        # - confirm against the original file.
        return F.relu(merged)
class AutoIntNet(nn.Module):
    """AutoInt-style network: a DNN tower plus stacked self-attention layers.

    Both towers consume the same concatenated field embeddings; their scalar
    outputs are summed and passed through a sigmoid.

    Args:
        sparse_fields: Sequence with the vocabulary size of each sparse field.
        dense_fields: Sequence with the vocabulary size of each *encoded*
            dense field (dense values are label-encoded to pick an embedding
            row, then scaled by the raw value).
        emb_dim: Embedding size shared by every field.
        head_num: Number of attention heads per attention layer.
        attn_layers: Number of stacked ``MultiheadAttention`` layers.
        scaling: Forwarded to ``MultiheadAttention``.
        use_residual: Forwarded to ``MultiheadAttention``.
        dnn_hidden_units: Hidden layer widths of the DNN tower.
        dropout: Dropout probability after each hidden DNN layer.
        use_bn: Insert ``BatchNorm1d`` after each hidden linear layer when True.
    """

    def __init__(self, sparse_fields, dense_fields, emb_dim=10, head_num=1, attn_layers=1, scaling=True, use_residual=True,
                 dnn_hidden_units=(64, 32), dropout=0.2, use_bn=True):
        super(AutoIntNet, self).__init__()
        self.dense_fields_num = len(dense_fields)
        self.sparse_field_num = len(sparse_fields)
        # All fields of one kind share a single embedding table; each field's
        # ids are shifted by the cumulative size of the fields before it.
        self.sparse_offsets = np.array((0, *np.cumsum(sparse_fields)[:-1]), dtype=np.longlong)
        self.dense_offsets = np.array((0, *np.cumsum(dense_fields)[:-1]), dtype=np.longlong)

        # Embedding layer. The sizes are cast to plain ints so numpy integer
        # inputs (e.g. an int32 array of field sizes) are accepted uniformly.
        self.sparse_embedding = nn.Embedding(int(np.sum(sparse_fields)) + 1, embedding_dim=emb_dim)
        self.dense_embedding = nn.Embedding(int(np.sum(dense_fields)) + 1, embedding_dim=emb_dim)
        torch.nn.init.xavier_uniform_(self.sparse_embedding.weight.data)
        torch.nn.init.xavier_uniform_(self.dense_embedding.weight.data)

        # DNN layer
        self.dnn_hidden_units = dnn_hidden_units
        dnn_layers = []
        self.dnn_input_dim = self.sparse_field_num * emb_dim + self.dense_fields_num * emb_dim
        input_dim = self.dnn_input_dim
        for hidden in dnn_hidden_units:
            dnn_layers.append(nn.Linear(input_dim, hidden))
            if use_bn:
                dnn_layers.append(nn.BatchNorm1d(hidden))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.Dropout(p=dropout))
            input_dim = hidden
        dnn_layers.append(nn.Linear(input_dim, 1))
        self.DNN = nn.Sequential(*dnn_layers)

        # Interaction layer: attention keeps the [fields, emb_dim] layout, so
        # its flattened width equals the DNN input width.
        self.att_output_dim = self.sparse_field_num * emb_dim + self.dense_fields_num * emb_dim
        multi_attn_layers = [
            MultiheadAttention(emb_dim=emb_dim, head_num=head_num, scaling=scaling, use_residual=use_residual)
            for _ in range(attn_layers)
        ]
        self.multi_attn = nn.Sequential(*multi_attn_layers)
        self.attn_fc = torch.nn.Linear(self.att_output_dim, 1)

    def forward(self, inputs):
        """Compute click probabilities for a batch.

        Args:
            inputs: 2-D tensor whose columns are laid out as
                ``[encoded dense ids | raw dense values | sparse ids]`` with
                ``dense_fields_num``, ``dense_fields_num`` and
                ``sparse_field_num`` columns respectively.

        Returns:
            1-D tensor of sigmoid probabilities, one per row.

        Raises:
            ValueError: If the column count does not match the configured
                number of dense and sparse fields.
        """
        n_dense = self.dense_fields_num
        expected_width = 2 * n_dense + self.sparse_field_num
        if inputs.dim() != 2 or inputs.shape[1] != expected_width:
            raise ValueError(
                "inputs must have shape [batch, {}], got {}".format(expected_width, tuple(inputs.shape)))

        # Slice by the configured field counts instead of hard-coded 13/26
        # boundaries so the model works for any number of dense fields.
        dense_enc_inputs = inputs[:, :n_dense].long()
        dense_inputs = inputs[:, n_dense:2 * n_dense]
        sparse_inputs = inputs[:, 2 * n_dense:].long()

        # Dense embedding: look up the row for the encoded id, then scale it
        # by the raw numeric value -> [batch, n_dense, emb_dim]
        dense_enc_inputs = dense_enc_inputs + dense_enc_inputs.new_tensor(self.dense_offsets).unsqueeze(0)
        dense_emb = self.dense_embedding(dense_enc_inputs)
        dense_emb = torch.unsqueeze(dense_inputs, dim=-1) * dense_emb

        # Sparse embedding -> [batch, n_sparse, emb_dim]
        sparse_inputs = sparse_inputs + sparse_inputs.new_tensor(self.sparse_offsets).unsqueeze(0)
        spare_emb = self.sparse_embedding(sparse_inputs)

        # [batch, n_sparse + n_dense, emb_dim]
        x = torch.cat([spare_emb, dense_emb], dim=1)

        # reshape (rather than view) also handles non-contiguous tensors.
        dnn_out = self.DNN(x.reshape(-1, self.dnn_input_dim))  # [batch, 1]
        attn_out = self.multi_attn(x)  # [batch, fields, emb_dim]
        attn_out = self.attn_fc(attn_out.reshape(-1, self.att_output_dim))  # [batch, 1]

        outs = dnn_out + attn_out
        return torch.sigmoid(outs.squeeze(1))
训练代码
# coding:utf-8
# @Email: [email protected]
# @Time: 2022/7/22 3:16 下午
# @File: BaseModel.py
import pandas as pd, numpy as np
from sklearn import metrics
from sklearn.metrics import roc_auc_score, accuracy_score
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
def printlog(info):
    """Write the string form of ``info`` to standard output on its own line."""
    message = str(info)
    print(message)
class BaseModel():
    """Training / evaluation helper around a binary-classification network.

    The wrapped ``net`` is called as ``net(x)`` and is expected to return
    probabilities; hard labels are derived with a strict ``> 0.5`` threshold.
    """

    # Metric names that ``fit`` knows how to compute.
    _SUPPORTED_METRICS = ('accuracy', 'auc')

    def __init__(self, net):
        super(BaseModel, self).__init__()
        self.net = net

    @staticmethod
    def _binarize(pred_probs):
        """Map probabilities to hard 0/1 labels using a strict > 0.5 threshold."""
        return torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs))

    def _run_epoch(self, loader, loss_function, optimizer=None):
        """Run one pass over ``loader``; update weights only if ``optimizer`` is given.

        Returns:
            Tuple ``(losses, probs, labels, preds)`` of plain Python lists:
            per-batch losses and per-sample probabilities, labels and hard
            predictions.
        """
        losses, probs, labels, preds = [], [], [], []
        for x, y in loader:
            if optimizer is not None:
                # Clear gradients left over from the previous step.
                optimizer.zero_grad()
            pred_probs = self.net(x)
            loss = loss_function(pred_probs, y.float().detach())
            if optimizer is not None:
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            detached = pred_probs.detach()
            probs.extend(detached.tolist())
            labels.extend(y.tolist())
            # Store plain floats rather than 0-dim tensors.
            preds.extend(self._binarize(detached).tolist())
        return losses, probs, labels, preds

    def _predict_proba(self, x):
        """Run ``net`` on ``x`` in eval mode without gradients, then restore the mode."""
        was_training = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                return self.net(x)
        finally:
            self.net.train(was_training)

    def fit(self, train_loader, val_loader, epochs, loss_function, optimizer, metric_name):
        """Train ``net`` and validate after every epoch.

        Args:
            train_loader: Iterable of ``(x, y)`` training batches.
            val_loader: Iterable of ``(x, y)`` validation batches.
            epochs: Number of epochs to run.
            loss_function: Callable ``loss_function(pred_probs, y)``.
            optimizer: Optimizer stepping the parameters of ``net``.
            metric_name: Metric names to record; each must be ``'accuracy'``
                or ``'auc'``.

        Returns:
            ``pandas.DataFrame`` with one row per epoch and the columns
            ``epoch, loss, <metrics>, val_loss, val_<metrics>``.

        Raises:
            ValueError: If ``metric_name`` contains an unsupported metric.
        """
        unknown = [mn for mn in metric_name if mn not in self._SUPPORTED_METRICS]
        if unknown:
            raise ValueError("unsupported metric(s): {}".format(unknown))

        start_time = time.time()
        print("\n" + "********** start training **********")

        columns = ["epoch", "loss", *metric_name, "val_loss"] + ['val_' + mn for mn in metric_name]
        dfhistory = pd.DataFrame(columns=columns)

        for epoch in range(1, epochs + 1):
            printlog("Epoch {0} / {1}".format(epoch, epochs))
            step_start = time.time()

            # Training pass.
            self.net.train()
            train_loss, train_pred_probs, train_y, train_pre = self._run_epoch(
                train_loader, loss_function, optimizer)
            step_num = len(train_loss)

            # Validation pass: no weight updates and no gradient tracking.
            self.net.eval()
            with torch.no_grad():
                val_loss, val_pred_probs, val_y, val_pre = self._run_epoch(val_loader, loss_function)

            # End of epoch: aggregate and log.
            epoch_loss, epoch_val_loss = np.mean(train_loss), np.mean(val_loss)
            train_auc = roc_auc_score(y_true=train_y, y_score=train_pred_probs)
            train_acc = accuracy_score(y_true=train_y, y_pred=train_pre)
            val_auc = roc_auc_score(y_true=val_y, y_score=val_pred_probs)
            val_acc = accuracy_score(y_true=val_y, y_pred=val_pre)

            # Build the row in the order given by metric_name so values always
            # land under the column that carries their name.
            train_scores = {'accuracy': train_acc, 'auc': train_auc}
            val_scores = {'accuracy': val_acc, 'auc': val_auc}
            dfhistory.loc[epoch - 1] = (
                epoch, epoch_loss, *[train_scores[mn] for mn in metric_name],
                epoch_val_loss, *[val_scores[mn] for mn in metric_name])

            step_end = time.time()
            # Report the full elapsed seconds (no modulo 60, which would drop
            # whole minutes from the figure).
            print("step_num: %s - %.1fs - loss: %.5f accuracy: %.5f auc: %.5f - val_loss: %.5f val_accuracy: %.5f val_auc: %.5f"
                  % (step_num, step_end - step_start, epoch_loss, train_acc, train_auc, epoch_val_loss, val_acc, val_auc))

        end_time = time.time()
        print('********** end of training run time: {:.0f}分 {:.0f}秒 **********'.format((end_time - start_time) // 60,
                                                                                      (end_time - start_time) % 60))
        print()
        return dfhistory

    def evaluate(self, val_X, val_y):
        """Score ``net`` on a held-out set.

        Args:
            val_X: Array-like feature matrix accepted by ``torch.as_tensor``.
            val_y: Array-like of true binary labels.

        Returns:
            Tuple ``(precision, recall, accuracy, f1, auc, loss,
            acc_condition, precision_condition, recall_condition)``; the first
            six are rounded to 4 decimals, the last three are the summary
            strings from ``accDealWith2``.
        """
        val_X = torch.as_tensor(val_X).float()
        pred_probs = self._predict_proba(val_X)
        pred = self._binarize(pred_probs).cpu().numpy()
        pred_probs = pred_probs.cpu().numpy()

        precision = np.around(metrics.precision_score(val_y, pred), 4)
        recall = np.around(metrics.recall_score(val_y, pred), 4)
        accuracy = np.around(metrics.accuracy_score(val_y, pred), 4)
        f1 = np.around(metrics.f1_score(val_y, pred), 4)
        auc = np.around(metrics.roc_auc_score(val_y, pred_probs), 4)
        # Log loss is defined on probabilities, not on thresholded labels.
        loss = np.around(metrics.log_loss(val_y, pred_probs), 4)

        acc_condition, precision_condition, recall_condition = self.accDealWith2(val_y, pred)
        return precision, recall, accuracy, f1, auc, loss, acc_condition, precision_condition, recall_condition

    def predict(self, x):
        """Print and return hard 0/1 predictions for ``x``.

        Args:
            x: Array-like feature matrix accepted by ``torch.as_tensor``.

        Returns:
            Tensor of 0/1 predictions with the same shape as the net output.
        """
        pred_probs = self._predict_proba(torch.as_tensor(x).float())
        print(pred_probs)
        pred = self._binarize(pred_probs)
        print(pred)
        return pred

    def plot_metric(self, dfhistory, metric):
        """Plot the training and validation curves of ``metric`` from ``dfhistory``."""
        train_metrics = dfhistory[metric]
        val_metrics = dfhistory['val_' + metric]
        epochs = range(1, len(train_metrics) + 1)
        plt.plot(epochs, train_metrics, 'bo--')
        plt.plot(epochs, val_metrics, 'ro-')
        plt.title('Training and validation ' + metric)
        plt.xlabel("Epochs")
        plt.ylabel(metric)
        plt.legend(["train_" + metric, 'val_' + metric])
        plt.show()

    def accDealWith2(self, y_test, y_pre):
        """Build human-readable accuracy / precision / recall summaries.

        Args:
            y_test: Array-like of true labels (any shape; it is flattened).
            y_pre: Array-like of predicted labels (any shape; it is flattened).

        Returns:
            Tuple ``(acc_condition, precision_condition, recall_condition)``
            of formatted strings.

        Raises:
            ValueError: If the two inputs hold a different number of labels.
        """
        y_true = np.asarray(y_test).ravel()
        y_hat = np.asarray(y_pre).ravel()
        if y_true.shape != y_hat.shape:
            raise ValueError("y_test and y_pre must contain the same number of labels")
        lenall = len(y_true)

        predicted_pos = y_hat == 1
        actual_pos = y_true == 1
        pos = int(np.sum(y_true == y_hat))  # correctly predicted samples
        pre = int(np.sum(predicted_pos))  # samples predicted positive
        rec = int(np.sum(actual_pos))  # samples that are truly positive
        # True positives serve as numerator for both precision and recall.
        precisoinlen = recallLen = int(np.sum(predicted_pos & actual_pos))

        acc_condition = '预测对的:{},总样本:{}'.format(pos, lenall)
        # Guard the ratios against empty denominators.
        precision_value = np.around(precisoinlen / pre, 4) if pre != 0 else 0.0
        precision_condition = '预测为正的样本数:{},其中实际为正的样本数:{},精确率:{}'.format(pre, precisoinlen,
                                                                                 precision_value)
        recall_value = np.around(recallLen / rec, 4) if rec != 0 else 0.0
        recall_condition = '正例样本:{},正例中预测正确的数量:{},召回率:{}'.format(rec, recallLen, recall_value)
        return acc_condition, precision_condition, recall_condition
main
if __name__ == '__main__':
    # Load the sampled Criteo data set.
    data = pd.read_csv('./data/criteo_sampled_data_test.csv')

    # I1-I13: 13 numeric feature columns; C1-C26: 26 categorical feature columns.
    dense_cols = ['I{}'.format(idx) for idx in range(1, 14)]
    sparse_cols = ['C{}'.format(idx) for idx in range(1, 27)]

    # Project helper from `tools`; judging by its name and arguments it reports
    # label statistics - TODO confirm.
    stat_pnrate_pd(data=data, labname='label', message='criteo_sampled_data_test')

    # Label-encode the numeric features so they can index an embedding table.
    data, dense_cols_enc, dense_enc_dict = label_enc_sk(data=data, cols=dense_cols)

    # Column order must match what AutoIntNet.forward slices:
    # encoded dense ids, raw dense values, sparse ids.
    data_X = data[dense_cols_enc + dense_cols + sparse_cols]
    data_y = data['label']

    # Vocabulary size per sparse field (largest id + 1).
    sparse_fields = (data_X[sparse_cols].max().values + 1).astype(np.int32)
    print(sparse_fields)

    # Vocabulary size per encoded dense field (largest id + 1).
    dense_fields = (data_X[dense_cols_enc].max().values + 1).astype(np.int32)
    print(dense_fields)

    # Hold out 1% as a test split, then 1% of the remainder for validation.
    tmp_X, test_X, tmp_y, test_y = train_test_split(
        data_X, data_y, test_size=0.01, random_state=42, stratify=data_y)
    train_X, val_X, train_y, val_y = train_test_split(
        tmp_X, tmp_y, test_size=0.01, random_state=42, stratify=tmp_y)
    print(train_X.shape)
    print(val_X.shape)

    train_features = torch.tensor(train_X.values).float()
    train_labels = torch.tensor(train_y.values).float()
    val_features = torch.tensor(val_X.values).float()
    val_labels = torch.tensor(val_y.values).float()

    train_loader = DataLoader(dataset=TensorDataset(train_features, train_labels), batch_size=2048, shuffle=True)
    val_loader = DataLoader(dataset=TensorDataset(val_features, val_labels), batch_size=2048, shuffle=False)

    net = AutoIntNet(
        sparse_fields=sparse_fields,
        dense_fields=dense_fields,
        emb_dim=10,
        head_num=5,
        attn_layers=6,
        scaling=True,
        use_residual=True,
        dnn_hidden_units=(128, 64, 32),
        dropout=0.2,
        use_bn=True,
    )

    loss_function = nn.BCELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    base = BaseModel.BaseModel(net=net)
    dfhistory = base.fit(
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=1,
        loss_function=loss_function,
        optimizer=optimizer,
        metric_name=['accuracy', 'auc'],
    )

    # Print a layer-by-layer summary using the validation features as sample input.
    summary(model=net, input_data=torch.tensor(val_X.values).float())
边栏推荐
猜你喜欢
随机推荐
contos安装php-ffmpeg和tp5.1使用插件
VB.net如何使用List类型
Content type ‘applicationx-www-form-urlencoded;charset=UTF-8‘ not supported“【已解决】
一文读懂PCB品质体系认证
VS项目配置管理器
Postman知识汇总
2021-06-15
【项目案例】配置小型网络WLAN基本业务示例
Redis哨兵模式+过期策略、淘汰策略、读写策略
WinServer2012r2破解多用户同时远程登录,并取消用户控制
PCB 多层板为什么都是偶数层?
Mysql去除重复数据
2021新版idea过滤无用文件.idea .iml
使用Contab调用Shell脚本执行expdp自动备份Oracle
国内首款PCB资料分析软件,华秋DFM使用介绍
数据库OracleRAC节点宕机处理流程
mysql 数据去重的三种方式[实战]
【EA Price strategy OC1】以实时价格为依据的EA,首月翻仓!】
【设计指南】避免PCB板翘,合格的工程师都会这样设计!
Migration of BOA Server