当前位置:网站首页 (Current location: Home) > ResNet18 in practice: Pokémon classification
ResNet18 in practice: Pokémon classification
2022-07-05 12:26:00 【Dongcheng West que】
File layout

pokemon.py (custom dataset loading)
import torch
import os,glob
import random,csv
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
# Default dataset root used by main(): one sub-directory per Pokemon class.
datapath = "pokemon"
class Pokemon(Dataset):
    """Image-folder dataset used by the Pokemon classification example.

    ``root`` must contain one sub-directory per class; every .png/.jpg/.jpeg/
    .gif file directly inside a class directory is one sample. Class labels
    are assigned in sorted directory-name order. The shuffled (path, label)
    list is cached in ``root/images.csv`` on first use, so every instance
    reads the same order and the train/val/test slices are disjoint.

    Args:
        root: dataset root directory.
        resize: side length in pixels of the square tensors returned.
        mode: "train" -> first 60% of samples, "val" -> 60%..80%, any other
            value -> last 20% (test). Random-rotation augmentation is applied
            only in "train" mode so evaluation is deterministic.
    """

    # Channel statistics used by Normalize and undone by denormalize.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]
    # Glob patterns for the image files collected from each class directory.
    EXTENSIONS = ("*.png", "*.jpg", "*.jpeg", "*.gif")

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Sorted class-directory names -> consecutive integer labels,
        # e.g. {"bulbasaur": 0, "charmander": 1, ...}. Plain files are skipped.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if os.path.isdir(os.path.join(root, name)):
                self.name2label[name] = len(self.name2label)

        images, labels = self.load_csv("images.csv")
        n = len(images)
        if mode == "train":      # first 60%
            start, stop = 0, int(0.6 * n)
        elif mode == "val":      # 60% -> 80%
            start, stop = int(0.6 * n), int(0.8 * n)
        else:                    # 80% -> 100% (test)
            start, stop = int(0.8 * n), n
        self.images = images[start:stop]
        self.labels = labels[start:stop]

    def load_csv(self, filename):
        """Return ``(image_paths, labels)`` read from ``root/filename``.

        If the file does not exist yet it is created first: all class images
        are globbed, shuffled once, and written as ``path,label`` rows.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            samples = []
            for name, label in self.name2label.items():
                for pattern in self.EXTENSIONS:
                    for img in glob.glob(os.path.join(self.root, name, pattern)):
                        # The label is known from the directory being globbed;
                        # no need to re-parse it out of the path string.
                        samples.append((img, label))
            random.shuffle(samples)
            # The csv module requires newline="" on the file objects it uses.
            with open(csv_path, mode="w", newline="") as f:
                writer = csv.writer(f)
                for img, label in samples:
                    writer.writerow([img, label])
            print("writen into csv file:", filename)

        images, labels = [], []
        with open(csv_path, newline="") as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines, e.g. from hand edits
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        """Number of samples in this split."""
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo Normalize: ``x = x_hat * std + mean``.

        ``x_hat`` is [c, h, w] (or [b, c, h, w]; the [3, 1, 1] statistics
        broadcast over the batch). The statistics are created on
        ``x_hat``'s device so GPU tensors work too.
        """
        mean = torch.tensor(self.MEAN, device=x_hat.device).view(3, 1, 1)
        std = torch.tensor(self.STD, device=x_hat.device).view(3, 1, 1)
        return x_hat * std + mean

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return a fully loaded RGB copy, closing the file."""
        with Image.open(path) as im:
            return im.convert("RGB")

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        image: float tensor [3, resize, resize], normalized with MEAN/STD.
        label: 0-dim int64 tensor.
        """
        img, label = self.images[idx], self.labels[idx]
        steps = [
            self._load_rgb,  # string path => PIL RGB image
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
        ]
        # Augment only the training split; val/test must be deterministic.
        if self.mode == "train":
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        img = transforms.Compose(steps)(img)
        return img, torch.tensor(label)
