IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> pytorch预测结果的精度、召回率、F1值、准确率计算(简单实现) -> 正文阅读

[人工智能]pytorch预测结果的精度、召回率、F1值、准确率计算(简单实现)

1. 导入必要的库函数

import os
import json
import torch
from PIL import Image
from torchvision import transforms
# 导入自己的模型
from model_v3 import mobilenet_v3_small

2. 文件存储格式

imgs文件下下是需要预测的图像

label.txt为每幅图对应的标签

### 这个无所谓,能正常读取就行

?3. 标签读取辅助函数

targets_path = "test/label.txt"
    with open(targets_path, 'r') as file:
        targets = file.readlines()

?4.?精度、召回率、F1值、准确率计算函数

# 获取标签
target = int(targets[i])
# 输出结果
output = torch.squeeze(model(img.to(device))).cpu()
# 取出最大值的索引
predict = torch.softmax(output, dim=0)
# 得到数值
predict_cla = torch.argmax(predict).item()
# 0为正类, 1为负类
if predict_cla == 0 and target == 0:
    TP += 1
if predict_cla == 1 and target == 1:
    TN += 1
if predict_cla == 0 and target == 1:
    FP += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
if predict_cla == 1 and target == 0:
    FN += 1
    print(str(imgs_path) + " " + str(file) + " is predicted wrong")

# 根据公式写的,如果有问题欢迎大家支持,一起学习
P = TP / (TP + FP + esp)
R = TP / (TP + FN + esp)
F1 = 2 * P * R / (P + R + esp)
acc = (TP + TN) / (TP + TN + FP + FN + esp)

5. 完整实现,修改路径可以用于自己的模型

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model_v3 import mobilenet_v3_small


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


    imgs_path = "I:/ZTC950V763_211118/CV/imgs/"
    targets_path = "I:/ZTC950V763_211118/CV/y_CV.txt"
    with open(targets_path, 'r') as file:
        targets = file.readlines()
    TP, TN, FP, FN = 0, 0, 0, 0
    esp = 1e-6
    i = 0
    for _, __, files in os.walk(imgs_path):
        for file in files:
            img = Image.open(imgs_path + str(file))
            target = int(targets[i])
            # plt.imshow(img)
            # [N, C, H, W]
            img = data_transform(img)
            # expand batch dimension
            img = torch.unsqueeze(img, dim=0)
            # read class_indict
            json_path = 'class_indices.json'
            assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

            json_file = open(json_path, "r")
            class_indict = json.load(json_file)

            # create model
            model = mobilenet_v3_small(num_classes=2).to(device)
            # load model weights
            model_weight_path = "C:/Users/00769111/PycharmProjects/mobilenet_juanyang/weights/No_freeze_MobileNetV3.pth"
            model.load_state_dict(torch.load(model_weight_path, map_location=device))
            model.eval()
            with torch.no_grad():
                # predict class
                output = torch.squeeze(model(img.to(device))).cpu()
                predict = torch.softmax(output, dim=0)
                predict_cla = torch.argmax(predict).item()
                if predict_cla == 0 and target == 0:
                    TP += 1
                if predict_cla == 1 and target == 1:
                    TN += 1
                if predict_cla == 0 and target == 1:
                    FP += 1
                    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
                if predict_cla == 1 and target == 0:
                    FN += 1
                    print(str(imgs_path) + " " + str(file) + " is predicted wrong")
            i += 1
            if i % 200 == 0:
                P = TP / (TP + FP + esp)
                R = TP / (TP + FN + esp)
                F1 = 2 * P * R / (P + R + esp)
                acc = (TP + TN) / (TP + TN + FP + FN + esp)
                print(f"精度为: {P}\n")
                print(f"召回率为: {R}\n")
                print(f"F1值为: {F1}\n")
                print(f"准确率为: {acc}")
        P = TP / (TP + FP + esp)
        R = TP / (TP + FN + esp)
        F1 = 2 * P * R / (P + R + esp)
        acc = (TP + TN) / (TP + TN + FP + FN + esp)
        print("结果汇总\n")
        print(f"精度为: {P}\n")
        print(f"召回率为: {R}\n")
        print(f"F1值为: {F1}\n")
        print(f"准确率为: {acc}")


if __name__ == '__main__':
    main()

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-23 12:20:39  更:2021-11-23 12:24:00 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 4:17:24-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码