当前位置:网站首页>PyTorch --- using PyTorch for image localization
PyTorch --- using PyTorch for image localization
2022-07-02 04:11:00 【Brother Shui is very water】
1. The dataset used by the code can be obtained through the following link
Baidu Netdisk extraction code: vc56
2. Code running environment
Pytorch-gpu==1.10.1
Python==3.8
3. The dataset processing code is as follows (data_loader.py, the module imported by the training and prediction scripts)
import os
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import numpy as np
from torchvision import transforms
from PIL import Image
from torchvision.utils import draw_bounding_boxes
class PetDataset(Dataset):
    """Dataset that pairs each image file with four bounding-box values.

    Args:
        images_path: Indexable, sized collection of image file paths.
        labels: Indexable collection whose entries unpack into exactly four
            values, returned in the same order after the image.
        transform: Callable applied to the RGB ``PIL`` image; its result is
            returned as the first element of every sample.
    """

    def __init__(self, images_path, labels, transform):
        super().__init__()
        self.images_path = images_path
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images_path)

    def __getitem__(self, index):
        """Return ``(transformed_image, v1, v2, v3, v4)`` for sample ``index``."""
        # Force three channels so every sample has the same layout.
        rgb_image = Image.open(self.images_path[index]).convert('RGB')
        transformed = self.transform(rgb_image)
        first, second, third, fourth = self.labels[index]
        return transformed, first, second, third, fourth
def to_labels(path):
    """Read one XML annotation file and return its normalised bounding box.

    The file must contain ``size/width``, ``size/height`` and
    ``object/bndbox/{xmin,ymin,xmax,ymax}`` elements. Only the first
    ``object`` element is read.

    Args:
        path: Path of the XML annotation file (opened as UTF-8 text).

    Returns:
        ``[xmin, ymin, xmax, ymax]`` as floats, with the x values divided by
        the annotated width and the y values divided by the annotated height.

    Raises:
        AttributeError: If one of the expected elements is missing.
        ZeroDivisionError: If the annotated width or height is zero.
    """
    # The context manager closes the handle even when parsing fails; the
    # previous bare open() leaked one file descriptor per annotation.
    with open(path, encoding='utf-8') as xml_file:
        root = ET.parse(xml_file).getroot()

    size = root.find('size')
    width = float(size.find('width').text)
    height = float(size.find('height').text)

    # Look the box element up once instead of re-walking the tree per corner.
    bndbox = root.find('object').find('bndbox')
    xmin, ymin, xmax, ymax = (
        float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
    )
    return [xmin / width, ymin / height, xmax / width, ymax / height]
def load_data(dataset_path=r'/Users/leeakita/Desktop/dataset', batch_size=32):
    """Build the training and test data loaders for the localization dataset.

    Annotations are read from ``<dataset_path>/xmls`` and each one is paired
    with ``<dataset_path>/images/<same stem>.jpg``. The samples are shuffled
    with a fixed seed and split 80% / 20% into train and test sets. Images are
    resized to 224x224 and converted to tensors.

    Args:
        dataset_path: Root directory holding the ``xmls`` and ``images``
            sub-directories. Defaults to the path originally hard-coded here.
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_dl, test_dl)``; only the training loader shuffles.

    Raises:
        FileNotFoundError: If no ``.xml`` file is found in the ``xmls`` folder.
    """
    xml_dir = os.path.join(dataset_path, 'xmls')
    image_dir = os.path.join(dataset_path, 'images')

    # os.listdir returns entries in arbitrary order, so sort them: the seeded
    # permutation below then yields the same split on every run, which the
    # prediction script relies on when it reloads the test set.
    # Entries that are not XML files (e.g. hidden OS files) are skipped.
    xml_names = sorted(
        name for name in os.listdir(xml_dir) if name.lower().endswith('.xml')
    )
    if not xml_names:
        raise FileNotFoundError(
            'no .xml annotation files found in {}'.format(xml_dir)
        )

    xml_paths = [os.path.join(xml_dir, name) for name in xml_names]
    # splitext removes only the final extension, so dots inside the file stem
    # survive (split('.')[0] would truncate them).
    image_paths = [
        os.path.join(image_dir, os.path.splitext(name)[0] + '.jpg')
        for name in xml_names
    ]
    labels = [to_labels(xml_path) for xml_path in xml_paths]

    # Same fixed seed as before; images and labels share one permutation so
    # they stay aligned.
    np.random.seed(2022)
    index = np.random.permutation(len(image_paths))
    image_paths = np.array(image_paths)[index]
    labels = np.array(labels, dtype=np.float32)[index]

    train_split = int(len(image_paths) * 0.8)
    train_images, test_images = image_paths[:train_split], image_paths[train_split:]
    train_labels, test_labels = labels[:train_split], labels[train_split:]

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    train_ds = PetDataset(images_path=train_images, labels=train_labels, transform=transform)
    test_ds = PetDataset(images_path=test_images, labels=test_labels, transform=transform)
    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)
    return train_dl, test_dl
