当前位置:网站首页>IFM网络详解及torch复现
IFM网络详解及torch复现
2022-08-03 05:29:00 【WGS.】
IFM网络详解
网络结构代码
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from itertools import combinations
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch.nn.functional as F
# import torch.utils.data as Data
from torchkeras import summary
from sklearn.preprocessing import LabelEncoder
from tools import *
import BaseModel
class FM(nn.Module):
    """Second-order factorization-machine interaction term.

    Only the pairwise part sum_{i<j} <v_i, v_j> is produced (no linear term,
    no bias). It is evaluated with the identity
    0.5 * ((sum_i v_i)^2 - sum_i v_i^2), reduced over the embedding axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        """Return the pairwise interaction logit.

        Args:
            inputs: tensor of shape (batch_size, field_size, embedding_size).

        Returns:
            Tensor of shape (batch_size, 1).
        """
        field_sum = inputs.sum(dim=1, keepdim=True)  # (batch, 1, emb)
        sum_then_square = field_sum * field_sum
        square_then_sum = (inputs * inputs).sum(dim=1, keepdim=True)
        interaction = sum_then_square - square_then_sum
        # Collapse the embedding axis -> (batch, 1).
        return 0.5 * interaction.sum(dim=2)
class Linear_W(nn.Module):
    """First-order (linear) part of IFM with input-aware reweighting of sparse terms.

    Args:
        dense_nums: number of dense input features. When it is ``None`` or 0 no
            dense weight is allocated (``self.weight`` is ``None``) and any dense
            input passed to ``forward`` must be absent or have zero columns.
    """

    def __init__(self, dense_nums):
        super(Linear_W, self).__init__()
        self.dense_nums = dense_nums
        # Allocate a dense weight only when there are dense features; a truthiness
        # test covers both None and 0 (torch.Tensor(None, 1) would raise).
        if dense_nums:
            self.weight = nn.Parameter(torch.Tensor(dense_nums, 1))
            torch.nn.init.normal_(self.weight, mean=0, std=0.01)
        else:
            self.register_parameter('weight', None)

    def forward(self, sparse_inputs, dense_inputs=None, sparse_feat_refine_weight=None):
        """Compute the linear logit.

        Args:
            sparse_inputs: tensor reduced over its last two axes; in IFM it is the
                (batch, field_num, emb_dim) embedding tensor.
            dense_inputs: optional (batch, dense_nums) tensor multiplied by ``self.weight``.
            sparse_feat_refine_weight: optional (batch, field_num) per-field factors
                (IFM's m_{x,i}); when ``None`` every field gets weight 1.

        Returns:
            Tensor of shape (batch, 1).

        Raises:
            ValueError: dense columns were supplied but no dense weight exists.
        """
        if sparse_feat_refine_weight is None:
            weighted = sparse_inputs
        else:
            weighted = sparse_inputs * sparse_feat_refine_weight.unsqueeze(-1)
        # Sum over the embedding axis, then over the field axis -> (batch, 1).
        # Deriving the result from the input keeps it on the input's device/dtype.
        linear_logit = weighted.sum(dim=-1).sum(dim=-1, keepdim=True)
        if dense_inputs is not None and dense_inputs.shape[-1] > 0:
            if self.weight is None:
                raise ValueError(
                    'Linear_W was built with dense_nums=%r but received %d dense columns'
                    % (self.dense_nums, dense_inputs.shape[-1]))
            linear_logit = linear_logit + torch.matmul(dense_inputs, self.weight)
        return linear_logit
class IFM(nn.Module):
    """Input-aware Factorization Machine for binary classification.

    ``forward`` expects a float tensor whose first ``dense_nums`` columns are dense
    features and whose remaining ``len(sparse_fields)`` columns are integer ids of
    the sparse fields.

    Args:
        sparse_fields: per-field vocabulary sizes. A single embedding table is
            shared; each field's ids are shifted by the cumulative size of the
            preceding fields.
        dense_nums: number of leading dense columns in the input (``None`` is
            treated as 0).
        emb_dim: embedding size of the sparse fields.
        dnn_hidden_units: hidden sizes of the factor estimating network (FEN).
        use_bn: insert ``BatchNorm1d`` after each FEN linear layer.
        dropout: dropout probability after each FEN activation.
        l2_reg_dnn: accepted for interface compatibility.
            NOTE(review): it is not applied anywhere in this class.
    """

    def __init__(self, sparse_fields, dense_nums, emb_dim=8, dnn_hidden_units=(256, 128), use_bn=True, dropout=0.2, l2_reg_dnn=0):
        super(IFM, self).__init__()
        self.sparse_field_num = len(sparse_fields)
        # Number of leading dense columns that forward() splits off the input.
        self.dense_nums = dense_nums or 0
        # Start offset of each field inside the shared embedding table.
        self.offsets = np.array((0, *np.cumsum(sparse_fields)[:-1]), dtype=np.longlong)
        self.embedding = nn.Embedding(sum(sparse_fields) + 1, embedding_dim=emb_dim)
        torch.nn.init.xavier_uniform_(self.embedding.weight.data)

        # FEN: MLP over the flattened field embeddings.
        self.dnn_hidden_units = dnn_hidden_units
        dnn_layers = []
        self.dnn_input_dim = self.sparse_field_num * emb_dim
        input_dim = self.dnn_input_dim
        for hidden in dnn_hidden_units:
            dnn_layers.append(nn.Linear(input_dim, hidden))
            if use_bn:
                dnn_layers.append(nn.BatchNorm1d(hidden))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.Dropout(p=dropout))
            input_dim = hidden
        self.factor_estimating_net = nn.Sequential(*dnn_layers)
        # Re-init only the Linear weights: matching parameter names on 'weight'
        # would also hit BatchNorm gammas and shrink them to ~0.
        for layer in self.factor_estimating_net:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0, std=0.01)

        # P: maps the FEN output to one score per sparse field. ``input_dim`` is the
        # last hidden size, or the FEN input size when there are no hidden layers.
        self.transform_weight_matrix_P = nn.Linear(input_dim, self.sparse_field_num, bias=False)
        self.linear_model = Linear_W(dense_nums=dense_nums)
        self.fm = FM()

    def forward(self, inputs):
        """Return positive-class probabilities of shape (batch,).

        Args:
            inputs: (batch, dense_nums + sparse_field_num) float tensor.
        """
        dense_inputs = inputs[:, :self.dense_nums]
        sparse_inputs = inputs[:, self.dense_nums:].long()
        sparse_inputs = sparse_inputs + sparse_inputs.new_tensor(self.offsets).unsqueeze(0)
        sparse_emb = self.embedding(sparse_inputs)  # (batch, field_num, emb_dim)

        # FEN
        dnn_output = self.factor_estimating_net(sparse_emb.view(-1, self.dnn_input_dim))
        dnn_output = self.transform_weight_matrix_P(dnn_output)  # m'_{x} = U_x * P
        # Softmax over fields scaled by field_num, so the factors m_{x,i} sum to
        # field_num (mean 1 per field).
        input_aware_factor = self.sparse_field_num * dnn_output.softmax(1)

        # Reweighting
        logit = self.linear_model(sparse_emb, dense_inputs, sparse_feat_refine_weight=input_aware_factor)  # w_{x,i} = m_{x,i} * w_i
        refined_fm_input = sparse_emb * input_aware_factor.unsqueeze(-1)  # v_{x,i} = m_{x,i} * v_i
        logit = logit + self.fm(refined_fm_input)
        return torch.sigmoid(logit.squeeze(-1))
训练代码
class BaseModel():
    """Training / evaluation helper around a binary-classification network.

    ``net`` is called on a float feature batch and must return a 1-D tensor of
    positive-class probabilities; hard predictions threshold them at 0.5.
    The module defining this class must provide ``time``, ``pd``, ``np``,
    ``torch``, ``printlog``, ``roc_auc_score``, ``accuracy_score``, ``metrics``
    and ``plt``.
    """

    # Metric names accepted in ``fit(metric_name=...)``.
    _SUPPORTED_METRICS = ('accuracy', 'auc')

    def __init__(self, net):
        super(BaseModel, self).__init__()
        self.net = net

    def fit(self, train_loader, val_loader, epochs, loss_function, optimizer, metric_name):
        """Train for ``epochs`` epochs, validating after each one.

        Args:
            train_loader: iterable of ``(x, y)`` training batches.
            val_loader: iterable of ``(x, y)`` validation batches.
            epochs: number of epochs.
            loss_function: called as ``loss_function(pred_probs, y.float())``.
            optimizer: torch optimizer over ``net``'s parameters.
            metric_name: metric columns to record, each one of ``_SUPPORTED_METRICS``.

        Returns:
            DataFrame with one row per epoch and columns
            ``epoch, loss, <metric_name...>, val_loss, val_<metric_name...>``.

        Raises:
            ValueError: ``metric_name`` contains an unsupported metric.
        """
        unknown = [mn for mn in metric_name if mn not in self._SUPPORTED_METRICS]
        if unknown:
            raise ValueError('unsupported metric(s) %s; choose from %s' % (unknown, self._SUPPORTED_METRICS))

        start_time = time.time()
        print("\n" + "********** start training **********")
        columns = ["epoch", "loss", *metric_name, "val_loss"] + ['val_' + mn for mn in metric_name]
        dfhistory = pd.DataFrame(columns=columns)
        for epoch in range(1, epochs + 1):
            printlog("Epoch {0} / {1}".format(epoch, epochs))
            step_start = time.time()

            # Training pass.
            self.net.train()
            epoch_loss, train_probs, train_y, step_num = self._run_epoch(train_loader, loss_function, optimizer)

            # Validation pass: eval mode (dropout off, BN running stats), no autograd.
            self.net.eval()
            with torch.no_grad():
                epoch_val_loss, val_probs, val_y, _ = self._run_epoch(val_loader, loss_function)

            # End of epoch: compute metrics and record a history row by column name,
            # so the row matches ``columns`` whatever the order of ``metric_name``.
            train_metrics = self._binary_metrics(train_y, train_probs)
            val_metrics = self._binary_metrics(val_y, val_probs)
            row = {'epoch': epoch, 'loss': epoch_loss, 'val_loss': epoch_val_loss}
            for mn in metric_name:
                row[mn] = train_metrics[mn]
                row['val_' + mn] = val_metrics[mn]
            dfhistory.loc[epoch - 1] = [row[c] for c in columns]

            # Full elapsed seconds for the epoch (not reduced modulo 60).
            step_seconds = time.time() - step_start
            print("step_num: %s - %.1fs - loss: %.5f accuracy: %.5f auc: %.5f - val_loss: %.5f val_accuracy: %.5f val_auc: %.5f"
                  % (step_num, step_seconds, epoch_loss, train_metrics['accuracy'], train_metrics['auc'],
                     epoch_val_loss, val_metrics['accuracy'], val_metrics['auc']))
        elapsed = time.time() - start_time
        print('********** end of training run time: {:.0f}分 {:.0f}秒 **********'.format(elapsed // 60, elapsed % 60))
        print()
        return dfhistory

    def _run_epoch(self, loader, loss_function, optimizer=None):
        """Run one pass over ``loader``; back-propagate and step when ``optimizer`` is given.

        Returns:
            ``(mean_loss, probs, labels, step_num)`` with ``probs``/``labels`` as flat lists.
        """
        losses, probs, labels = [], [], []
        step_num = 0
        for x, y in loader:
            step_num += 1
            if optimizer is not None:
                optimizer.zero_grad()
            pred_probs = self.net(x)
            loss = loss_function(pred_probs, y.float().detach())
            if optimizer is not None:
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
            probs.extend(pred_probs.detach().cpu().tolist())
            labels.extend(y.tolist())
        return np.mean(losses), probs, labels, step_num

    @staticmethod
    def _binary_metrics(labels, probs):
        """Return ``{'accuracy': ..., 'auc': ...}`` for 0.5-thresholded ``probs``."""
        preds = [1.0 if p > 0.5 else 0.0 for p in probs]
        return {'accuracy': accuracy_score(y_true=labels, y_pred=preds),
                'auc': roc_auc_score(y_true=labels, y_score=probs)}

    def evaluate(self, val_X, val_y):
        """Score ``net`` on a held-out set. Leaves ``net`` in eval mode.

        Args:
            val_X: feature matrix (array-like or DataFrame) in the column order ``net`` expects.
            val_y: binary labels.

        Returns:
            ``(precision, recall, accuracy, f1, auc, loss, acc_condition,
            precision_condition, recall_condition)``. The numeric values are rounded
            to 4 decimals, and ``loss`` is the log loss of the predicted
            probabilities. The last three values are summary strings from
            ``accDealWith2``.
        """
        self.net.eval()
        with torch.no_grad():
            pred_probs = self.net(torch.tensor(np.asarray(val_X)).float()).cpu().numpy()
        pred = (pred_probs > 0.5).astype(np.float32)
        precision = np.around(metrics.precision_score(val_y, pred), 4)
        recall = np.around(metrics.recall_score(val_y, pred), 4)
        accuracy = np.around(metrics.accuracy_score(val_y, pred), 4)
        f1 = np.around(metrics.f1_score(val_y, pred), 4)
        auc = np.around(metrics.roc_auc_score(val_y, pred_probs), 4)
        # Log loss must be computed on probabilities, not on hard 0/1 predictions.
        loss = np.around(metrics.log_loss(val_y, pred_probs), 4)
        acc_condition, precision_condition, recall_condition = self.accDealWith2(val_y, pred)
        return precision, recall, accuracy, f1, auc, loss, acc_condition, precision_condition, recall_condition

    def predict(self, x):
        """Print predicted probabilities and 0/1 predictions for ``x`` (eval mode, no grad)."""
        self.net.eval()
        with torch.no_grad():
            pred_probs = self.net(torch.tensor(np.asarray(x)).float()).data
        print(pred_probs)
        pred = torch.where(pred_probs > 0.5, torch.ones_like(pred_probs), torch.zeros_like(pred_probs))
        print(pred)

    def plot_metric(self, dfhistory, metric):
        """Plot the train and validation curves of ``metric`` from ``fit``'s history."""
        train_metrics = dfhistory[metric]
        val_metrics = dfhistory['val_' + metric]
        epochs = range(1, len(train_metrics) + 1)
        plt.plot(epochs, train_metrics, 'bo--')
        plt.plot(epochs, val_metrics, 'ro-')
        plt.title('Training and validation ' + metric)
        plt.xlabel("Epochs")
        plt.ylabel(metric)
        plt.legend(["train_" + metric, 'val_' + metric])
        plt.show()

    def accDealWith2(self, y_test, y_pre):
        """Build human-readable accuracy / precision / recall summaries.

        Args:
            y_test: true 0/1 labels (list, ndarray, Series or CPU tensor).
            y_pre: predicted 0/1 labels, same length as ``y_test``.

        Returns:
            ``(acc_condition, precision_condition, recall_condition)`` strings.
        """
        y_true = np.asarray(y_test).ravel()
        lenall = len(y_true)
        y_hat = np.asarray(y_pre).ravel()[:lenall]
        pred_pos = y_hat == 1
        true_pos = y_true == 1
        pos = int(np.sum(y_true == y_hat))       # correctly predicted samples
        pre = int(np.sum(pred_pos))              # predicted positives
        rec = int(np.sum(true_pos))              # actual positives
        tp = int(np.sum(pred_pos & true_pos))    # true positives
        acc_condition = '预测对的:{},总样本:{}'.format(pos, lenall)
        if pre != 0:
            precision_condition = '预测为正的样本数:{},其中实际为正的样本数:{},精确率:{}'.format(pre, tp, np.around(tp / pre, 4))
        else:
            precision_condition = '预测为正的样本数:{},其中实际为正的样本数:{},精确率:{}'.format(pre, tp, 0.0)
        if rec != 0:
            recall_condition = '正例样本:{},正例中预测正确的数量:{},召回率:{}'.format(rec, tp, np.around(tp / rec, 4))
        else:
            recall_condition = '正例样本:{},正例中预测正确的数量:{},召回率:{}'.format(rec, tp, 0.0)
        return acc_condition, precision_condition, recall_condition
主程序（main）
if __name__ == '__main__':
    # Criteo sample: 'label', I1-I13 (13 numeric features) and C1-C26
    # (26 categorical features). The categorical columns are assumed to be
    # label-encoded already, since max + 1 is used as the vocabulary size.
    data = pd.read_csv('./data/criteo_sampled_data_test.csv')
    dense_cols = ['I' + str(i) for i in range(1, 14)]
    sparse_cols = ['C' + str(i) for i in range(1, 27)]
    stat_pnrate_pd(data=data, labname='label', message='criteo_sampled_data_test')

    # IFM splits off the leading dense columns, so dense features must come first.
    data_X = data[dense_cols + sparse_cols]
    data_y = data['label']

    # Vocabulary size per sparse field (assumes ids in [0, max]).
    sparse_fields = (data_X[sparse_cols].max().values + 1).astype(np.int32)
    print(sparse_fields)
    # Derived from the column list so the count cannot drift from the data.
    dense_fields_num = len(dense_cols)

    # Stratified splits: 1% held out as test, then 1% of the rest as validation.
    tmp_X, test_X, tmp_y, test_y = train_test_split(data_X, data_y, test_size=0.01, random_state=42, stratify=data_y)
    train_X, val_X, train_y, val_y = train_test_split(tmp_X, tmp_y, test_size=0.01, random_state=42, stratify=tmp_y)
    print(train_X.shape)
    print(val_X.shape)

    train_set = TensorDataset(torch.tensor(train_X.values).float(), torch.tensor(train_y.values).float())
    val_set = TensorDataset(torch.tensor(val_X.values).float(), torch.tensor(val_y.values).float())
    train_loader = DataLoader(dataset=train_set, batch_size=2048, shuffle=True)
    val_loader = DataLoader(dataset=val_set, batch_size=2048, shuffle=False)

    # NOTE(review): l2_reg_dnn is accepted by IFM's constructor but never applied there.
    net = IFM(sparse_fields=sparse_fields, dense_nums=dense_fields_num, emb_dim=10,
              dnn_hidden_units=(256, 128, 64), use_bn=True, dropout=0.2, l2_reg_dnn=0.1)
    # IFM returns sigmoid probabilities, hence BCELoss rather than BCEWithLogitsLoss.
    loss_function = nn.BCELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    base = BaseModel.BaseModel(net=net)
    dfhistory = base.fit(train_loader=train_loader, val_loader=val_loader,
                         epochs=1, loss_function=loss_function, optimizer=optimizer,
                         metric_name=['accuracy', 'auc'])
    summary(model=net, input_data=torch.tensor(val_X.values).float())
边栏推荐
猜你喜欢
随机推荐
PCB制造常用的13种测试方法,你了解几种?
【nohup】nohup命令的简单使用
MySQL的触发器
【英语单词】常见深度学习中编程用到的英语词汇
MySQL中的行锁
Converting MySQL data to Redis key-value pair format
Mysql去除重复数据
process.env环境变量配置方式(配置环境变量区分开发环境和生产环境)
IDEA连接MySQL又报错！Server returns invalid timezone. Go to 'Advanced' tab and set 'serverTimezone' property
Use of Alibaba Cloud SMS Service (create, test notes)
超全!9种PCB表面处理工艺大对比
Docker安装Mysql
高密度 PCB 线路板设计中的过孔知识
MySQL 日期时间类型精确到毫秒
Making a C# program open with administrator rights by default
在OracleLinux8.6的Zabbix6.0中监控Oracle11gR2
【入职第一篇知识总结- Prometheus】
sql中 exists的用法
ES6中 async 函数、await表达式 的基本用法
Monitoring containers and pods with Prometheus, with email alerts









