当前位置:网站首页>Pytorch---使用Pytorch进行鸟类的预测
Pytorch---使用Pytorch进行鸟类的预测
2022-07-02 04:04:00 【水哥很水】
一、代码中的数据集可以通过以下链接获取
二、代码运行环境
PyTorch(GPU 版,pip 包名为 torch)==1.7.1
Python==3.8
三、数据集处理代码如下所示
import glob
import numpy as np
from torch.utils.data.dataset import T_co
from torchvision import transforms
from torch.utils import data
from PIL import Image
import matplotlib.pyplot as plt
import torch
class FeatureDataset(data.Dataset):
    """Dataset over pre-computed feature vectors and their class labels.

    Each item is ``(feature, label)`` where ``feature`` is returned exactly as
    stored and ``label`` is converted to a ``torch.long`` tensor.
    """

    def __init__(self, feature_list, label_list):
        """Store the parallel feature and label sequences.

        Args:
            feature_list: Indexable sequence of per-sample features.
            label_list: Indexable sequence of per-sample labels, same length
                as ``feature_list``.

        Raises:
            ValueError: If the two sequences differ in length.
        """
        # Fail fast: a length mismatch would otherwise only show up as an
        # IndexError (or silently ignored labels) in the middle of training.
        if len(feature_list) != len(label_list):
            raise ValueError(
                'feature_list and label_list must have the same length, got '
                '{} and {}'.format(len(feature_list), len(label_list)))
        self.feature_list = feature_list
        self.label_list = label_list

    def __getitem__(self, index):
        """Return ``(feature, label)`` for ``index``; the label is a long tensor."""
        label = torch.as_tensor(self.label_list[index], dtype=torch.long)
        return self.feature_list[index], label

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.feature_list)
class BirdsDataset(data.Dataset):
    """Dataset that loads images from file paths and applies a transform.

    Each item is ``(transformed_image, label)``; the label is returned exactly
    as stored in ``labels``.
    """

    def __init__(self, img_path, labels, trans):
        """Store the image paths, their labels and the transform.

        Args:
            img_path: Indexable sequence of image file paths.
            labels: Indexable sequence of labels, same length as ``img_path``.
            trans: Callable applied to the RGB ``PIL.Image`` of every sample.

        Raises:
            ValueError: If ``img_path`` and ``labels`` differ in length.
        """
        if len(img_path) != len(labels):
            raise ValueError(
                'img_path and labels must have the same length, got '
                '{} and {}'.format(len(img_path), len(labels)))
        self.imgs = img_path
        self.labels = labels
        self.trans = trans

    def __getitem__(self, index):
        """Load image ``index`` as RGB, transform it and return it with its label."""
        # ``convert`` reads the pixel data into a new image object, so the
        # result stays valid after the context manager closes the file handle.
        with Image.open(self.imgs[index]) as pil_img:
            rgb_img = pil_img.convert('RGB')
        return self.trans(rgb_img), self.labels[index]

    def __len__(self):
        """Return the number of images."""
        return len(self.imgs)
