当前位置:网站首页>Deep learning medical image model reproduction
Deep learning medical image model reproduction
2022-07-28 05:37:00 【A little knowledge, a hundred Xiaosheng】
A. Sevastopolsky, Optic disc and cup segmentation methods for glaucoma detection with modification of u-net convolutional neural network, Pattern Recognition and Image Analysis 27 (2017) 618–624
The original author's code is based on Keras; I reproduced it in PyTorch.
from multiprocessing import pool
from turtle import forward
from sklearn.preprocessing import scale
from torch import nn
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp
class encoder(nn.Module):
    """Two 3x3 convolutions, each followed by in-place ReLU and Dropout(0.3).

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
        ]
        # The attribute name and the layer order determine the state_dict
        # keys, so both are kept as they were.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.encoder(x)
class decoder(nn.Module):
    """Conv-ReLU-Dropout(0.3) followed by Conv-ReLU, all with 3x3 kernels.

    Stride 1 and padding 1 keep the spatial size of the input unchanged.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.ReLU(inplace=True),
        ]
        # The attribute name and the layer order determine the state_dict
        # keys, so both are kept as they were.
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the convolution stack and return the result."""
        return self.decoder(x)
class modified_UNet(nn.Module):
    """U-Net variant with four 2x2 max-pooling stages and bilinear upsampling.

    The contracting path uses five ``encoder`` blocks (32, 64, 64, 64, 64
    channels); the expanding path upsamples, concatenates the matching
    encoder output along the channel axis and applies a ``decoder`` block.
    A final 1x1 convolution followed by a sigmoid produces the output map.

    Args:
        in_channels: number of channels of the input image (default 3).
        n_classes: number of output channels (default 1).
    """

    def __init__(self, in_channels=3, n_classes=1) -> None:
        super().__init__()
        # Contracting path.
        self.encoder1 = encoder(in_channels=in_channels, out_channels=32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = encoder(in_channels=32, out_channels=64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = encoder(in_channels=64, out_channels=64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = encoder(in_channels=64, out_channels=64)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder5 = encoder(in_channels=64, out_channels=64)
        # Expanding path: every decoder input is the upsampled map
        # concatenated with a skip connection (64 + 64, or 64 + 32 at the top).
        self.decoder6 = decoder(in_channels=128, out_channels=64)
        self.decoder7 = decoder(in_channels=128, out_channels=64)
        self.decoder8 = decoder(in_channels=128, out_channels=64)
        self.decoder9 = decoder(in_channels=96, out_channels=32)
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=n_classes, kernel_size=1, stride=1)

    @staticmethod
    def _upsample_and_concat(deep, skip):
        """Bilinearly resize ``deep`` to ``skip``'s spatial size and concatenate.

        Resizing to the skip tensor's size (instead of a fixed factor of 2)
        gives the same result when the sizes differ by exactly 2x, and also
        works when a pooling stage rounded an odd size down.
        """
        upsampled = F.interpolate(
            deep, size=skip.shape[2:], mode='bilinear', align_corners=False
        )
        return torch.cat((upsampled, skip), dim=1)

    def forward(self, x):
        """Compute the sigmoid output map for a batch of images.

        Args:
            x: batched 4-D tensor ``(N, in_channels, H, W)``. H and W must be
                at least 16 so the map is non-empty after four poolings.

        Returns:
            Tensor of shape ``(N, n_classes, H, W)`` with values in (0, 1).
        """
        conv1 = self.encoder1(x)
        conv2 = self.encoder2(self.maxpool1(conv1))
        conv3 = self.encoder3(self.maxpool2(conv2))
        conv4 = self.encoder4(self.maxpool3(conv3))
        conv5 = self.encoder5(self.maxpool4(conv4))

        conv6 = self.decoder6(self._upsample_and_concat(conv5, conv4))
        conv7 = self.decoder7(self._upsample_and_concat(conv6, conv3))
        conv8 = self.decoder8(self._upsample_and_concat(conv7, conv2))
        conv9 = self.decoder9(self._upsample_and_concat(conv8, conv1))

        # torch.sigmoid replaces the deprecated F.sigmoid; same values.
        return torch.sigmoid(self.conv10(conv9))
H. Fu, J. Cheng, Y. Xu, D. W. K. Wong, J. Liu, X. Cao, Joint optic disc and cup segmentation based on multi-label deep network and polar transformation, IEEE Transactions on Medical Imaging 37 (2018)
1597–1605.
The original author published only part of the Keras code; I reproduced the model in PyTorch based on that Keras version.
from torch import nn, relu
import torch
from torch.nn import functional as F
class encoder(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by an in-place ReLU.

    Stride 1 and padding 1 keep the spatial size of the input unchanged.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        first_conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        second_conv = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        # The attribute name and the layer order determine the state_dict
        # keys, so both are kept as they were.
        self.encoder = nn.Sequential(
            first_conv,
            nn.ReLU(inplace=True),
            second_conv,
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.encoder(x)
class M_Conv(nn.Module):
    """Single 3x3 convolution (stride 1, padding 1) followed by in-place ReLU.

    Args:
        input_channels: number of channels of the input tensor.
        output_channels: number of channels produced by the convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Attribute name kept as ``encode`` so state_dict keys do not change.
        self.encode = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Return the rectified convolution of ``x``."""
        return self.encode(x)
class ConvRelu2d(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by an in-place ReLU.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # Attribute name kept as ``decoder`` so state_dict keys do not change.
        self.decoder = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Return the rectified convolution of ``x``."""
        return self.decoder(x)
