当前位置:网站首页>Pytorch---使用Pytorch进行鸟类的预测
Pytorch---使用Pytorch进行鸟类的预测
2022-07-02 04:04:00 【水哥很水】
一、数据集获取（原文中的下载链接缺失）。代码要求数据集按 birds/<编号>.<类别名>/*.jpg 的目录结构存放，分类器输出 200 个类别。
二、代码运行环境
Pytorch-gpu==1.7.1
Python==3.8
三、数据集处理代码（保存为 data_loader.py，后面的训练和预测脚本会导入它）如下所示
import glob
import numpy as np
from torch.utils.data.dataset import T_co
from torchvision import transforms
from torch.utils import data
from PIL import Image
import matplotlib.pyplot as plt
import torch
class FeatureDataset(data.Dataset):
    """Dataset over pre-extracted feature vectors and their labels.

    Each item is ``(feature, label)``; the label is converted to a
    ``torch.long`` tensor, the dtype ``nn.CrossEntropyLoss`` expects for
    class indices.
    """

    def __init__(self, feature_list, label_list):
        # Parallel sequences: feature_list[i] belongs to label_list[i].
        self.feature_list = feature_list
        self.label_list = label_list

    def __len__(self):
        # The feature list defines the dataset size.
        return len(self.feature_list)

    def __getitem__(self, index):
        feature = self.feature_list[index]
        target = torch.as_tensor(data=self.label_list[index], dtype=torch.long)
        return feature, target
class BirdsDataset(data.Dataset):
    """Image-classification dataset backed by a list of image file paths.

    Args:
        img_path: sequence of image file paths.
        labels: sequence of labels, parallel to ``img_path``.
        trans: callable applied to the RGB ``PIL.Image`` (for example a
            torchvision ``transforms.Compose``); its result is returned.
    """

    def __init__(self, img_path, labels, trans):
        self.imgs = img_path
        self.labels = labels
        self.trans = trans

    def __getitem__(self, index):
        """Return ``(trans(rgb_image), label)`` for sample ``index``."""
        label = self.labels[index]
        # Use a context manager so the file handle is closed as soon as the
        # pixels are decoded; convert() returns a new, fully loaded image that
        # no longer needs the file. Without this, every sample leaks a handle
        # until garbage collection.
        with Image.open(self.imgs[index]) as pil_img:
            rgb_img = pil_img.convert('RGB')
        return self.trans(rgb_img), label

    def __len__(self):
        return len(self.imgs)
def loader_data(root='birds', batch_size=32, train_ratio=0.8, seed=200):
    """Build train/test DataLoaders for a folder of per-class bird images.

    Images are expected at ``<root>/<id>.<class_name>/*.jpg``. The class name
    is the part of the directory name after the first ``.``.

    Args:
        root: dataset root directory (default ``'birds'``, as before).
        batch_size: batch size of both loaders (default 32, as before).
        train_ratio: fraction of the shuffled files used for training.
        seed: seed of the train/test permutation (default 200, as before).

    Returns:
        ``(train_dl, test_dl, index_classes)``, where ``index_classes`` maps a
        label index to its class name.

    Raises:
        FileNotFoundError: if no ``.jpg`` file matches the expected layout.
    """
    import os

    # os.path.join keeps the pattern portable; the old r'birds\*\*.jpg' only
    # matched on Windows. Sorting makes the seeded split independent of the
    # order in which the filesystem returns entries.
    pattern = os.path.join(root, '*', '*.jpg')
    file_names = sorted(glob.glob(pattern))
    if not file_names:
        raise FileNotFoundError('no .jpg images found matching {!r}'.format(pattern))

    # '<root>/001.Some_Bird/x.jpg' -> 'Some_Bird'
    file_classes = [os.path.basename(os.path.dirname(name)).split('.', 1)[-1]
                    for name in file_names]
    classes = np.unique(file_classes)
    index_classes = dict(enumerate(classes))
    classes_index = {cla: idx for idx, cla in index_classes.items()}
    # Label each file from its own directory. The former substring test
    # (`cla in name`) could match zero or several classes for one path, which
    # left all_labels misaligned with file_names.
    all_labels = [classes_index[cla] for cla in file_classes]

    # A local RandomState yields the same permutation as
    # np.random.seed(seed) + np.random.permutation without resetting NumPy's
    # global random state.
    random_index = np.random.RandomState(seed).permutation(len(file_names))
    file_names = np.array(file_names)[random_index]
    all_labels = np.array(all_labels)[random_index]

    split = int(len(file_names) * train_ratio)
    train_path, test_path = file_names[:split], file_names[split:]
    train_label, test_label = all_labels[:split], all_labels[split:]

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    train_ds = BirdsDataset(img_path=train_path, labels=train_label, trans=transform)
    test_ds = BirdsDataset(img_path=test_path, labels=test_label, trans=transform)
    train_dl = data.DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = data.DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl, index_classes
if __name__ == '__main__':
    # Visual sanity check: show the first six images of one test batch with
    # their class names as titles.
    _, test_loader, idx_to_class = loader_data()
    images, targets = next(iter(test_loader))
    plt.figure(figsize=(12, 8))
    for pos, (tensor_img, target) in enumerate(zip(images[:6], targets[:6]), start=1):
        # CHW tensor -> HWC array, the layout imshow expects.
        hwc_img = tensor_img.permute(1, 2, 0).numpy()
        plt.subplot(2, 3, pos)
        plt.axis('off')
        plt.title(idx_to_class.get(target.item()))
        plt.imshow(hwc_img)
    plt.show()
四、模型的构建代码（保存为 model_loader.py）如下所示
import torchvision
from torch import nn
class FCModel(nn.Module):
    """Fully connected classifier head: in_size -> 2048 -> 1024 -> out_size.

    Args:
        in_size: length of the flattened input feature vector.
        out_size: number of output logits (classes).
        hidden_activation: apply ReLU after each hidden layer (default).
            Without it, the three Linear layers compose into a single affine
            map, so the hidden layers add parameters but no expressive power.
            Pass ``False`` only to reproduce the old activation-free head,
            e.g. for weights trained with it (retraining is recommended).
    """

    def __init__(self, in_size, out_size, hidden_activation=True):
        super().__init__()
        self.lin1 = nn.Linear(in_features=in_size, out_features=2048)
        self.lin2 = nn.Linear(in_features=2048, out_features=1024)
        self.lin3 = nn.Linear(in_features=1024, out_features=out_size)
        # Parameter-free modules, so the state_dict keys stay lin1/lin2/lin3
        # and existing checkpoints still load.
        self.act = nn.ReLU() if hidden_activation else nn.Identity()

    def forward(self, x):
        """Map features (last dim ``in_size``) to logits (last dim ``out_size``)."""
        x = self.act(self.lin1(x))
        x = self.act(self.lin2(x))
        return self.lin3(x)
def load_model(pretrained=True):
    """Return the frozen DenseNet-121 convolutional feature extractor.

    Only ``densenet121(...).features`` is kept (the classifier is dropped).
    All parameters are frozen, and the module is returned in eval mode.

    Args:
        pretrained: load torchvision's pretrained weights (default, as
            before); ``False`` builds a randomly initialised network and
            skips the weight download.
    """
    model = torchvision.models.densenet121(pretrained=pretrained).features
    for parameter in model.parameters():
        parameter.requires_grad = False
    # requires_grad=False does not freeze BatchNorm. In train mode it
    # normalises with per-batch statistics and overwrites its running
    # averages, so extracted features would depend on batch composition and
    # differ from features computed later in eval mode.
    return model.eval()
if __name__ == '__main__':
    # Print the architecture of the frozen feature extractor.
    print(load_model())
五、模型的训练代码如下所示（依赖上面的 data_loader.py 和 model_loader.py，训练得到的分类器权重保存在 model_data/net.pth）
import torch
import tqdm
from data_loader import loader_data, FeatureDataset
from model_loader import load_model, FCModel
from torch.utils import data
from torch import nn, optim
from sklearn.metrics import accuracy_score
import numpy as np
import os
# Training script: extract frozen DenseNet-121 features once, then train the
# FC classifier head on those cached features.


def _extract_features(backbone, loader):
    """Pass every batch of *loader* through the frozen *backbone*.

    Returns ``(features, labels)``: two lists with one entry per sample, the
    flattened feature vector (on CPU) and its label.
    """
    features, labels = [], []
    progress = tqdm.tqdm(iterable=loader, total=len(loader))
    # The backbone is frozen, so no autograd graph is needed.
    with torch.no_grad():
        for img, lab in progress:
            out = backbone(img.to(devices))
            out = out.view(out.size(0), -1)
            features.extend(out.cpu())
            labels.extend(lab)
    progress.close()
    return features, labels


def _run_epoch(head, loader, criterion, description, opt=None):
    """Run *head* over *loader* once; train with optimizer *opt* when given.

    Loss and accuracy are accumulated per sample instead of averaged per
    batch, so a smaller final batch does not skew the reported numbers.
    """
    training = opt is not None
    head.train(training)
    loss_total, correct, seen = 0.0, 0, 0
    progress = tqdm.tqdm(iterable=loader, total=len(loader))
    progress.set_description_str(description)
    with torch.set_grad_enabled(training):
        for img_feature, img_label in progress:
            img_feature, img_label = img_feature.to(devices), img_label.to(devices)
            pred = head(img_feature)
            loss = criterion(pred, img_label)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            batch = img_label.size(0)
            loss_total += loss.item() * batch
            correct += (torch.argmax(pred, dim=-1) == img_label).sum().item()
            seen += batch
            progress.set_postfix_str(
                'loss is {:14f}, accuracy is {:14f}'.format(loss_total / seen, correct / seen))
    progress.close()


# Environment configuration
devices = 'cuda' if torch.cuda.is_available() else 'cpu'
# Must equal the out_size the prediction script uses to rebuild FCModel.
NUM_CLASSES = 200
NUM_EPOCHS = 100
MODEL_DIR = 'model_data'

# Load the data
train_dl, test_dl, index_classes = loader_data()

# Load the backbone. eval() matters even though every parameter is frozen:
# in train mode BatchNorm normalises with per-batch statistics and overwrites
# its running averages, so the features would not match prediction time.
model = load_model()
model = model.to(devices)
model.eval()

# Extract the features once; the head is then trained on the cached vectors.
train_features, train_labels = _extract_features(model, train_dl)
test_features, test_labels = _extract_features(model, test_dl)

# Build the feature datasets
train_feat_ds = FeatureDataset(feature_list=train_features, label_list=train_labels)
test_feat_ds = FeatureDataset(feature_list=test_features, label_list=test_labels)
train_feat_dl = data.DataLoader(dataset=train_feat_ds, batch_size=32, shuffle=True)
test_feat_dl = data.DataLoader(dataset=test_feat_ds, batch_size=32)

# Build the feature classifier
in_size = train_features[0].shape[0]
net = FCModel(in_size=in_size, out_size=NUM_CLASSES)
net = net.to(devices)

# Training configuration
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=0.00001)

# Train, evaluate and save
os.makedirs(MODEL_DIR, exist_ok=True)
for epoch in range(NUM_EPOCHS):
    _run_epoch(net, train_feat_dl, loss_fn, 'Train epoch {:2d}'.format(epoch), optimizer)
    _run_epoch(net, test_feat_dl, loss_fn, 'Test epoch {:2d}'.format(epoch))
    # Save every epoch so an interrupted run still leaves the latest weights;
    # after the last epoch the file holds the final model, as before.
    torch.save(net.state_dict(), os.path.join(MODEL_DIR, 'net.pth'))
六、模型的预测代码如下所示（读取训练阶段保存的 model_data/net.pth）
from data_loader import loader_data
from model_loader import load_model, FCModel
import torch
import os
import matplotlib.pyplot as plt
# Prediction script: classify one image from the first test batch and save a
# figure with its predicted and true class names.

# Environment configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load one batch of test data
train_dl, test_dl, index_classes = loader_data()
image, label = next(iter(test_dl))
image, label = image.to(device), label.to(device)
# Position in the batch of the image to show; clamped so a short batch
# cannot raise IndexError.
index = min(5, image.size(0) - 1)

# Load the frozen backbone in eval mode so BatchNorm uses running statistics.
model = load_model()
model.eval()
model = model.to(device=device)

with torch.no_grad():
    feature = model(image)
    feature = feature.view(feature.size(0), -1)

# Rebuild the classifier head. Its input size comes from the extracted
# features (previously hard-coded as 50176) instead of a magic number, and
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
net = FCModel(in_size=feature.size(1), out_size=200)
model_state_dict = torch.load(os.path.join('model_data', 'net.pth'), map_location=device)
net.load_state_dict(model_state_dict)
net = net.to(device=device)
net.eval()

# Predict
with torch.no_grad():
    pre = torch.argmax(input=net(feature), dim=-1)

plt.axis('off')
plt.title('predict result: ' + index_classes.get(pre[index].cpu().item()) + '\nlabel result: ' + index_classes.get(
    label[index].cpu().item()))
plt.imshow(image[index].cpu().permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
七、代码的运行结果如下所示（结果图片缺失；运行预测代码后，结果图会保存为 result.png）

边栏推荐
- The 7th Blue Bridge Cup single chip microcomputer provincial competition
- 云服务器的安全设置常识
- The second game of the 11th provincial single chip microcomputer competition of the Blue Bridge Cup
- 2022-07-01:某公司年会上,大家要玩一食发奖金游戏,一共有n个员工, 每个员工都有建设积分和捣乱积分, 他们需要排成一队,在队伍最前面的一定是老板,老板也有建设积分和捣乱积分, 排好队后,所有
- Feature Engineering: summary of common feature transformation methods
- 初识P4语言
- Vite: scaffold assembly
- The 6th Blue Bridge Cup single chip microcomputer provincial competition
- cookie、session、tooken
- Déchirure à la main - tri
猜你喜欢
![[wireless image transmission] FPGA based simple wireless image transmission system Verilog development, matlab assisted verification](/img/77/4df7a1439ff1a53f94d409a19a47d6.png)
[wireless image transmission] FPGA based simple wireless image transmission system Verilog development, matlab assisted verification

Go语言介绍

Www 2022 | rethinking the knowledge map completion of graph convolution network

Common sense of cloud server security settings

云服务器的安全设置常识

软件测试人的第一个实战项目:web端(视频教程+文档+用例库)
![[tips] use Matlab GUI to read files in dialog mode](/img/51/6d6051836bfc9caa957d0275245bd3.png)
[tips] use Matlab GUI to read files in dialog mode

The first game of the 12th Blue Bridge Cup single chip microcomputer provincial competition

66.qt quick-qml自定义日历组件(支持竖屏和横屏)

First acquaintance with string+ simple usage (II)
随机推荐
Introduction to vmware workstation and vSphere
【人员密度检测】基于形态学处理和GRNN网络的人员密度检测matlab仿真
云服务器的安全设置常识
What is 5g industrial wireless gateway? What functions can 5g industrial wireless gateway achieve?
[live broadcast review] the first 8 live broadcasts of battle code Pioneer have come to a perfect end. Please look forward to the next one!
【直播回顾】战码先锋首期8节直播完美落幕,下期敬请期待!
【小技巧】使用matlab GUI以对话框模式读取文件
【c语言】动态规划---入门到起立
JVM知识点
2022-07-01: at the annual meeting of a company, everyone is going to play a game of giving bonuses. There are a total of N employees. Each employee has construction points and trouble points. They nee
"No war on the Western Front" we just began to love life, but we had to shoot at everything
Suggestions on settlement solution of u standard contract position explosion
Force buckle 540 A single element in an ordered array
MySQL advanced SQL statement 2
Today's plan: February 15, 2022
Raspberry pie GPIO pin controls traffic light and buzzer
Pandora IOT development board learning (HAL Library) - Experiment 2 buzzer experiment (learning notes)
手撕——排序
微信小程序 - 实现获取手机验证码倒计时 60 秒(手机号+验证码登录功能)
Déchirure à la main - tri