当前位置:网站首页>PyTorch: Classifying Bird Species with PyTorch
PyTorch: Classifying Bird Species with PyTorch
2022-07-02 04:11:00 【Brother Shui is very water】
1. Dataset
The dataset used in the code can be downloaded from Baidu Netdisk (extraction code: lala).
2. Runtime environment
PyTorch (GPU build) == 1.7.1
Python == 3.8
3. Dataset processing code (data_loader.py)
import glob
import numpy as np
from torch.utils.data.dataset import T_co
from torchvision import transforms
from torch.utils import data
from PIL import Image
import matplotlib.pyplot as plt
import torch
class FeatureDataset(data.Dataset):
    """Map-style dataset over precomputed feature tensors and their labels.

    Args:
        feature_list: Indexable sequence of per-sample feature tensors.
        label_list: Indexable sequence of labels aligned with ``feature_list``;
            each entry must be something ``torch.as_tensor`` accepts (e.g. an
            int or a 0-d tensor).
    """

    def __init__(self, feature_list, label_list):
        self.feature_list, self.label_list = feature_list, label_list

    def __len__(self):
        # The dataset size is driven by the feature list.
        return len(self.feature_list)

    def __getitem__(self, index):
        """Return ``(feature, label)`` with the label cast to a ``torch.long`` tensor."""
        feature = self.feature_list[index]
        target = torch.as_tensor(data=self.label_list[index], dtype=torch.long)
        return feature, target
class BirdsDataset(data.Dataset):
    """Dataset of image files paired with labels.

    Args:
        img_path: Indexable sequence of image file paths.
        labels: Indexable sequence of labels, aligned index-for-index with
            ``img_path``.
        trans: Callable applied to each image after conversion to RGB
            (e.g. a torchvision ``transforms.Compose`` pipeline).
    """

    def __init__(self, img_path, labels, trans):
        self.imgs = img_path
        self.labels = labels
        self.trans = trans

    def __getitem__(self, index) -> tuple:
        """Return ``(trans(rgb_image), label)`` for sample ``index``."""
        image_file = self.imgs[index]
        label = self.labels[index]
        # Open the image as a context manager so its file handle is released
        # deterministically, even if decoding or conversion raises.
        # convert() returns a new, fully loaded image that outlives the block.
        with Image.open(image_file) as pil_img:
            rgb_img = pil_img.convert('RGB')
        return self.trans(rgb_img), label

    def __len__(self):
        return len(self.imgs)
