当前位置:网站首页>Pytorch---使用Pytorch进行图像定位
PyTorch——使用 PyTorch 进行图像定位
2022-07-02 04:04:00 【水哥很水】
一、代码中使用的数据集可以通过以下链接获取(注:转载版本中此处的链接已缺失)
二、代码运行环境
Pytorch-gpu==1.10.1
Python==3.8
三、数据集处理代码如下所示(保存为 data_loader.py,训练与预测脚本通过 from data_loader import load_data 导入)
import os
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import numpy as np
from torchvision import transforms
from PIL import Image
from torchvision.utils import draw_bounding_boxes
class PetDataset(Dataset):
    """Map-style dataset pairing image files with bounding-box label rows.

    Args:
        images_path: Indexable sequence of image file paths.
        labels: Indexable sequence of 4-element rows ``[xmin, ymin, xmax, ymax]``
            aligned with ``images_path`` (``load_data`` passes coordinates
            already divided by the image width/height).
        transform: Callable applied to the loaded PIL image.

    Each item is the tuple ``(image_tensor, xmin, ymin, xmax, ymax)``.
    """

    def __init__(self, images_path, labels, transform):
        super().__init__()
        self.images_path = images_path
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        # convert('RGB') gives every image three channels whatever its source mode.
        picture = Image.open(self.images_path[index]).convert('RGB')
        tensor = self.transform(picture)
        # Unpacking enforces exactly four coordinates per label row.
        xmin, ymin, xmax, ymax = self.labels[index]
        return tensor, xmin, ymin, xmax, ymax

    def __len__(self):
        return len(self.images_path)
def to_labels(path):
    """Read the first bounding box from an XML annotation file.

    The file is expected to contain ``<size>`` (``width``/``height``) and
    ``<object>/<bndbox>`` (``xmin``/``ymin``/``xmax``/``ymax``) elements, as in
    the Pascal VOC annotation layout.

    Args:
        path: Path to the ``.xml`` annotation file (read as UTF-8).

    Returns:
        ``[xmin, ymin, xmax, ymax]`` as floats. The x values are divided by the
        image width and the y values by the image height.

    Raises:
        ValueError: If the annotation contains no ``<object>`` element.
    """
    # Context manager so the file handle is closed as soon as parsing is done.
    with open(path, encoding='utf-8') as xml_file:
        root = ET.parse(xml_file).getroot()

    size = root.find('size')
    width = float(size.find('width').text)
    height = float(size.find('height').text)

    # Only the first <object> is used: one box per image.
    obj = root.find('object')
    if obj is None:
        raise ValueError(f'{path}: annotation has no <object> element')
    bndbox = obj.find('bndbox')

    xmin = float(bndbox.find('xmin').text) / width
    ymin = float(bndbox.find('ymin').text) / height
    xmax = float(bndbox.find('xmax').text) / width
    ymax = float(bndbox.find('ymax').text) / height
    return [xmin, ymin, xmax, ymax]
