Current location: Home > Calculating precision, recall, F1 score and accuracy of PyTorch prediction results (simple implementation)

Calculating precision, recall, F1 score and accuracy of PyTorch prediction results (simple implementation)

2022-06-13 01:46:00 No change of name

1. Import the required libraries

import os
import json
import torch
from PIL import Image
from torchvision import transforms
#  Import your own model 
from model_v3 import mobilenet_v3_small

2. File storage format

The `imgs` folder contains the images to be predicted.

`label.txt` holds one integer label per line, in the same order in which the image files are listed in `imgs`.

### The exact folder layout does not matter, as long as the files can be read as shown below.

3. Reading the labels

# Read the ground-truth labels: one label per line, in image order.
# (The original snippet indented the `with` block under nothing, which is an
# IndentationError.)
targets_path = "test/label.txt"
with open(targets_path, "r", encoding="utf-8") as file:
    targets = file.readlines()

4. Computing precision, recall, F1 score and accuracy

# Excerpt of the per-image loop body. It relies on names bound by the enclosing
# loop: targets, i, model, img, device, imgs_path, file, the TP/TN/FP/FN
# counters and the small constant esp.
#  Ground-truth label of the current image
target = int(targets[i])
#  Run the model and squeeze away size-1 dimensions
#  (presumably [1, num_classes] -> [num_classes]; confirm against the model)
output = torch.squeeze(model(img.to(device))).cpu()
#  Turn the scores into probabilities along dim 0
predict = torch.softmax(output, dim=0)
#  Index of the most probable class, i.e. the predicted class
predict_cla = torch.argmax(predict).item()
# Class 0 is treated as the positive class, class 1 as the negative class
if predict_cla == 0 and target == 0:
    TP += 1
if predict_cla == 1 and target == 1:
    TN += 1
# Predicted positive (0) but labelled negative (1): false positive
if predict_cla == 0 and target == 1:
    FP += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
# Predicted negative (1) but labelled positive (0): false negative
if predict_cla == 1 and target == 0:
    FN += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")

# Standard formulas; esp is added to each denominator to avoid division by zero
P = TP / (TP + FP + esp)
R = TP / (TP + FN + esp)
F1 = 2 * P * R / (P + R + esp)
acc = (TP + TN) / (TP + TN + FP + FN + esp)

5. Complete implementation (change the paths to use it with your own model)

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model_v3 import mobilenet_v3_small


def _compute_metrics(tp, tn, fp, fn, eps=1e-6):
    """Return (precision, recall, f1, accuracy) from confusion-matrix counts.

    ``eps`` is added to every denominator so that an empty class yields 0.0
    instead of raising ZeroDivisionError.
    """
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn + eps)
    return precision, recall, f1, accuracy


def _print_metrics(tp, tn, fp, fn, eps=1e-6):
    """Print precision, recall, F1 and accuracy for the given counts."""
    precision, recall, f1, accuracy = _compute_metrics(tp, tn, fp, fn, eps)
    print(f" Precision is : {precision}\n")
    print(f" The recall rate is : {recall}\n")
    print(f"F1 The value is : {f1}\n")
    print(f" Accuracy rate is : {accuracy}")


def main(imgs_path="I:/ZTC950V763_211118/CV/imgs/",
         targets_path="I:/ZTC950V763_211118/CV/y_CV.txt",
         model_weight_path="C:/Users/00769111/PycharmProjects/mobilenet_juanyang/weights/No_freeze_MobileNetV3.pth",
         report_every=200):
    """Evaluate a two-class MobileNetV3 model on a folder of images.

    Labels are read from ``targets_path`` (one integer per line) and matched
    to images by position, in the order ``os.listdir`` returns the files in
    ``imgs_path``. Class 0 is the positive class and class 1 the negative
    class. Misclassified images are printed as they are found. Precision,
    recall, F1 and accuracy are printed every ``report_every`` images and
    once more as a final summary.

    Args:
        imgs_path: Directory whose regular files are classified.
            Subdirectories are ignored.
        targets_path: Text file with one integer label per line.
        model_weight_path: State-dict file loaded into the model.
        report_every: Interval for intermediate reports; 0/None disables them.

    Raises:
        ValueError: If there are fewer labels than images.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    with open(targets_path, 'r') as f:
        targets = f.readlines()

    # NOTE(review): os.listdir order is arbitrary (filesystem-defined); the
    # label file must have been written in this same order.
    files = [name for name in os.listdir(imgs_path)
             if os.path.isfile(os.path.join(imgs_path, name))]
    if len(targets) < len(files):
        raise ValueError(
            f"{targets_path} has {len(targets)} labels but {imgs_path} "
            f"contains {len(files)} images")

    # Build the model and load its weights once, not once per image.
    model = mobilenet_v3_small(num_classes=2).to(device)
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    tp = tn = fp = fn = 0
    eps = 1e-6
    with torch.no_grad():
        for i, file in enumerate(files):
            target = int(targets[i])
            # `with` closes the file handle. convert("RGB") gives grayscale,
            # RGBA and palette images the 3 channels Normalize expects.
            with Image.open(os.path.join(imgs_path, file)) as raw:
                tensor = data_transform(raw.convert("RGB"))
            # Add the batch dimension: [C, H, W] -> [1, C, H, W].
            batch = torch.unsqueeze(tensor, dim=0).to(device)
            output = torch.squeeze(model(batch)).cpu()
            # softmax is monotonic, so argmax over raw scores picks the same class.
            predict_cla = torch.argmax(output).item()

            if predict_cla == 0 and target == 0:
                tp += 1
            elif predict_cla == 1 and target == 1:
                tn += 1
            elif predict_cla == 0 and target == 1:
                fp += 1
                print(f"{imgs_path} {file} is predicted wrong")
            elif predict_cla == 1 and target == 0:
                fn += 1
                print(f"{imgs_path} {file} is predicted wrong")

            if report_every and (i + 1) % report_every == 0:
                _print_metrics(tp, tn, fp, fn, eps)

    print(" The results summary \n")
    _print_metrics(tp, tn, fp, fn, eps)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()

原网站

版权声明
本文为[No change of name]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/02/202202280550191646.html