def loader_data(root='birds', batch_size=32, train_ratio=0.8, seed=200):
    """Build train/test DataLoaders for a folder-per-class image dataset.

    Images are read from ``<root>/<dir>/*.jpg``. The class name is taken from
    the image's parent directory, assumed to be named ``<prefix>.<ClassName>``
    (the part after the first '.' is used).

    Args:
        root: Dataset root directory (default ``'birds'``).
        batch_size: Batch size for both loaders.
        train_ratio: Fraction of the shuffled images used for training.
        seed: Seed for the permutation that decides the train/test split.

    Returns:
        ``(train_dl, test_dl, index_classes)``. ``index_classes`` maps an
        integer label to its class name. The training loader shuffles; the
        test loader does not.

    Raises:
        FileNotFoundError: If no ``.jpg`` files are found under ``root``.
    """
    import os

    # os.path.join makes the pattern valid on both Windows and POSIX. glob's
    # order is filesystem-dependent, so sort to make the seeded split
    # reproducible across machines.
    file_names = sorted(glob.glob(os.path.join(root, '*', '*.jpg')))
    if not file_names:
        raise FileNotFoundError('no .jpg images found under {!r}'.format(root))

    def class_of(path):
        # '<root>/<prefix>.<ClassName>/<file>.jpg' -> '<ClassName>'
        return os.path.basename(os.path.dirname(path)).split('.')[1]

    file_classes = [class_of(name) for name in file_names]
    classes = np.unique(file_classes)
    index_classes = dict(enumerate(classes))
    classes_index = {cla: index for index, cla in index_classes.items()}
    # Exactly one label per file, taken from its own directory. A substring
    # search of every class name in the path can match several classes or
    # none, which silently misaligns labels with files.
    all_labels = [classes_index[cla] for cla in file_classes]

    # A local RandomState gives the same permutation as np.random.seed(seed)
    # without reseeding NumPy's global generator.
    random_index = np.random.RandomState(seed).permutation(len(file_names))
    file_names = np.array(file_names)[random_index]
    all_labels = np.array(all_labels)[random_index]

    split = int(len(file_names) * train_ratio)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    train_ds = BirdsDataset(img_path=file_names[:split], labels=all_labels[:split], trans=transform)
    test_ds = BirdsDataset(img_path=file_names[split:], labels=all_labels[split:], trans=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl, index_classes
if __name__ == '__main__':
    # Preview the first six images of one test batch, titled with class names.
    _, test_loader, label_names = loader_data()
    images, targets = next(iter(test_loader))
    plt.figure(figsize=(12, 8))
    for position, (image, target) in enumerate(zip(images[:6], targets[:6]), start=1):
        plt.subplot(2, 3, position)
        plt.axis('off')
        plt.title(label_names.get(target.item()))
        # CHW tensor -> HWC array, as expected by imshow
        plt.imshow(image.permute(1, 2, 0).numpy())
    plt.show()
4. Model construction code (model_loader.py)
import torchvision
from torch import nn
class FCModel(nn.Module):
    """Fully connected classifier head: in_size -> 2048 -> 1024 -> out_size.

    Args:
        in_size: Number of input features per sample.
        out_size: Number of output scores (one per class).
        activation: Optional module applied after each hidden layer, e.g.
            ``nn.ReLU()``. The default is the identity, which reproduces the
            original purely linear stack, so weights trained with it still
            load and behave the same. Without a non-linearity, the three layers
            compose to a single affine map: the hidden layers add parameters
            but no expressive power.
    """

    def __init__(self, in_size, out_size, activation=None):
        super().__init__()
        self.lin1 = nn.Linear(in_features=in_size, out_features=2048)
        self.lin2 = nn.Linear(in_features=2048, out_features=1024)
        self.lin3 = nn.Linear(in_features=1024, out_features=out_size)
        # nn.Identity has no parameters, so the default keeps the state_dict
        # keys (lin1/lin2/lin3) and existing checkpoints compatible.
        self.activation = nn.Identity() if activation is None else activation

    def forward(self, x):
        """Return raw (un-normalised) class scores for the feature batch ``x``."""
        x = self.activation(self.lin1(x))
        x = self.activation(self.lin2(x))
        return self.lin3(x)
def load_model():
    """Return the frozen, pretrained DenseNet-121 convolutional feature trunk.

    Only the ``.features`` sub-module of torchvision's ``densenet121`` is
    returned. All of its parameters have ``requires_grad=False``, and it is
    put in eval mode. Its BatchNorm layers therefore normalise with their
    stored running statistics rather than per-batch statistics, and do not
    update those statistics while being used as a fixed feature extractor.
    Call ``.train()`` on the result if fine-tuning is intended.
    """
    model = torchvision.models.densenet121(pretrained=True).features
    for parameter in model.parameters():
        parameter.requires_grad = False
    # Freezing weights is not enough. In train mode, BatchNorm would still use
    # batch statistics (making features depend on batch composition) and would
    # overwrite its running buffers.
    model.eval()
    return model
if __name__ == '__main__':
    # Show the layer structure of the frozen feature extractor.
    print(load_model())
5. Model training code
import torch
import tqdm
from data_loader import loader_data, FeatureDataset
from model_loader import load_model, FCModel
from torch.utils import data
from torch import nn, optim
from sklearn.metrics import accuracy_score
import numpy as np
import os
# Environment configuration
devices = 'cuda' if torch.cuda.is_available() else 'cpu'


def _extract_features(extractor, loader):
    """Run ``extractor`` over every batch of ``loader`` with gradients disabled.

    Returns:
        ``(features, labels)``: one flattened extractor output (on CPU) per
        sample, and the matching per-sample label tensors from ``loader``.
    """
    features, labels = [], []
    progress = tqdm.tqdm(iterable=loader, total=len(loader))
    with torch.no_grad():  # pure inference: no autograd bookkeeping needed
        for img, lab in progress:
            out = extractor(img.to(devices))
            features.extend(out.view(out.size(0), -1).cpu())
            labels.extend(lab)
    progress.close()
    return features, labels


def _run_epoch(classifier, loader, epoch, criterion, optimizer=None):
    """Make one pass of ``classifier`` over ``loader`` and show progress.

    If ``optimizer`` is given, the pass trains (backward + step). Otherwise it
    only evaluates, with gradients disabled.

    Returns:
        ``(mean_loss, accuracy)`` averaged per sample, so a short final batch
        is not weighted like a full one.
    """
    training = optimizer is not None
    classifier.train(training)
    loss_sum, correct, seen = 0.0, 0, 0
    progress = tqdm.tqdm(iterable=loader, total=len(loader))
    progress.set_description_str('{} epoch {:2d}'.format('Train' if training else 'Test', epoch))
    with torch.set_grad_enabled(training):
        for img_feature, img_label in progress:
            img_feature, img_label = img_feature.to(devices), img_label.to(devices)
            pred = classifier(img_feature)
            loss = criterion(pred, img_label)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Display indicators: running sample-weighted averages
            batch_size = img_label.size(0)
            loss_sum += loss.item() * batch_size  # criterion returns a batch mean
            correct += (torch.argmax(input=pred, dim=-1) == img_label).sum().item()
            seen += batch_size
            progress.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(loss_sum / seen, correct / seen))
    progress.close()
    return loss_sum / max(seen, 1), correct / max(seen, 1)