def main():
    """Visual smoke test for Pokemon: show one sample, then batches, in Visdom.

    Requires a running Visdom server. Shows the first training sample, then
    loops over shuffled batches of 32, pausing 10 s per batch.

    An alternative way to load the same folders is
    torchvision.datasets.ImageFolder(root="pokemon", transform=...) with
    Resize((64, 64)) + ToTensor, wrapped in a shuffled DataLoader of batch 32.
    """
    from visdom import Visdom
    import time
    import torchvision  # noqa: F401  (only used by the ImageFolder alternative)

    viz = Visdom()

    dataset = Pokemon(datapath, 128, "train")
    sample_x, sample_y = next(iter(dataset))
    print("sample", sample_x.shape, sample_y.shape, sample_y)
    viz.image(dataset.denormalize(sample_x), win="sample_x",
              opts={"title": "sample_x"})

    batches = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)
    for batch_x, batch_y in batches:
        viz.images(dataset.denormalize(batch_x), nrow=8, win="batch",
                   opts={"title": "batch"})
        viz.text(str(batch_y.numpy()), win="lablel", opts={"title": "batch-y"})
        time.sleep(10)
# Run the Visdom smoke test only when executed as a script.
if __name__ == "__main__":
    main()

# resnet.py (ResNet network model definition)
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Basic residual block: two 3x3 conv+BN layers plus a shortcut.

    Maps [b, ch_in, h, w] -> [b, ch_out, h', w'] with
    h' = (h - 1) // stride + 1 (likewise for w). The shortcut is the
    identity when input and output shapes already match, otherwise a
    1x1 conv (same stride) followed by BN.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels.
        :param ch_out: number of output channels.
        :param stride: stride of the first 3x3 conv and of the shortcut conv.
        """
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.extra = nn.Sequential()
        # A projection is needed whenever the channel count OR the spatial
        # size changes. Checking only channels made ResBlk(c, c, stride=2)
        # add tensors of different spatial shapes and fail.
        if ch_out != ch_in or stride != 1:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h', w'] (see class docstring), after a final ReLU.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Shortcut: identity or projection, then element-wise add.
        out = self.extra(x) + out
        return F.relu(out)
class ResNet18(nn.Module):
    """Compact ResNet-style classifier: a strided stem plus four ResBlk stages.

    Despite the name this is not the 18-layer torchvision ResNet. For a
    [b, 3, 224, 224] input the spatial size shrinks
    224 -> 74 (stem) -> 25 -> 9 -> 5 -> 3, giving a [b, 256, 3, 3] map that
    feeds a linear layer producing [b, num_class] logits.
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        # Stem: [b, 3, h, w] => [b, 16, h', w'] (3x3 conv, stride 3, no padding).
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # Four residual stages, each widening channels and downsampling.
        self.blk1 = ResBlk(16, 32, stride=3)    # => [b, 32, ...]
        self.blk2 = ResBlk(32, 64, stride=3)    # => [b, 64, ...]
        self.blk3 = ResBlk(64, 128, stride=2)   # => [b, 128, ...]
        self.blk4 = ResBlk(128, 256, stride=2)  # => [b, 256, 3, 3] for 224 input
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, h, w] image batch.
        :return: [b, num_class] logits.
        """
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # Bring the map to the 3x3 grid the classifier was sized for. For a
        # 224x224 input it is already 3x3, so this is an exact identity and
        # existing checkpoints behave the same. Other input sizes no longer
        # hit a shape mismatch in outlayer.
        x = F.adaptive_avg_pool2d(x, (3, 3))
        x = x.view(x.size(0), -1)
        return self.outlayer(x)
def main():
    """Print output shapes of a ResBlk and of ResNet18, plus the parameter count."""
    block = ResBlk(64, 128)
    block_out = block(torch.randn(2, 64, 224, 224))
    print('block:', block_out.shape)

    net = ResNet18(5)
    net_out = net(torch.randn(2, 3, 224, 224))
    print('resnet:', net_out.shape)

    n_params = sum(param.numel() for param in net.parameters())
    print('parameters size:', n_params)
# Run the shape smoke test only when executed as a script.
if __name__ == '__main__':
    main()

# train.py (training script)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
from resnet import ResNet18
# Training hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 20
# Use the GPU when there is one. A hard-coded torch.device("cuda") makes the
# first .to(device) fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# All three datasets slice the same cached images.csv, so the splits are
# disjoint: train 60%, val 20%, test 20%.
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode for the pass, so layers such as
    BatchNorm use their running statistics instead of batch statistics and
    do not update them. It is then restored to whatever mode it was in.
    Batches are moved to the module-level ``device`` and no gradients are
    tracked.

    Returns:
        correct / len(loader.dataset) as a float; 0.0 for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Leave the caller's training loop in the mode it expects.
        model.train(was_training)
    return correct / total
def main():
    """Train ResNet18 on the Pokemon split and report test accuracy.

    Loss is plotted per step and validation accuracy once per epoch, both in
    Visdom. The weights with the best validation accuracy are saved to
    "best.mdl", reloaded at the end, and evaluated on the test split.
    """
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any possible accuracy so the first epoch always writes a
    # checkpoint. With 0 and val accuracy stuck at 0, "best.mdl" was never
    # written and torch.load below crashed.
    best_acc, best_epoch = -1, 0
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        model.train()  # make sure BatchNorm is in training mode for this epoch
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate once per epoch.
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")

    print("best acc:", best_acc, "best epoch:", best_epoch)
    # map_location lets a checkpoint saved on GPU load on the current device.
    model.load_state_dict(torch.load("best.mdl", map_location=device))
    print("loaded from ckpt!")
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)
# Train only when executed as a script.
if __name__ == '__main__':
    main()

# utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn
class Flatten(nn.Module):
    """Flatten all dimensions except the batch one: [b, d1, d2, ...] -> [b, d1*d2*...]."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten copies when it has to, so unlike x.view it also works
        # on non-contiguous inputs (e.g. after permute/transpose) and on an
        # empty batch.
        return torch.flatten(x, start_dim=1)
