当前位置:网站首页>pspnet完整代码实现
pspnet完整代码实现
2022-06-23 06:18:00 【休斯顿凤梨】
# PSP module: pool with different bin sizes, interpolate back to the input size, then concat
# Code reference: PaddlePaddle
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph import Conv2D
from paddle.fluid.dygraph import BatchNorm
from paddle.fluid.dygraph import Dropout
from resnet_dilated import ResNet50
class PSPModule(Layer):
    """Pyramid Pooling Module (PPM) as used in PSPNet.

    For every entry of ``bin_size_list`` the input feature map is adaptively
    pooled to that output size, projected by a 1x1 Conv2D followed by a
    BatchNorm with ReLU, and resized back to the input's spatial size with
    ``interpolate`` (bilinear by default). The input and all branches are then
    concatenated along axis 1.

    The output therefore has
    ``num_channels + len(bin_size_list) * (num_channels // len(bin_size_list))``
    channels.

    Args:
        num_channels: number of channels of the input feature map.
        bin_size_list: pooled output sizes, one branch per entry.
        pool_type: pooling type passed to ``adaptive_pool2d`` ('max' or
            'avg'). Defaults to 'max', which is the ``adaptive_pool2d``
            default this module has always used. Note that the PSPNet paper
            reports average pooling working better.
    """

    def __init__(self, num_channels, bin_size_list, pool_type='max'):
        super(PSPModule, self).__init__()
        self.bin_size_list = bin_size_list
        self.pool_type = pool_type
        num_filters = num_channels // len(bin_size_list)
        # A plain Python list does not register its elements as sublayers:
        # their parameters would be absent from parameters()/state_dict()
        # and would never be trained or saved. LayerList registers them.
        self.features = fluid.dygraph.LayerList([
            fluid.dygraph.Sequential(
                Conv2D(num_channels, num_filters, 1),
                BatchNorm(num_filters, act='relu'),
            )
            for _ in bin_size_list
        ])

    def forward(self, inputs):
        """Apply pyramid pooling and concatenate the branches with ``inputs``.

        Args:
            inputs: 4-D feature map. ``inputs.shape[2:]`` is used as the
                target spatial size of every branch (assumes NCHW layout;
                TODO confirm with callers).

        Returns:
            Concatenation along axis 1 of ``inputs`` and every branch output.
        """
        out_size = inputs.shape[2:]
        out = [inputs]
        for bin_size, branch in zip(self.bin_size_list, self.features):
            x = fluid.layers.adaptive_pool2d(
                inputs, bin_size, pool_type=self.pool_type)
            x = branch(x)
            # Resize back to the input resolution so the branches can be
            # concatenated with the input along the channel axis.
            x = fluid.layers.interpolate(x, out_size, align_corners=True)
            out.append(x)
        return fluid.layers.concat(out, axis=1)
class PSPNet(Layer):
    """PSPNet segmentation network built on ``ResNet50`` from ``resnet_dilated``.

    Pipeline:
    1. The backbone stem (``conv`` + ``pool2d_max``) and ``layer1``..``layer4``.
    2. A ``PSPModule`` with bins [1, 2, 3, 6].
    3. A classifier: 3x3 conv -> BatchNorm(ReLU) -> Dropout(0.1) -> 1x1 conv
       to ``num_classes``.
    4. Resizing of the logits to the input's spatial size.

    Args:
        num_classes: number of output channels of the final 1x1 conv.
        backbone: backbone name; only 'resnet50' is implemented.

    Raises:
        ValueError: if ``backbone`` is anything other than 'resnet50'
            (previously such values were silently ignored).
    """

    def __init__(self, num_classes=59, backbone='resnet50'):
        super(PSPNet, self).__init__()
        if backbone != 'resnet50':
            raise ValueError(
                "Unsupported backbone %r; only 'resnet50' is implemented"
                % (backbone,))
        res = ResNet50(pretrained=False)
        self.layer0 = fluid.dygraph.Sequential(
            res.conv,
            res.pool2d_max,
        )
        self.layer1 = res.layer1
        self.layer2 = res.layer2
        self.layer3 = res.layer3
        self.layer4 = res.layer4
        # Assumes layer4 of the backbone outputs 2048 channels (standard for
        # ResNet-50) -- confirm against resnet_dilated.ResNet50.
        num_channels = 2048
        self.pspmodule = PSPModule(num_channels, [1, 2, 3, 6])
        # PSPModule concatenates its input with 4 branches of
        # num_channels // 4 channels each, doubling the channel count.
        num_channels *= 2
        self.classifier = fluid.dygraph.Sequential(
            Conv2D(num_channels=num_channels,
                   num_filters=512,
                   filter_size=3,
                   padding=1),
            BatchNorm(512, act='relu'),
            Dropout(0.1),
            Conv2D(num_channels=512, num_filters=num_classes, filter_size=1),
        )

    def forward(self, inputs):
        """Return per-pixel class logits resized to ``inputs.shape[2:]``.

        Args:
            inputs: image batch; assumed NCHW with 3 input channels
                (TODO confirm against the backbone's first conv).

        Returns:
            A single tensor of logits with ``num_classes`` channels. There is
            no auxiliary output.
        """
        x = inputs
        for stage in (self.layer0, self.layer1, self.layer2,
                      self.layer3, self.layer4):
            x = stage(x)
        x = self.pspmodule(x)
        x = self.classifier(x)
        # Resize logits back to the input resolution.
        return fluid.layers.interpolate(x, inputs.shape[2:], align_corners=True)
def main():
    """Smoke test: push a random batch through PSPNet on CPU and print shapes."""
    with fluid.dygraph.guard(fluid.CPUPlace()):
        x_data = np.random.rand(2, 3, 473, 473).astype(np.float32)
        x = to_variable(x_data)
        model = PSPNet(num_classes=59)
        model.train()
        # PSPNet.forward returns a single logits tensor (no auxiliary head),
        # so the result must not be unpacked into (pred, aux).
        pred = model(x)
        print(pred.shape)


if __name__ == "__main__":
    main()
边栏推荐
猜你喜欢

Xiaobai must see in investment and wealth management: illustrated fund buying and selling rules

Eureka

【BULL中文文档】用于在 NodeJS 中处理分布式作业和消息的队列包

【项目实训】多段线扩充为平行线

cmder

云原生落地进入深水区,博云容器云产品族释放四大价值

QT设计师无法修改窗口大小,无法通过鼠标拖动窗口改变大小的解决方案

Common setup modes (Abstract Factory & responsibility chain mode & observer mode)

正则表达式图文超详细总结不用死记硬背(上篇)

The illustration shows three handshakes and four waves. Xiaobai can understand them
随机推荐
cmder
Using fuser to view file usage
QT method of compiling projects using multithreading
301. 删除无效的括号
Linux安装mysql8.0.25
[QT] basic learning notes
306. 累加数
Xxl-sso enables SSO single sign on
20220621 Dual Quaternion
别找了诸位 【十二款超级好用的谷歌插件都在这】(确定不来看看?)
ssm + ftp +ueditor
Traversal of binary tree and related knowledge
GIS实战应用案例100篇(七十九)-多规整合底图的制作要点
ssm + ftp +ueditor
Cloud box is deeply convinced to create a smart dual-mode teaching resource sharing platform for Nanjing No. 1 middle school
【STL】关联容器之unordered_map用法总结
MySQL重做日志 redo log
Concepts and differences of DQL, DML, DDL and DCL
宝塔忘记密码
Mongodb record