PyTorch: Using PyTorch for Image Localization
2022-07-02 04:04:00 【水哥很水】
1. The dataset used in the code can be obtained from the following link
2. Runtime environment
Pytorch-gpu==1.10.1
Python==3.8
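Besides PyTorch and Python, the scripts below also use torchvision, numpy, matplotlib, tqdm, and Pillow. A quick way to confirm the local setup (a minimal check, not part of the original post) is:
import torch
import torchvision
print(torch.__version__, torchvision.__version__, torch.cuda.is_available())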
3. The dataset processing code is shown below
import os
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import numpy as np
from torchvision import transforms
from PIL import Image
from torchvision.utils import draw_bounding_boxes


class PetDataset(Dataset):
    def __init__(self, images_path, labels, transform):
        super(PetDataset, self).__init__()
        self.images_path = images_path
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        img = self.images_path[index]
        pil_img = Image.open(img).convert('RGB')
        img_tensor = self.transform(pil_img)
        label1, label2, label3, label4 = self.labels[index]
        return img_tensor, label1, label2, label3, label4

    def __len__(self):
        return len(self.images_path)


def to_labels(path):
    # Parse one Pascal VOC style XML annotation and return the bounding box
    # coordinates normalized to [0, 1] by the image width and height.
    with open(path, encoding='utf-8') as xml_file:
        tree = ET.parse(xml_file)
    root = tree.getroot()
    width = float(root.find('size').find('width').text)
    height = float(root.find('size').find('height').text)
    xmin = float(root.find('object').find('bndbox').find('xmin').text) / width
    ymin = float(root.find('object').find('bndbox').find('ymin').text) / height
    xmax = float(root.find('object').find('bndbox').find('xmax').text) / width
    ymax = float(root.find('object').find('bndbox').find('ymax').text) / height
    return [xmin, ymin, xmax, ymax]


def load_data():
    DATASET_PATH = r'/Users/leeakita/Desktop/dataset'
    BATCH_SIZE = 32
    XML_PATH = os.path.join(DATASET_PATH, 'xmls')
    IMAGE_PATH = os.path.join(DATASET_PATH, 'images')
    # Only images that have a matching XML annotation are used.
    xml_names = os.listdir(XML_PATH)
    file_names = [name.split('.')[0] for name in xml_names]
    image_paths = [os.path.join(IMAGE_PATH, file_name + '.jpg') for file_name in file_names]
    xml_paths = [os.path.join(XML_PATH, file_name + '.xml') for file_name in file_names]
    labels = [to_labels(xml_path) for xml_path in xml_paths]
    # Shuffle images and labels with the same permutation, then split 80/20.
    np.random.seed(2022)
    index = np.random.permutation(len(image_paths))
    image_paths = np.array(image_paths)[index]
    labels = np.array(labels)[index].astype(np.float32)
    train_split = int(len(image_paths) * 0.8)
    train_images = image_paths[:train_split]
    train_labels = labels[:train_split]
    test_images = image_paths[train_split:]
    test_labels = labels[train_split:]
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    train_ds = PetDataset(images_path=train_images, labels=train_labels, transform=transform)
    test_ds = PetDataset(images_path=test_images, labels=test_labels, transform=transform)
    train_dl = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=BATCH_SIZE)
    return train_dl, test_dl


if __name__ == '__main__':
    train_dl, test_dl = load_data()
    image, xminn, yminn, xmaxn, ymaxn = next(iter(test_dl))
    index = 6
    image, xminn, yminn, xmaxn, ymaxn = image[index], xminn[index], yminn[index], xmaxn[index], ymaxn[index]
    # The labels are normalized, so scale them back up to the 224x224 image.
    boxes = [xminn.item() * 224, yminn.item() * 224, xmaxn.item() * 224, ymaxn.item() * 224]
    boxes = torch.FloatTensor(boxes)
    boxes = boxes.unsqueeze(0)
    result = draw_bounding_boxes(image=torch.as_tensor(data=image * 255, dtype=torch.uint8), boxes=boxes, colors='red')
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
4. The model construction code is shown below
import torch
import torchvision
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Use a pretrained ResNet-101 as the feature extractor and drop its
        # final fully connected layer.
        resnet = torchvision.models.resnet101(pretrained=True)
        self.conv_base = nn.Sequential(*list(resnet.children())[:-1])
        # Four separate regression heads, one per bounding-box coordinate.
        self.fc1 = nn.Linear(in_features=resnet.fc.in_features, out_features=1)
        self.fc2 = nn.Linear(in_features=resnet.fc.in_features, out_features=1)
        self.fc3 = nn.Linear(in_features=resnet.fc.in_features, out_features=1)
        self.fc4 = nn.Linear(in_features=resnet.fc.in_features, out_features=1)

    def forward(self, x):
        x = self.conv_base(x)
        # Flatten (B, 2048, 1, 1) to (B, 2048); unlike squeeze, flatten keeps
        # the batch dimension even when the batch size is 1.
        x = torch.flatten(x, start_dim=1)
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        x3 = self.fc3(x)
        x4 = self.fc4(x)
        return x1, x2, x3, x4


if __name__ == '__main__':
    model = Net()
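As a quick sanity check (a sketch, not part of the original post), the main block above could be extended with a dummy forward pass; each of the four heads should return a tensor of shape (batch_size, 1):
    dummy = torch.randn(2, 3, 224, 224)  # a fake batch of two 224x224 RGB images
    pred_xmin, pred_ymin, pred_xmax, pred_ymax = model(dummy)
    print(pred_xmin.shape, pred_ymax.shape)  # expected: torch.Size([2, 1]) for each head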
5. The model training code is shown below
import numpy as np
import torch
from data_loader import load_data
from model_loader import Net
import tqdm
import os

# Environment configuration
devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the data
train_dl, test_dl = load_data()
# Load the model
model = Net()
model.to(device=devices)
# Training configuration: loss, optimizer and learning-rate scheduler
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.7)
# Start training
for epoch in range(50):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:2d}'.format(epoch))
    train_loss_sum = []
    for image, xmin, ymin, xmax, ymax in train_tqdm:
        image, xmin, ymin, xmax, ymax = image.to(devices), xmin.to(devices), ymin.to(devices), xmax.to(devices), ymax.to(devices)
        pred_xmin, pred_ymin, pred_xmax, pred_ymax = model(image)
        pred_xmin, pred_ymin, pred_xmax, pred_ymax = pred_xmin.squeeze(), pred_ymin.squeeze(), pred_xmax.squeeze(), pred_ymax.squeeze()
        # The total loss is the sum of the MSE losses of the four coordinates.
        loss_xmin = loss_fn(pred_xmin, xmin)
        loss_ymin = loss_fn(pred_ymin, ymin)
        loss_xmax = loss_fn(pred_xmax, xmax)
        loss_ymax = loss_fn(pred_ymax, ymax)
        loss = loss_xmin + loss_ymin + loss_xmax + loss_ymax
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            train_loss_sum.append(loss.item())
            train_tqdm.set_postfix_str('loss is :{:14f}'.format(np.array(train_loss_sum).mean()))
    train_tqdm.close()
    with torch.no_grad():
        model.eval()
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:2d}'.format(epoch))
        test_loss_sum = []
        for image, xmin, ymin, xmax, ymax in test_tqdm:
            image, xmin, ymin, xmax, ymax = image.to(devices), xmin.to(devices), ymin.to(devices), xmax.to(devices), ymax.to(devices)
            pred_xmin, pred_ymin, pred_xmax, pred_ymax = model(image)
            pred_xmin, pred_ymin, pred_xmax, pred_ymax = pred_xmin.squeeze(), pred_ymin.squeeze(), pred_xmax.squeeze(), pred_ymax.squeeze()
            loss_xmin = loss_fn(pred_xmin, xmin)
            loss_ymin = loss_fn(pred_ymin, ymin)
            loss_xmax = loss_fn(pred_xmax, xmax)
            loss_ymax = loss_fn(pred_ymax, ymax)
            loss = loss_xmin + loss_ymin + loss_xmax + loss_ymax
            test_loss_sum.append(loss.item())
            test_tqdm.set_postfix_str('loss is :{:14f}'.format(np.array(test_loss_sum).mean()))
        test_tqdm.close()
    # Decay the learning rate once per epoch so the StepLR schedule takes effect.
    lr_scheduler.step()
# Save the trained weights
if not os.path.exists(os.path.join('model_data')):
    os.mkdir(os.path.join('model_data'))
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
6. The model prediction code is shown below
import torch
from data_loader import load_data
from model_loader import Net
import os
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt

# Load the data
train_dl, test_dl = load_data()
image, xmin, ymin, xmax, ymax = next(iter(test_dl))
# Load the trained model
model = Net()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
model.load_state_dict(model_state_dict)
model.eval()
# Run the prediction
index = 0
with torch.no_grad():
    pred_xmin, pred_ymin, pred_xmax, pred_ymax = model(image)
    pred_xmin, pred_ymin, pred_xmax, pred_ymax = pred_xmin[index], pred_ymin[index], pred_xmax[index], pred_ymax[index]
    # Scale the normalized predictions back up to the 224x224 image.
    pre_boxes = [pred_xmin.item() * 224, pred_ymin.item() * 224, pred_xmax.item() * 224, pred_ymax.item() * 224]
    pre_boxes = torch.FloatTensor(pre_boxes)
    pre_boxes = torch.unsqueeze(input=pre_boxes, dim=0)
    label_boxes = [xmin[index].item() * 224, ymin[index].item() * 224, xmax[index].item() * 224,
                   ymax[index].item() * 224]
    label_boxes = torch.FloatTensor(label_boxes)
    label_boxes = torch.unsqueeze(input=label_boxes, dim=0)
    img = image[index]
    img = torch.as_tensor(data=img * 255, dtype=torch.uint8)
    # Draw the predicted box in red and the ground-truth box in blue.
    result = draw_bounding_boxes(image=img, boxes=pre_boxes, colors='red')
    result = draw_bounding_boxes(image=result, boxes=label_boxes, colors='blue')
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    plt.imshow(result.permute(1, 2, 0))
    plt.savefig('result.png')
    plt.show()
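To put a number on the visual result, the predicted and ground-truth boxes computed above can be compared with an IoU score (a minimal sketch, not part of the original post; torchvision.ops.box_iou expects boxes in (xmin, ymin, xmax, ymax) order, which matches the format used here):
from torchvision.ops import box_iou

# pre_boxes and label_boxes both have shape (1, 4) in pixel coordinates.
iou = box_iou(pre_boxes, label_boxes)
print('IoU between prediction and ground truth: {:.3f}'.format(iou.item()))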
7. The running result is shown below
(result image: the predicted box is drawn in red, the ground-truth box in blue)