def plot_image(img, label, name):
    """Show up to six single-channel images in a 2x3 grid with their labels.

    For each i, ``img[i][0]`` is drawn in grayscale after undoing a
    Normalize(mean=0.1307, std=0.3081) (x * 0.3081 + 0.1307). The subplot
    title is "<name>: <label[i].item()>". Fewer than six images are fine; the
    remaining grid cells stay empty.
    """
    plt.figure()
    count = min(6, len(img), len(label))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot has its title, so none overlap.
    plt.tight_layout()
    plt.show()
train_transfer.py (transfer-learning implementation)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
# from resnet import ResNet18
from torchvision.models import resnet18
from utils import Flatten
# Training hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 20
# Use the GPU when there is one. A hard-coded torch.device("cuda") makes the
# first .to(device) fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1234)

# All three datasets slice the same cached images.csv, so the splits are
# disjoint: train 60%, val 20%, test 20%.
train_db = Pokemon("pokemon", 224, mode="train")
val_db = Pokemon("pokemon", 224, mode="val")
test_db = Pokemon("pokemon", 224, mode="test")
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
                          num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
viz = visdom.Visdom()
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to eval mode for the pass, so layers such as
    BatchNorm use their running statistics instead of batch statistics and
    do not update them. It is then restored to whatever mode it was in.
    Batches are moved to the module-level ``device`` and no gradients are
    tracked.

    Returns:
        correct / len(loader.dataset) as a float; 0.0 for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Leave the caller's training loop in the mode it expects.
        model.train(was_training)
    return correct / total
