当前位置:网站首页>Pytorch---使用Pytorch进行鸟类的预测
Pytorch---使用Pytorch进行鸟类的预测
2022-07-02 04:04:00 【水哥很水】
一、代码中的数据集可以通过以下链接获取(原文中的链接已丢失。代码要求图片按 birds/<前缀>.<类别名>/*.jpg 的目录结构存放)
二、代码运行环境
Pytorch-gpu==1.7.1
Python==3.8
三、数据集处理代码如下所示(请保存为 data_loader.py,后面的训练和预测代码通过 `from data_loader import ...` 导入它)
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils import data
from torch.utils.data.dataset import T_co
from torchvision import transforms
class FeatureDataset(data.Dataset):
    """Dataset that pairs pre-computed feature tensors with class labels.

    Args:
        feature_list: indexable sequence of features, one entry per sample.
        label_list: indexable sequence of labels aligned with ``feature_list``.
    """

    def __init__(self, feature_list, label_list):
        self.feature_list = feature_list
        self.label_list = label_list

    def __len__(self):
        return len(self.feature_list)

    def __getitem__(self, index):
        """Return ``(feature, label)`` with the label as a ``torch.long`` tensor."""
        feature = self.feature_list[index]
        label = torch.as_tensor(data=self.label_list[index], dtype=torch.long)
        return feature, label
class BirdsDataset(data.Dataset):
    """Image-classification dataset backed by a list of image file paths.

    Args:
        img_path: indexable sequence of image file paths.
        labels: indexable sequence of labels aligned with ``img_path``.
        trans: callable applied to each RGB PIL image (e.g. a torchvision
            ``transforms.Compose`` pipeline).
    """

    def __init__(self, img_path, labels, trans):
        self.imgs = img_path
        self.labels = labels
        self.trans = trans

    def __getitem__(self, index) -> tuple:
        """Return ``(trans(rgb_image), label)`` for sample ``index``."""
        imgg = self.imgs[index]
        label = self.labels[index]
        # Close the file handle deterministically instead of leaving it to the
        # garbage collector; convert() returns a new, fully loaded image, so
        # it remains valid after the source image is closed.
        with Image.open(imgg) as pil_img:
            rgb_img = pil_img.convert('RGB')
        img_tensor = self.trans(rgb_img)
        return img_tensor, label

    def __len__(self):
        return len(self.imgs)
def loader_data(root='birds', train_ratio=0.8, batch_size=32, seed=200):
    """Build train/test DataLoaders for a folder-per-class bird image set.

    Images are expected at ``<root>/<dir>/<file>.jpg``. The class name is the
    part of ``<dir>`` between its first and second '.' (e.g. a directory named
    ``001.Some_Bird`` yields the class ``Some_Bird``).

    Args:
        root: dataset root directory.
        train_ratio: fraction of the shuffled files used for training.
        batch_size: batch size of both loaders.
        seed: seed of the shuffle that decides the train/test split.

    Returns:
        ``(train_dl, test_dl, index_classes)``, where ``index_classes`` maps a
        label index to its class name (classes sorted alphabetically).

    Raises:
        FileNotFoundError: if no ``.jpg`` file matches the expected layout.
    """
    # os.path.join keeps the pattern portable (a hard-coded '\\' only works on
    # Windows); sorting makes the split depend on `seed` alone, not on the
    # order in which the filesystem happens to list files.
    pattern = os.path.join(root, '*', '*.jpg')
    file_names = sorted(glob.glob(pattern))
    if not file_names:
        raise FileNotFoundError('no .jpg images found matching {!r}'.format(pattern))

    def _class_of(path):
        # Parse the class from the parent directory only, so a class name that
        # is a substring of another class (or of a file name) cannot mislabel
        # a file.
        return os.path.basename(os.path.dirname(path)).split('.')[1]

    classes = np.unique([_class_of(name) for name in file_names])
    index_classes = dict(enumerate(classes))
    classes_index = {cla: index for index, cla in index_classes.items()}
    all_labels = np.array([classes_index[_class_of(name)] for name in file_names])

    # A private RandomState yields the same permutation as seeding the global
    # RNG with `seed`, without resetting NumPy's global random state.
    random_index = np.random.RandomState(seed).permutation(len(file_names))
    file_names = np.array(file_names)[random_index]
    all_labels = all_labels[random_index]

    split = int(len(file_names) * train_ratio)
    train_path, test_path = file_names[:split], file_names[split:]
    train_label, test_label = all_labels[:split], all_labels[split:]

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    train_ds = BirdsDataset(img_path=train_path, labels=train_label, trans=transform)
    test_ds = BirdsDataset(img_path=test_path, labels=test_label, trans=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl, index_classes
if __name__ == '__main__':
    # Preview: plot the first six images of one test batch with their class names.
    train_loader, test_loader, idx_to_class = loader_data()
    images, labels = next(iter(test_loader))
    plt.figure(figsize=(12, 8))
    preview = zip(images[:6], labels[:6])
    for position, (tensor_img, class_idx) in enumerate(preview, start=1):
        # CHW tensor -> HWC array, the layout imshow expects.
        hwc_img = tensor_img.permute(1, 2, 0).numpy()
        plt.subplot(2, 3, position)
        plt.axis('off')
        plt.title(idx_to_class.get(class_idx.item()))
        plt.imshow(hwc_img)
    plt.show()
四、模型的构建代码如下所示(请保存为 model_loader.py,后面的训练和预测代码通过 `from model_loader import ...` 导入它)
import torchvision
from torch import nn
class FCModel(nn.Module):
    """Three-layer fully connected classifier head (in -> 2048 -> 1024 -> out).

    Args:
        in_size: number of input features.
        out_size: number of output logits.
        activation: optional module applied after each hidden layer. With the
            default (``None``, i.e. identity) the three Linear layers compose
            to a single affine map, which is what the original model computed;
            pass e.g. ``nn.ReLU()`` to make the hidden layers non-linear.
            Parameter-free activations leave the state_dict keys unchanged,
            so previously saved weights still load.
    """

    def __init__(self, in_size, out_size, activation=None):
        super().__init__()
        self.lin1 = nn.Linear(in_features=in_size, out_features=2048)
        self.lin2 = nn.Linear(in_features=2048, out_features=1024)
        self.lin3 = nn.Linear(in_features=1024, out_features=out_size)
        self.act = activation if activation is not None else nn.Identity()

    def forward(self, x):
        x = self.act(self.lin1(x))
        x = self.act(self.lin2(x))
        return self.lin3(x)
def load_model(pretrained=True):
    """Return DenseNet-121's ``features`` sub-module as a frozen extractor.

    Args:
        pretrained: forwarded to ``torchvision.models.densenet121``.

    Returns:
        The ``features`` module with ``requires_grad`` disabled on every
        parameter, already switched to eval mode.
    """
    model = torchvision.models.densenet121(pretrained=pretrained).features
    for parameter in model.parameters():
        parameter.requires_grad = False
    # The extractor is never trained, so keep BatchNorm on its stored running
    # statistics. In train mode it would normalize with per-batch statistics
    # and overwrite the running averages while features are extracted.
    model.eval()
    return model
if __name__ == '__main__':
    # Print the architecture of the frozen DenseNet-121 feature extractor.
    print(load_model())
五、模型的训练代码如下所示
import torch
import tqdm
from data_loader import loader_data, FeatureDataset
from model_loader import load_model, FCModel
from torch.utils import data
from torch import nn, optim
from sklearn.metrics import accuracy_score
import numpy as np
import os
# Environment: use the GPU when one is available.
devices = 'cuda' if torch.cuda.is_available() else 'cpu'
# Width of the classifier's output layer. The prediction script rebuilds
# FCModel with out_size=200, so this value must stay in sync with it.
NUM_CLASSES = 200


def _extract_features(extractor, loader, device):
    """Run the frozen *extractor* over *loader* and flatten each output.

    Returns ``(features, labels)``: one flattened CPU feature tensor and one
    label per sample, in loader order.
    """
    features, labels = [], []
    # no_grad: the extractor's parameters are frozen, so no autograd graph is
    # needed while computing features.
    with torch.no_grad(), tqdm.tqdm(iterable=loader, total=len(loader)) as progress:
        for img, lab in progress:
            out = extractor(img.to(device))
            features.extend(out.view(out.size(0), -1).cpu())
            labels.extend(lab)
    return features, labels


def _run_epoch(classifier, loader, device, criterion, description, optimizer=None):
    """Run one pass of *classifier* over *loader*, showing running metrics.

    Trains (backward + ``optimizer.step()``) when *optimizer* is given;
    otherwise evaluates with gradients disabled. Loss and accuracy are
    averaged per sample, so a smaller final batch is not over-weighted.

    Returns ``(mean_loss, accuracy)`` over the pass, or ``(0.0, 0.0)`` for an
    empty loader.
    """
    training = optimizer is not None
    classifier.train(training)
    loss_total, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training), tqdm.tqdm(iterable=loader, total=len(loader)) as progress:
        progress.set_description_str(description)
        for img_feature, img_label in progress:
            img_feature, img_label = img_feature.to(device), img_label.to(device)
            pred = classifier(img_feature)
            loss = criterion(pred, img_label)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = img_label.size(0)
            # criterion returns the batch mean; scale back to a per-sample sum.
            loss_total += loss.item() * batch_size
            correct += (torch.argmax(input=pred, dim=-1) == img_label).sum().item()
            seen += batch_size
            progress.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(loss_total / seen, correct / seen))
    if seen == 0:
        return 0.0, 0.0
    return loss_total / seen, correct / seen


# Load the data.
train_dl, test_dl, index_classes = loader_data()
if len(index_classes) > NUM_CLASSES:
    # CrossEntropyLoss requires every target to be < the number of logits.
    raise ValueError('dataset has {} classes but the classifier only has {} outputs'.format(
        len(index_classes), NUM_CLASSES))

# Load the frozen feature extractor. eval() keeps BatchNorm on its running
# statistics, matching how the prediction script extracts features.
model = load_model()
model = model.to(devices)
model.eval()

# Extract features once; only the classifier head is trained below.
train_features, train_labels = _extract_features(model, train_dl, devices)
test_features, test_labels = _extract_features(model, test_dl, devices)

# Wrap the cached features in datasets / loaders.
train_feat_ds = FeatureDataset(feature_list=train_features, label_list=train_labels)
test_feat_ds = FeatureDataset(feature_list=test_features, label_list=test_labels)
train_feat_dl = data.DataLoader(dataset=train_feat_ds, batch_size=32, shuffle=True)
test_feat_dl = data.DataLoader(dataset=test_feat_ds, batch_size=32)

# Build the classifier head on top of the flattened features.
in_size = train_features[0].shape[0]
net = FCModel(in_size=in_size, out_size=NUM_CLASSES)
net = net.to(devices)

# Training configuration.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.00001)

# Train, evaluating on the held-out features after every epoch.
for epoch in range(100):
    _run_epoch(net, train_feat_dl, devices, loss_fn, 'Train epoch {:2d}'.format(epoch), optimizer=optimizer)
    _run_epoch(net, test_feat_dl, devices, loss_fn, 'Test epoch {:2d}'.format(epoch))

# Save the classifier weights.
os.makedirs('model_data', exist_ok=True)
torch.save(net.state_dict(), os.path.join('model_data', 'net.pth'))
六、模型的预测代码如下所示
from data_loader import loader_data
from model_loader import load_model, FCModel
import torch
import os
import matplotlib.pyplot as plt
# Environment: use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the data and take the first test batch.
train_dl, test_dl, index_classes = loader_data()
image, label = next(iter(test_dl))
image, label = image.to(device), label.to(device)

# Load the frozen feature extractor in eval mode.
model = load_model()
model = model.to(device=device)
model.eval()

# Extract features first so the classifier's input size is taken from the
# data instead of being hard-coded (50176 only holds for 224x224 inputs).
with torch.no_grad():
    feature = model(image)
    feature = feature.view(feature.size(0), -1)

net = FCModel(in_size=feature.size(1), out_size=200)
# map_location lets weights saved on a GPU be loaded on a CPU-only machine.
model_state_dict = torch.load(os.path.join('model_data', 'net.pth'), map_location=device)
net.load_state_dict(model_state_dict)
net = net.to(device=device)
net.eval()

# Predict.
with torch.no_grad():
    pre = torch.argmax(input=net(feature), dim=-1)

# Show sample 5 of the batch, or the last sample if the batch is smaller.
index = min(5, image.size(0) - 1)
predicted = pre[index].cpu().item()
actual = label[index].cpu().item()
# Fall back to the raw index if a predicted logit has no class name (the head
# has 200 outputs, which may exceed the number of classes in the dataset).
plt.axis('off')
plt.title('predict result: ' + str(index_classes.get(predicted, predicted))
          + '\nlabel result: ' + str(index_classes.get(actual, actual)))
plt.imshow(image[index].cpu().permute(1, 2, 0).numpy())
plt.savefig('result.png')
plt.show()
七、代码的运行结果如下所示(原文的结果截图已丢失;预测代码会把结果图保存为 result.png)

边栏推荐
- 集成底座方案演示说明
- okcc为什么云呼叫中心比传统呼叫中心更好?
- Go variables and constants
- 10 minutes to understand CMS garbage collector in JVM
- FAQ | FAQ for building applications for large screen devices
- The first game of the 11th provincial single chip microcomputer competition of the Blue Bridge Cup
- Pandora IOT development board learning (HAL Library) - Experiment 2 buzzer experiment (learning notes)
- JVM知识点
- Uni app - realize the countdown of 60 seconds to obtain the mobile verification code (mobile number + verification code login function)
- 微信小程序 - 实现获取手机验证码倒计时 60 秒(手机号+验证码登录功能)
猜你喜欢

Document declaration and character encoding

Learn more about materialapp and common attribute parsing in fluent

MySQL advanced SQL statement 2

《动手学深度学习》(二)-- 多层感知机

Basic operations of MySQL database (based on tables)

云服务器的安全设置常识

"No war on the Western Front" we just began to love life, but we had to shoot at everything
![[source code analysis] NVIDIA hugectr, GPU version parameter server - (1)](/img/e1/620443dbc6ea8b326e1242f25d6d74.jpg)
[source code analysis] NVIDIA hugectr, GPU version parameter server - (1)

The 8th Blue Bridge Cup single chip microcomputer provincial competition

Pandora IOT development board learning (RT thread) - Experiment 1 LED flashing experiment (learning notes)
随机推荐
Installation et utilisation du lac bleu
LCM of Spreadtrum platform rotates 180 °
Monkey测试
Li Kou interview question 02.08 Loop detection
【无线图传】基于FPGA的简易无线图像传输系统verilog开发,matlab辅助验证
Blue Bridge Cup single chip microcomputer sixth temperature recorder
L'avènement de l'ère 5G, une brève discussion sur la vie passée et présente des communications mobiles
手撕——排序
【c语言】动态规划---入门到起立
Go variables and constants
[source code analysis] NVIDIA hugectr, GPU version parameter server - (1)
Document declaration and character encoding
Introduction to vmware workstation and vSphere
Wechat applet calculates the distance between the two places
The second game of the 12th provincial single chip microcomputer competition of the Blue Bridge Cup
A thorough understanding of the development of scorecards - the determination of Y (Vintage analysis, rolling rate analysis, etc.)
【直播回顾】战码先锋首期8节直播完美落幕,下期敬请期待!
go 函数
How should the team choose the feature branch development mode or trunk development mode?
The second game of the 11th provincial single chip microcomputer competition of the Blue Bridge Cup