当前位置:网站首页>Deep learning medical image model reproduction
Deep learning medical image model reproduction
2022-07-28 05:37:00 【A little knowledge, a hundred Xiaosheng】
A. Sevastopolsky, Optic disc and cup segmentation methods for glaucoma detection with modification of u-net convolutional neural network, Pattern Recognition and Image Analysis 27 (2017) 618–624
The original author's code is based on Keras; I reproduced it in PyTorch.
from multiprocessing import pool
from turtle import forward
from sklearn.preprocessing import scale
from torch import nn
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp
class encoder(nn.Module):
    """Encoder stage: two (3x3 conv -> in-place ReLU -> Dropout(0.3)) units.

    Stride 1 and padding 1 keep the spatial size unchanged; the channel count
    goes ``in_channels`` -> ``out_channels`` -> ``out_channels``.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        stages = []
        # The first conv maps in_channels -> out_channels; the second keeps out_channels.
        for c_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(in_channels=c_in, out_channels=out_channels,
                          kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Dropout(0.3),
            ]
        # Attribute name kept as ``encoder`` so state_dict keys stay "encoder.<i>.*".
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv/ReLU/dropout stack to ``x``."""
        return self.encoder(x)
class decoder(nn.Module):
    """Decoder stage: conv3x3 -> ReLU -> Dropout(0.3) -> conv3x3 -> ReLU.

    Spatial size is preserved (stride 1, padding 1). There is no dropout
    after the second convolution.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        widen = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=3, stride=1, padding=1)
        refine = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                           kernel_size=3, stride=1, padding=1)
        # Attribute name kept as ``decoder`` so state_dict keys stay "decoder.<i>.*".
        self.decoder = nn.Sequential(
            widen, nn.ReLU(True), nn.Dropout(0.3),
            refine, nn.ReLU(True),
        )

    def forward(self, x):
        """Apply the decoder stack to ``x``."""
        return self.decoder(x)
class modified_UNet(nn.Module):
    """Lightweight U-Net for optic disc/cup segmentation (Sevastopolsky, 2017).

    Structure:
      * four encoder stages, each followed by 2x2 max-pooling;
      * a bottleneck encoder stage;
      * four decoder stages, each preceded by bilinear upsampling and
        concatenation with the matching encoder output (skip connection);
      * a 1x1 convolution and a sigmoid, giving a single-channel map in (0, 1).

    The input must have 3 channels (``encoder1`` is built with
    ``in_channels=3``). The output has the same H and W as the input.
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder1 = encoder(in_channels=3, out_channels=32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = encoder(in_channels=32, out_channels=64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = encoder(in_channels=64, out_channels=64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = encoder(in_channels=64, out_channels=64)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder5 = encoder(in_channels=64, out_channels=64)
        # Decoder input channels = upsampled channels + skip channels
        # (64+64 for stages 6-8, 64+32 for stage 9).
        self.decoder6 = decoder(in_channels=128, out_channels=64)
        self.decoder7 = decoder(in_channels=128, out_channels=64)
        self.decoder8 = decoder(in_channels=128, out_channels=64)
        self.decoder9 = decoder(in_channels=96, out_channels=32)
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=1, stride=1)

    @staticmethod
    def _up_cat(x, skip):
        """Bilinearly resize ``x`` to ``skip``'s H x W and concatenate on channels."""
        # Resize to the skip tensor's exact size instead of scale_factor=2.
        # MaxPool2d floors odd sizes, so a fixed x2 would break the concat for
        # inputs whose sides are not multiples of 16. When sizes divide evenly
        # this is the same x2 bilinear resize as before.
        up = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return torch.cat((up, skip), dim=1)

    def forward(self, x):
        """Return per-pixel probabilities of shape (N, 1, H, W) for input ``x``."""
        conv1 = self.encoder1(x)
        conv2 = self.encoder2(self.maxpool1(conv1))
        conv3 = self.encoder3(self.maxpool2(conv2))
        conv4 = self.encoder4(self.maxpool3(conv3))
        conv5 = self.encoder5(self.maxpool4(conv4))

        conv6 = self.decoder6(self._up_cat(conv5, conv4))
        conv7 = self.decoder7(self._up_cat(conv6, conv3))
        conv8 = self.decoder8(self._up_cat(conv7, conv2))
        conv9 = self.decoder9(self._up_cat(conv8, conv1))

        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        return torch.sigmoid(self.conv10(conv9))
H. Fu, J. Cheng, Y. Xu, D. W. K. Wong, J. Liu, X. Cao, Joint optic disc and cup segmentation based on multi-label deep network and polar transformation, IEEE Transactions on Medical Imaging 37 (2018) 1597–1605.
The original authors released only part of their Keras code; I reproduced the model in PyTorch based on that Keras version.
from torch import nn, relu
import torch
from torch.nn import functional as F
class encoder(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Stride 1 and padding 1 keep the spatial size; channels go
    ``in_channels`` -> ``out_channels`` -> ``out_channels``.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        stages = []
        for c_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(in_channels=c_in, out_channels=out_channels,
                                    kernel_size=3, stride=1, padding=1))
            stages.append(nn.ReLU(True))
        # Attribute name kept as ``encoder`` so state_dict keys stay "encoder.<i>.*".
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/ReLU units to ``x``."""
        return self.encoder(x)
class M_Conv(nn.Module):
    """A single 3x3 convolution followed by an in-place ReLU.

    Padding 1 and stride 1 keep the spatial size unchanged. In ``MNet`` this
    block processes the down-scaled copies of the input image.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1, stride=1)
        # Attribute name kept as ``encode`` so state_dict keys stay "encode.<i>.*".
        self.encode = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Return ReLU(conv3x3(x))."""
        return self.encode(x)
class ConvRelu2d(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by an in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, stride=1),
            nn.ReLU(inplace=True),
        ]
        # Attribute name kept as ``decoder`` so state_dict keys stay "decoder.<i>.*".
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return ReLU(conv3x3(x)) with unchanged spatial size."""
        return self.decoder(x)
class M_Decoder(nn.Module):
    """Up-sampling block: 2x transposed conv, skip concat, two conv+ReLU units.

    ``x0`` is upsampled by a 2x2/stride-2 transposed convolution from
    ``input_channels`` to ``output_channels`` channels, then concatenated
    with the skip tensor ``x1`` along the channel axis. The first refinement
    conv expects ``input_channels`` channels, so ``x1`` must carry
    ``input_channels - output_channels`` channels and match the upsampled
    spatial size.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2)
        refine = [
            ConvRelu2d(input_channels, output_channels),
            ConvRelu2d(output_channels, output_channels),
        ]
        self.deconv = nn.Sequential(*refine)

    def forward(self, x0, x1):
        """Upsample ``x0``, concatenate the skip ``x1``, and refine."""
        merged = torch.cat((self.up(x0), x1), dim=1)
        return self.deconv(merged)
class MNet(nn.Module):
    """M-Net (Fu et al., 2018): U-Net with multi-scale inputs and side outputs.

    Args:
        in_channels: channel count of the input image (default 3). Used by
            the first encoder stage and the three multi-scale input convs.
        n_classes: channel count of each side output (default 1).

    ``forward`` returns ``[ave_out, side_6, side_7, side_8, side_9]``. Every
    tensor has the input's H x W, and ``ave_out`` is the mean of the four
    side outputs. No activation is applied, so the values are raw scores.

    Inputs whose sides are not multiples of 16 can make the decoder
    concatenations fail. Max-pooling floors odd sizes, while the
    transposed-conv decoder exactly doubles them.
    """

    def __init__(self, in_channels=3, n_classes=1) -> None:
        super().__init__()
        # multi-scale simple convolution on the down-scaled input images
        self.conv2 = M_Conv(in_channels, 64)
        self.conv3 = M_Conv(in_channels, 128)
        self.conv4 = M_Conv(in_channels, 256)
        # down path; stages 2-4 also receive the matching down-scaled input
        self.encoder1 = encoder(in_channels, 32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = encoder(64 + 32, 64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = encoder(128 + 64, 128)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = encoder(256 + 128, 256)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # the center
        self.center = encoder(256, 512)
        # up path; each stage concatenates the matching encoder output
        self.up5 = M_Decoder(512, 256)
        self.up6 = M_Decoder(256, 128)
        self.up7 = M_Decoder(128, 64)
        self.up8 = M_Decoder(64, 32)
        # 1x1 side-output classifiers, one per decoder stage
        self.side_6 = nn.Conv2d(256, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_7 = nn.Conv2d(128, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_8 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_9 = nn.Conv2d(32, n_classes, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_channels, H, W)."""
        # Use both H and W. The original used H for both dimensions, which
        # broke the side-output average for non-square inputs.
        img_h, img_w = x.shape[-2:]
        full_size = (img_h, img_w)

        # Down-scaled copies of the input (F.interpolate default mode).
        x_2 = F.interpolate(x, scale_factor=0.5)
        x_3 = F.interpolate(x, scale_factor=0.25)
        x_4 = F.interpolate(x, scale_factor=0.125)

        conv1 = self.encoder1(x)
        pool1 = self.maxpool1(conv1)
        input2 = torch.cat([self.conv2(x_2), pool1], dim=1)
        conv2 = self.encoder2(input2)
        pool2 = self.maxpool2(conv2)
        input3 = torch.cat([self.conv3(x_3), pool2], dim=1)
        conv3 = self.encoder3(input3)
        pool3 = self.maxpool3(conv3)
        input4 = torch.cat([self.conv4(x_4), pool3], dim=1)
        conv4 = self.encoder4(input4)
        pool4 = self.maxpool4(conv4)
        conv5 = self.center(pool4)

        up6 = self.up5(conv5, conv4)
        up7 = self.up6(up6, conv3)
        up8 = self.up7(up7, conv2)
        up9 = self.up8(up8, conv1)

        # Resize the coarser decoder maps to full resolution before the 1x1
        # side classifiers. F.interpolate replaces the deprecated F.upsample.
        side_6 = self.side_6(F.interpolate(up6, size=full_size, mode='bilinear', align_corners=False))
        side_7 = self.side_7(F.interpolate(up7, size=full_size, mode='bilinear', align_corners=False))
        side_8 = self.side_8(F.interpolate(up8, size=full_size, mode='bilinear', align_corners=False))
        side_9 = self.side_9(up9)
        ave_out = (side_6 + side_7 + side_8 + side_9) / 4
        return [ave_out, side_6, side_7, side_8, side_9]
边栏推荐
- Scanf function of input and output function in C language
- 多线程进阶:锁的策略
- 2022 summer practice (first week)
- 【MySQL】MySQL时区问题、数据库时间相差8小时问题解决
- MyBatis foreach multi select query, index loop, and cancel the and/or tag
- ECCV22 最新54篇论文主图整理
- 冶金物理化学复习 --- 液 - 液相反应动力学
- Multi module packaging: package: XXX does not exist
- 框架一步一步方便使用的流程
- ResNet结构对比
猜你喜欢

JUC notes

FusionGAN代码学习(一)

21 day SQL punch in summary

ByteBuffer.position 抛出异常 IllegalArgumentException

2022 summer practice (PowerDesigner tutorial learning record) (first week)

Redis 之布隆过滤器

CentOS7安装MySQL5.7

科研论文写作方法:在方法部分添加分析和讨论说明自己的贡献和不同

Tomato timing dimming table lamp touch chip-dlt8t10s-jericho

Eccv2022 | 29 papers of Tencent Youtu were selected, including face security, image segmentation, target detection and other research directions
随机推荐
Edge computing: KubeEdge + EdgeMesh
Fusiongan code learning (I)
Redis 之布隆过滤器
FusionGAN代码学习(一)
About localdatetime in swagger
Scanf function of input and output function in C language
PyTorch 使用 MaxPool 实现图像的膨胀和腐蚀
(黑马)MYSQL初级-高级笔记(博主懒狗)
How to compare long and integer and why to report errors
ResNet结构对比
Eccv2022 | 29 papers of Tencent Youtu were selected, including face security, image segmentation, target detection and other research directions
pytorch使用hook获得特征图
MySQL practice 45 lectures
LocalDateTime去掉T,JSONField失效
repackage failed: Unable to find main class
JVM篇 笔记4:内存模型
Non functional test
Confused, I'm going to start running in the direction of [test]
多线程进阶:volatile的作用以及实现原理
BeanUtils.copyProperties无法复制不同List集合问题解决 Lists.transform函数