def loader_data(root='birds', batch_size=32, train_ratio=0.8, seed=200):
    """Build the train/test DataLoaders for a ``<root>/<class dir>/*.jpg`` tree.

    The class name of an image is taken from its parent directory: for a
    directory named ``<prefix>.<name>`` the class is ``<name>``; a directory
    name without a dot is used as-is.

    Args:
        root: Directory that contains one sub-directory per class.
        batch_size: Batch size of both loaders.
        train_ratio: Fraction of the images put into the training split;
            must be strictly between 0 and 1.
        seed: Seed of the shuffle performed before splitting.

    Returns:
        ``(train_dl, test_dl, index_classes)`` where ``index_classes`` maps
        the integer label to the class name.

    Raises:
        ValueError: If ``train_ratio`` is not strictly between 0 and 1.
        FileNotFoundError: If no ``*.jpg`` file is found under ``root``.
    """
    # Function-scope import keeps this block self-contained.
    import os

    if not 0.0 < train_ratio < 1.0:
        raise ValueError(
            'train_ratio must be between 0 and 1 (exclusive), got {!r}'.format(train_ratio))

    # os.path.join instead of a hard-coded backslash pattern so the lookup
    # also works outside Windows. Sorting makes the seeded split reproducible,
    # because the order returned by glob depends on the file system.
    file_names = sorted(glob.glob(os.path.join(root, '*', '*.jpg')))
    if not file_names:
        raise FileNotFoundError(
            'no .jpg images found under {!r}'.format(os.path.join(root, '*')))

    def class_of(path):
        # '<prefix>.<name>' -> '<name>'; plain directory names are kept.
        folder = os.path.basename(os.path.dirname(path))
        parts = folder.split('.')
        return parts[1] if len(parts) > 1 else folder

    # Exactly one label per file, derived from its own directory. Substring
    # matching of class names against the whole path could yield zero or
    # several labels for a file and misalign labels with paths.
    class_names = [class_of(name) for name in file_names]
    classes = np.unique(class_names)
    index_classes = dict(enumerate(classes))
    classes_index = {cla: index for index, cla in index_classes.items()}
    all_labels = np.array([classes_index[cla] for cla in class_names])

    # A private RandomState yields the same permutation as seeding the global
    # generator with the same seed, without touching global RNG state.
    random_index = np.random.RandomState(seed).permutation(len(file_names))
    file_names = np.array(file_names)[random_index]
    all_labels = all_labels[random_index]

    split_at = int(len(file_names) * train_ratio)
    train_path, test_path = file_names[:split_at], file_names[split_at:]
    train_label, test_label = all_labels[:split_at], all_labels[split_at:]

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    train_ds = BirdsDataset(img_path=train_path, labels=train_label, trans=transform)
    test_ds = BirdsDataset(img_path=test_path, labels=test_label, trans=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl, index_classes
if __name__ == '__main__':
    # Visual sanity check: show the first six images of the first test batch
    # in a 2x3 grid, each titled with its class name.
    _, test_loader, idx_to_class = loader_data()
    images, targets = next(iter(test_loader))
    plt.figure(figsize=(12, 8))
    preview = zip(images[:6], targets[:6])
    for slot, (picture, target) in enumerate(preview, start=1):
        plt.subplot(2, 3, slot)
        plt.axis('off')
        plt.title(idx_to_class.get(target.item()))
        # (C, H, W) tensor -> (H, W, C) array, the layout imshow expects.
        plt.imshow(picture.permute(1, 2, 0).numpy())
    plt.show()
四、模型的构建代码如下所示
import torchvision
from torch import nn
class FCModel(nn.Module):
    """Three-layer fully connected classifier head (in -> 2048 -> 1024 -> out).

    With the default ``activation=None`` the three linear layers are applied
    back to back, exactly as before; note that such a stack is mathematically
    a single affine map. Passing an activation inserts it after the first and
    second layer. The parameter names (``lin1``, ``lin2``, ``lin3``) do not
    depend on this choice.
    """

    def __init__(self, in_size, out_size, activation=None):
        """Create the layers.

        Args:
            in_size: Number of input features.
            out_size: Number of output logits.
            activation: Optional callable applied after ``lin1`` and ``lin2``
                (for example ``torch.relu``). ``None`` applies nothing.
        """
        super().__init__()
        self.lin1 = nn.Linear(in_features=in_size, out_features=2048)
        self.lin2 = nn.Linear(in_features=2048, out_features=1024)
        self.lin3 = nn.Linear(in_features=1024, out_features=out_size)
        self.activation = activation

    def forward(self, x):
        """Map a batch of feature vectors to a batch of logits."""
        hidden = self.lin1(x)
        if self.activation is not None:
            hidden = self.activation(hidden)
        hidden = self.lin2(hidden)
        if self.activation is not None:
            hidden = self.activation(hidden)
        return self.lin3(hidden)
def load_model():
    """Return the pretrained DenseNet-121 feature extractor, frozen and in eval mode.

    Only the ``features`` sub-module of ``torchvision.models.densenet121`` is
    returned. Every parameter has ``requires_grad`` set to ``False``.

    Returns:
        The feature-extraction module, switched to evaluation mode.
    """
    model = torchvision.models.densenet121(pretrained=True).features
    for parameter in model.parameters():
        parameter.requires_grad = False
    # A freshly built module is in training mode. Switching to eval mode here
    # makes every caller get the same inference-time behaviour from the frozen
    # backbone, whether or not it calls ``.eval()`` itself.
    model.eval()
    return model
if __name__ == '__main__':
    # Print the layer structure of the feature extractor.
    print(load_model())
五、模型的训练代码如下所示
import torch
import tqdm
from data_loader import loader_data, FeatureDataset
from model_loader import load_model, FCModel
from torch.utils import data
from torch import nn, optim
from sklearn.metrics import accuracy_score
import numpy as np
import os
# Training script: extract features once with the frozen backbone, then train
# the fully connected classifier on those features and save its weights.

# Size of the classifier output layer.
NUM_CLASSES = 200


def extract_features(feature_model, loader, device):
    """Run ``feature_model`` over ``loader`` and collect flattened features.

    Args:
        feature_model: Module mapping an image batch to a feature batch.
        loader: Iterable of ``(image_batch, label_batch)`` pairs.
        device: Device the image batches are moved to.

    Returns:
        ``(features, labels)``: two lists with one CPU tensor per sample.
    """
    features, labels = [], []
    # Eval mode and no_grad: the backbone is only used for inference here, so
    # its behaviour must match the prediction script, which calls .eval().
    feature_model.eval()
    progress = tqdm.tqdm(iterable=loader, total=len(loader))
    with torch.no_grad():
        for img, lab in progress:
            out = feature_model(img.to(device))
            out = out.view(out.size(0), -1)  # flatten everything but the batch axis
            features.extend(out.cpu())
            labels.extend(lab)
    progress.close()
    return features, labels


def run_epoch(classifier, loader, criterion, device, description, optimizer=None):
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    Args:
        classifier: The model being trained or evaluated.
        loader: Iterable of ``(feature_batch, label_batch)`` pairs.
        criterion: Loss function called as ``criterion(pred, label)``.
        device: Device the batches are moved to.
        description: Text shown in front of the progress bar.
        optimizer: Optimizer used for the update step, or ``None`` to evaluate.

    Returns:
        ``(mean_loss, mean_accuracy)``, both averaged over the batches.
    """
    training = optimizer is not None
    classifier.train(training)
    losses, accuracies = [], []
    progress = tqdm.tqdm(iterable=loader, total=len(loader))
    progress.set_description_str(description)
    with torch.set_grad_enabled(training):
        for img_feature, img_label in progress:
            img_feature, img_label = img_feature.to(device), img_label.to(device)
            pred = classifier(img_feature)
            loss = criterion(pred, img_label)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Running metrics shown on the progress bar
            losses.append(loss.item())
            pred = torch.argmax(input=pred, dim=-1)
            accuracies.append(accuracy_score(y_true=img_label.cpu().numpy(), y_pred=pred.cpu().numpy()))
            progress.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(np.mean(losses), np.mean(accuracies)))
    progress.close()
    return np.mean(losses), np.mean(accuracies)


# Environment configuration
devices = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the data
train_dl, test_dl, index_classes = loader_data()
if len(index_classes) > NUM_CLASSES:
    # Otherwise the loss fails later with an out-of-range target error.
    raise ValueError('found {} classes but the classifier only has {} outputs'.format(
        len(index_classes), NUM_CLASSES))

# Load the feature-extraction model
model = load_model()
model = model.to(devices)

# Feature extraction
train_features, train_labels = extract_features(model, train_dl, devices)
test_features, test_labels = extract_features(model, test_dl, devices)

# Datasets and loaders over the extracted features
train_feat_ds = FeatureDataset(feature_list=train_features, label_list=train_labels)
test_feat_ds = FeatureDataset(feature_list=test_features, label_list=test_labels)
train_feat_dl = data.DataLoader(dataset=train_feat_ds, batch_size=32, shuffle=True)
test_feat_dl = data.DataLoader(dataset=test_feat_ds, batch_size=32)

# Feature classifier; its input size is the length of one flattened feature
in_size = train_features[0].shape[0]
net = FCModel(in_size=in_size, out_size=NUM_CLASSES)
net = net.to(devices)

# Training configuration
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.00001)

# Training loop: one training pass and one evaluation pass per epoch
for epoch in range(100):
    run_epoch(net, train_feat_dl, loss_fn, devices,
              'Train epoch {:2d}'.format(epoch), optimizer=optimizer)
    run_epoch(net, test_feat_dl, loss_fn, devices,
              'Test epoch {:2d}'.format(epoch))

# Save the classifier weights
os.makedirs('model_data', exist_ok=True)
torch.save(net.state_dict(), os.path.join('model_data', 'net.pth'))
六、模型的预测代码如下所示
from data_loader import loader_data
from model_loader import load_model, FCModel
import torch
import os
import matplotlib.pyplot as plt
# Prediction script: classify one image of the first test batch and plot it
# together with its predicted and true class name.

# Environment configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the data
train_dl, test_dl, index_classes = loader_data()
image, label = next(iter(test_dl))
image, label = image.to(device), label.to(device)

# Load the feature extractor
model = load_model()
model.eval()

# Load the classifier weights. map_location lets a checkpoint that was saved
# from a CUDA model be loaded on a CPU-only machine as well.
model_state_dict = torch.load(os.path.join('model_data', 'net.pth'), map_location=device)
# Take the layer sizes from the checkpoint instead of hard-coding them, so the
# rebuilt classifier always matches the saved weights. nn.Linear stores its
# weight as (out_features, in_features).
in_size = model_state_dict['lin1.weight'].shape[1]
out_size = model_state_dict['lin3.weight'].shape[0]
net = FCModel(in_size=in_size, out_size=out_size)
net.load_state_dict(model_state_dict)
net.eval()
model = model.to(device=device)
net = net.to(device=device)

# Run the prediction
# Sample shown in the plot; clamped so a batch with fewer than six images
# does not cause an index error.
index = min(5, image.size(0) - 1)
with torch.no_grad():
    feature = model(image)
    feature = feature.view(feature.size(0), -1)  # flatten everything but the batch axis
    pre = net(feature)
    pre = torch.argmax(input=pre, dim=-1)
plt.axis('off')
plt.title('predict result: ' + index_classes.get(pre[index].cpu().item()) + '\nlabel result: ' + index_classes.get(
    label[index].cpu().item()))
# (C, H, W) tensor -> (H, W, C), the layout imshow expects
plt.imshow(image[index].cpu().permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
七、代码的运行结果如下所示
边栏推荐
- Qt插件之Qt Designer插件实现
- Document declaration and character encoding
- 【leetcode】74. Search 2D matrix
- 手撕——排序
- PR zero foundation introductory guide note 2
- Blue Bridge Cup SCM digital tube skills
- The first game of the 11th provincial single chip microcomputer competition of the Blue Bridge Cup
- Fingertips life Chapter 4 modules and packages
- JVM知识点
- 2022-07-01:某公司年会上,大家要玩一食发奖金游戏,一共有n个员工, 每个员工都有建设积分和捣乱积分, 他们需要排成一队,在队伍最前面的一定是老板,老板也有建设积分和捣乱积分, 排好队后,所有
猜你喜欢
BGP experiment the next day
毕设-基于SSM电影院购票系统
Go语言介绍
Sword finger offer II 006 Sort the sum of two numbers in the array
Sorted out an ECS summer money saving secret, this time @ old users come and take it away
Force buckle 540 A single element in an ordered array
Lost a few hairs, and finally learned - graph traversal -dfs and BFS
整理了一份ECS夏日省钱秘籍,这次@老用户快来领走
手撕——排序
Set vscode. When double clicking, the selected string includes the $symbol - convenient for PHP operation
随机推荐
【leetcode】74. Search 2D matrix
JVM knowledge points
[untitled]
【c语言】基础篇学习笔记
Introduction to vmware workstation and vSphere
[Li Kou brush questions] 15 Sum of three numbers (double pointer); 17. Letter combination of phone number (recursive backtracking)
Handling of inconsistency between cursor and hinttext position in shutter textfield
Three ways for programmers to learn PHP easily and put chaos out of order
SQL:常用的 SQL 命令
Lei Jun wrote a blog when he was a programmer. It's awesome
Document declaration and character encoding
Target free or target specific: a simple and effective zero sample position detection comparative learning method
Where can I buy cancer insurance? Which product is better?
BGP experiment the next day
Go语言介绍
LxC limits the number of CPUs
【直播回顾】战码先锋首期8节直播完美落幕,下期敬请期待!
Set vscode. When double clicking, the selected string includes the $symbol - convenient for PHP operation
Wechat applet calculates the distance between the two places
2022-07-01:某公司年会上,大家要玩一食发奖金游戏,一共有n个员工, 每个员工都有建设积分和捣乱积分, 他们需要排成一队,在队伍最前面的一定是老板,老板也有建设积分和捣乱积分, 排好队后,所有