当前位置:网站首页>Resnet+attention project complete code learning
ResNet + attention (CBAM) project: complete code walkthrough
2022-07-05 12:26:00 【Dongcheng West que】
Project name: CBAM.PyTorch-master
Source paper: CBAM: Convolutional Block Attention Module (ECCV 2018)

Project directory layout:

train.py
import os
from collections import OrderedDict
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import transforms, models, datasets
import matplotlib.pyplot as plt
from data_loader.ImageNet_datasets import ImageNetData
import model.resnet_cbam as resnet_cbam
from model.Medical import CovNet
from trainer.trainer import Trainer
from utils.logger import Logger
from PIL import Image
from torchnet.meter import ClassErrorMeter
from tensorboardX import SummaryWriter
import torch.backends.cudnn as cudnn
import warnings
# Suppress all warnings for the rest of the run (this also hides library
# deprecation notices).
warnings.filterwarnings("ignore")
# NOTE(review): `resize` is not referenced by the visible train.py code; the
# input size actually used comes from the --imagesize argument.
resize=224
def load_state_dict(model_dir, is_multi_gpu):
    """Load the ``state_dict`` entry of a checkpoint file.

    Args:
        model_dir: Path to a checkpoint holding a dict with a ``'state_dict'`` key.
        is_multi_gpu: If True, remove the ``module.`` prefix that
            ``nn.DataParallel`` adds to parameter names, so the weights fit a
            bare (unwrapped) model.

    Returns:
        The state dict (renamed when ``is_multi_gpu``), with tensors on CPU.
    """
    # Map every storage to CPU so the file loads even without the GPU it was
    # saved from.
    state_dict = torch.load(model_dir, map_location=lambda storage, loc: storage)['state_dict']
    if not is_multi_gpu:
        return state_dict
    prefix = 'module.'
    # Only strip the prefix where it is present; slicing k[7:] blindly would
    # mangle the keys of a checkpoint saved without DataParallel.
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
def main(args):
    """Train an image classifier configured by the command-line ``args``.

    Builds ImageFolder loaders from ``<data_root>/t256`` (train) and
    ``<data_root>/v256`` (validation), instantiates the model named by
    ``args.model`` and runs ``Trainer.fit``.

    Args:
        args: ``argparse.Namespace`` produced by the ``__main__`` block.

    Raises:
        ModuleNotFoundError: If ``args.model`` is not a supported model name.
    """
    # Pin the visible GPUs before anything can initialise CUDA.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    gpus = args.gpu.split(',')

    # Truncate the log for a fresh run; append to it when resuming.
    os.makedirs('./logs', exist_ok=True)
    log_path = './logs/' + args.model + '.log'
    if 0 == len(args.resume):
        logger = Logger(log_path)
    else:
        logger = Logger(log_path, True)
    logger.append(vars(args))
    writer = SummaryWriter() if args.display else None

    # NOTE: ImageNet mean/std normalisation is deliberately not applied.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((args.imagesize, args.imagesize)),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]),
        'val': transforms.Compose([
            transforms.Resize((args.imagesize, args.imagesize)),
            transforms.ToTensor(),
        ]),
    }
    train_datasets = datasets.ImageFolder(os.path.join(args.data_root, 't256'), data_transforms['train'])
    val_datasets = datasets.ImageFolder(os.path.join(args.data_root, 'v256'), data_transforms['val'])
    # --batch_size is per GPU.
    train_dataloaders = torch.utils.data.DataLoader(
        train_datasets, batch_size=args.batch_size * len(gpus), shuffle=True, num_workers=4)
    val_dataloaders = torch.utils.data.DataLoader(
        val_datasets, batch_size=16, shuffle=True, num_workers=4)

    if args.debug:
        # Show one training batch with its labels as a quick sanity check.
        x, y = next(iter(train_dataloaders))
        plt.text(2, -20, "labels:" + str(y.numpy()), fontsize=15)
        grid_img = torchvision.utils.make_grid(x, nrow=8)
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()
        print("x.shape", x.shape)

    is_use_cuda = torch.cuda.is_available()
    cudnn.benchmark = True
    # NOTE(review): the resnet101/152 variants use 2 output classes while the
    # others use 5 -- confirm this matches the dataset being trained on.
    if 'resnet50' == args.model:
        my_model = models.resnet50(pretrained=False)
        my_model.fc = nn.Linear(my_model.fc.in_features, 5)
    elif 'resnet18' == args.model:
        my_model = models.resnet18(pretrained=True)
        my_model.fc = nn.Linear(my_model.fc.in_features, 5)
    elif 'resnet50_cbam' == args.model:
        my_model = resnet_cbam.resnet50_cbam(pretrained=True)
        my_model.fc = nn.Linear(my_model.fc.in_features, 5)
    elif 'resnet101_cbam' == args.model:
        my_model = resnet_cbam.resnet101_cbam(pretrained=True)
        my_model.fc = nn.Linear(my_model.fc.in_features, 2)
    elif 'resnet101' == args.model:
        my_model = models.resnet101(pretrained=True)
        my_model.fc = nn.Linear(my_model.fc.in_features, 2)
    elif 'resnet152_cbam' == args.model:
        my_model = resnet_cbam.resnet152_cbam(pretrained=True)
        my_model.fc = nn.Linear(my_model.fc.in_features, 2)
    elif 'resnet152' == args.model:
        my_model = models.resnet152(pretrained=True)
        my_model.fc = nn.Linear(my_model.fc.in_features, 2)
    elif 'vgg19' == args.model:
        my_model = models.vgg19(pretrained=True)
        # VGG has no `.fc`: assigning one adds an unused layer and leaves the
        # output at 1000 classes. The real head is classifier[6] (4096 -> 1000).
        my_model.classifier[6] = nn.Linear(4096, 5)
    elif 'CovNet' == args.model.split('_')[0]:
        my_model = CovNet(5)
    else:
        raise ModuleNotFoundError

    if is_use_cuda and 1 == len(gpus):
        my_model = my_model.cuda()
    elif is_use_cuda and 1 < len(gpus):
        my_model = nn.DataParallel(my_model.cuda())
    print(my_model)

    loss_fn = [nn.CrossEntropyLoss()]
    optimizer = optim.SGD(my_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
    lr_schedule = lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40, 60], gamma=0.1)
    metric = [ClassErrorMeter([1, 2], True)]
    start_epoch = 0
    num_epochs = 50
    my_trainer = Trainer(my_model, args.model, loss_fn, optimizer, lr_schedule, 6, is_use_cuda,
                         train_dataloaders, val_dataloaders, metric, start_epoch, num_epochs,
                         args.debug, logger, writer)
    my_trainer.fit()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch Template')
    parser.add_argument('-r', '--resume', default='', type=str,
                        help='path to latest checkpoint (default: None)')
    # --debug / --display are `store_true` with default=True, so on their own
    # they can never be switched off; the --no-* flags below make that possible
    # while keeping the old flags and defaults working.
    parser.add_argument('--debug', action='store_true', default=True, dest='debug',
                        help='trainer debug flag')
    parser.add_argument('--no-debug', action='store_false', dest='debug',
                        help='disable the trainer debug flag')
    parser.add_argument('-g', '--gpu', default='0', type=str,
                        help='GPU ID Select')
    parser.add_argument('-d', '--data_root', default='./datasets',
                        type=str, help='data root')
    parser.add_argument('-t', '--train_file', default='./datasets/train.txt',
                        type=str, help='train file')
    parser.add_argument('-v', '--val_file', default='./datasets/val.txt',
                        type=str, help='validation file')
    parser.add_argument('-m', '--model', default='CovNet',
                        type=str, help='model type')
    parser.add_argument('--batch_size', default=32,
                        type=int, help='model train batch size')
    parser.add_argument('--display', action='store_true', dest='display', default=True,
                        help='Use TensorboardX to Display')
    parser.add_argument('--no-display', action='store_false', dest='display',
                        help='disable TensorboardX display')
    parser.add_argument('--imagesize', default=224,
                        type=int, help='side length (pixels) images are resized to')
    args = parser.parse_args()
    main(args)
