当前位置:网站首页>Pytoch --- use pytoch to predict birds
Pytoch --- use pytoch to predict birds
2022-07-02 04:11:00 【Brother Shui is very water】
One 、 The datasets in the code can be obtained through the following link
Baidu online disk extraction code :lala
Two 、 Code running environment
Pytorch-gpu==1.7.1
Python==3.8
3、 ... and 、 Data set processing codes are as follows
import glob
import numpy as np
from torch.utils.data.dataset import T_co
from torchvision import transforms
from torch.utils import data
from PIL import Image
import matplotlib.pyplot as plt
import torch
class FeatureDataset(data.Dataset):
def __getitem__(self, index):
return self.feature_list[index], torch.as_tensor(data=self.label_list[index], dtype=torch.long)
def __init__(self, feature_list, label_list):
self.feature_list = feature_list
self.label_list = label_list
def __len__(self):
return len(self.feature_list)
class BirdsDataset(data.Dataset):
def __init__(self, img_path, labels, trans):
self.imgs = img_path
self.labels = labels
self.trans = trans
def __getitem__(self, index) -> T_co:
imgg = self.imgs[index]
label = self.labels[index]
pil_img = Image.open(imgg)
pil_img = pil_img.convert('RGB')
img_tensor = self.trans(pil_img)
return img_tensor, label
def __len__(self):
return len(self.imgs)
def loader_data():
file_names = glob.glob(r'birds\*\*.jpg')
classes = np.unique([name.split('\\')[1].split('.')[1] for name in file_names])
index_classes = dict((index, cla) for index, cla in enumerate(classes))
classes_index = dict((v, k) for k, v in index_classes.items())
all_labels = []
for name in file_names:
for cla in classes:
if cla in name:
all_labels.append(classes_index.get(cla))
np.random.seed(200)
random_index = np.random.permutation(len(file_names))
file_names = np.array(file_names)[random_index]
all_labels = np.array(all_labels)[random_index]
i = int(len(file_names) * 0.8)
train_path = file_names[:i]
train_label = all_labels[:i]
test_path = file_names[i:]
test_label = all_labels[i:]
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
train_ds = BirdsDataset(img_path=train_path, labels=train_label, trans=transform)
test_ds = BirdsDataset(img_path=test_path, labels=test_label, trans=transform)
train_dl = data.DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
test_dl = data.DataLoader(dataset=test_ds, batch_size=32)
return train_dl, test_dl, index_classes
if __name__ == '__main__':
a, b, c = loader_data()
img_batch, label_batch = next(iter(b))
plt.figure(figsize=(12, 8))
for ii, (img, lab) in enumerate(zip(img_batch[:6], label_batch[:6])):
img = img.permute(1, 2, 0).numpy()
plt.subplot(2, 3, ii + 1)
plt.axis('off')
plt.title(c.get(lab.item()))
plt.imshow(img)
plt.show()
Four 、 The construction code of the model is as follows
import torchvision
from torch import nn
class FCModel(nn.Module):
def __init__(self, in_size, out_size):
super().__init__()
self.lin1 = nn.Linear(in_features=in_size, out_features=2048)
self.lin2 = nn.Linear(in_features=2048, out_features=1024)
self.lin3 = nn.Linear(in_features=1024, out_features=out_size)
def forward(self, x):
return self.lin3(self.lin2(self.lin1(x)))
def load_model():
model = torchvision.models.densenet121(pretrained=True).features
for parameter in model.parameters():
parameter.requires_grad = False
return model
if __name__ == '__main__':
mol = load_model()
print(mol)
5、 ... and 、 The training code of the model is as follows
import torch
import tqdm
from data_loader import loader_data, FeatureDataset
from model_loader import load_model, FCModel
from torch.utils import data
from torch import nn, optim
from sklearn.metrics import accuracy_score
import numpy as np
import os
# Environment configuration
devices = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load data
train_dl, test_dl, index_classes = loader_data()
# Load model
model = load_model()
model = model.to(devices)
# Start feature extraction
train_features = []
train_labels = []
tqdm_train_dl = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
for img, lab in tqdm_train_dl:
out = model(img.to(devices))
out = out.view(out.size(0), -1)
train_features.extend(out.cpu().data)
train_labels.extend(lab)
tqdm_train_dl.close()
test_features = []
test_labels = []
tqdm_test_dl = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
for img, lab in tqdm_test_dl:
out = model(img.to(devices))
out = out.view(out.size(0), -1)
test_features.extend(out.cpu().data)
test_labels.extend(lab)
tqdm_test_dl.close()
# Building feature data sets
train_feat_ds = FeatureDataset(feature_list=train_features, label_list=train_labels)
test_feat_ds = FeatureDataset(feature_list=test_features, label_list=test_labels)
train_feat_dl = data.DataLoader(dataset=train_feat_ds, batch_size=32, shuffle=True)
test_feat_dl = data.DataLoader(dataset=test_feat_ds, batch_size=32)
# Build feature classifier
in_size = train_features[0].shape[0]
net = FCModel(in_size=in_size, out_size=200)
net = net.to(devices)
# Configuration used for training
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.00001)
# Start training
for epoch in range(100):
net.train()
train_accuracy_sum = []
train_loss_sum = []
train_tqdm = tqdm.tqdm(iterable=train_feat_dl, total=len(train_feat_dl))
train_tqdm.set_description_str('Train epoch {:2d}'.format(epoch))
for img_feature, img_label in train_tqdm:
img_feature, img_label = img_feature.to(devices), img_label.to(devices)
pred = net(img_feature)
loss = loss_fn(pred, img_label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Display indicators
train_loss_sum.append(loss.item())
pred = torch.argmax(input=pred, dim=-1)
train_accuracy_sum.append(accuracy_score(y_true=img_label.cpu().numpy(), y_pred=pred.cpu().numpy()))
train_tqdm.set_postfix_str(
'loss is {:14f}, accuracy is {:14f}'.format(np.mean(train_loss_sum), np.mean(train_accuracy_sum)))
train_tqdm.close()
with torch.no_grad():
net.eval()
test_accuracy_sum = []
test_loss_sum = []
test_tqdm = tqdm.tqdm(iterable=test_feat_dl, total=len(test_feat_dl))
test_tqdm.set_description_str('Test epoch {:2d}'.format(epoch))
for img_feature, img_label in test_tqdm:
img_feature, img_label = img_feature.to(devices), img_label.to(devices)
pred = net(img_feature)
loss = loss_fn(pred, img_label)
# Display indicators
test_loss_sum.append(loss.item())
pred = torch.argmax(input=pred, dim=-1)
test_accuracy_sum.append(accuracy_score(y_true=img_label.cpu().numpy(), y_pred=pred.cpu().numpy()))
test_tqdm.set_postfix_str(
'loss is {:14f}, accuracy is {:14f}'.format(np.mean(test_loss_sum), np.mean(test_accuracy_sum)))
test_tqdm.close()
# Save model
if not os.path.exists(os.path.join('model_data')):
os.mkdir(os.path.join('model_data'))
torch.save(net.state_dict(), os.path.join('model_data', 'net.pth'))
6、 ... and 、 The prediction code of the model is as follows
from data_loader import loader_data
from model_loader import load_model, FCModel
import torch
import os
import matplotlib.pyplot as plt
# Configuration of environment variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Data loading
train_dl, test_dl, index_classes = loader_data()
image, label = next(iter(test_dl))
image, label = image.to(device), label.to(device)
# Model loading
model = load_model()
model.eval()
net = FCModel(in_size=50176, out_size=200)
model_state_dict = torch.load(os.path.join('model_data', 'net.pth'))
net.load_state_dict(model_state_dict)
net.eval()
model = model.to(device=device)
net = net.to(device=device)
# Predict the model
index = 5
with torch.no_grad():
feature = model(image)
feature = feature.view(feature.size(0), -1)
pre = net(feature)
pre = torch.argmax(input=pre, dim=-1)
plt.axis('off')
plt.title('predict result: ' + index_classes.get(pre[index].cpu().item()) + '\nlabel result: ' + index_classes.get(
label[index].cpu().item()))
plt.imshow(image[index].cpu().permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
7、 ... and 、 The running result of the code is as follows

边栏推荐
- 10 minutes to understand CMS garbage collector in JVM
- pip 安装第三方库
- 2022-07-01:某公司年会上,大家要玩一食发奖金游戏,一共有n个员工, 每个员工都有建设积分和捣乱积分, 他们需要排成一队,在队伍最前面的一定是老板,老板也有建设积分和捣乱积分, 排好队后,所有
- Wechat applet map annotation
- What is 5g industrial wireless gateway? What functions can 5g industrial wireless gateway achieve?
- 【leetcode】74. Search 2D matrix
- Welcome the winter vacation multi school league game 2 partial solution (B, C, D, F, G, H)
- 5g era is coming in an all-round way, talking about the past and present life of mobile communication
- Sorted out an ECS summer money saving secret, this time @ old users come and take it away
- 【人员密度检测】基于形态学处理和GRNN网络的人员密度检测matlab仿真
猜你喜欢

66.qt quick-qml自定义日历组件(支持竖屏和横屏)

Demonstration description of integrated base scheme

Play with concurrency: draw a thread state transition diagram

Playing with concurrency: what are the ways of communication between threads?

10 minutes to understand CMS garbage collector in JVM

Common sense of cloud server security settings

Microsoft Research Institute's new book "Fundamentals of data science", 479 Pages pdf

Go语言介绍

Go language introduction

Actual combat | use composite material 3 in application
随机推荐
Lost a few hairs, and finally learned - graph traversal -dfs and BFS
Actual combat | use composite material 3 in application
go 包的使用
How much can a job hopping increase? Today, I saw the ceiling of job hopping.
Déchirure à la main - tri
A summary of common interview questions in 2022, including 25 technology stacks, has helped me successfully get an offer from Tencent
Raspberry pie GPIO pin controls traffic light and buzzer
Document declaration and character encoding
Hand tear - sort
[personal notes] PHP common functions - custom functions
Homework in Chapter 3 of slam course of dark blue vision -- derivative application of T6 common functions
Welcome the winter vacation multi school league game 2 partial solution (B, C, D, F, G, H)
Recently, the weather has been extremely hot, so collect the weather data of Beijing, Shanghai, Guangzhou and Shenzhen last year, and make a visual map
【c语言】动态规划---入门到起立
Pandora IOT development board learning (RT thread) - Experiment 1 LED flashing experiment (learning notes)
Jetpack之LiveData扩展MediatorLiveData
Feature Engineering: summary of common feature transformation methods
Pytorch---使用Pytorch实现U-Net进行语义分割
u本位合约爆仓清算解决方案建议
Wechat applet JWT login issue token