def main():
    """Fine-tune a pretrained torchvision resnet18 on the Pokemon split.

    Every child of resnet18 except its last one is kept, followed by Flatten
    and a new 5-way linear head. Loss is plotted per step and validation
    accuracy once per epoch in Visdom. The best-validation weights are saved
    to "best.mdl", reloaded at the end, and evaluated on the test split.
    """
    # NOTE(review): newer torchvision releases deprecate pretrained=True in
    # favour of weights=...; kept for compatibility with older versions.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # [b, 512, 1, 1]
        Flatten(),                             # [b, 512, 1, 1] => [b, 512]
        nn.Linear(512, 5),                     # new classification head
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any possible accuracy so the first epoch always writes a
    # checkpoint. With 0 and val accuracy stuck at 0, "best.mdl" was never
    # written and torch.load below crashed.
    best_acc, best_epoch = -1, 0
    global_step = 0
    viz.line([0], [-1], win="loss", opts=dict(title="loss"))
    viz.line([0], [-1], win="val_acc", opts=dict(title="val_acc"))

    for epoch in range(epochs):
        model.train()  # make sure BatchNorm is in training mode for this epoch
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 10 == 0:
                print("epoch:", epoch, "step:", step, "loss:", loss.item())
            viz.line([loss.item()], [global_step], win="loss", update="append")
            global_step += 1

        # Validate once per epoch.
        val_acc = evalute(model, val_loader)
        viz.line([val_acc], [global_step], win="val_acc", update="append")
        print("epoch:", epoch, "val_acc:", val_acc)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), "best.mdl")

    print("best acc:", best_acc, "best epoch:", best_epoch)
    # map_location lets a checkpoint saved on GPU load on the current device.
    model.load_state_dict(torch.load("best.mdl", map_location=device))
    print("loaded from ckpt!")
    test_acc = evalute(model, test_loader)
    print("test acc:", test_acc)
# Entry point: run the transfer-learning training only as a script.
if __name__ == "__main__":
    main()
边栏推荐
- Summary of C language learning problems (VS)
- The evolution of mobile cross platform technology
- Learn JVM garbage collection 02 - a brief introduction to the reference and recycling method area
- Linux Installation and deployment lamp (apache+mysql+php)
- July Huaqing learning-1
- Constructing expression binary tree with prefix expression
- Want to ask, how to choose a securities firm? Is it safe to open an account online?
- [HDU 2096] 小明A+B
- [superhard core] is the core technology of redis
- What is digital existence? Digital transformation starts with digital existence
猜你喜欢

Solve the problem of cache and database double write data consistency

Simple production of wechat applet cloud development authorization login

Get the variable address of structure member in C language

Take you hand in hand to develop a service monitoring component

Constructing expression binary tree with prefix expression

Redis highly available sentinel cluster
Take you two minutes to quickly master the route and navigation of flutter

Learn the memory management of JVM 02 - memory allocation of JVM

Interviewer: is acid fully guaranteed for redis transactions?

Matlab imoverlay function (burn binary mask into two-dimensional image)
随机推荐
Get the variable address of structure member in C language
Constructing expression binary tree with prefix expression
Matlab superpixels function (2D super pixel over segmentation of image)
Which domestic cloud management platform manufacturer is good in 2022? Why?
Take you two minutes to quickly master the route and navigation of flutter
One article tells the latest and complete learning materials of flutter
The evolution of mobile cross platform technology
MySQL index (1)
Wireless WiFi learning 8-channel transmitting remote control module
Why do you always fail in automated tests?
Select drop-down box realizes three-level linkage of provinces and cities in China
Learn JVM garbage collection 02 - a brief introduction to the reference and recycling method area
Learn memory management of JVM 01 - first memory
What is the difference between canvas and SVG?
Third party payment interface design
[pytorch modifies the pre training model: there is little difference between the measured loading pre training model and the random initialization of the model]
Matlab struct function (structure array)
byte2String、string2Byte
一款新型的智能家居WiFi选择方案——SimpleWiFi在无线智能家居中的应用
Video networkstate property