# Load data
train_dl, test_dl, index_classes = loader_data()
# Load the frozen feature extractor. eval() matters even though its weights are
# frozen: in train mode its BatchNorm layers would normalise with per-batch
# statistics (and update their running stats), so training features would not
# match the eval-mode features computed at prediction time.
model = load_model()
model = model.to(devices)
model.eval()
# Start feature extraction
train_features, train_labels = _extract_features(model, train_dl)
test_features, test_labels = _extract_features(model, test_dl)
# Building feature data sets
train_feat_ds = FeatureDataset(feature_list=train_features, label_list=train_labels)
test_feat_ds = FeatureDataset(feature_list=test_features, label_list=test_labels)
train_feat_dl = data.DataLoader(dataset=train_feat_ds, batch_size=32, shuffle=True)
test_feat_dl = data.DataLoader(dataset=test_feat_ds, batch_size=32)
# Build feature classifier sized from the actual flattened feature width
in_size = train_features[0].shape[0]
net = FCModel(in_size=in_size, out_size=200)
net = net.to(devices)
# Configuration used for training
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.00001)
# Start training: one training pass and one evaluation pass per epoch
for epoch in range(100):
    _run_epoch(net, train_feat_dl, epoch, loss_fn, optimizer)
    _run_epoch(net, test_feat_dl, epoch, loss_fn)
# Save model
os.makedirs('model_data', exist_ok=True)
torch.save(net.state_dict(), os.path.join('model_data', 'net.pth'))
6. Model prediction code
from data_loader import loader_data
from model_loader import load_model, FCModel
import torch
import os
import matplotlib.pyplot as plt
# Configuration of environment variables
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Data loading: take the first test batch
train_dl, test_dl, index_classes = loader_data()
image, label = next(iter(test_dl))
image, label = image.to(device), label.to(device)
# Model loading: frozen feature extractor in eval mode, so BatchNorm uses its
# running statistics
model = load_model()
model.eval()
model = model.to(device=device)
# Extract features first, so the classifier can be sized from their real width
# instead of the previously hard-coded 50176.
with torch.no_grad():
    feature = model(image)
    feature = feature.view(feature.size(0), -1)
net = FCModel(in_size=feature.size(1), out_size=200)
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
model_state_dict = torch.load(os.path.join('model_data', 'net.pth'), map_location=device)
net.load_state_dict(model_state_dict)
net.eval()
net = net.to(device=device)
# Predict the model
index = min(5, image.size(0) - 1)  # sample to display; clamped for short batches
with torch.no_grad():
    pre = torch.argmax(input=net(feature), dim=-1)
plt.axis('off')
plt.title('predict result: ' + index_classes.get(pre[index].cpu().item()) + '\nlabel result: ' + index_classes.get(
    label[index].cpu().item()))
plt.imshow(image[index].cpu().permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
7. Results of running the code

边栏推荐
- Handling of inconsistency between cursor and hinttext position in shutter textfield
- 初识P4语言
- BiShe cinema ticket purchasing system based on SSM
- SQL: common SQL commands
- 【leetcode】34. Find the first and last positions of elements in a sorted array
- Set vscode. When double clicking, the selected string includes the $symbol - convenient for PHP operation
- [wireless image transmission] FPGA based simple wireless image transmission system Verilog development, matlab assisted verification
- Yyds dry goods inventory kubernetes introduction foundation pod concept and related operations
- Déchirure à la main - tri
- The original author is out! Faker. JS has been controlled by the community..
猜你喜欢

【无线图传】基于FPGA的简易无线图像传输系统verilog开发,matlab辅助验证

Sorted out an ECS summer money saving secret, this time @ old users come and take it away

First acquaintance with P4 language

【leetcode】74. Search 2D matrix

【c语言】基础篇学习笔记

How should the team choose the feature branch development mode or trunk development mode?

Vite: configure IP access

【c语言】动态规划---入门到起立

Go语言介绍

Li Kou interview question 02.08 Loop detection
随机推荐
[tips] use Matlab GUI to read files in dialog mode
A thorough understanding of the development of scorecards - the determination of Y (Vintage analysis, rolling rate analysis, etc.)
Uni app - realize the countdown of 60 seconds to obtain the mobile verification code (mobile number + verification code login function)
C语言:逻辑运算和判断选择结构例题
Sorted out an ECS summer money saving secret, this time @ old users come and take it away
BGP experiment the next day
Spring moves are coming. Watch the gods fight
Playing with concurrency: what are the ways of communication between threads?
The original author is out! Faker. JS has been controlled by the community..
Fingertips life Chapter 4 modules and packages
Homework in Chapter 3 of slam course of dark blue vision -- derivative application of T6 common functions
文档声明与字符编码
[ibdfe] matlab simulation of frequency domain equalization based on ibdfe
Sword finger offer II 006 Sort the sum of two numbers in the array
Go language introduction
How should the team choose the feature branch development mode or trunk development mode?
Www2022 | know your way back: self training method of graph neural network under distribution and migration
Dare to go out for an interview without learning some distributed technology?
如何解决在editor模式下 无法删除物体的问题
Common sense of cloud server security settings