当前位置:网站首页>深度学习医学图像模型复现
深度学习医学图像模型复现
2022-07-28 05:17:00 【一知半解百晓生】
A. Sevastopolsky, Optic disc and cup segmentation methods for glaucoma detection with modification of U-Net convolutional neural network, Pattern Recognition and Image Analysis 27 (2017) 618–624.
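原作者的代码基于 Keras，这里使用 PyTorch 进行复现。（注：下面代码开头导入的 `pool`、`forward`（来自 `turtle`）、`scale` 以及 `segmentation_models_pytorch` 在代码中均未被使用，可以删除；其中 `from turtle import forward` 会间接导入 tkinter，在未安装 Tk 的服务器环境中可能引发 ImportError。）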
from multiprocessing import pool
from turtle import forward
from sklearn.preprocessing import scale
from torch import nn
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp
class encoder(nn.Module):
    """Encoder stage of the modified U-Net: two 3x3 convolutions with ReLU and Dropout.

    Each of the two convolutions uses stride 1 and padding 1, so H and W are
    preserved. Each one is followed by an in-place ReLU and ``Dropout(0.3)``.
    Only the channel count changes, from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        layers = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for conv_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(
                    in_channels=conv_in,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(True),
                nn.Dropout(0.3),
            ]
        # The attribute name and layer order fix the state_dict keys
        # (encoder.0.*, encoder.3.*); keep them stable for saved checkpoints.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)
class decoder(nn.Module):
    """Decoder stage of the modified U-Net: conv-ReLU-Dropout-conv-ReLU.

    Both convolutions are 3x3 with stride 1 and padding 1, so H and W are
    preserved. ``Dropout(0.3)`` sits only between the two convolutions; there
    is no dropout after the final ReLU.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        first_conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        second_conv = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Positions 0 and 3 hold the convs -> state_dict keys decoder.0.*, decoder.3.*.
        self.decoder = nn.Sequential(
            first_conv,
            nn.ReLU(True),
            nn.Dropout(0.3),
            second_conv,
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.decoder(x)
class modified_UNet(nn.Module):
    """Modified U-Net (Sevastopolsky, 2017) reimplemented in PyTorch.

    Structure:

    * Four ``encoder`` stages, each followed by 2x2 max pooling, then a
      bottleneck ``encoder``.
    * Four ``decoder`` stages. Each decoder input is the previous feature
      map, bilinearly upsampled to the skip connection's spatial size and
      concatenated with that skip on the channel axis.
    * A 1x1 convolution followed by a sigmoid, so output values lie in
      [0, 1].

    Args:
        in_channels: Number of channels of the input image (default 3).
        n_classes: Number of output maps (default 1).

    Shapes:
        Input is ``(N, in_channels, H, W)``. H and W must be at least 16, so
        that four 2x2 max pools leave at least one pixel. Output is
        ``(N, n_classes, H, W)``.
    """

    def __init__(self, in_channels=3, n_classes=1) -> None:
        super(modified_UNet, self).__init__()
        self.encoder1 = encoder(in_channels=in_channels, out_channels=32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = encoder(in_channels=32, out_channels=64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = encoder(in_channels=64, out_channels=64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = encoder(in_channels=64, out_channels=64)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder5 = encoder(in_channels=64, out_channels=64)
        # Decoder input = upsampled features + skip features:
        # 64 + 64 = 128 for stages 6-8, 64 + 32 = 96 for stage 9.
        self.decoder6 = decoder(in_channels=128, out_channels=64)
        self.decoder7 = decoder(in_channels=128, out_channels=64)
        self.decoder8 = decoder(in_channels=128, out_channels=64)
        self.decoder9 = decoder(in_channels=96, out_channels=32)
        self.conv10 = nn.Conv2d(in_channels=32, out_channels=n_classes, kernel_size=1, stride=1)

    @staticmethod
    def _upsample_cat(low, skip):
        """Bilinearly resize ``low`` to ``skip``'s H x W and concatenate on channels.

        Resizing to the skip's exact size, rather than by a fixed factor of 2,
        keeps the concatenation valid when H or W is not divisible by 16. For
        divisible sizes it yields the same result as ``scale_factor=2``.
        """
        up = F.interpolate(low, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return torch.cat((up, skip), dim=1)

    def forward(self, x):
        conv1 = self.encoder1(x)
        conv2 = self.encoder2(self.maxpool1(conv1))
        conv3 = self.encoder3(self.maxpool2(conv2))
        conv4 = self.encoder4(self.maxpool3(conv3))
        conv5 = self.encoder5(self.maxpool4(conv4))
        conv6 = self.decoder6(self._upsample_cat(conv5, conv4))
        conv7 = self.decoder7(self._upsample_cat(conv6, conv3))
        conv8 = self.decoder8(self._upsample_cat(conv7, conv2))
        conv9 = self.decoder9(self._upsample_cat(conv8, conv1))
        # torch.sigmoid replaces the deprecated F.sigmoid; values are identical.
        return torch.sigmoid(self.conv10(conv9))
H. Fu, J. Cheng, Y. Xu, D. W. K. Wong, J. Liu, X. Cao, Joint optic disc and cup segmentation based on multi-label deep network and polar transformation, IEEE Transactions on Medical Imaging 37 (2018) 1597–1605.
原作者只公布了部分 Keras 代码，我参照该 Keras 版本使用 PyTorch 进行了复现。注意：以下代码与上一段代码相互独立——两段都定义了名为 `encoder` 的类（本段的版本不含 Dropout），若放在同一个文件中，后定义的会覆盖前者，因此请分别保存为不同的文件。
from torch import nn, relu
import torch
from torch.nn import functional as F
class encoder(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by an in-place ReLU.

    Used by ``MNet`` for its encoder stages and its center block. Unlike the
    U-Net variant, this version has no dropout. H and W are preserved, and
    channels go from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        stages = []
        for c_in, c_out in ((in_channels, out_channels), (out_channels, out_channels)):
            stages.append(
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.ReLU(True))
        # Convs land at indices 0 and 2 -> state_dict keys encoder.0.*, encoder.2.*.
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        return self.encoder(x)
class M_Conv(nn.Module):
    """Single 3x3 same-padding convolution followed by an in-place ReLU.

    ``MNet`` uses one of these per downscaled copy of the input image. Each
    one projects that copy to ``output_channels`` before it is concatenated
    with pooled encoder features. H and W are preserved.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=1, padding=1)
        # Keys: encode.0.weight / encode.0.bias.
        self.encode = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        return self.encode(x)
class ConvRelu2d(nn.Module):
    """3x3 same-padding convolution followed by an in-place ReLU.

    This is the building block of ``M_Decoder``. H and W are preserved.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Attribute name fixes the state_dict keys decoder.0.weight / decoder.0.bias.
        self.decoder = nn.Sequential(conv, nn.ReLU(inplace=True))

    def forward(self, x):
        out = self.decoder(x)
        return out
class M_Decoder(nn.Module):
    """Decoder stage of M-Net: 2x transposed-conv upsampling, skip concat, two conv+ReLU.

    Channel bookkeeping:

    * ``up`` (kernel 2, stride 2) maps ``input_channels`` to
      ``output_channels`` and exactly doubles H and W.
    * Its output is concatenated with the skip tensor ``x1`` on the channel
      axis.
    * The first ``ConvRelu2d`` expects ``input_channels`` channels, so ``x1``
      must carry ``input_channels - output_channels`` channels. ``MNet``
      always uses ``input_channels == 2 * output_channels``.
    * ``x1`` must also have twice the spatial size of ``x0``.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=2, stride=2)
        self.deconv = nn.Sequential(
            ConvRelu2d(input_channels, output_channels),
            ConvRelu2d(output_channels, output_channels),
        )

    def forward(self, x0, x1):
        merged = torch.cat((self.up(x0), x1), dim=1)
        return self.deconv(merged)
class MNet(nn.Module):
    """M-Net (Fu et al., 2018): U-Net with multi-scale inputs and deep side outputs.

    Structure:

    * Multi-scale inputs: the image is downscaled by 1/2, 1/4 and 1/8. Each
      copy passes through an ``M_Conv`` and is concatenated with the pooled
      features entering encoder stages 2-4.
    * Decoder: four ``M_Decoder`` stages.
    * Side outputs: the outputs of the first three decoder stages are
      bilinearly resized to the input's H x W. Each of the four side outputs
      then goes through a 1x1 conv to ``n_classes`` channels, and the four
      are averaged.

    Args:
        in_channels: Number of channels of the input image (default 3).
        n_classes: Channels of every side output (default 1).

    ``forward`` returns ``[ave_out, side_6, side_7, side_8, side_9]``, each of
    shape ``(N, n_classes, H, W)``. No sigmoid/softmax is applied.

    H and W must each be divisible by 16. Otherwise the transposed-conv
    outputs in the decoder do not match the skip tensors' sizes and
    ``torch.cat`` fails.
    """

    def __init__(self, in_channels=3, n_classes=1) -> None:
        super(MNet, self).__init__()
        # Multi-scale input branches, widened to the matching encoder stage.
        self.conv2 = M_Conv(in_channels, 64)
        self.conv3 = M_Conv(in_channels, 128)
        self.conv4 = M_Conv(in_channels, 256)
        # Encoder; stages 2-4 take [scaled-input features, pooled features].
        self.encoder1 = encoder(in_channels, 32)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = encoder(64 + 32, 64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = encoder(128 + 64, 128)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = encoder(256 + 128, 256)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Bottleneck.
        self.center = encoder(256, 512)
        # Decoder: each stage upsamples 2x and concatenates an encoder skip.
        self.up5 = M_Decoder(512, 256)
        self.up6 = M_Decoder(256, 128)
        self.up7 = M_Decoder(128, 64)
        self.up8 = M_Decoder(64, 32)
        # 1x1 side-output heads, one per decoder stage.
        self.side_6 = nn.Conv2d(256, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_7 = nn.Conv2d(128, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_8 = nn.Conv2d(64, n_classes, kernel_size=1, padding=0, stride=1, bias=True)
        self.side_9 = nn.Conv2d(32, n_classes, kernel_size=1, padding=0, stride=1, bias=True)

    def forward(self, x):
        # Side outputs are resized to the full (H, W). Both spatial dims are
        # used so non-square inputs are supported.
        out_size = tuple(x.shape[-2:])
        # Downscaled copies of the input. NOTE(review): F.interpolate defaults
        # to mode='nearest' (plain subsampling) here; confirm this is the
        # intended downscaling for the multi-scale branches.
        x_2 = F.interpolate(x, scale_factor=0.5)
        x_3 = F.interpolate(x, scale_factor=0.25)
        x_4 = F.interpolate(x, scale_factor=0.125)

        conv1 = self.encoder1(x)
        pool1 = self.maxpool1(conv1)
        conv2 = self.encoder2(torch.cat([self.conv2(x_2), pool1], dim=1))
        pool2 = self.maxpool2(conv2)
        conv3 = self.encoder3(torch.cat([self.conv3(x_3), pool2], dim=1))
        pool3 = self.maxpool3(conv3)
        conv4 = self.encoder4(torch.cat([self.conv4(x_4), pool3], dim=1))
        pool4 = self.maxpool4(conv4)
        conv5 = self.center(pool4)

        up6 = self.up5(conv5, conv4)
        up7 = self.up6(up6, conv3)
        up8 = self.up7(up7, conv2)
        up9 = self.up8(up8, conv1)

        # F.interpolate replaces the deprecated F.upsample (same semantics).
        side_6 = self.side_6(F.interpolate(up6, size=out_size, mode='bilinear', align_corners=False))
        side_7 = self.side_7(F.interpolate(up7, size=out_size, mode='bilinear', align_corners=False))
        side_8 = self.side_8(F.interpolate(up8, size=out_size, mode='bilinear', align_corners=False))
        side_9 = self.side_9(up9)  # already at full resolution
        ave_out = (side_6 + side_7 + side_8 + side_9) / 4
        return [ave_out, side_6, side_7, side_8, side_9]
边栏推荐
- Oracle创建表、删除表、修改表(添加字段、修改字段、删除字段)语句总结
- Response < t > class
- Redis 之布隆过滤器
- 注册中心服务eureka 切换到 nocas遇到的问题
- 2022 summer practice (PowerDesigner tutorial learning record) (first week)
- 7. < tag string and API trade-offs> supplement: Sword finger offer 05. replace spaces
- Flask Development & get/post request
- Invalid bound statement (not found): com.exam.mapper.UserMapper.findbyid
- List < long >, list < integer > convert each other
- 论文写作用词
猜你喜欢

How to compare long and integer and why to report errors

自定义Json返回数据

PC端-bug记录

How does Alibaba use DDD to split microservices?

FusionGAN代码学习(一)

From the basic concept of micro services to core components - explain and analyze through an example

使用navicat或plsql导出csv格式,超过15位数字后面变成000(E+19)的问题

Edge calculation kubeedge+edgemash

regular expression

MySQL practice 45 lectures
随机推荐
There is no crossover in the time period within 24 hours
I've been in an outsourcing company for two years, and I feel like I'm going to die
VMware Workstation 与 Device/Credential Guard 不兼容。禁用 Device/Credential Guard
Redis 之布隆过滤器
Export excel, generate multiple sheet pages, and name them
21 day SQL punch in summary
How should programmers keep warm when winter is coming
Invalid bound statement (not found): com.exam.mapper.UserMapper.findbyid
First acquaintance with C language (1)
Test Development - UI testing in automated testing
[slam] lvi-sam analysis - Overview
数据库面试
ssm项目快速搭建项目配置文件
BigDecimal 进行四舍五入 四舍六入和保留两位小数
Microservice failure mode and building elastic system
注册中心服务eureka 切换到 nocas遇到的问题
Digital twin solutions inject new momentum into the construction of chemical parks
使用navicat或plsql导出csv格式,超过15位数字后面变成000(E+19)的问题
YUV to uiimage
regular expression