A Not-So-Successful ResNet Implementation and Lessons Learned
2022-07-03 09:09:00 【weixin_37682…】
import torch
from torch import nn
from torch.nn import functional as F


class ResBlk(nn.Module):
    """ResNet basic block."""

    def __init__(self, ch_in, ch_out):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        """
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()
        if ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h, w]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h, w]
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut.
        # extra module: [b, ch_in, h, w] => [b, ch_out, h, w]
        # element-wise add:
        out = self.extra(x) + out
        return out
class ResNet18(nn.Module):

    def __init__(self):
        super(ResNet18, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        # followed by 4 blocks
        # [b, 16, h, w] => [b, 32, h, w]
        self.blk1 = ResBlk(16, 32)
        # [b, 32, h, w] => [b, 64, h, w]
        self.blk2 = ResBlk(32, 64)
        # [b, 64, h, w] => [b, 128, h, w]
        self.blk3 = ResBlk(64, 128)
        # [b, 128, h, w] => [b, 256, h, w]
        self.blk4 = ResBlk(128, 256)

        # for a 32x32 input, the stem leaves h = w = 10, hence 256*10*10
        self.outlayer = nn.Linear(256 * 10 * 10, 10)

    def forward(self, x):
        """
        :param x: [b, 3, h, w]
        :return: [b, 10] class logits
        """
        x = F.relu(self.conv1(x))
        # [b, 16, h, w] => [b, 256, h, w]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # flatten [b, 256, 10, 10] => [b, 256*10*10]
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x
def main():
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('blk:', out.shape)

    model = ResNet18()
    tmp = torch.randn(2, 3, 32, 32)
    out = model(tmp)
    print('resnet:', out.shape)


if __name__ == '__main__':
    main()
This is the code my teacher wrote; my own version is a bit messy, so I won't post it.
My biggest takeaway: once the CNN layers have stride and padding, the [b, ch, h, w] shapes get messy quickly, and it is easy for the x a layer produces to mismatch the x the next layer expects. Stride and padding mainly affect h and w.
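To keep track of h and w, the standard Conv2d output-size formula can be wrapped in a small helper (a minimal sketch of my own, not part of the original code):

# Minimal sketch: the standard Conv2d output-size formula,
# out = floor((in + 2*padding - kernel_size) / stride) + 1
def conv2d_out(size, kernel_size, stride=1, padding=0):
    return (size + 2 * padding - kernel_size) // stride + 1

# The stem above, Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
# maps a 32x32 input to 10x10, which is why outlayer uses 256*10*10.
print(conv2d_out(32, kernel_size=3, stride=3, padding=0))  # 10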
Also, when the CNN feeds into a fully-connected layer, remember to flatten ch * h * w first, i.e. turn the CNN's [b, ch, h, w] output into [b, ch*h*w].
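For example (a minimal sketch of the flatten step, equivalent to the x.view call in the code above):

import torch

x = torch.randn(2, 256, 10, 10)         # [b, ch, h, w]
flat = x.view(x.size(0), -1)            # [b, ch*h*w] == [2, 25600]
flat2 = torch.flatten(x, start_dim=1)   # same result via torch.flatten
print(flat.shape, flat2.shape)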
There is one key mistake in the shortcut:

out = self.extra(x) + out

was mistakenly written as

out = self.extra(out) + x

It took me two hours to find this problem. I am still not familiar with the dimension transformations in the middle of the network, and I still don't understand the shortcut well.
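In hindsight the bug is a channel mismatch: self.extra is a 1x1 projection from ch_in to ch_out, so it must be applied to the block input x, which still has ch_in channels, not to out, which already has ch_out channels. A minimal sketch of my own, using the same shapes as main() above:

import torch
from torch import nn

extra = nn.Conv2d(64, 128, kernel_size=1)  # shortcut projection: 64 -> 128
x = torch.randn(2, 64, 32, 32)             # block input, ch_in = 64
out = torch.randn(2, 128, 32, 32)          # main-branch output, ch_out = 128

y = extra(x) + out    # correct: both sides are [2, 128, 32, 32]
# extra(out) + x      # wrong: extra expects 64 input channels but out has 128,
#                     # and x (64 channels) could not be added to a 128-channel map anyway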
The training script on CIFAR-10:

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from resnet_teacher import ResNet18


def main():
    batchsz = 32

    cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)

    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    device = torch.device('cuda')
    # model = Lenet5().to(device)
    # note: the .to(device) moves below are commented out, so this trains on CPU
    model = ResNet18()
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    print(model)

    for epoch in range(1000):

        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32], label: [b]
            # x, label = x.to(device), label.to(device)

            logits = model(x)
            # logits: [b, 10], label: [b], loss: scalar tensor
            loss = criteon(logits, label)

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, 'loss:', loss.item())

        model.eval()
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32], label: [b]
                # x, label = x.to(device), label.to(device)

                logits = model(x)            # [b, 10]
                pred = logits.argmax(dim=1)  # [b]
                # [b] vs [b] => scalar
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)

            acc = total_correct / total_num
            print(epoch, 'acc:', acc)


if __name__ == '__main__':
    main()