NaN appears during model training
2022-08-11 08:52:00 【小乐快乐】
【Function module】The complete code is in the attachment; the dataset can also be provided if needed.
# Imports used by this snippet (the complete code is in the attachment):
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from collections import OrderedDict


class EmbeddingImagenet(nn.Cell):
    def __init__(self, emb_size, cifar_flag=False):
        super(EmbeddingImagenet, self).__init__()
        # set size
        self.hidden = 64
        self.last_hidden = self.hidden * 25 if not cifar_flag else self.hidden * 4
        self.emb_size = emb_size
        self.out_dim = emb_size
        # set layers
        self.conv_1 = nn.SequentialCell(nn.Conv2d(in_channels=3,
                                                  out_channels=self.hidden,
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=self.hidden),
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2))
        self.conv_2 = nn.SequentialCell(nn.Conv2d(in_channels=self.hidden,
                                                  out_channels=int(self.hidden * 1.5),
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=int(self.hidden * 1.5)),
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2))
        self.conv_3 = nn.SequentialCell(nn.Conv2d(in_channels=int(self.hidden * 1.5),
                                                  out_channels=self.hidden * 2,
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=self.hidden * 2),
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2),
                                        nn.Dropout(0.6))
        self.conv_4 = nn.SequentialCell(nn.Conv2d(in_channels=self.hidden * 2,
                                                  out_channels=self.hidden * 4,
                                                  kernel_size=3,
                                                  padding=1,
                                                  pad_mode='pad',
                                                  has_bias=False),
                                        nn.BatchNorm2d(num_features=self.hidden * 4),  # 16 * 64 * (5 * 5)
                                        nn.MaxPool2d(kernel_size=2, stride=2),
                                        nn.LeakyReLU(alpha=0.2),
                                        nn.Dropout(0.5))
        # self.layer_last = nn.SequentialCell(nn.Dense(in_channels=self.last_hidden * 4,
        #                                              out_channels=self.emb_size, has_bias=True),
        #                                     nn.BatchNorm1d(self.emb_size))
        self.layer_last = nn.Dense(in_channels=self.last_hidden * 4, out_channels=self.emb_size, has_bias=True)
        # self.bn = nn.BatchNorm1d(self.emb_size)

    def construct(self, input_data):
        # print("img:", input_data[0])
        x = self.conv_1(input_data)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        # x = ops.Reshape()(x, (x.shape[0], -1))
        print("feat:", input_data[0])
        # x = self.layer_last(x)
        x = self.layer_last(x.view(x.shape[0], -1))
        print("last--------------------------------:", x[0])
        return x


class NodeUpdateNetwork(nn.Cell):
    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 1],
                 dropout=0.0):
        super(NodeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.dropout = dropout
        self.eye = ops.Eye()
        self.bmm = ops.BatchMatMul()
        self.cat = ops.Concat(-1)
        self.split = ops.Split(1, 2)
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()
        self.squeeze = ops.Squeeze()
        self.transpose = ops.Transpose()
        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            layer_list['conv{}'.format(l)] = nn.Conv2d(
                in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features * 3,
                out_channels=self.num_features_list[l],
                kernel_size=1,
                has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)
            if self.dropout > 0 and l == (len(self.num_features_list) - 1):
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1 - self.dropout)
        self.network = nn.SequentialCell(layer_list)

    def construct(self, node_feat, edge_feat):
        # get size
        num_tasks = node_feat.shape[0]
        num_data = node_feat.shape[1]
        # get eye matrix (batch_size x 2 x node_size x node_size)
        diag_mask = 1.0 - self.repeat(self.unsqueeze(self.unsqueeze(self.eye(num_data, num_data, ms.float32), 0), 0), (num_tasks, 2, 1, 1))
        # set diagonal as zero and normalize (the original paper uses L1 normalization)
        # edge_feat = edge_feat * diag_mask
        # edge_feat = edge_feat / ops.clip_by_value(ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1), Tensor(0, ms.float32), Tensor(num_data, ms.float32))
        edge_feat = ops.L2Normalize(-1)(edge_feat * diag_mask)
        # compute attention and aggregate
        aggr_feat = self.bmm(self.squeeze(ops.Concat(2)(self.split(edge_feat))), node_feat)
        node_feat = self.cat([node_feat, self.cat(ops.Split(1, 2)(aggr_feat))]).swapaxes(1, 2)
        # node_feat = self.transpose(self.cat([node_feat, self.cat(ops.Split(1, 2)(aggr_feat))]), (0, 2, 1))
        node_feat = self.network(self.unsqueeze(node_feat, -1)).swapaxes(1, 2).squeeze()
        # node_feat = self.squeeze(self.transpose(self.network(self.unsqueeze(node_feat, -1)), (0, 2, 1, 3)))
        return node_feat
class EdgeUpdateNetwork(nn.Cell):
    def __init__(self,
                 in_features,
                 num_features,
                 ratio=[2, 2, 1, 1],
                 separate_dissimilarity=False,
                 dropout=0.0):
        super(EdgeUpdateNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.num_features_list = [num_features * r for r in ratio]
        self.separate_dissimilarity = separate_dissimilarity
        self.dropout = dropout
        self.eye = ops.Eye()
        self.repeat = ops.Tile()
        self.unsqueeze = ops.ExpandDims()
        # layers
        layer_list = OrderedDict()
        for l in range(len(self.num_features_list)):
            # set layer
            layer_list['conv{}'.format(l)] = nn.Conv2d(in_channels=self.num_features_list[l - 1] if l > 0 else self.in_features,
                                                       out_channels=self.num_features_list[l],
                                                       kernel_size=1,
                                                       has_bias=False)
            layer_list['norm{}'.format(l)] = nn.BatchNorm2d(num_features=self.num_features_list[l])
            layer_list['relu{}'.format(l)] = nn.LeakyReLU(alpha=1e-2)
            if self.dropout > 0:
                layer_list['drop{}'.format(l)] = nn.Dropout(keep_prob=1 - self.dropout)
        layer_list['conv_out'] = nn.Conv2d(in_channels=self.num_features_list[-1],
                                           out_channels=1,
                                           kernel_size=1)
        self.sim_network = nn.SequentialCell(layer_list)

    def construct(self, node_feat, edge_feat):
        # compute abs(x_i, x_j)
        x_i = ops.ExpandDims()(node_feat, 2)
        x_j = x_i.swapaxes(1, 2)
        # x_j = ops.Transpose()(x_i, (0, 2, 1, 3))
        # x_ij = (x_i - x_j) ** 2
        x_ij = ops.Abs()(x_i - x_j)
        # print("x_ij:", x_ij[0, 0, :, :])
        x_ij = ops.Transpose()(x_ij, (0, 3, 2, 1))
        sim_val = self.sim_network(x_ij)
        sim_val = ops.Sigmoid()(sim_val)
        # print("sim_val", sim_val[0, 0, :, :])
        dsim_val = 1.0 - sim_val
        diag_mask = 1.0 - self.repeat(self.unsqueeze(self.unsqueeze(self.eye(node_feat.shape[1], node_feat.shape[1], ms.float32), 0), 0), (node_feat.shape[0], 2, 1, 1))
        edge_feat = edge_feat * diag_mask
        merge_sum = ops.ReduceSum(keep_dims=True)(edge_feat, -1)
        # set diagonal as zero and normalize
        # edge_feat = ops.Concat(1)([sim_val, dsim_val]) * edge_feat
        # edge_feat = edge_feat / ops.clip_by_value((ops.ReduceSum(keep_dims=True)(ops.Abs()(edge_feat), -1)), Tensor(0, ms.float32), Tensor(num_data, ms.float32))
        # edge_feat = edge_feat * merge_sum
        edge_feat = ops.L2Normalize(-1)(ops.Concat(1)([sim_val, dsim_val]) * edge_feat) * merge_sum
        force_edge_feat = self.repeat(self.unsqueeze(ops.Concat(0)([self.unsqueeze(self.eye(node_feat.shape[1], node_feat.shape[1], ms.float32), 0), self.unsqueeze(ops.Zeros()((node_feat.shape[1], node_feat.shape[1]), ms.float32), 0)]), 0), (node_feat.shape[0], 1, 1, 1))
        edge_feat = edge_feat + force_edge_feat
        edge_feat = edge_feat + 1e-6
        # print("sum_edge", self.repeat(self.unsqueeze(ops.ReduceSum()(edge_feat, 1), 1), (1, 2, 1, 1))[0, 0])
        edge_feat = edge_feat / self.repeat(self.unsqueeze(ops.ReduceSum()(edge_feat, 1), 1), (1, 2, 1, 1))
        return edge_feat
class GraphNetwork(nn.Cell):
    def __init__(self,
                 in_features,
                 node_features,
                 edge_features,
                 num_layers,
                 dropout=0.0):
        super(GraphNetwork, self).__init__()
        # set size
        self.in_features = in_features
        self.node_features = node_features
        self.edge_features = edge_features
        self.num_layers = num_layers
        self.dropout = dropout
        self.layers = nn.CellList()
        # for each layer
        for l in range(self.num_layers):
            # set edge to node
            edge2node_net = NodeUpdateNetwork(in_features=self.in_features if l == 0 else self.node_features,
                                              num_features=self.node_features,
                                              dropout=self.dropout if l < self.num_layers - 1 else 0.0)
            # set node to edge
            node2edge_net = EdgeUpdateNetwork(in_features=self.node_features,
                                              num_features=self.edge_features,
                                              separate_dissimilarity=False,
                                              dropout=self.dropout if l < self.num_layers - 1 else 0.0)
            self.layers.append(nn.CellList([edge2node_net, node2edge_net]))

    # forward
    def construct(self, node_feat, edge_feat):
        # for each layer
        edge_feat_list = []
        # print("node_feat---------------------------------------------------------- -1", node_feat[0, 0, :])
        for l in range(self.num_layers):
            # (1) edge to node
            node_feat = self.layers[l][0](node_feat, edge_feat)
            # (2) node to edge
            edge_feat = self.layers[l][1](node_feat, edge_feat)
            # save edge feature
            edge_feat_list.append(edge_feat)
        return edge_feat_list
【Steps & Problem Symptoms】
The main function of our code is to extract image features with 4 convolutional layers plus one fully connected layer, and then treat each image's feature vector as a node of a graph and run a GNN on it. (The code is in the attachment.)
1. After training for many batches, the extracted features (after the 4 convolutional layers and the fully connected layer) take on extremely large values, and NaN appears a few batches later; the feature values before the fully connected layer are still normal (a small diagnostic sketch follows this list).
2.
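To narrow down where the blow-up starts, one option is to dump the magnitude of the Dense layer's parameters every few batches. This is only a minimal sketch, assuming PyNative mode (asnumpy() cannot be called inside a graph-compiled construct); check_tensor and enc_module are hypothetical names, with enc_module standing in for the EmbeddingImagenet instance:

import numpy as np

def check_tensor(tag, tensor):
    # Print the largest magnitude and whether every element is finite.
    arr = tensor.asnumpy()
    print(tag, "max|x| =", np.abs(arr).max(), ", all finite:", np.isfinite(arr).all())

# Called between batches to see whether the weights themselves are drifting:
check_tensor("layer_last.weight", enc_module.layer_last.weight)
check_tensor("layer_last.bias", enc_module.layer_last.bias)

If the weights of layer_last keep growing while the convolutional features stay in a normal range, the usual suspects are the learning rate or the missing normalization after layer_last rather than the data itself.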
【Screenshot Info】
These are the features printed by the code:
last--------------------------------: [ 1.918492 -0.8280923 2.0575197 0.3089749 -1.0514854 0.5368729
0.14135109 1.5270222 -1.4794292 -1.4336827 1.0335447 -0.7093582
-0.41919574 -0.5667086 -0.3535831 1.5567536 0.5002996 -1.4093596
0.9674009 -0.18156137 0.14888959 0.6358457 1.406878 -0.03820777
-0.24577822 -0.25783274 0.5756687 -1.4558431 -1.1002262 0.68062806
-1.6467474 0.88712454 0.3551372 -1.3449378 -1.7011788 -0.8629771
-0.92482185 0.9867192 -1.5548937 1.340383 -2.299356 -0.3421743
1.3239275 -1.3792732 -0.31955895 -0.58364254 -3.7381008 -1.2121737
-0.75104207 -0.7562581 0.04980466 0.45131734 -1.2448095 -0.33418307
0.86268485 -1.3601649 1.2753168 2.469506 -1.7358601 -2.9104383
-0.07392117 -0.73263663 0.11657254 -0.05724781 0.34374043 -0.31884825
0.13456154 2.3561432 -0.18908082 0.5410311 1.7249999 0.9508886
-0.30631644 1.6836481 1.1513023 -0.33672807 -0.889638 -0.76715356
-0.7316199 1.597606 -1.6586273 0.4502733 0.5224928 -3.5851111
-2.906651 -1.5284328 0.83426046 1.354644 -1.4453334 2.0504599
-1.3200179 -0.50427496 0.97681373 0.30048305 0.17170379 0.8179815
-0.92994857 1.333491 -1.2931286 -0.3569969 2.7953048 -3.352736
1.878619 2.018083 -1.1191074 -1.1341975 1.4532931 -0.66957355
2.3269157 -0.4198427 0.7148121 0.5458231 -1.3050007 -0.34666243
2.519589 0.804219 0.91191477 1.3088121 0.6767241 2.1667008
0.24471135 1.2600335 -1.8683847 2.5641935 -0.9636249 -1.0340385
-0.32570755 -1.7694132 ]
------------------------------------------
------------------------------------------------------------------------------- 1 0.7806913
---------------------------------------------
feat: [[[-1.6726604 -1.6897851 -1.7069099 ... 0.43368444 0.46793392
0.41655967]
[-1.7069099 -1.7069099 -1.7069099 ... 0.5364329 0.5193082
0.4850587 ]
[-1.7240347 -1.7240347 -1.7069099 ... 0.60493195 0.5535577
0.4850587 ]
...
[-0.6622999 -0.8335474 -0.8677969 ... -0.02868402 0.00556549
-0.02868402]
[-0.6622999 -0.69654936 -0.69654936 ... -0.11430778 -0.11430778
-0.14855729]
[-0.95342064 -0.8335474 -0.78217316 ... -0.26843056 -0.30268008
-0.31980482]]
[[-1.7556022 -1.7731092 -1.7906162 ... -0.617647 -0.582633
-0.635154 ]
[-1.7906162 -1.7906162 -1.7906162 ... -0.512605 -0.512605
-0.565126 ]
[-1.8081232 -1.8081232 -1.7906162 ... -0.460084 -0.495098
-0.565126 ]
...
[-0.28501397 -0.37254897 -0.40756297 ... -1.0028011 -0.9677871
-1.0203081 ]
[-0.26750696 -0.33753496 -0.32002798 ... -1.12535 -1.1428571
-1.160364 ]
[-0.53011197 -0.53011197 -0.44257697 ... -1.2829131 -1.317927
-1.317927 ]]
[[-1.68244 -1.6998693 -1.7172985 ... -1.490719 -1.4558606
-1.490719 ]
[-1.7172985 -1.7172985 -1.7172985 ... -1.4732897 -1.4384314
-1.4732897 ]
[-1.7347276 -1.7347276 -1.7172985 ... -1.4558606 -1.4732897
-1.5255773 ]
...
[-1.3338562 -1.4210021 -1.4210021 ... -1.6127234 -1.5430065
-1.5604358 ]
[-1.2815686 -1.3512855 -1.3338562 ... -1.6127234 -1.5952941
-1.6127234 ]
[-1.5081482 -1.4732897 -1.4210021 ... -1.5778649 -1.6127234
-1.6301525 ]]]
last--------------------------------: [-9.7715964e+37 -1.3229437e+37 -1.5262715e+38 -2.5811514e+38
3.2964988e+38 -7.1266450e+37 -7.2963347e+37 -3.0699307e+38
-1.6108344e+38 5.8011444e+37 -3.9925391e+37 -9.5891957e+37
-1.7783365e+38 2.2280316e+38 -4.4186918e+37 3.4825655e+37
5.8457292e+37 7.2160006e+37 1.4259578e+38 9.4037617e+37
7.4650717e+37 1.8146209e+37 -2.5143476e+38 2.4387442e+38
-7.5397363e+37 1.4157064e+38 -1.1084308e+38 1.9522180e+38
2.5864164e+37 -8.5381704e+37 3.3140050e+36 -1.2379668e+38
-3.3449897e+37 1.6203643e+38 1.4627435e+38 6.6909600e+37
6.0661751e+37 -1.2335753e+38 1.3377397e+38 -3.7530971e+37
3.5314601e+37 -1.4393099e+37 -inf -6.0411279e+37
-7.0721061e+37 1.5951782e+38 9.0163464e+37 1.3680580e+37
-1.2254094e+37 1.0919689e+38 -1.5229139e+37 -3.4862508e+36
-8.9739065e+37 2.8713203e+38 9.4768839e+37 7.8658815e+37
-2.6619306e+38 -7.8224467e+37 6.8780734e+37 inf
-9.8889302e+37 -1.9009123e+38 -1.4562352e+38 -4.5324568e+37
-2.6728082e+38 1.0300855e+38 -5.7767852e+37 1.3662499e+37
-4.0048543e+37 -3.1911765e+37 -1.9702732e+38 -6.5395945e+37
1.0223747e+38 -2.8775531e+38 -1.1156091e+38 -1.8772822e+38
1.2472896e+38 1.2465860e+38 -6.7286062e+37 -8.9167649e+37
-2.8327554e+37 -2.7379526e+37 -1.5994879e+37 1.1577176e+38
1.1864721e+38 1.7089999e+38 -1.5323652e+37 -1.5374746e+38
1.2187025e+38 -8.9546139e+37 1.7550813e+38 -5.7048014e+37
-8.5996788e+37 -5.2310546e+36 -1.4450948e+37 -1.9950120e+37
4.2429252e+37 -1.4849557e+38 1.0697206e+38 -7.6313524e+37
-inf 1.7437526e+38 -1.0569269e+38 -1.5577321e+38
-7.8117285e+37 6.4801082e+37 -3.3032475e+37 -6.4655517e+36
-2.3770844e+38 1.0880277e+38 3.6430118e+37 -6.9370110e+37
8.5146681e+37 1.1550550e+38 -2.5614073e+38 -2.1489826e+38
-8.3233807e+37 2.7233982e+37 -1.3777926e+38 -9.6201629e+37
-2.1125345e+38 -1.4252791e+36 3.6633845e+37 2.6106833e+37
9.6643025e+37 -1.4538810e+37 -1.3660478e+38 1.9220696e+38]
1. Use warmup to adjust the learning rate, and set the maximum learning rate to 0.01;
2. Apply gradient clipping as a safeguard;
3. Check whether normalization is applied at the end; the value range is probably not within 0-1. (A sketch of points 1 and 2 follows this list.)
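Points 1 and 2 can be combined in a single custom training step. This is a minimal sketch in MindSpore 1.x style, following the stock nn.TrainOneStepCell pattern; names such as net_with_loss and steps_per_epoch are placeholders, not taken from the attached code:

import mindspore.nn as nn
import mindspore.ops as ops

class ClippedTrainStep(nn.TrainOneStepCell):
    # TrainOneStepCell variant that clips the global gradient norm before the update.
    def __init__(self, network, optimizer, clip_norm=1.0):
        super(ClippedTrainStep, self).__init__(network, optimizer)
        self.clip_norm = clip_norm

    def construct(self, *inputs):
        loss = self.network(*inputs)
        sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens)
        grads = self.grad(self.network, self.weights)(*inputs, sens)
        grads = ops.clip_by_global_norm(grads, self.clip_norm)   # point 2: gradient clipping
        grads = self.grad_reducer(grads)
        loss = ops.depend(loss, self.optimizer(grads))
        return loss

# Point 1: warm the learning rate up to a maximum of 0.01 over the first epochs.
steps_per_epoch = 100                      # placeholder, depends on the dataset
lr_list = nn.warmup_lr(learning_rate=0.01,
                       total_step=steps_per_epoch * 50,
                       step_per_epoch=steps_per_epoch,
                       warmup_epoch=5)
optimizer = nn.Adam(net_with_loss.trainable_params(), learning_rate=lr_list)
train_step = ClippedTrainStep(net_with_loss, optimizer, clip_norm=1.0)

For point 3, re-enabling the commented-out nn.BatchNorm1d after layer_last (or applying an L2 normalization to the embedding) is one way to keep the embedding range bounded.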