if __name__ == '__main__':
    # Visual sanity check: draw the labelled box of one test sample.
    _, test_loader = load_data()
    images, xmins, ymins, xmaxs, ymaxs = next(iter(test_loader))
    sample = 6
    # Labels are fractions of the image size; scale them by 224, the side
    # length the loader resizes every image to.
    corners = [
        coord[sample].item() * 224 for coord in (xmins, ymins, xmaxs, ymaxs)
    ]
    box = torch.FloatTensor(corners).unsqueeze(0)
    # draw_bounding_boxes is given a uint8 image, so rescale the float tensor.
    picture = torch.as_tensor(data=images[sample] * 255, dtype=torch.uint8)
    annotated = draw_bounding_boxes(image=picture, boxes=box, colors='red')
    # Channels-first tensor -> channels-last array for matplotlib.
    plt.imshow(annotated.permute(1, 2, 0).numpy())
    plt.show()
4. The model construction code is as follows (model_loader.py)
import torch
import torchvision
from torch import nn
class Net(nn.Module):
    """Pretrained ResNet-101 backbone followed by four one-output linear heads.

    Each head (``fc1`` .. ``fc4``) produces one value per sample. The
    training script pairs them, in order, with the xmin, ymin, xmax and ymax
    targets.
    """

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet101(pretrained=True)
        # Keep every child module except the last one (the classifier).
        self.conv_base = nn.Sequential(*list(resnet.children())[:-1])
        in_features = resnet.fc.in_features
        self.fc1 = nn.Linear(in_features=in_features, out_features=1)
        self.fc2 = nn.Linear(in_features=in_features, out_features=1)
        self.fc3 = nn.Linear(in_features=in_features, out_features=1)
        self.fc4 = nn.Linear(in_features=in_features, out_features=1)

    def forward(self, x):
        """Return the four head outputs for a batch of images.

        Args:
            x: Batch of images, batch axis first.

        Returns:
            Tuple ``(x1, x2, x3, x4)``; each tensor has shape ``(batch, 1)``.
        """
        x = self.conv_base(x)
        # Flatten everything after the batch axis. torch.squeeze(x) would also
        # drop the batch axis when the batch holds a single sample, making
        # the heads return shape (1,) instead of (1, 1).
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x), self.fc2(x), self.fc3(x), self.fc4(x)
if __name__ == '__main__':
    # Smoke test: running this file directly only constructs the network.
    model = Net()
5. The model training code is as follows
import numpy as np
import torch
from data_loader import load_data
from model_loader import Net
import tqdm
import os
# Configuration of environment variables
devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data loading
train_dl, test_dl = load_data()

# Model loading
model = Net()
model.to(device=devices)

# Training related configurations
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.7)


def _box_loss(batch):
    """Return the summed MSE loss over the four coordinates of one batch.

    ``batch`` is ``(image, xmin, ymin, xmax, ymax)`` as yielded by the data
    loaders; every tensor is moved to ``devices`` first.
    """
    image, *targets = (tensor.to(devices) for tensor in batch)
    predictions = model(image)
    # reshape(-1) always gives a 1-D tensor matching the 1-D targets.
    # squeeze() collapsed a one-sample batch to a 0-d tensor, which no longer
    # matched the target shape.
    return sum(
        loss_fn(pred.reshape(-1), target)
        for pred, target in zip(predictions, targets)
    )


# Start training
for epoch in range(50):
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:2d}'.format(epoch))
    train_loss_sum = []
    for batch in train_tqdm:
        loss = _box_loss(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_sum.append(loss.item())
        train_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(train_loss_sum)))
    train_tqdm.close()

    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:2d}'.format(epoch))
        test_loss_sum = []
        for batch in test_tqdm:
            test_loss_sum.append(_box_loss(batch).item())
            test_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(test_loss_sum)))
        test_tqdm.close()

    # The scheduler only changes the learning rate when it is stepped; it was
    # configured above but never advanced, so the rate stayed at 0.0001.
    # Step once per epoch, after that epoch's optimizer updates.
    lr_scheduler.step()

