Current location: Home > ResNet18 in practice: Pokémon classification
ResNet18 in Practice: Pokémon Image Classification
2022-07-05 12:26:00 【Dongcheng West que】
File structure
pokemon.py (custom dataset loading)
import torch
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
datapath = "pokemon"


class Pokemon(Dataset):
    """Image-folder dataset with a cached, shuffled train/val/test split.

    Expected layout: ``root/<class_name>/<image files>``. Every sub-directory
    of ``root`` is one class; labels are assigned in sorted directory-name
    order. On first use the (path, label) list is shuffled once and cached in
    ``root/images.csv`` so that all three splits read the same ordering.

    :param root: dataset root directory
    :param resize: side length (pixels) of the square output image
    :param mode: ``"train"`` selects the first 60%, ``"val"`` the next 20%;
        any other value selects the remaining 20% (test split)
    """

    # ImageNet channel statistics used by Normalize and by denormalize().
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Map each class directory name to a contiguous integer label.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)

        images, labels = self.load_csv("images.csv")
        n = len(images)
        if mode == "train":  # 0% -> 60%
            start, stop = 0, int(0.6 * n)
        elif mode == "val":  # 60% -> 80%
            start, stop = int(0.6 * n), int(0.8 * n)
        else:  # test: 80% -> 100%
            start, stop = int(0.8 * n), n
        self.images = images[start:stop]
        self.labels = labels[start:stop]

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` read from ``root/filename``.

        If the CSV does not exist yet, it is created first: all
        png/jpg/jpeg/gif files of every class directory are collected,
        shuffled, and written as ``path,label`` rows.
        """
        path = os.path.join(self.root, filename)
        if not os.path.exists(path):
            images = []
            for name in self.name2label:
                for pattern in ("*.png", "*.jpg", "*.jpeg", "*.gif"):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            random.shuffle(images)
            with open(path, mode="w", newline="") as f:
                writer = csv.writer(f)
                for img in images:
                    # The class name is the image's parent directory,
                    # e.g. pokemon/bulbasaur/00000000.png -> bulbasaur.
                    name = img.split(os.sep)[-2]
                    writer.writerow([img, self.name2label[name]])
            print("writen into csv file:", filename)

        images, labels = [], []
        # newline="" is what the csv module expects for both reading and writing.
        with open(path, newline="") as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines in a hand-edited file
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):
        """Invert Normalize: ``x = x_hat * std + mean``.

        mean/std are reshaped to [3, 1, 1], so this broadcasts over a single
        image [c, h, w] as well as a batch [b, c, h, w].
        """
        mean = torch.tensor(self.MEAN, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self.STD, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean

    @staticmethod
    def _load_rgb(path):
        """Open an image file as an RGB PIL image, closing the file afterwards."""
        with Image.open(path) as im:
            return im.convert("RGB")

    def __getitem__(self, idx):
        """Return ``(image_tensor [3, resize, resize], label_tensor)`` for ``idx``."""
        img, label = self.images[idx], self.labels[idx]
        size = int(self.resize * 1.25)
        steps = [
            self._load_rgb,  # string path => RGB image data
            transforms.Resize((size, size)),
        ]
        if self.mode == "train":
            # Random augmentation only for training: val/test results must be
            # deterministic and comparable between epochs.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        tf = transforms.Compose(steps)
        return tf(img), torch.tensor(label)
def main():
    """Visual sanity check: push one sample and shuffled batches to Visdom.

    Requires a running Visdom server. Images are denormalized before display.
    """
    from visdom import Visdom
    import time
    import torchvision

    viz = Visdom()

    # Alternative loading method (method 2) using torchvision's ImageFolder,
    # which derives labels from the directory layout directly:
    #     tf = transforms.Compose([
    #         transforms.Resize((64, 64)),
    #         transforms.ToTensor(),
    #     ])
    #     db = torchvision.datasets.ImageFolder(root="pokemon", transform=tf)
    #     loader = DataLoader(db, batch_size=32, shuffle=True)
    #     print("make-code", db.class_to_idx)
    #     for x, y in loader:
    #         viz.images(x, nrow=8, win="batch", opts=dict(title="batch"))
    #         viz.text(str(y.numpy()), win="lablel", opts=dict(title="batch-y"))
    #         time.sleep(10)

    dataset = Pokemon(datapath, 128, "train")

    sample_img, sample_label = next(iter(dataset))
    print("sample", sample_img.shape, sample_label.shape, sample_label)
    viz.image(
        dataset.denormalize(sample_img),
        win="sample_x",
        opts=dict(title="sample_x"),
    )

    batches = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)
    for batch_imgs, batch_labels in batches:
        shown = dataset.denormalize(batch_imgs)
        viz.images(shown, nrow=8, win="batch", opts=dict(title="batch"))
        viz.text(str(batch_labels.numpy()), win="lablel", opts=dict(title="batch-y"))
        time.sleep(10)


if __name__ == "__main__":
    main()
resnet.py (ResNet network model definition)
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """
    Basic residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    out = relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first 3x3 conv (spatial downsampling)
        """
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut when the main path preserves the shape; otherwise
        # a 1x1 conv projection. A projection is required whenever the channel
        # count changes OR the main path downsamples (stride != 1): an identity
        # shortcut would then differ in spatial size and the element-wise add
        # in forward() would fail.
        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w'] where h', w' depend on ``stride``
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Shortcut: extra module maps [b, ch_in, h, w] => [b, ch_out, h', w'],
        # then element-wise add with the main path.
        out = self.extra(x) + out
        return F.relu(out)
class ResNet18(nn.Module):
    """
    Small residual classifier: a stem conv followed by four ResBlk stages.

    Despite the name this is not torchvision's 18-layer ResNet: it is one
    3x3 stride-3 stem conv plus four residual blocks with strides 3, 3, 2, 2.
    For a 224x224 input the spatial size goes
    224 -> 74 (stem) -> 25 -> 9 -> 5 -> 3, i.e. [b, 256, 3, 3] before the
    linear classifier.

    :param num_class: number of output logits
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        # Stem: [b, 3, h, w] => [b, 16, h', w'] (3x3 conv, stride 3, no padding)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # followed by 4 blocks, each widening channels and downsampling
        # [b, 16, h, w] => [b, 32, h', w']
        self.blk1 = ResBlk(16, 32, stride=3)
        # [b, 32, h, w] => [b, 64, h', w']
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b, 64, h, w] => [b, 128, h', w']
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h', w']
        self.blk4 = ResBlk(128, 256, stride=2)
        # Classifier over the flattened [b, 256, 3, 3] feature map.
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """
        :param x: image batch [b, 3, h, w]
        :return: logits [b, num_class]
        """
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # outlayer expects a 3x3 map, which is exactly what 224x224 inputs
        # produce. For other input sizes (e.g. 128x128 -> 2x2), adaptively
        # pool to 3x3 so the flattened size still matches; the 224x224 path
        # is left untouched. The pool has no parameters, so saved
        # state_dicts remain compatible.
        if x.shape[-2:] != (3, 3):
            x = F.adaptive_avg_pool2d(x, (3, 3))
        x = x.view(x.size(0), -1)
        return self.outlayer(x)
def main():
    """Shape smoke test: one ResBlk, then the whole model on 224x224 inputs."""
    block = ResBlk(64, 128)
    block_input = torch.randn(2, 64, 224, 224)
    block_output = block(block_input)
    print('block:', block_output.shape)

    net = ResNet18(5)
    net_input = torch.randn(2, 3, 224, 224)
    net_output = net(net_input)
    print('resnet:', net_output.shape)

    total_params = sum(param.numel() for param in net.parameters())
    print('parameters size:', total_params)


if __name__ == '__main__':
    main()
train.py (training script)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
from resnet import ResNet18
# Training hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 20
# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# 60% / 20% / 20% splits of the same cached, shuffled image list.
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode for the pass, so BatchNorm uses its running
    statistics and does not update them with validation/test data. Its
    previous train/eval mode is restored afterwards. Inputs are moved to the
    device of the model's parameters (CPU if it has none).

    :param model: classifier whose output is argmax-ed over dim 1 (logits)
    :param loader: DataLoader yielding (x, y) batches; ``len(loader.dataset)``
        is the number of samples
    :return: accuracy in [0, 1] (0.0 for an empty dataset)
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(dev), y.to(dev)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Restore the caller's mode even if the forward pass raised.
        model.train(was_training)
    return correct / total
def main():
    """Train ResNet18 (5 classes) and report test accuracy.

    After every epoch the model is validated. The checkpoint with the best
    validation accuracy is saved to ``best.mdl``, reloaded at the end, and
    evaluated on the test split. Loss and validation accuracy are plotted
    in Visdom.
    """
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        # Ensure training mode (BatchNorm batch statistics) for the updates.
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate after every epoch.
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        # Always checkpoint after the first epoch. Otherwise, if accuracy
        # stays at 0, best.mdl is never written and the reload below fails
        # or silently picks up a stale file from an earlier run.
        if epoch == 0 or val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")

    print("best acc:", best_acc, "best epoch:", best_epoch)
    model.load_state_dict(torch.load("best.mdl", map_location=device))
    print("loaded from ckpt!")
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)


if __name__ == '__main__':
    main()
utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn
class Flatten(nn.Module):
    """Flatten all dimensions after the batch dimension.

    [b, d1, d2, ...] => [b, d1 * d2 * ...]
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten returns a view when possible and copies otherwise, so
        # non-contiguous inputs (e.g. after transpose/permute) work, unlike
        # .view(). It also avoids building a tensor just to multiply the
        # shape, and handles an empty batch.
        return torch.flatten(x, start_dim=1)
def plot_image(img, label, name):
    """Show up to six single-channel images in a 2x3 grid with their labels.

    Each image is un-normalized as ``v * 0.3081 + 0.1307`` before display.
    NOTE(review): these look like MNIST statistics; confirm they match the
    caller's normalization.

    :param img: indexable batch; ``img[i][0]`` is drawn as a 2-D grayscale image
    :param label: indexable labels; ``label[i].item()`` appears in the title
    :param name: prefix for each subplot title
    """
    plt.figure()
    # Plot at most six images, but don't fail on smaller batches.
    for i in range(min(6, len(img))):
        plt.subplot(2, 3, i + 1)
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot exists.
    plt.tight_layout()
    plt.show()
train_transfer.py (transfer-learning implementation)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
# from resnet import ResNet18
from torchvision.models import resnet18
from utils import Flatten
# Training hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 20
# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# 60% / 20% / 20% splits of the same cached, shuffled image list.
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode for the pass, so BatchNorm uses its running
    statistics and does not update them with validation/test data. Its
    previous train/eval mode is restored afterwards. Inputs are moved to the
    device of the model's parameters (CPU if it has none).

    :param model: classifier whose output is argmax-ed over dim 1 (logits)
    :param loader: DataLoader yielding (x, y) batches; ``len(loader.dataset)``
        is the number of samples
    :return: accuracy in [0, 1] (0.0 for an empty dataset)
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(dev), y.to(dev)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Restore the caller's mode even if the forward pass raised.
        model.train(was_training)
    return correct / total
def main():
    """Fine-tune a pretrained torchvision resnet18 (5 classes) and report test accuracy.

    All children of the pretrained network except the last are kept as a
    feature extractor producing [b, 512, 1, 1]. They are followed by Flatten
    and a new 5-way linear head, and all parameters are optimized. After
    every epoch the model is validated; the best checkpoint is saved to
    ``best.mdl``, reloaded at the end, and evaluated on the test split.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=ResNet18_Weights.DEFAULT``; confirm against the
    # installed version before switching.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # [b, 512, 1, 1]
        Flatten(),  # [b, 512, 1, 1] => [b, 512]
        nn.Linear(512, 5),
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        # Ensure training mode (BatchNorm batch statistics) for the updates.
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate after every epoch.
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        # Always checkpoint after the first epoch. Otherwise, if accuracy
        # stays at 0, best.mdl is never written and the reload below fails
        # or silently picks up a stale file from an earlier run.
        if epoch == 0 or val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")

    print("best acc:", best_acc, "best epoch:", best_epoch)
    model.load_state_dict(torch.load("best.mdl", map_location=device))
    print("loaded from ckpt!")
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)


if __name__ == '__main__':
    main()
Sidebar recommendations
- Codeforces: 5 problems per day (~1700 rating) - day 5
- Which domestic cloud management platform manufacturer is good in 2022? Why?
- Tabbar configuration at the bottom of wechat applet
- Principle and performance analysis of lepton lossless compression
- A guide to threaded and asynchronous UI development, from the "Quick Start Flutter Development" tutorial series
- Knowledge representation (KR)
- Wireless WiFi learning 8-channel transmitting remote control module
- Reinforcement learning - learning notes 3 | policy learning
- Migrate data from Mysql to neo4j database
- Automated testing lifecycle
You may also like
MySQL splits strings for conditional queries
Simple production of wechat applet cloud development authorization login
Why do you always fail in automated tests?
Get all stock data of big a
Select drop-down box realizes three-level linkage of provinces and cities in China
The evolution of mobile cross platform technology
Linux Installation and deployment lamp (apache+mysql+php)
Tabbar configuration at the bottom of wechat applet
Codeforces: 5 problems per day (~1700 rating) - day 5
Installing and deploying LAMP (Apache+MySQL+PHP) on Linux
Random recommendations
MVVM framework part I lifecycle
Acid transaction theory
JS for loop number exception
Which domestic cloud management platform manufacturer is good in 2022? Why?
Constructing expression binary tree with prefix expression
Take you hand in hand to develop a service monitoring component
A new WiFi option for smart home -- the application of simplewifi in wireless smart home
MySQL index (1)
Array cyclic shift problem
Migrate data from Mysql to neo4j database
MySQL stored procedure
Swift - enables textview to be highly adaptive
[untitled]
Codeforces Round #804 (Div. 2)
What is digital existence? Digital transformation starts with digital existence
POJ-2499 Binary Tree
Video networkState property
Matlab superpixels function (2D super pixel over segmentation of image)
Course design of compilation principle --- formula calculator (a simple calculator with interface developed based on QT)
MySQL data table operation DDL & data type