当前位置:网站首页>ResNet18 in practice: classifying Pokémon
ResNet18 in practice: classifying Pokémon
2022-07-05 12:26:00 【Dongcheng West que】
File path
pokemon.py (custom dataset loading file)
import torch
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
# Root directory of the image dataset used by the demo in main(); Pokemon
# treats every sub-directory of this folder as one class.
datapath="pokemon"
class Pokemon(Dataset):
    """Image-classification dataset read from a folder-per-class directory.

    Every sub-directory of ``root`` is one class; class indices are assigned
    in sorted directory-name order.  The (path, label) pairs are shuffled once
    and cached in ``root/images.csv`` so that the train/val/test split stays
    the same between runs as long as that file is kept.

    Args:
        root: dataset directory containing one sub-directory per class.
        resize: side length, in pixels, of the square images returned.
        mode: ``"train"`` selects the first 60% of the cached sample list,
            ``"val"`` the next 20%, and any other value the last 20%.
    """

    # Per-channel normalization statistics shared by the transform pipeline
    # and denormalize() (presumably the usual ImageNet values -- confirm).
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        # Built on first __getitem__ so that constructing the dataset only
        # touches the file system.
        self._transform = None

        # e.g. {"bulbasaur": 0, "charmander": 1, ...}
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        self.images, self.labels = self.load_csv("images.csv")

        total = len(self.images)
        train_end, val_end = int(0.6 * total), int(0.8 * total)
        if mode == "train":  # first 60%
            chosen = slice(None, train_end)
        elif mode == "val":  # 60% -> 80%
            chosen = slice(train_end, val_end)
        else:  # 80% -> 100%
            chosen = slice(val_end, None)
        self.images = self.images[chosen]
        self.labels = self.labels[chosen]

    def load_csv(self, filename):
        """Return ``(images, labels)`` listed in ``root/filename``.

        When the file does not exist yet, every png/jpg/jpeg/gif file found
        in the class directories is collected, shuffled and written to it
        first, one ``path,label`` row per image.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for ext in ("png", "jpg", "jpeg", "gif"):
                    images += glob.glob(os.path.join(self.root, name, "*." + ext))
            random.shuffle(images)
            with open(csv_path, mode="w", newline="") as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the directory that holds the image.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print("writen into csv file:", filename)

        # Always read back from the CSV so every instance sees the same order.
        images, labels = [], []
        with open(csv_path, newline="") as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Invert the normalization applied in ``__getitem__``.

        Normalization computes ``x_hat = (x - mean) / std``; this returns
        ``x = x_hat * std + mean``.  The statistics are reshaped from [3] to
        [3, 1, 1], so the channel axis must be third from the end (a single
        [c, h, w] image, or a [b, c, h, w] batch via broadcasting).
        """
        # Create the constants on x_hat's device so GPU tensors work too.
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def _build_transform(self):
        """Build the PIL-image -> normalized-tensor pipeline for ``self.resize``."""
        enlarged = int(self.resize * 1.25)
        # NOTE(review): the random rotation is applied in every mode,
        # including "val" and "test" -- kept as in the original.
        return transforms.Compose([
            transforms.Resize((enlarged, enlarged)),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``."""
        path, label = self.images[idx], self.labels[idx]
        if self._transform is None:
            # Cached after the first sample; it only depends on self.resize
            # as set at construction time.
            self._transform = self._build_transform()
        # Close the file as soon as the pixels are decoded instead of leaving
        # the handle to the garbage collector.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        return self._transform(img), torch.tensor(label)
def main():
    """Visual smoke test: send one sample, then whole batches, to Visdom."""
    from visdom import Visdom
    import time
    import torchvision

    board = Visdom()

    # Alternative way to load the data (method 2), kept for reference:
    #
    #   tf = transforms.Compose([
    #       transforms.Resize((64,64)),
    #       transforms.ToTensor(),
    #   ])
    #   db=torchvision.datasets.ImageFolder(root="pokemon",transform=tf)
    #   loader=DataLoader(db,batch_size=32,shuffle=True)
    #   print("make-code",db.class_to_idx)
    #   for x, y in loader:
    #       viz.images(x, nrow=8, win="batch", opts=dict(title="batch"))
    #       viz.text(str(y.numpy()), win="lablel", opts=dict(title="batch-y"))
    #       time.sleep(10)

    dataset = Pokemon(datapath, 128, "train")

    # Show the first sample on its own.
    sample_img, sample_label = next(iter(dataset))
    print("sample", sample_img.shape, sample_label.shape, sample_label)
    board.image(
        dataset.denormalize(sample_img),
        win="sample_x",
        opts={"title": "sample_x"},
    )

    # Then page through shuffled batches, pausing so each one can be viewed.
    batches = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)
    for images, labels in batches:
        board.images(
            dataset.denormalize(images),
            nrow=8,
            win="batch",
            opts={"title": "batch"},
        )
        board.text(str(labels.numpy()), win="lablel", opts={"title": "batch-y"})
        time.sleep(10)


if __name__ == "__main__":
    main()
resnet.py (ResNet network model definition)
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
        stride: stride of the first convolution; the projection shortcut
            uses the same stride so both branches keep matching sizes.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()

        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default.  A 1x1 projection is needed whenever
        # the main branch changes the tensor shape: a different channel count
        # OR a stride other than 1 (which shrinks the spatial size even when
        # the channel count is unchanged).
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Apply the block.

        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w'] where h', w' are reduced by ``stride``
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Element-wise add of the (possibly projected) input, then activate.
        out = self.extra(x) + out
        return F.relu(out)
class ResNet18(nn.Module):
    """Small residual classifier: one stem convolution, four ResBlk stages
    and a linear output layer.

    Args:
        num_class: number of output classes (size of the logit vector).
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # Stem: [b, 3, h, w] => [b, 16, h/3, w/3]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # Four residual stages; each one widens the channels and shrinks the
        # spatial size by its stride.
        self.blk1 = ResBlk(16, 32, stride=3)    # 16  -> 32
        self.blk2 = ResBlk(32, 64, stride=3)    # 32  -> 64
        self.blk3 = ResBlk(64, 128, stride=2)   # 64  -> 128
        self.blk4 = ResBlk(128, 256, stride=2)  # 128 -> 256

        # The classifier consumes a [b, 256, 3, 3] feature map, which is what
        # the stages above produce for a 224x224 input.
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """Return class logits.

        :param x: [b, 3, h, w] image batch
        :return: [b, num_class]
        """
        x = F.relu(self.conv1(x))

        # [b, 16, ., .] => [b, 256, ., .]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Pool to the 3x3 grid the linear layer was sized for.  For 224x224
        # inputs the map is already 3x3, so this is an identity; for other
        # input sizes it keeps the flattened length at 256*3*3 instead of
        # failing in the linear layer.
        x = F.adaptive_avg_pool2d(x, (3, 3))
        x = x.view(x.size(0), -1)
        return self.outlayer(x)
def main():
    """Shape check: run random tensors through one block and the full net."""
    block = ResBlk(64, 128)
    block_out = block(torch.randn(2, 64, 224, 224))
    print('block:', block_out.shape)

    net = ResNet18(5)
    logits = net(torch.randn(2, 3, 224, 224))
    print('resnet:', logits.shape)

    # Total number of scalar parameters in the network.
    n_params = sum(param.numel() for param in net.parameters())
    print('parameters size:', n_params)


if __name__ == '__main__':
    main()
train.py (training script)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
from resnet import ResNet18
# Training hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 20

# Use the GPU when one is available, otherwise fall back to the CPU instead
# of failing on the first .to(device) call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# The three splits of the same dataset directory, at 224x224 resolution.
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")

# Only the training loader shuffles.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

# Visdom client used for the loss / accuracy curves.
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` gets right.

    The model is put into evaluation mode for the duration of the call, so
    that layers such as BatchNorm use (and do not update) their running
    statistics, and its previous train/eval mode is restored afterwards.

    Args:
        model: classifier returning [b, num_class] logits.
        loader: DataLoader yielding ``(x, y)`` batches; ``len(loader.dataset)``
            is used as the denominator.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    # Move batches to wherever the model lives; fall back to the module-level
    # ``device`` only for parameter-less models.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target), y.to(target)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total
def main():
    """Train ResNet18 from scratch, keep the best checkpoint, report test accuracy.

    After every epoch the model is scored on the validation split; the
    weights with the best validation accuracy are written to ``best.mdl``
    and reloaded before the final evaluation on the test split.
    """
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    # Tracks whether THIS run wrote best.mdl, so that a run whose validation
    # accuracy never rises above 0 still saves a checkpoint and never loads
    # a stale file left by an earlier run.
    checkpoint_saved = False
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        # Training mode: BatchNorm uses batch statistics and updates its
        # running estimates.
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate after every epoch, in evaluation mode so BatchNorm uses
        # its running statistics and is not updated by validation data.
        model.eval()
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        if val_acc > best_acc or not checkpoint_saved:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")
            checkpoint_saved = True

    print("best acc:", best_acc, "best epoch:", best_epoch)
    if checkpoint_saved:
        # map_location keeps the load working when the checkpoint was
        # written on a different device than the current one.
        model.load_state_dict(torch.load("best.mdl", map_location=device))
        print("loaded from ckpt!")
    model.eval()
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)


if __name__ == '__main__':
    main()
utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn
class Flatten(nn.Module):
    """Flatten every dimension after the first: [b, d1, d2, ...] => [b, d1*d2*...]."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` viewed as a 2-D tensor with the first dimension kept."""
        # Multiply the trailing sizes as plain Python ints.  Going through
        # torch.prod(torch.tensor(...)).item() allocated a tensor on every
        # call and produced a float (1.0) for 1-D input, which view() rejects.
        features = 1
        for dim in x.shape[1:]:
            features *= dim
        return x.view(-1, features)
def plot_image(img, label, name):
    """Show up to six images in a 2x3 grid, each titled ``"<name>: <label>"``.

    Args:
        img: indexable batch of images; channel 0 of each image is drawn
            in gray scale.
        label: indexable batch of labels; ``label[i].item()`` is shown in
            the title of image ``i``.
        name: prefix used in every subplot title.
    """
    plt.figure()
    # Never index past the end of a batch smaller than six.
    count = min(6, len(img), len(label))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        # NOTE(review): 0.3081 / 0.1307 look like the MNIST std / mean used
        # to undo a normalization -- confirm they match the data passed in.
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot (including its title) exists.
    plt.tight_layout()
    plt.show()
train_transfer.py (transfer-learning implementation)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
# from resnet import ResNet18
from torchvision.models import resnet18
from utils import Flatten
# Training hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 20

# Use the GPU when one is available, otherwise fall back to the CPU instead
# of failing on the first .to(device) call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# The three splits of the same dataset directory, at 224x224 resolution.
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")

# Only the training loader shuffles.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

# Visdom client used for the loss / accuracy curves.
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` gets right.

    The model is put into evaluation mode for the duration of the call, so
    that layers such as BatchNorm use (and do not update) their running
    statistics, and its previous train/eval mode is restored afterwards.

    Args:
        model: classifier returning [b, num_class] logits.
        loader: DataLoader yielding ``(x, y)`` batches; ``len(loader.dataset)``
            is used as the denominator.

    Returns:
        Accuracy as a float in [0, 1]; 0.0 for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    # Move batches to wherever the model lives; fall back to the module-level
    # ``device`` only for parameter-less models.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target), y.to(target)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total
def main():
    """Fine-tune a pretrained torchvision ResNet-18 on the five-class dataset.

    The pretrained network's final layer is replaced by ``Flatten`` plus a
    new 512 -> 5 linear layer.  After every epoch the model is scored on the
    validation split; the best weights are written to ``best.mdl`` and
    reloaded before the final evaluation on the test split.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=`` -- confirm against the installed version.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # everything but the last layer: [b,512,1,1]
        Flatten(),                             # [b,512,1,1] => [b,512]
        nn.Linear(512, 5),
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    # Tracks whether THIS run wrote best.mdl, so that a run whose validation
    # accuracy never rises above 0 still saves a checkpoint and never loads
    # a stale file left by an earlier run.
    checkpoint_saved = False
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        # Training mode: BatchNorm uses batch statistics and updates its
        # running estimates.
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate after every epoch, in evaluation mode so the pretrained
        # BatchNorm statistics are used and not updated by validation data.
        model.eval()
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        if val_acc > best_acc or not checkpoint_saved:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")
            checkpoint_saved = True

    print("best acc:", best_acc, "best epoch:", best_epoch)
    if checkpoint_saved:
        # map_location keeps the load working when the checkpoint was
        # written on a different device than the current one.
        model.load_state_dict(torch.load("best.mdl", map_location=device))
        print("loaded from ckpt!")
    model.eval()
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)


if __name__ == '__main__':
    main()
边栏推荐
- ZABBIX agent2 monitors mongodb nodes, clusters and templates (official blog)
- Matlab label2idx function (convert the label matrix into a cell array with linear index)
- Principle and performance analysis of lepton lossless compression
- Application of a class of identities (vandermond convolution and hypergeometric functions)
- Instance + source code = see through 128 traps
- Four operations and derivative operations of MATLAB polynomials
- ZABBIX 5.0 - LNMP environment compilation and installation
- July Huaqing learning-1
- Video networkState 属性
- 图像超分实验:SRCNN/FSRCNN
猜你喜欢
Linux Installation and deployment lamp (apache+mysql+php)
Understand kotlin from the perspective of an architect
Learn memory management of JVM 01 - first memory
MySQL storage engine
ZABBIX ODBC database monitoring
Get data from the database when using JMeter for database assertion
Mmclassification training custom data
Embedded software architecture design - message interaction
Master the new features of fluent 2.10
Interviewer: is acid fully guaranteed for redis transactions?
随机推荐
Implementing Yang Hui triangle with cyclic queue C language
Two minutes will take you to quickly master the project structure, resources, dependencies and localization of flutter
PIP command reports an error pip is configured with locations that requires tls/ssl problems
Semantic segmentation experiment: UNET network /msrc2 dataset
Why learn harmonyos and how to get started quickly?
Solve the problem of cache and database double write data consistency
One article tells the latest and complete learning materials of flutter
Redis highly available slice cluster
【ijkplayer】when i compile file “compile-ffmpeg.sh“ ,it show error “No such file or directory“.
Complete activity switching according to sliding
Learn garbage collection 01 of JVM -- garbage collection for the first time and life and death judgment
Mmclassification training custom data
Redis highly available sentinel mechanism
Read and understand the rendering mechanism and principle of flutter's three trees
Clear neo4j database data
Halcon template matching actual code (I)
Acid transaction theory
Instance + source code = see through 128 traps
Learning JVM garbage collection 06 - memory set and card table (hotspot)
Summary of C language learning problems (VS)