# Save the model
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair, which could fail
# if the directory appeared between the two calls.
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))
6. The model prediction code is as follows
import torch
from data_loader import load_data
from model_loader import Net
import os
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt
# Data loading: only the first test batch is used.
train_dl, test_dl = load_data()
image, xmin, ymin, xmax, ymax = next(iter(test_dl))

# Model loading: restore the weights written by the training script on the CPU.
model = Net()
weights_file = os.path.join('model_data', 'model.pth')
model.load_state_dict(torch.load(weights_file, map_location='cpu'))
model.eval()

# Start Forecasting
index = 0
with torch.no_grad():
    predictions = model(image)

# Coordinates are fractions of the image size; scale by 224, the side length
# the data loader resizes images to. Each box becomes a (1, 4) tensor.
pre_boxes = torch.FloatTensor(
    [coord[index].item() * 224 for coord in predictions]
).unsqueeze(0)
label_boxes = torch.FloatTensor(
    [coord[index].item() * 224 for coord in (xmin, ymin, xmax, ymax)]
).unsqueeze(0)

# Convert the float image to uint8 before drawing on it.
img = torch.as_tensor(data=image[index] * 255, dtype=torch.uint8)
# Prediction in red, ground-truth label in blue, drawn on the same image.
result = draw_bounding_boxes(image=img, boxes=pre_boxes, colors='red')
result = draw_bounding_boxes(image=result, boxes=label_boxes, colors='blue')

plt.figure(figsize=(8, 8), dpi=500)
plt.axis('off')
# Channels-first -> channels-last for matplotlib.
plt.imshow(result.permute(1, 2, 0))
plt.savefig('result.png')
plt.show()
7. The running result of the code is as follows
边栏推荐
- Spring moves are coming. Watch the gods fight
- How much can a job hopping increase? Today, I saw the ceiling of job hopping.
- Play with concurrency: draw a thread state transition diagram
- 【力扣刷题】15.三数之和(双指针);17.电话号码的字母组合(递归回溯)
- Wechat applet - realize the countdown of 60 seconds to obtain the mobile verification code (mobile number + verification code login function)
- 【c语言】动态规划---入门到起立
- Target free or target specific: a simple and effective zero sample position detection comparative learning method
- SQL:常用的 SQL 命令
- A thorough understanding of the development of scorecards - the determination of Y (Vintage analysis, rolling rate analysis, etc.)
- Lost a few hairs, and finally learned - graph traversal -dfs and BFS
猜你喜欢
云服务器的安全设置常识
Pandora IOT development board learning (RT thread) - Experiment 1 LED flashing experiment (learning notes)
10 minutes to understand CMS garbage collector in JVM
Recently, the weather has been extremely hot, so collect the weather data of Beijing, Shanghai, Guangzhou and Shenzhen last year, and make a visual map
Common sense of cloud server security settings
PR zero foundation introductory guide note 2
Sorted out an ECS summer money saving secret, this time @ old users come and take it away
The confusion I encountered when learning stm32
手撕——排序
A thorough understanding of the development of scorecards - the determination of Y (Vintage analysis, rolling rate analysis, etc.)
随机推荐
[Li Kou brush questions] 15 Sum of three numbers (double pointer); 17. Letter combination of phone number (recursive backtracking)
Finally got byte offer. The 25-year-old inexperienced perception of software testing is written to you who are still confused
Introduction to vmware workstation and vSphere
向数据库中存入数组数据,代码出错怎么解决
Www2022 | know your way back: self training method of graph neural network under distribution and migration
2022-07-01:某公司年会上,大家要玩一食发奖金游戏,一共有n个员工, 每个员工都有建设积分和捣乱积分, 他们需要排成一队,在队伍最前面的一定是老板,老板也有建设积分和捣乱积分, 排好队后,所有
Homework in Chapter 3 of slam course of dark blue vision -- derivative application of T6 common functions
【leetcode】81. Search rotation sort array II
《西线无战事》我们才刚开始热爱生活,却不得不对一切开炮
微信小程序 - 实现获取手机验证码倒计时 60 秒(手机号+验证码登录功能)
Pytorch---使用Pytorch实现U-Net进行语义分割
office_ Delete the last page of word (the seemingly blank page)
[JS -- map string]
C语言猜数字游戏
5G時代全面到來,淺談移動通信的前世今生
First acquaintance with P4 language
Realizing deep learning framework from zero -- Introduction to neural network
[ibdfe] matlab simulation of frequency domain equalization based on ibdfe
Spring moves are coming. Watch the gods fight
C language: examples of logical operation and judgment selection structure