Current location: Home > ResNet18 in practice: Pokémon image classification
ResNet18 in practice: Pokémon image classification
2022-07-05 12:26:00 【Dongcheng West que】
Project files
pokemon.py (custom dataset loading)
import torch
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
# Dataset root used by main() below. Pokemon treats every sub-directory of
# this root as one class.
datapath="pokemon"
class Pokemon(Dataset):
    """Image-classification dataset for a ``root/<class_name>/<image file>`` layout.

    Every sub-directory of ``root`` is one class. Classes get integer labels in
    sorted directory-name order. The (path, label) index is cached in
    ``root/images.csv``, which is created with a shuffled order on first use and
    reused afterwards. That index is sliced into 60% train / 20% val / 20% test.

    Args:
        root: dataset root directory.
        resize: side length (pixels) of the square output images.
        mode: ``"train"`` or ``"val"``; any other value selects the test split.
            Random augmentation is applied only in ``"train"`` mode.
    """

    # Channel statistics used by Normalize and undone by denormalize().
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)
    # File-name suffixes (matched case-insensitively) that load_csv() indexes.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif")

    def __init__(self, root, resize, mode):
        super().__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Map each class directory name to a contiguous integer label.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)

        self.images, self.labels = self.load_csv("images.csv")

        # The cached CSV order never changes once written, so the splits are stable.
        n = len(self.images)
        if mode == "train":      # 0% -> 60%
            start, end = 0, int(0.6 * n)
        elif mode == "val":      # 60% -> 80%
            start, end = int(0.6 * n), int(0.8 * n)
        else:                    # 80% -> 100%
            start, end = int(0.8 * n), n
        self.images = self.images[start:end]
        self.labels = self.labels[start:end]

        # Built once instead of on every __getitem__ call. It contains no
        # lambdas, so the dataset stays picklable for DataLoader workers.
        self.transform = self._build_transform(train=(mode == "train"))

    def _build_transform(self, train):
        """Return the PIL-image -> normalized-tensor pipeline.

        The random rotation is added only when *train* is true, so that
        validation/test accuracy is deterministic.
        """
        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]
        if train:
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self.MEAN), std=list(self.STD)),
        ]
        return transforms.Compose(steps)

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` read from ``root/filename``.

        If the file does not exist yet, it is created first. Each class directory
        is scanned for image files (non-hidden names ending in one of
        IMAGE_EXTENSIONS, any case). The list is shuffled and written as
        ``path,label`` rows. An existing file is reused unchanged.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                class_dir = os.path.join(self.root, name)
                for fname in sorted(os.listdir(class_dir)):
                    # Skip hidden files (a "*" glob would not match them either);
                    # compare extensions case-insensitively so "X.PNG" is kept.
                    if fname.startswith("."):
                        continue
                    if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                        images.append(os.path.join(class_dir, fname))
            random.shuffle(images)
            with open(csv_path, mode="w", newline="") as f:
                writer = csv.writer(f)
                for img in images:
                    # e.g. pokemon/bulbasaur/00000000.png -> label of "bulbasaur"
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print("writen into csv file:", filename)

        images, labels = [], []
        with open(csv_path, newline="") as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the Normalize step: ``x = x_hat * std + mean``.

        The statistics are shaped ``[3, 1, 1]``, so ``[3, h, w]`` and
        ``[b, 3, h, w]`` inputs both broadcast. They are created on
        ``x_hat``'s device.
        """
        mean = torch.tensor(self.MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self.STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample *idx*."""
        path, label = self.images[idx], self.labels[idx]
        # Close the file handle explicitly; convert() returns an in-memory copy.
        with Image.open(path) as im:
            rgb = im.convert("RGB")
        return self.transform(rgb), torch.tensor(label)
def main():
    """Visual smoke test: show one training sample, then shuffled batches, in Visdom.

    Requires a running Visdom server. For the same folder layout,
    ``torchvision.datasets.ImageFolder(root="pokemon", transform=...)`` is an
    alternative loader; it exposes the name->label mapping as ``class_to_idx``.
    """
    from visdom import Visdom
    import time

    viz = Visdom()

    db = Pokemon(datapath, 128, "train")
    x, y = db[0]
    print("sample", x.shape, y.shape, y)
    viz.image(db.denormalize(x), win="sample_x", opts=dict(title="sample_x"))

    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win="batch", opts=dict(title="batch"))
        viz.text(str(y.numpy()), win="lablel", opts=dict(title="batch-y"))
        time.sleep(10)  # keep each batch on screen before showing the next


if __name__ == "__main__":
    main()
resnet.py (ResNet network model definition)
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Basic residual block: two 3x3 conv + BN layers and a shortcut connection.

    The shortcut is a 1x1 conv + BN projection whenever the block changes the
    tensor shape (different channel count or ``stride != 1``). Otherwise it is
    the identity.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first conv and of the projection shortcut
        """
        super().__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        self.extra = nn.Sequential()  # identity shortcut
        # A projection is also needed when only the stride shrinks h/w.
        # Otherwise the element-wise add in forward() gets mismatched shapes.
        if ch_out != ch_in or stride != 1:
            # [b, ch_in, h, w] => [b, ch_out, h', w']
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w'] where h' = (h - 1) // stride + 1 (same for w)
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Shortcut branch plus element-wise add, then the final ReLU.
        out = self.extra(x) + out
        return F.relu(out)
class ResNet18(nn.Module):
    """Compact ResNet-style classifier: conv stem, 4 residual blocks, linear head.

    For a 224x224 input the last block outputs [b, 256, 3, 3], and the linear
    head is sized for that. forward() applies an adaptive average pool to 3x3
    before flattening. This is an exact identity for 224x224 inputs and lets
    other input resolutions work without changing any parameters, so existing
    state_dicts still load.
    """

    def __init__(self, num_class):
        super().__init__()
        # Spatial sizes below are for a 224x224 input.
        # [b, 3, 224, 224] => [b, 16, 74, 74]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # followed by 4 blocks
        # [b, 16, 74, 74] => [b, 32, 25, 25]
        self.blk1 = ResBlk(16, 32, stride=3)
        # [b, 32, 25, 25] => [b, 64, 9, 9]
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b, 64, 9, 9] => [b, 128, 5, 5]
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b, 128, 5, 5] => [b, 256, 3, 3]
        self.blk4 = ResBlk(128, 256, stride=2)
        # [b, 256*3*3] => [b, num_class]
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, h, w] image batch (224x224 in the training scripts)
        :return: [b, num_class] logits
        """
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)                      # [b, 256, 3, 3] for 224x224 input
        x = F.adaptive_avg_pool2d(x, (3, 3))  # no-op at 3x3; adapts other sizes
        x = x.flatten(1)                      # [b, 256*3*3]
        return self.outlayer(x)
def main():
    """Print sample output shapes of ResBlk and ResNet18, plus the model's parameter count."""
    block = ResBlk(64, 128)
    block_out = block(torch.randn(2, 64, 224, 224))
    print('block:', block_out.shape)

    net = ResNet18(5)
    logits = net(torch.randn(2, 3, 224, 224))
    print('resnet:', logits.shape)

    n_params = sum(param.numel() for param in net.parameters())
    print('parameters size:', n_params)


if __name__ == '__main__':
    main()
train.py (training script)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
from resnet import ResNet18
batchsz = 32
lr = 1e-3
epochs = 20
# Use the GPU only when one exists. With a hard-coded "cuda" device, every
# .to(device) call fails on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# Train / validation / test splits of the dataset rooted at "pokemon".
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
# Visdom client used for the loss / validation-accuracy curves.
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the classification accuracy of *model* over *loader*.

    The model runs in eval mode during the pass, so layers such as BatchNorm use
    their running statistics and do not update them. The previous train/eval
    mode is restored afterwards. Batches are moved to the module-level
    ``device``.

    :param model: classifier returning logits of shape [b, num_classes]
    :param loader: DataLoader yielding (x, y) batches with integer class labels
    :return: fraction of correctly classified samples (0.0 for an empty dataset)
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total
def main():
    """Train ResNet18 from scratch, keep the best-validation checkpoint, report test accuracy."""
    # One output per class directory found by the dataset.
    num_classes = len(train_db.name2label)
    model = ResNet18(num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first epoch always writes
    # best.mdl. Otherwise a validation accuracy that never rises above 0.0
    # leaves no checkpoint for the final torch.load.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        model.train()  # make training mode explicit for every epoch
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate once per epoch and keep the best weights on disk.
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")

    print("best acc:", best_acc, "best epoch:", best_epoch)
    # map_location lets the checkpoint load on whichever device is in use.
    model.load_state_dict(torch.load("best.mdl", map_location=device))
    print("loaded from ckpt!")
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)


if __name__ == '__main__':
    main()
utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn
class Flatten(nn.Module):
    """Flatten every dimension after the first: [b, d1, d2, ...] => [b, d1*d2*...]."""

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. after a
        # transpose/permute. torch.Size.numel() gives the feature size directly,
        # without building a tensor on every call.
        return x.reshape(-1, x.shape[1:].numel())
def plot_image(img, label, name):
    """Show up to six images in a 2x3 grid, titled with their labels.

    :param img: indexable batch of images; channel 0 of each image is drawn in grayscale
    :param label: indexable labels, one per image, each read with ``.item()``
    :param name: prefix for each subplot title
    NOTE(review): pixels are mapped as ``v * 0.3081 + 0.1307``. This presumably
    undoes an MNIST-style normalisation. It does not match the ImageNet
    mean/std used by the Pokemon dataset, so confirm before using it there.
    """
    plt.figure()
    # Plot at most six images; fewer are fine (previously an IndexError).
    for i in range(min(6, len(img))):
        plt.subplot(2, 3, i + 1)
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after all subplots and titles exist.
    plt.tight_layout()
    plt.show()
train_transfer.py (transfer-learning implementation)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
# from resnet import ResNet18
from torchvision.models import resnet18
from utils import Flatten
batchsz = 32
lr = 1e-3
epochs = 20
# Use the GPU only when one exists. With a hard-coded "cuda" device, every
# .to(device) call fails on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# Train / validation / test splits of the dataset rooted at "pokemon".
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
# Visdom client used for the loss / validation-accuracy curves.
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the classification accuracy of *model* over *loader*.

    The model runs in eval mode during the pass, so layers such as BatchNorm use
    their running statistics and do not update them. The previous train/eval
    mode is restored afterwards. Batches are moved to the module-level
    ``device``.

    :param model: classifier returning logits of shape [b, num_classes]
    :param loader: DataLoader yielding (x, y) batches with integer class labels
    :return: fraction of correctly classified samples (0.0 for an empty dataset)
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / total
def main():
    """Fine-tune a pretrained torchvision resnet18 with a new linear head.

    Keeps the best-validation checkpoint in best.mdl and reports test accuracy.
    """
    try:
        # Newer torchvision deprecates `pretrained=` in favour of `weights=`.
        from torchvision.models import ResNet18_Weights
        trained_model = resnet18(weights=ResNet18_Weights.DEFAULT)
    except ImportError:
        # Older torchvision without the weights enums.
        trained_model = resnet18(pretrained=True)

    # One output per class directory found by the dataset.
    num_classes = len(train_db.name2label)
    # Drop the original final layer and attach a fresh classifier.
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # [b, 512, 1, 1]
        Flatten(),                             # [b, 512, 1, 1] => [b, 512]
        nn.Linear(512, num_classes),
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first epoch always writes
    # best.mdl. Otherwise a validation accuracy that never rises above 0.0
    # leaves no checkpoint for the final torch.load.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        model.train()  # make training mode explicit for every epoch
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate once per epoch and keep the best weights on disk.
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")

    print("best acc:", best_acc, "best epoch:", best_epoch)
    # map_location lets the checkpoint load on whichever device is in use.
    model.load_state_dict(torch.load("best.mdl", map_location=device))
    print("loaded from ckpt!")
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)


if __name__ == '__main__':
    main()
边栏推荐
- MySQL installation, Windows version
- [HDU 2096] 小明A+B
- Tabbar configuration at the bottom of wechat applet
- Knowledge representation (KR)
- Get the variable address of structure member in C language
- One article tells the latest and complete learning materials of flutter
- Four operations and derivative operations of MATLAB polynomials
- byte2String、string2Byte
- 16 channel water lamp experiment based on Proteus (assembly language)
- Hiengine: comparable to the local cloud native memory database engine
猜你喜欢
Error modulenotfounderror: no module named 'cv2 aruco‘
报错ModuleNotFoundError: No module named ‘cv2.aruco‘
Solve the problem of cache and database double write data consistency
abap查表程序
Detailed steps for upgrading window mysql5.5 to 5.7.36
Third party payment interface design
MySQL index - extended data
强化学习-学习笔记3 | 策略学习
Wireless WiFi learning 8-channel transmitting remote control module
Two minutes will take you to quickly master the project structure, resources, dependencies and localization of flutter
随机推荐
Video networkstate property
Learning JVM garbage collection 06 - memory set and card table (hotspot)
语义分割实验:Unet网络/MSRC2数据集
ZABBIX 5.0 - LNMP environment compilation and installation
Acid transaction theory
Handwriting blocking queue: condition + lock
A new WiFi option for smart home -- the application of simplewifi in wireless smart home
Hexadecimal conversion summary
MySQL basic operation -dql
ZABBIX agent2 installation
Uniapp + unicloud + Unipay realize wechat applet payment function
How to recover the information server and how to recover the server data [easy to understand]
自动化测试生命周期
Differences between IPv6 and IPv4 three departments including the office of network information technology promote IPv6 scale deployment
Flutter2 heavy release supports web and desktop applications
互联网公司实习岗位选择与简易版职业发展规划
Two minutes will take you to quickly master the project structure, resources, dependencies and localization of flutter
Leetcode-1. Sum of two numbers (Application of hash table)
Want to ask, how to choose a securities firm? Is it safe to open an account online?
Average lookup length when hash table lookup fails