class M_Decoder(nn.Module):
    """Upsample with a transposed convolution, merge a skip tensor, refine.

    ``up`` is a 2x2, stride-2 transposed convolution mapping
    ``input_channels`` to ``output_channels``. Its output is concatenated
    with the skip tensor along the channel axis and passed through two
    ``ConvRelu2d`` layers. The first of these expects ``input_channels``
    channels, so the skip tensor must carry
    ``input_channels - output_channels`` channels.

    Args:
        input_channels: channels of the tensor to upsample (and of the
            concatenated tensor fed to the refinement convolutions).
        output_channels: channels of the block's output.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(
            input_channels, output_channels, kernel_size=2, stride=2
        )
        refine_layers = [
            ConvRelu2d(input_channels, output_channels),
            ConvRelu2d(output_channels, output_channels),
        ]
        self.deconv = nn.Sequential(*refine_layers)

    def forward(self, x0, x1):
        """Upsample ``x0``, concatenate it with skip tensor ``x1``, and refine."""
        upsampled = self.up(x0)
        merged = torch.cat((upsampled, x1), dim=1)
        return self.deconv(merged)
class MNet(nn.Module):
    """Multi-scale U-shaped network with four averaged side outputs.

    Downscaled copies of the input (1/2, 1/4, 1/8) are each passed through
    an ``M_Conv`` and concatenated with the pooled features of the previous
    encoder stage. Every decoder stage feeds a 1x1 "side" convolution; the
    three coarser decoder maps are first resized bilinearly to the input
    resolution.

    Args:
        in_channels: number of channels of the input image (default 3).
        n_classes: number of channels of every side output (default 1).
    """

    def __init__(self, in_channels=3, n_classes=1) -> None:
        super().__init__()
        # multi-scale simple convolution (applied to the downscaled inputs)
        self.conv2 = M_Conv(in_channels, 64)
        self.conv3 = M_Conv(in_channels, 128)
        self.conv4 = M_Conv(in_channels, 256)
        # the down convolution contain concat operation
        self.encoder1 = encoder(in_channels, 32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = encoder(64 + 32, 64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = encoder(128 + 64, 128)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = encoder(256 + 128, 256)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # the center
        self.center = encoder(256, 512)
        # the up convolution contain concat operation
        self.up5 = M_Decoder(512, 256)
        self.up6 = M_Decoder(256, 128)
        self.up7 = M_Decoder(128, 64)
        self.up8 = M_Decoder(64, 32)
        # the sideoutput
        self.side_6 = nn.Conv2d(256, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_7 = nn.Conv2d(128, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_8 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_9 = nn.Conv2d(32, n_classes, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        """Compute the averaged output and the four side outputs.

        Args:
            x: batched 4-D tensor ``(N, in_channels, H, W)``. H and W should
                be multiples of 16 so that each decoder stage's exact 2x
                upsampling matches the size of its skip tensor.

        Returns:
            ``[ave_out, side_6, side_7, side_8, side_9]``; each entry has
            shape ``(N, n_classes, H, W)`` and no activation is applied.
            ``ave_out`` is the mean of the four side outputs.
        """
        # Use both spatial dimensions so non-square inputs are handled.
        full_size = tuple(x.shape[-2:])

        # Downscaled copies of the input for the multi-scale branches.
        x_2 = F.interpolate(x, scale_factor=0.5)
        x_3 = F.interpolate(x, scale_factor=0.25)
        x_4 = F.interpolate(x, scale_factor=0.125)

        conv1 = self.encoder1(x)
        pool1 = self.maxpool1(conv1)
        input2 = torch.cat([self.conv2(x_2), pool1], dim=1)

        conv2 = self.encoder2(input2)
        pool2 = self.maxpool2(conv2)
        input3 = torch.cat([self.conv3(x_3), pool2], dim=1)

        conv3 = self.encoder3(input3)
        pool3 = self.maxpool3(conv3)
        input4 = torch.cat([self.conv4(x_4), pool3], dim=1)

        conv4 = self.encoder4(input4)
        pool4 = self.maxpool4(conv4)

        conv5 = self.center(pool4)

        up6 = self.up5(conv5, conv4)
        up7 = self.up6(up6, conv3)
        up8 = self.up7(up7, conv2)
        up9 = self.up8(up8, conv1)

        # F.interpolate replaces the deprecated F.upsample; align_corners is
        # spelled out as False, the value F.upsample used implicitly.
        side_6 = F.interpolate(up6, size=full_size, mode='bilinear', align_corners=False)
        side_7 = F.interpolate(up7, size=full_size, mode='bilinear', align_corners=False)
        side_8 = F.interpolate(up8, size=full_size, mode='bilinear', align_corners=False)

        side_6 = self.side_6(side_6)
        side_7 = self.side_7(side_7)
        side_8 = self.side_8(side_8)
        side_9 = self.side_9(up9)

        ave_out = (side_6 + side_7 + side_8 + side_9) / 4
        return [ave_out, side_6, side_7, side_8, side_9]
边栏推荐
- 蒸馏模型图
- JMeter related knowledge sorting
- Example of main diagram of paper model
- Scanf function of input and output function in C language
- SSLError
- 深度学习热力图可视化的方式
- ByteBuffer. Position throws exception illegalargumentexception
- (dark horse) MySQL beginner advanced notes (blogger lazy dog)
- 图像增强——MSRCR
- 2021csdn blog star selection, mutual investment
猜你喜欢
随机推荐
Test Development - UI testing in automated testing
latex和word之间相互转换
Invalid bound statement (not found): com.exam.mapper.UserMapper.findbyid
(dark horse) MySQL beginner advanced notes (blogger lazy dog)
MySQL practice 45 lectures
低照度图像数据集
注册中心服务eureka 切换到 nocas遇到的问题
SSLError
Openjudge: count the number of numeric characters
论文写作用词
Openjudge: upper and lower case letters are interchanged
BigDecimal 进行四舍五入 四舍六入和保留两位小数
Thinking on multi system architecture design
冶金物理化学复习 --- 气-液相反应动力学
Response<T>类
JVM note 4: Memory Model
ResNet结构对比
You must configure either the server or JDBC driver (via the ‘serverTimezone)
2021csdn blog star selection, mutual investment
Example of main diagram of paper model