def load_data(dataset_path=r'/Users/leeakita/Desktop/dataset', batch_size=32,
              train_ratio=0.8, image_size=224, seed=2022):
    """Build train/test DataLoaders for the pet localisation dataset.

    Expects ``<dataset_path>/xmls/<name>.xml`` annotations with matching
    ``<dataset_path>/images/<name>.jpg`` images.

    Args:
        dataset_path: Root folder containing the ``xmls`` and ``images`` folders.
        batch_size: Batch size used by both loaders.
        train_ratio: Fraction of the (shuffled) samples put in the training split.
        image_size: Side length images are resized to. The visualisation code
            rescales normalised boxes by 224, so change this together with it.
        seed: Seed for the one-off train/test permutation.

    Returns:
        ``(train_dl, test_dl)``. Only ``train_dl`` reshuffles every epoch.

    Raises:
        FileNotFoundError: If ``<dataset_path>/xmls`` holds no ``.xml`` files.
    """
    xml_dir = os.path.join(dataset_path, 'xmls')
    image_dir = os.path.join(dataset_path, 'images')

    # Keep only annotation files. This skips stray entries such as macOS
    # ".DS_Store", which previously became a bogus ".xml" path. Sorting makes
    # the seeded split independent of os.listdir ordering.
    xml_names = sorted(name for name in os.listdir(xml_dir) if name.endswith('.xml'))
    if not xml_names:
        raise FileNotFoundError(f'No .xml annotation files found in {xml_dir}')

    # splitext keeps stems that themselves contain dots intact.
    stems = [os.path.splitext(name)[0] for name in xml_names]
    image_paths = [os.path.join(image_dir, stem + '.jpg') for stem in stems]
    labels = [to_labels(os.path.join(xml_dir, name)) for name in xml_names]

    # A local RandomState yields the same permutation as np.random.seed(seed)
    # followed by np.random.permutation, without reseeding numpy's global RNG.
    order = np.random.RandomState(seed).permutation(len(image_paths))
    image_paths = np.array(image_paths)[order]
    labels = np.array(labels, dtype=np.float32)[order]

    split = int(len(image_paths) * train_ratio)
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    train_ds = PetDataset(images_path=image_paths[:split], labels=labels[:split], transform=transform)
    test_ds = PetDataset(images_path=image_paths[split:], labels=labels[split:], transform=transform)
    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: draw the ground-truth box of one test-set image.
    _, test_loader = load_data()
    batch_images, *batch_coords = next(iter(test_loader))
    sample = 6
    # Labels are fractions of the image size; the images were resized to 224x224.
    box = torch.FloatTensor([[coord[sample].item() * 224 for coord in batch_coords]])
    canvas = torch.as_tensor(data=batch_images[sample] * 255, dtype=torch.uint8)
    result = draw_bounding_boxes(image=canvas, boxes=box, colors='red')
    plt.imshow(result.permute(1, 2, 0).numpy())
    plt.show()
四、模型的构建代码如下所示(保存为 model_loader.py,训练与预测脚本通过 from model_loader import Net 导入)
import torch
import torchvision
from torch import nn
class Net(nn.Module):
    """ResNet-101 feature extractor followed by four scalar regression heads.

    ``forward`` takes a batch of images (assumed ``(batch, 3, H, W)``; the
    data pipeline in this project uses 224x224). It returns the tuple
    ``(xmin, ymin, xmax, ymax)``, where each prediction has shape ``(batch, 1)``.
    """

    def __init__(self):
        super(Net, self).__init__()
        # pretrained=True matches the torchvision release this project targets;
        # newer torchvision versions deprecate it in favour of ``weights=``.
        resnet = torchvision.models.resnet101(pretrained=True)
        # Drop only the last child (the classification head ``fc``) and keep the rest.
        self.conv_base = nn.Sequential(*list(resnet.children())[:-1])
        in_features = resnet.fc.in_features
        # Four separate single-output heads. Keeping them separate preserves the
        # state_dict keys (fc1..fc4) used by saved checkpoints.
        self.fc1 = nn.Linear(in_features=in_features, out_features=1)
        self.fc2 = nn.Linear(in_features=in_features, out_features=1)
        self.fc3 = nn.Linear(in_features=in_features, out_features=1)
        self.fc4 = nn.Linear(in_features=in_features, out_features=1)

    def forward(self, x):
        x = self.conv_base(x)
        # Flatten from dim 1 instead of squeeze(): squeeze() would also drop the
        # batch dimension when the batch contains a single image.
        x = torch.flatten(x, 1)
        return self.fc1(x), self.fc2(x), self.fc3(x), self.fc4(x)
if __name__ == '__main__':
    # Smoke check: constructing the network builds the ResNet-101 backbone
    # with pretrained weights (see Net.__init__).
    model = Net()
五、模型的训练代码如下所示
import numpy as np
import torch
from data_loader import load_data
from model_loader import Net
import tqdm
import os
# Device configuration
devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Data loading
train_dl, test_dl = load_data()
# Model
model = Net()
model.to(device=devices)
# Training configuration
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
# Multiplies the learning rate by 0.7 every 7 epochs. This only takes effect
# because lr_scheduler.step() is called once per epoch below.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.7)


def _batch_loss(batch):
    """Move one (image, xmin, ymin, xmax, ymax) batch to ``devices`` and
    return the sum of the four per-coordinate MSE losses."""
    image, *targets = (tensor.to(devices) for tensor in batch)
    predictions = model(image)
    # reshape(-1) instead of squeeze(): a 1-sample batch stays 1-D and matches
    # the target's shape, instead of collapsing to a 0-d tensor.
    return sum(loss_fn(pred.reshape(-1), target) for pred, target in zip(predictions, targets))