test.py
import os
from collections import OrderedDict
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torchvision import transforms, models
from model import *
# import pretrainedmodels
import numpy as np
import model.resnet_cbam as resnet_cbam
# Directory whose files are scored; the commented-out paths are earlier test sets.
#DATA_ROOT = './datasets/xuelang_round1_test_a_20180709'
#DATA_ROOT = './datasets/xuelang_round1_test_b'
DATA_ROOT = './datasets/xuelang_round2_test_a_20180809'
# Suffix of the output file, written as
# checkpoint/<model_name>/<model_name>_<img_size>_result.csv.
RESULT_FILE = 'result.csv'
import warnings
# Suppress all warnings for the rest of the run.
warnings.filterwarnings("ignore")
def test_and_generate_result(epoch_num, model_name='resnet101', img_size=320, is_multi_gpu=False):
    """Score every ``.jpg`` in ``DATA_ROOT`` with a trained two-class model.

    Loads ``./checkpoint/<model_name>/Models_epoch_<epoch_num>.ckpt`` into the
    architecture selected by the text of ``model_name`` before the first ``_``.
    It then writes one ``<file>,<p>`` line per image, where ``p`` is the softmax
    score of class index 1. ``p`` is rounded to 6 decimals and clamped to
    ``[1e-6, 1 - 1e-6]``.

    Args:
        epoch_num: Epoch of the checkpoint to load (str or int).
        model_name: Checkpoint directory name; its prefix selects the model.
        img_size: Target size of the shorter image side for ``Resize``.
        is_multi_gpu: Strip the ``module.`` prefix added by ``nn.DataParallel``.

    Raises:
        ModuleNotFoundError: If the model prefix is not recognised.
    """
    data_transform = transforms.Compose([
        # LANCZOS is the filter PIL's ANTIALIAS alias pointed to; that alias
        # was removed in Pillow 10.
        transforms.Resize(img_size, Image.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize([0.53744068, 0.51462684, 0.52646497], [0.06178288, 0.05989952, 0.0618901])
    ])
    # NOTE(review): only effective if CUDA has not been initialised yet.
    os.environ['CUDA_VISIBLE_DEVICES'] = '4'
    is_use_cuda = torch.cuda.is_available()

    arch = model_name.split('_')[0]
    if 'resnet152' == arch:
        model_ft = models.resnet152(pretrained=True)
        my_model = resnet152.MyResNet152(model_ft)
        del model_ft
    elif 'resnet50' == arch:
        model_ft = models.resnet50(pretrained=True)
        my_model = resnet50.MyResNet50(model_ft)
        del model_ft
    elif 'resnet101' == arch:
        model_ft = models.resnet101(pretrained=True)
        my_model = resnet101.MyResNet101(model_ft)
        del model_ft
    elif 'densenet121' == arch:
        model_ft = models.densenet121(pretrained=True)
        my_model = densenet121.MyDenseNet121(model_ft)
        del model_ft
    elif 'densenet169' == arch:
        model_ft = models.densenet169(pretrained=True)
        my_model = densenet169.MyDenseNet169(model_ft)
        del model_ft
    elif 'densenet201' == arch:
        model_ft = models.densenet201(pretrained=True)
        my_model = densenet201.MyDenseNet201(model_ft)
        del model_ft
    elif 'densenet161' == arch:
        model_ft = models.densenet161(pretrained=True)
        my_model = densenet161.MyDenseNet161(model_ft)
        del model_ft
    elif 'ranet' == arch:
        my_model = ranet.ResidualAttentionModel_92()
    elif 'senet154' == arch:
        # NOTE(review): `import pretrainedmodels` is commented out at the top
        # of test.py, so this branch raises NameError as written.
        model_ft = pretrainedmodels.models.senet154(num_classes=1000, pretrained='imagenet')
        my_model = MySENet154(model_ft)
        del model_ft
    else:
        raise ModuleNotFoundError

    ckpt_path = './checkpoint/' + model_name + '/Models_epoch_' + str(epoch_num) + '.ckpt'
    # Load onto the CPU so this also works without a GPU; the model is moved
    # to the GPU below when one is available.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    if is_multi_gpu:
        state_dict = OrderedDict(
            (k[len('module.'):] if k.startswith('module.') else k, v)
            for k, v in state_dict.items()
        )
    my_model.load_state_dict(state_dict)
    if is_use_cuda:
        my_model = my_model.cuda()
    my_model.eval()

    result_path = os.path.join('checkpoint', model_name, model_name + '_' + str(img_size) + '_' + RESULT_FILE)
    # no_grad replaces the long-removed Variable(volatile=True): no autograd
    # graph is built during inference.
    with open(result_path, 'w', encoding='utf-8') as fd, torch.no_grad():
        fd.write('filename|defect,probability\n')
        for file_name in os.listdir(DATA_ROOT):
            if '.jpg' not in file_name:
                continue
            file_path = os.path.join(DATA_ROOT, file_name)
            with Image.open(file_path) as img:
                img_tensor = data_transform(img.convert('RGB')).unsqueeze(0)
            if is_use_cuda:
                img_tensor = img_tensor.cuda()
            output = F.softmax(my_model(img_tensor), dim=1)
            defect_prob = round(output[0, 1].item(), 6)
            # Keep the probability strictly inside (0, 1).
            defect_prob = min(max(defect_prob, 0.000001), 0.999999)
            fd.write('%s,%.6f\n' % (file_name, defect_prob))
def test_and_generate_result_round2(epoch_num, model_name='resnet101', img_size=224, is_multi_gpu=False):
    """Write per-class softmax scores for every file in ``DATA_ROOT``.

    For each file, two lines are written: ``<file>|norm,<p0>`` and
    ``<file>|defect_1,<p1>``. ``p0`` and ``p1`` are the softmax scores of
    classes 0 and 1, rounded to 6 decimals and clamped to
    ``[1e-6, 1 - 1e-6]``.

    Args:
        epoch_num: Epoch of the checkpoint to load (str or int).
        model_name: Checkpoint directory name; the text before the first ``_``
            selects the architecture.
        img_size: Images are resized to ``img_size x img_size``.
        is_multi_gpu: Strip the ``module.`` prefix added by ``nn.DataParallel``.

    Raises:
        ModuleNotFoundError: If the prefix is not ``resnet50``,
            ``resnet50-cbam`` or ``resnet101``.
    """
    data_transform = transforms.Compose([
        # LANCZOS is the filter PIL's ANTIALIAS alias pointed to; that alias
        # was removed in Pillow 10.
        transforms.Resize((img_size, img_size), Image.LANCZOS),
        transforms.ToTensor(),
    ])
    # NOTE(review): only effective if CUDA has not been initialised yet.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    is_use_cuda = torch.cuda.is_available()
    print(epoch_num, model_name, img_size, is_multi_gpu)

    arch = model_name.split('_')[0]
    # NOTE(review): Trainer saves to ./checkpoint/<args.model>, so a run with
    # --model resnet50_cbam has prefix 'resnet50' here and is rebuilt as a
    # plain torchvision ResNet-50 -- confirm the intended naming scheme.
    if 'resnet50' == arch:
        my_model = models.resnet50(pretrained=False)
    elif 'resnet50-cbam' == arch:
        my_model = resnet_cbam.resnet50_cbam(pretrained=False)
    elif 'resnet101' == arch:
        # ImageNet weights are unnecessary: the checkpoint overwrites them all.
        my_model = models.resnet101(pretrained=False)
        my_model.fc = nn.Linear(2048, 2)
    else:
        raise ModuleNotFoundError

    ckpt_path = './checkpoint/' + model_name + '/Models_epoch_' + str(epoch_num) + '.ckpt'
    print(ckpt_path)
    # Load onto the CPU so this also works without a GPU.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']
    if is_multi_gpu:
        state_dict = OrderedDict(
            (k[len('module.'):] if k.startswith('module.') else k, v)
            for k, v in state_dict.items()
        )
    my_model.load_state_dict(state_dict)
    if is_use_cuda:
        my_model = my_model.cuda()
    my_model.eval()

    result_path = os.path.join('checkpoint', model_name, model_name + '_' + str(img_size) + '_' + RESULT_FILE)
    with open(result_path, 'w', encoding='utf-8') as fd, torch.no_grad():
        fd.write('filename|defect,probability\n')
        # NOTE(review): every entry of DATA_ROOT is scored (no '.jpg' filter),
        # so a non-image file makes Image.open raise.
        for file_name in os.listdir(DATA_ROOT):
            file_path = os.path.join(DATA_ROOT, file_name)
            with Image.open(file_path) as img:
                img_tensor = data_transform(img.convert('RGB')).unsqueeze(0)
            if is_use_cuda:
                img_tensor = img_tensor.cuda()
            output = F.softmax(my_model(img_tensor), dim=1)
            for k in range(2):
                prob = round(output[0, k].item(), 6)
                # Keep the probability strictly inside (0, 1).
                prob = min(max(prob, 0.000001), 0.999999)
                label = 'norm' if 0 == k else 'defect_' + str(k)
                fd.write('%s,%.6f\n' % (file_name + '|' + label, prob))
if __name__ == '__main__':
    # Earlier runs scored other checkpoints, for example:
    #   test_and_generate_result('10', 'resnet152_2018073100', 416, True)
    #   test_and_generate_result_round2('14', 'resnet152-r2-2o_2018081300', 600, True)
    # Current run: epoch-9 resnet101 checkpoint, 224x224 inputs, weights saved
    # without DataParallel.
    test_and_generate_result_round2(epoch_num='9', model_name='resnet101',
                                    img_size=224, is_multi_gpu=False)
logger.py
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import traceback
class Logger(object):
    """Append-only text log of the training process.

    Args:
        fpath: Path of the log file; missing parent directories are created.
        resume: If True, append to the file; otherwise truncate (or create) it.

    Usable as a context manager, which closes the file on exit.
    """

    def __init__(self, fpath, resume=False):
        self.file = None
        self.resume = resume
        parent = os.path.dirname(fpath)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # 'a' and 'w' both create a missing file, so only `resume` matters.
        self.file = open(fpath, 'a' if resume else 'w', encoding='utf-8')

    def append(self, target_str):
        """Write ``target_str`` (converted with ``str`` if needed) as one line and flush."""
        if not isinstance(target_str, str):
            try:
                target_str = str(target_str)
            except Exception:
                # Best effort: report the failure but never abort training
                # because a log entry could not be rendered.
                traceback.print_exc()
                return
        self.file.write(target_str + '\n')
        self.file.flush()

    def close(self):
        """Close the underlying file."""
        if self.file is not None:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
        return False
# trainer.py
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import time
import sys
import os
import torchvision
import matplotlib.pyplot as plt
class Trainer():
    """Epoch-based train / validate / checkpoint loop for a classifier.

    Args:
        model: ``nn.Module`` to train (possibly wrapped in ``nn.DataParallel``).
        model_type: Name used in logs and as the checkpoint sub-directory.
        loss_fn: List of loss callables; only ``loss_fn[0]`` is used.
        optimizer: Optimizer over the model parameters.
        lr_schedule: LR scheduler, stepped once per epoch.
        log_batchs: Log the running training loss every ``log_batchs`` batches
            and on the last batch.
        is_use_cuda: Move each batch to the GPU before the forward pass.
        train_data_loader: Sized iterable of ``(inputs, labels)`` batches.
        valid_data_loader: Same, used for validation.
        metric: Optional list whose first element offers ``reset()``,
            ``add(prob, target)`` and ``value()`` (e.g. torchnet
            ``ClassErrorMeter``); its first two values are logged.
        start_epoch: First epoch to run; the LR schedule is fast-forwarded to it.
        num_epochs: Epochs ``start_epoch .. num_epochs - 1`` are run.
        is_debug: Log the current hyper-parameters before every epoch.
        logger: Object with an ``append(str)`` method.
        writer: Optional TensorBoard ``SummaryWriter``; ``None`` disables it.
    """

    def __init__(self, model, model_type, loss_fn, optimizer, lr_schedule, log_batchs, is_use_cuda,
                 train_data_loader, valid_data_loader=None, metric=None, start_epoch=0, num_epochs=25,
                 is_debug=False, logger=None, writer=None):
        self.model = model
        self.model_type = model_type
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr_schedule = lr_schedule
        self.log_batchs = log_batchs
        self.is_use_cuda = is_use_cuda
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.metric = metric
        self.start_epoch = start_epoch
        self.num_epochs = num_epochs
        self.is_debug = is_debug
        self.cur_epoch = start_epoch
        self.best_acc = 0.
        self.best_loss = sys.float_info.max
        self.logger = logger
        self.writer = writer
        self.global_step = 0

    def fit(self):
        """Train, validate and save a checkpoint for every remaining epoch."""
        # Replay the schedule so the LR matches `start_epoch` when resuming.
        for _ in range(self.start_epoch):
            self.lr_schedule.step()
        for epoch in range(self.start_epoch, self.num_epochs):
            self.logger.append('Epoch {}/{}'.format(epoch, self.num_epochs - 1))
            self.logger.append('-' * 60)
            self.cur_epoch = epoch
            if self.is_debug:
                self._dump_infos()
            self._train()
            self.lr_schedule.step()
            self._valid()
            self._save_best_model()

    def _dump_infos(self):
        """Log the current run configuration and best scores so far."""
        self.logger.append('---------------------Current Parameters---------------------')
        self.logger.append('is use GPU: ' + ('True' if self.is_use_cuda else 'False'))
        # Read the LR actually in use: scheduler.get_lr() is deprecated and
        # reports an extra decay when called at a milestone epoch.
        self.logger.append('lr: %f' % (self.optimizer.param_groups[0]['lr']))
        self.logger.append('model_type: %s' % (self.model_type))
        self.logger.append('current epoch: %d' % (self.cur_epoch))
        self.logger.append('best accuracy: %f' % (self.best_acc))
        self.logger.append('best loss: %f' % (self.best_loss))
        self.logger.append('------------------------------------------------------------')

    def _train(self):
        """Run one training epoch, logging the running mean loss."""
        self.model.train()
        losses = []
        if self.metric is not None:
            self.metric[0].reset()
        last_batch = len(self.train_data_loader) - 1
        for i, (inputs, labels) in enumerate(self.train_data_loader):
            if self.writer is not None:
                # First image of every batch, tagged with its label.
                self.writer.add_image("label:" + str(labels[0]), inputs[0], global_step=i,
                                      walltime=None, dataformats='CHW')
            if self.is_use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()
            # reshape(-1) rather than squeeze(): squeeze() turns a final batch
            # of one sample into a 0-d tensor that no longer matches outputs.
            labels = labels.reshape(-1)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.loss_fn[0](outputs, labels)
            if i % 10 == 0:
                print("epoch:{},iter:{}, loss:{}".format(self.cur_epoch, i, loss.item()))
            if self.metric is not None:
                prob = F.softmax(outputs, dim=1).detach().cpu()
                self.metric[0].add(prob, labels.detach().cpu())
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
            if 0 == i % self.log_batchs or i == last_batch:
                local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
                batch_mean_loss = np.mean(losses)
                print_str = '[%s]\tTraining Batch[%d/%d]\t Class Loss: %.4f\t' \
                    % (local_time_str, i, last_batch, batch_mean_loss)
                if i == last_batch and self.metric is not None:
                    # NOTE: the second score is the meter's second top-k value
                    # (train.py builds it with [1, 2]) despite the "Top-5" label.
                    scores = self.metric[0].value()
                    print_str += '@Top-1 Score: %.4f\t' % (scores[0])
                    print_str += '@Top-5 Score: %.4f\t' % (scores[1])
                self.logger.append(print_str)
                if self.writer is not None:
                    self.writer.add_scalar('loss/loss_c', batch_mean_loss, self.global_step)
                self.global_step += 1

    def _valid(self):
        """Evaluate on the validation set and update the best accuracy/loss."""
        self.model.eval()
        losses = []
        if self.metric is not None:
            self.metric[0].reset()
        with torch.no_grad():
            for inputs, labels in self.valid_data_loader:
                if self.is_use_cuda:
                    inputs, labels = inputs.cuda(), labels.cuda()
                labels = labels.reshape(-1)
                outputs = self.model(inputs)
                loss = self.loss_fn[0](outputs, labels)
                if self.metric is not None:
                    prob = F.softmax(outputs, dim=1).cpu()
                    self.metric[0].add(prob, labels.cpu())
                losses.append(loss.item())
        local_time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        batch_mean_loss = np.mean(losses)
        print_str = '[%s]\tValidation: \t Class Loss: %.4f\t' \
            % (local_time_str, batch_mean_loss)
        if self.metric is None:
            # No accuracy available: just keep track of the lowest loss.
            self.logger.append(print_str)
            if batch_mean_loss <= self.best_loss:
                self.best_loss = batch_mean_loss
            return
        scores = self.metric[0].value()
        top1_acc_score = scores[0]
        top5_acc_score = scores[1]
        print_str += '@Top-1 Score: %.4f\t' % (top1_acc_score)
        print_str += '@Top-5 Score: %.4f\t' % (top5_acc_score)
        self.logger.append(print_str)
        print("cur_epoch:", self.cur_epoch, "top1_acc_s:", top1_acc_score, "best_acc:", self.best_acc,
              "batch_mean_loss:", batch_mean_loss, "best_loss", self.best_loss)
        if top1_acc_score >= self.best_acc:
            self.best_acc = top1_acc_score
            self.best_loss = batch_mean_loss

    def _save_best_model(self):
        """Save ./checkpoint/<model_type>/Models_epoch_<cur_epoch>.ckpt.

        Despite the name, a checkpoint is written after every epoch; the best
        accuracy so far is stored in it alongside the weights.
        """
        self.logger.append('Saving Model...')
        state = {
            'state_dict': self.model.state_dict(),
            'best_acc': self.best_acc,
            'cur_epoch': self.cur_epoch,
            'num_epochs': self.num_epochs
        }
        ckpt_dir = './checkpoint/' + self.model_type
        os.makedirs(ckpt_dir, exist_ok=True)
        torch.save(state, ckpt_dir + '/Models' + '_epoch_%d' % self.cur_epoch + '.ckpt')
# model
Medical.py
import torch,cv2
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torchvision
from torch.nn import functional as F
# Make scikit-image a hard requirement of this module, failing with a clear
# message when it is missing.
# NOTE(review): none of these skimage names are used by the model classes
# visible in this file.
try:
    from skimage import data_dir
    from skimage import io
    from skimage import color
    from skimage import img_as_float,transform
    from skimage.transform import resize
except ImportError:
    raise ImportError("This example requires scikit-image")
class Flatten(nn.Module):
    """Flatten all dimensions except the first: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten also handles non-contiguous inputs, where x.view()
        # raises, and avoids building a shape tensor on every call.
        return torch.flatten(x, 1)
class ConBlk(nn.Module):
    """Three conv stages followed by flattening.

    Stages 1 and 2 are conv(stride 2) -> 2x2 max-pool -> BatchNorm -> ReLU;
    stage 3 is conv(stride 1) -> 2x2 max-pool -> ReLU (no BatchNorm).
    For a 224x224 RGB input this yields 36 * 7 * 7 = 1764 features.
    """

    def __init__(self):
        super(ConBlk, self).__init__()
        width = 36
        self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.fla = Flatten()

    def forward(self, x):
        stage1 = self.pool1(self.conv1(x))
        stage1 = F.relu(self.bn1(stage1))
        stage2 = self.pool2(self.conv2(stage1))
        stage2 = F.relu(self.bn2(stage2))
        stage3 = F.relu(self.pool3(self.conv3(stage2)))
        return self.fla(stage3)
class CovNet(nn.Module):
    """Small CNN classifier: ``ConBlk`` features -> Linear(1764, 1024) -> Dropout(0.5) -> Linear.

    Args:
        num_class: Number of output logits.

    The 1764-wide head input matches ``ConBlk``'s output for 224x224 images.
    NOTE(review): there is no activation between the two Linear layers.
    """

    def __init__(self, num_class=2):
        super(CovNet, self).__init__()
        self.blk1 = ConBlk()
        head_layers = [
            nn.Linear(1764, 1024),
            nn.Dropout(0.5),
            nn.Linear(1024, num_class),
        ]
        self.outlayer = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.blk1(x)
        return self.outlayer(features)
resnet_cbam.py
import torch
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
# Names exported by a star-import of this module.
__all__ = ['ResNet', 'resnet18_cbam', 'resnet34_cbam', 'resnet50_cbam', 'resnet101_cbam',
           'resnet152_cbam']
# torchvision ImageNet checkpoints fetched by the resnet*_cbam constructors
# when called with pretrained=True.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 (size-preserving at stride 1)."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Average- and max-pooled (N, C, 1, 1) descriptors pass through a shared
    1x1-conv bottleneck; the two results are summed and squashed with a
    sigmoid. The (N, C, 1, 1) output is meant to be multiplied with the input.

    Args:
        in_planes: Number of input channels C.
        ratio: Reduction ratio of the bottleneck; the hidden width is
            ``in_planes // ratio`` (at least 1).
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # `ratio` was previously ignored (16 was hard-coded). Keep at least one
        # hidden channel so in_planes < ratio does not create a 0-width conv.
        hidden = max(1, in_planes // ratio)
        self.fc1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    The channel-wise mean and max maps are stacked into a 2-channel image,
    convolved down to one channel and passed through a sigmoid. The result is
    an (N, 1, H, W) mask to multiply with the input.

    Args:
        kernel_size: Convolution kernel size; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # "Same" padding for the supported kernels.
        pad = 1 if kernel_size != 7 else 3
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True)[0]
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv1(stacked))
class BasicBlock(nn.Module):
    """Two-conv ResNet basic block with CBAM attention on the residual branch.

    Channel then spatial attention rescale the branch output before the
    identity (or ``downsample``) shortcut is added and ReLU is applied.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.ca = ChannelAttention(planes)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        branch = branch * self.ca(branch)
        branch = branch * self.sa(branch)
        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 ResNet bottleneck block with CBAM attention.

    The block widens ``planes`` by ``expansion`` (4). Channel then spatial
    attention rescale the branch before the shortcut is added and ReLU is
    applied.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.ca = ChannelAttention(out_planes)
        self.sa = SpatialAttention()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y = y * self.ca(y)
        y = y * self.sa(y)
        identity = self.downsample(x) if self.downsample is not None else x
        y += identity
        return self.relu(y)
class ResNet(nn.Module):
    """ResNet backbone assembled from ``block`` plus a global-pool + linear head.

    Args:
        block: Residual block class with an ``expansion`` attribute and an
            ``(inplanes, planes, stride, downsample)`` constructor.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the final ``fc`` layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # AvgPool2d(7) only fits 224x224 inputs (7x7 final map). Adaptive
        # pooling gives the same result there and works for any input size.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-normal init with fan = k_h * k_w * out_channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` blocks; only the first may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Project the shortcut when the spatial size or width changes.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
def resnet18_cbam(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
pretrained_state_dict = model_zoo.load_url(model_urls['resnet18'])
now_state_dict = model.state_dict()
now_state_dict.update(pretrained_state_dict)
model.load_state_dict(now_state_dict)
return model
def resnet34_cbam(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
pretrained_state_dict = model_zoo.load_url(model_urls['resnet34'])
now_state_dict = model.state_dict()
now_state_dict.update(pretrained_state_dict)
model.load_state_dict(now_state_dict)
return model
def resnet50_cbam(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])
now_state_dict = model.state_dict()
now_state_dict.update(pretrained_state_dict)
model.load_state_dict(now_state_dict)
return model
def resnet101_cbam(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
pretrained_state_dict = model_zoo.load_url(model_urls['resnet101'])
now_state_dict = model.state_dict()
now_state_dict.update(pretrained_state_dict)
model.load_state_dict(now_state_dict)
return model
def resnet152_cbam(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
pretrained_state_dict = model_zoo.load_url(model_urls['resnet152'])
now_state_dict = model.state_dict()
now_state_dict.update(pretrained_state_dict)
model.load_state_dict(now_state_dict)
return model
边栏推荐
- Migrate data from Mysql to neo4j database
- Simple production of wechat applet cloud development authorization login
- Wireless WiFi learning 8-channel transmitting remote control module
- Learn the garbage collector of JVM -- a brief introduction to Shenandoah collector
- Just a coincidence? The mysterious technology of apple ios16 is actually the same as that of Chinese enterprises five years ago!
- 你做自动化测试为什么总是失败?
- Complete activity switching according to sliding
- Swift - add navigation bar
- Flutter2 heavy release supports web and desktop applications
- Two minutes will take you to quickly master the project structure, resources, dependencies and localization of flutter
猜你喜欢

Select drop-down box realizes three-level linkage of provinces and cities in China

Mmclassification training custom data

Flutter2 heavy release supports web and desktop applications

Hexadecimal conversion summary

Redis clean cache

Understand redis persistence mechanism in one article

mysql拆分字符串做条件查询
Solve the error 1045 of Navicat creating local connection - access denied for user [email protected] (using password)

Linux Installation and deployment lamp (apache+mysql+php)

mmclassification 训练自定义数据
随机推荐
Solve the error 1045 of Navicat creating local connection -access denied for user [email protected] (using password
ZABBIX agent2 monitors mongodb nodes, clusters and templates (official blog)
Automated test lifecycle
Learn the memory management of JVM 02 - memory allocation of JVM
Yum only downloads the RPM package of the software to the specified directory without installing it
Intern position selection and simplified career development planning in Internet companies
A guide to threaded and asynchronous UI development in the "quick start fluent Development Series tutorials"
Read and understand the rendering mechanism and principle of flutter's three trees
[hdu 2096] Xiaoming a+b
Learning items
Understanding the architecture type of mobile CPU
Solution to order timeout unpaid
Error modulenotfounderror: no module named 'cv2 aruco‘
[pytorch pre training model modification, addition and deletion of specific layers]
MySQL storage engine
一款新型的智能家居WiFi选择方案——SimpleWiFi在无线智能家居中的应用
GPS数据格式转换[通俗易懂]
Hiengine: comparable to the local cloud native memory database engine
byte2String、string2Byte
How does MySQL execute an SQL statement?