当前位置:网站首页>【Pytorch】nn.Linear,nn.Conv
【Pytorch】nn.Linear,nn.Conv
2022-08-11 06:28:00 【二进制人工智能】
nn.Linear

nn.Conv1d

当nn.Conv1d的kernel_size=1时,效果与nn.Linear相同,不过输入数据格式不同:
https://blog.csdn.net/l1076604169/article/details/107170146
import torch
def count_parameters(model):
"""Count the number of parameters in a model."""
return sum([p.numel() for p in model.parameters()])
conv = torch.nn.Conv1d(3, 32, kernel_size=1)
print(count_parameters(conv))
# 128
linear = torch.nn.Linear(3, 32)
print(count_parameters(linear))
# 128
print(conv.weight.shape)
# torch.Size([32, 3, 1])
print(linear.weight.shape)
# torch.Size([32, 3])
# use same initialization
linear.weight = torch.nn.Parameter(conv.weight.squeeze(2))
linear.bias = torch.nn.Parameter(conv.bias)
tensor = torch.randn(128, 256, 3) # [batch, feature_num,feature_size]
permuted_tensor = tensor.permute(0, 2, 1).clone().contiguous() # [batch, feature_size,feature_num]
out_linear = linear(tensor)
print(out_linear.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_linear.shape)
# torch.Size([128, 256, 32])
out_conv = conv(permuted_tensor)
print(out_conv.mean())
# tensor(0.0344, grad_fn=<MeanBackward0>)
print(out_conv.shape)
# torch.Size([128, 32, 256])
nn.Conv2d

nn.Conv3d

边栏推荐
- PIXHAWK飞控使用RTK
- js判断图片是否存在
- 软件测试基本流程有哪些?北京专业第三方软件检测机构安利
- 【latex异常和错误】Missing $ inserted.<inserted text>You can‘t use \spacefactor in math mode.输出文本要注意特殊字符的转义
- daily sql - user retention rate for two days
- 矩阵分析——微分、积分、极限
- Amazon Get AMAZON Product Details API Return Value Description
- Taobao API common interface and acquisition method
- js根据当天获取前几天的日期
- 软件测试主要做什么工作,难不难?
猜你喜欢
随机推荐
《猪猪1984》NFT 作品集将上线 The Sandbox 市场平台
拼多多API接口(附上我的可用API)
Unity3D 学习路线?
Unity底层是如何处理C#的
How Unity programmers can improve their abilities
buu—Re(5)
下一代 无线局域网--强健性
SQL sliding window
unable to extend table xxx by 1024 in tablespace xxxx
每日sql -用户两天留存率
淘宝API接口参考
exness:黄金1800关口遇阻,静待美国CPI出炉
maxwell 概念
sql--Users who have purchased more than 3 times (inclusive) within 7 days (including the current day), and the purchase amount in the past 7 days exceeds 1,000
每日sql -查询至少有5名下属的经理和选举
Daily sql - judgment + aggregation
jar服务导致cpu飙升问题-带解决方法
抖音关键词搜索商品-API工具
空间金字塔池化 -Spatial Pyramid Pooling(含源码)
博途PLC 1200/1500PLC ModbusTcp通信梯形图优化汇总(多服务器多从站轮询)








