torch.nn.Linear() function

2022-06-12 20:57:00 Human high quality Algorithm Engineer

torch.nn.Linear(in_features, out_features, bias=True) applies a linear transformation to the incoming data:
y = x·Aᵀ + b, where A is the layer's weight matrix and b is its bias.

Here, in_features is the size of each input sample, out_features is the size of each output sample, and bias defaults to True; if you set bias=False, the layer will not learn an additive bias.
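As a quick check (a minimal sketch), constructing the layer with bias=False leaves the bias parameter unset:

import torch

fc_no_bias = torch.nn.Linear(3, 5, bias=False)  # linear layer without an additive bias
print(fc_no_bias.bias)          # None: no bias parameter is created
print(fc_no_bias.weight.shape)  # torch.Size([5, 3])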

The Linear() module is usually used to define the fully connected layers in a network.
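As an illustration, a small fully connected network can be assembled by stacking Linear layers inside torch.nn.Sequential (a minimal sketch; the 784 → 128 → 10 layer sizes are arbitrary choices for this sketch):

import torch

# a simple fully connected network: 784 -> 128 -> 10
model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)
batch = torch.randn(8, 784)  # a batch of 8 samples with 784 features each
print(model(batch).shape)    # torch.Size([8, 10])

The example below works with a single Linear layer and verifies its computation by hand.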

import torch

x = torch.randn(8, 3)  # input: a batch of 8 samples, each with 3 features
fc = torch.nn.Linear(3, 5)  # 3 is the input sample size, 5 is the output sample size
output = fc(x)
print('fc.weight.shape:\n ', fc.weight.shape, fc.weight)
print('fc.bias.shape:\n', fc.bias.shape)
print('output.shape:\n', output.shape)

ans = torch.mm(x, torch.t(fc.weight)) + fc.bias  # computed by hand; should be identical to fc(x)
print('ans.shape:\n', ans.shape)

print(torch.equal(ans, output))

The output is:

fc.weight.shape:
  torch.Size([5, 3]) Parameter containing:
tensor([[-0.1878, -0.2082,  0.4506],
        [ 0.3230,  0.3543,  0.3187],
        [-0.0993, -0.0028, -0.1001],
        [-0.0479,  0.3248, -0.4867],
        [ 0.0574,  0.0451,  0.1525]], requires_grad=True)
fc.bias.shape:
 torch.Size([5])
output.shape:
 torch.Size([8, 5])
ans.shape:
 torch.Size([8, 5])
True

Process finished with exit code 0

First, the weight of nn.Linear(3, 5) has shape (5, 3), so before multiplying by x we take its transpose with torch.t; the (8, 3) input times the (3, 5) transposed weight gives the (8, 5) output of the fully connected layer, which torch.equal confirms is identical to fc(x). torch.mm is simply ordinary mathematical matrix multiplication.
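The same result can also be written with the @ operator or with torch.nn.functional.linear, which applies exactly this transposed multiplication (a minimal sketch reusing the x, fc and output variables from the example above):

import torch.nn.functional as F

ans2 = x @ fc.weight.t() + fc.bias      # same as torch.mm(x, torch.t(fc.weight)) + fc.bias
ans3 = F.linear(x, fc.weight, fc.bias)  # functional form of the same linear transformation
print(torch.equal(ans2, output), torch.equal(ans3, output))  # expected: True True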