# Training
for epoch in range(50):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:2d}'.format(epoch))
    train_losses = []
    for batch in train_tqdm:
        loss = _batch_loss(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(train_losses)))
    train_tqdm.close()
    # Advance the StepLR schedule; previously never called, so the LR never decayed.
    lr_scheduler.step()

    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:2d}'.format(epoch))
        test_losses = []
        for batch in test_tqdm:
            test_losses.append(_batch_loss(batch).item())
            test_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(test_losses)))
        test_tqdm.close()

# Save the final weights
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
六、模型的预测代码如下所示
import torch
from data_loader import load_data
from model_loader import Net
import os
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt
# Load one batch from the test split.
_, test_loader = load_data()
images, *true_coords = next(iter(test_loader))

# Restore the trained weights onto the CPU.
net = Net()
net.load_state_dict(torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu'))
net.eval()

# Predict, then draw the predicted box (red) and the ground-truth box (blue).
sample = 0
with torch.no_grad():
    pred_coords = net(images)
    # Coordinates are fractions of the 224x224 input; scale back to pixels.
    pred_box = torch.FloatTensor([[coord[sample].item() * 224 for coord in pred_coords]])
    true_box = torch.FloatTensor([[coord[sample].item() * 224 for coord in true_coords]])
    canvas = torch.as_tensor(data=images[sample] * 255, dtype=torch.uint8)
    canvas = draw_bounding_boxes(image=canvas, boxes=pred_box, colors='red')
    canvas = draw_bounding_boxes(image=canvas, boxes=true_box, colors='blue')
    plt.figure(figsize=(8, 8), dpi=500)
    plt.axis('off')
    plt.imshow(canvas.permute(1, 2, 0))
    plt.savefig('result.png')
    plt.show()
七、代码的运行结果如下所示(结果图片在转载中缺失;预测脚本会将结果保存为 result.png,其中红框为预测框,蓝框为真实标注框)
边栏推荐
- Use of go package
- Sorted out an ECS summer money saving secret, this time @ old users come and take it away
- Dare to go out for an interview without learning some distributed technology?
- [live broadcast review] the first 8 live broadcasts of battle code Pioneer have come to a perfect end. Please look forward to the next one!
- regular expression
- pip 安装第三方库
- Raspberry Pi GPIO pins control a traffic light and buzzer
- 【无线图传】基于FPGA的简易无线图像传输系统verilog开发,matlab辅助验证
- 云服务器的安全设置常识
- How much is the tuition fee of SCM training class? How long is the study time?
猜你喜欢
JVM knowledge points
LeetCode 540: Single Element in a Sorted Array
【人员密度检测】基于形态学处理和GRNN网络的人员密度检测matlab仿真
Homework in Chapter 3 of slam course of dark blue vision -- derivative application of T6 common functions
Use a mask to restrict the input of the qlineedit control
Lost a few hairs, and finally learned - graph traversal -dfs and BFS
Suggestions on settlement solution of u standard contract position explosion
Analysis of the overall design principle of Nacos configuration center (persistence, clustering, information synchronization)
The 7th Blue Bridge Cup single chip microcomputer provincial competition
How much is the tuition fee of SCM training class? How long is the study time?
随机推荐
Go branch and loop
Go language introduction
Www2022 | know your way back: self training method of graph neural network under distribution and migration
树莓派GPIO引脚控制红绿灯与蜂鸣器
云服务器的安全设置常识
【直播回顾】战码先锋首期8节直播完美落幕,下期敬请期待!
文档声明与字符编码
[wireless image transmission] FPGA based simple wireless image transmission system Verilog development, matlab assisted verification
Dare to go out for an interview without learning some distributed technology?
Implementing sorting algorithms by hand
【小技巧】使用matlab GUI以对话框模式读取文件
【无线图传】基于FPGA的简易无线图像传输系统verilog开发,matlab辅助验证
BiShe cinema ticket purchasing system based on SSM
Which insurance company has a better product of anti-cancer insurance?
Wechat applet pull-down loading more waterfall flow loading
Uni app - realize the countdown of 60 seconds to obtain the mobile verification code (mobile number + verification code login function)
[untitled]
Déchirure à la main - tri
Recently, the weather has been extremely hot, so collect the weather data of Beijing, Shanghai, Guangzhou and Shenzhen last year, and make a visual map
《动手学深度学习》(二)-- 多层感知机