当前位置:网站首页>PyTorch --- use PyTorch for image localization

PyTorch — Using PyTorch for Image Localization (Bounding-Box Regression)

2022-07-02 04:11:00 Brother Shui is very water

1. Dataset: the dataset used in the code can be downloaded from the following link

Baidu Netdisk (extraction code: vc56)

2. Runtime environment


3. Dataset processing code (data_loader.py)

import os
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import xml.etree.ElementTree as ET
import numpy as np
from torchvision import transforms
from PIL import Image
from torchvision.utils import draw_bounding_boxes

class PetDataset(Dataset):
    """Map-style dataset yielding an image tensor plus four box-label values.

    Args:
        images_path: Indexable sequence of image file paths.
        labels: Sequence aligned with ``images_path``; each entry must unpack
            into exactly four values (in this file ``load_data`` supplies the
            ``[xmin, ymin, xmax, ymax]`` lists produced by ``to_labels``).
        transform: Callable applied to the RGB ``PIL.Image`` before returning it.
    """

    def __init__(self, images_path, labels, transform):
        super().__init__()
        self.images_path = images_path
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, v1, v2, v3, v4)`` for sample ``index``."""
        # Force 3 channels so grayscale/RGBA files produce the same tensor layout.
        sample = self.transform(Image.open(self.images_path[index]).convert('RGB'))
        first, second, third, fourth = self.labels[index]
        return sample, first, second, third, fourth

    def __len__(self):
        """Number of samples, taken from the image-path sequence."""
        return len(self.images_path)

def to_labels(path):
    """Read the first object's bounding box from an XML annotation file.

    The root element must contain ``size/width``, ``size/height`` and
    ``object/bndbox/{xmin,ymin,xmax,ymax}`` (presumably Pascal-VOC style
    annotations — confirm against the dataset). Only the first ``object``
    element is read; any further objects are ignored.

    Args:
        path: Path of the XML file; it is decoded as UTF-8.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` as floats, with x coordinates divided by
        the annotated width and y coordinates by the annotated height.
    """
    # Close the handle promptly: load_data calls this once per annotation file.
    with open(path, encoding='utf-8') as xml_file:
        root = ET.parse(xml_file).getroot()

    size = root.find('size')
    width = float(size.find('width').text)
    height = float(size.find('height').text)

    bndbox = root.find('object').find('bndbox')
    xmin = float(bndbox.find('xmin').text) / width
    ymin = float(bndbox.find('ymin').text) / height
    xmax = float(bndbox.find('xmax').text) / width
    ymax = float(bndbox.find('ymax').text) / height

    return [xmin, ymin, xmax, ymax]

def load_data(dataset_path=r'/Users/leeakita/Desktop/dataset', batch_size=32,
              train_ratio=0.8, seed=None):
    """Build shuffled train/test DataLoaders for the localization dataset.

    Expects ``<dataset_path>/xmls/<name>.xml`` annotations, each paired with
    ``<dataset_path>/images/<name>.jpg``. Labels are the ``[xmin, ymin, xmax,
    ymax]`` lists returned by ``to_labels`` (divided by the annotated size).

    Args:
        dataset_path: Root directory containing ``xmls`` and ``images``.
        batch_size: Batch size for both loaders.
        train_ratio: Fraction of the shuffled samples used for training.
        seed: Optional seed for the train/test shuffle. With the default
            ``None`` NumPy's global RNG is used (as before), so separate calls
            produce different splits unless the caller seeds ``np.random``.

    Returns:
        ``(train_dl, test_dl)``; the training loader reshuffles every epoch,
        the test loader keeps a fixed order.

    Raises:
        ValueError: if ``<dataset_path>/xmls`` contains no ``.xml`` files.
    """
    xml_dir = os.path.join(dataset_path, 'xmls')
    image_dir = os.path.join(dataset_path, 'images')

    # Keep only annotation files (skips entries like macOS ``.DS_Store``);
    # splitext preserves stems that contain dots; sorting makes a seeded
    # split reproducible regardless of directory listing order.
    file_names = sorted(
        os.path.splitext(name)[0] for name in os.listdir(xml_dir) if name.endswith('.xml')
    )
    if not file_names:
        raise ValueError('no .xml annotation files found in {!r}'.format(xml_dir))

    image_paths = [os.path.join(image_dir, name + '.jpg') for name in file_names]
    labels = [to_labels(os.path.join(xml_dir, name + '.xml')) for name in file_names]

    if seed is None:
        index = np.random.permutation(len(image_paths))
    else:
        index = np.random.default_rng(seed).permutation(len(image_paths))

    image_paths = np.array(image_paths)[index]
    labels = np.array(labels, dtype=np.float32)[index]

    train_split = int(len(image_paths) * train_ratio)

    train_images, test_images = image_paths[:train_split], image_paths[train_split:]
    train_labels, test_labels = labels[:train_split], labels[train_split:]

    # Labels are relative to the annotated width/height, so the non-uniform
    # resize below keeps them valid. ToTensor yields a float CHW tensor in
    # [0, 1], which the drawing code in this file scales back up by 255.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    train_ds = PetDataset(images_path=train_images, labels=train_labels, transform=transform)
    test_ds = PetDataset(images_path=test_images, labels=test_labels, transform=transform)

    train_dl = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
    test_dl = DataLoader(dataset=test_ds, batch_size=batch_size)

    return train_dl, test_dl

if __name__ == '__main__':
    # Visual sanity check: draw the ground-truth box of one test sample.
    trainn, testt = load_data()
    image, xminn, yminn, xmaxn, ymaxn = next(iter(testt))
    # Sample 6 as before, clamped so a test batch smaller than 7 cannot raise IndexError.
    index = min(6, image.shape[0] - 1)
    img = image[index]
    # Labels are divided by the annotated width/height (see to_labels), so scale
    # x by the tensor's width and y by its height rather than a fixed 224.
    height, width = img.shape[-2:]
    boxes = torch.FloatTensor([[
        xminn[index].item() * width,
        yminn[index].item() * height,
        xmaxn[index].item() * width,
        ymaxn[index].item() * height,
    ]])
    result = draw_bounding_boxes(image=torch.as_tensor(data=img * 255, dtype=torch.uint8), boxes=boxes, colors='red')
    plt.imshow(result.permute(1, 2, 0).numpy())
    # Required when run as a script; otherwise the figure is never displayed.
    plt.show()

4. Model definition (model_loader.py)

import torch
import torchvision
from torch import nn

class Net(nn.Module):
    """ResNet-101 backbone with four independent scalar regression heads.

    The heads' outputs are used as predicted ``xmin, ymin, xmax, ymax`` by the
    training and prediction code in this file.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet101`` (default
            ``True``, as before). ``False`` skips the weight download, e.g.
            when a checkpoint will be loaded with ``load_state_dict`` anyway.
    """

    def __init__(self, pretrained=True):
        super(Net, self).__init__()
        resnet = torchvision.models.resnet101(pretrained=pretrained)
        # Every child except the final classifier ``fc``; presumably ends with
        # ResNet's global average pool, i.e. output (N, C, 1, 1) — confirm.
        self.conv_base = nn.Sequential(*list(resnet.children())[:-1])
        in_features = resnet.fc.in_features
        self.fc1 = nn.Linear(in_features=in_features, out_features=1)
        self.fc2 = nn.Linear(in_features=in_features, out_features=1)
        self.fc3 = nn.Linear(in_features=in_features, out_features=1)
        self.fc4 = nn.Linear(in_features=in_features, out_features=1)

    def forward(self, x):
        """Return a tuple of four ``(N, 1)`` tensors, one per head."""
        x = self.conv_base(x)
        # flatten from dim 1 rather than squeeze(): squeeze also drops the batch
        # dimension when N == 1, which changes every head's output shape.
        x = torch.flatten(x, 1)
        return self.fc1(x), self.fc2(x), self.fc3(x), self.fc4(x)

if __name__ == '__main__':
    # Smoke test: constructing the network exercises the backbone setup.
    net = Net()

5. Training code

import numpy as np
import torch
from data_loader import load_data
from model_loader import Net
import tqdm
import os

#  Configuration of environment variables
devices = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#  Data loading
train_dl, test_dl = load_data()

#  Model loading
# The model must live on the same device as the batches moved in _batch_loss.
model = Net().to(devices)

#  Training related configurations
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
# StepLR only decays the learning rate when .step() is called (once per epoch below).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=7, gamma=0.7)


def _batch_loss(batch):
    """Return the summed MSE over the four box coordinates for one batch.

    ``batch`` is ``(image, xmin, ymin, xmax, ymax)`` as yielded by the loaders.
    """
    image, *targets = (tensor.to(devices) for tensor in batch)
    predictions = model(image)
    # reshape(-1) turns each head's (N, 1) output into (N,) to match its target;
    # unlike squeeze() it never collapses a batch of one into a 0-d tensor.
    return sum(loss_fn(pred.reshape(-1), target) for pred, target in zip(predictions, targets))


#  Start training
for epoch in range(50):
    # train()/eval() toggle mode-dependent layers (presumably BatchNorm in the ResNet backbone).
    model.train()
    train_tqdm = tqdm.tqdm(iterable=train_dl, total=len(train_dl))
    train_tqdm.set_description_str('Train epoch: {:2d}'.format(epoch))
    train_loss_sum = []
    for batch in train_tqdm:
        loss = _batch_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss_sum.append(loss.item())
        train_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(train_loss_sum)))
    lr_scheduler.step()

    model.eval()
    with torch.no_grad():
        test_tqdm = tqdm.tqdm(iterable=test_dl, total=len(test_dl))
        test_tqdm.set_description_str('Test epoch: {:2d}'.format(epoch))
        test_loss_sum = []
        for batch in test_tqdm:
            test_loss_sum.append(_batch_loss(batch).item())
            test_tqdm.set_postfix_str('loss is :{:14f}'.format(np.mean(test_loss_sum)))

#  Save the model
os.makedirs('model_data', exist_ok=True)
torch.save(model.state_dict(), os.path.join('model_data', 'model.pth'))

6. Prediction code

import torch
from data_loader import load_data
from model_loader import Net
import os
from torchvision.utils import draw_bounding_boxes
import matplotlib.pyplot as plt

#  Data loading
# NOTE(review): load_data() shuffles with an unseeded RNG on every call, so this
# batch is not guaranteed to come from the split held out during training.
train_dl, test_dl = load_data()
image, xmin, ymin, xmax, ymax = next(iter(test_dl))

#  Model loading
model = Net()
model_state_dict = torch.load(os.path.join('model_data', 'model.pth'), map_location='cpu')
# Without this the trained weights are ignored and the heads stay randomly initialised.
model.load_state_dict(model_state_dict)
# Inference mode: mode-dependent layers use stored statistics, not this batch's.
model.eval()

#  Start Forecasting
index = 0
with torch.no_grad():
    pred_xmin, pred_ymin, pred_xmax, pred_ymax = model(image)
    img = image[index]
    # Labels/predictions are relative to image size: scale x by width, y by height.
    height, width = img.shape[-2:]
    pre_boxes = torch.FloatTensor([[
        pred_xmin[index].item() * width,
        pred_ymin[index].item() * height,
        pred_xmax[index].item() * width,
        pred_ymax[index].item() * height,
    ]])
    label_boxes = torch.FloatTensor([[
        xmin[index].item() * width,
        ymin[index].item() * height,
        xmax[index].item() * width,
        ymax[index].item() * height,
    ]])
    img = torch.as_tensor(data=img * 255, dtype=torch.uint8)
    # Red = prediction, blue = ground truth.
    result = draw_bounding_boxes(image=img, boxes=pre_boxes, colors='red')
    result = draw_bounding_boxes(image=result, boxes=label_boxes, colors='blue')
    plt.figure(figsize=(8, 8), dpi=500)
    plt.imshow(result.permute(1, 2, 0))
# Required when run as a script; otherwise the figure is never displayed.
plt.show()

7. Results

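(Result image not included in this copy. In the prediction output, the red box is the model's prediction and the blue box is the ground truth.)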


本文为[Brother Shui is very water]所创,转载请带上原文链接,感谢
