IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> [pytorch] Kaggle图片分类比赛 ArcFace + bounding box 代码学习 -> 正文阅读

[人工智能][pytorch] Kaggle图片分类比赛 ArcFace + bounding box 代码学习

比赛中的数据包含来自 28 个不同研究机构的 30 个不同物种(鲸鱼和海豚)的 15,000 多只独特个体海洋哺乳动物的图像。比赛要求是对测试集个体id的分类。
kaggle 比赛数据详情及数据集下载:Happywhale - Whale and Dolphin Identification

代码链接

arcface是本次比赛中表现最好的方法之一。代码:

  1. Pytorch Metric Learning [effnet + arcface])
  2. Pytorch Train Notebook(ArcFace + GeM Pooling)

基础知识

[理论] 度量学习 Metric Learning

度量学习(Metric Learning)是机器学习过程中经常用到的一种方法,它可以借助一系列观测,构造出对应的度量函数,从而学习数据间的距离或差异,有效地描述样本之间的相似度。这个度量函数对于相似度高的观测值,会返回一个小的距离值;对于差异巨大的观测值,则会返回一个大的距离值。当样本量不大时,度量学习在处理分类任务的准确率和高效率上,展现出了显著优势。

然而,如果要处理的分类任务十分复杂,具有多类别、小样本等特征时,结合深度学习和度量学习的深度度量学习((Deep Metric Learning,简称 DML)),才是真正的王者。深度度量学习又被称为距离度量学习(Distance Metric Learning)。相较于度量学习,深度度量学习可以对输入特征做非线性映射。

通过训练一个基于 CNN 的非线性特征提取模块或编码器,深度度量学习可以将提取的图像特征(Embedding)嵌入到近邻位置,同时借助欧氏距离、cosine 等距离度量方法,将不同的图像特征区分开来。

深度度量学习在 CV 领域的一些极端分类任务(类别众多、样本量不足)中表现优异,应用遍及人脸识别、行人重识别、图像检索、目标跟踪、特征匹配等场景。

参考链接:

  1. 度量学习和pytorch-metric-learning的使用
  2. PyTorch 深度度量学习无敌 Buff:九大模块、随意调用
  3. 度量学习/对比学习入门: 论文阅读笔记-Deep Metric Learning: A Survey

[理论] bounding box 目标检测

在图像分类任务中,我们假设图像中只有一个主要物体对象,我们只关注如何识别其类别。 然而,很多时候图像里有多个我们感兴趣的目标,我们不仅想知道它们的类别,还想得到它们在图像中的具体位置。 在计算机视觉里,我们将这类任务称为目标检测(object detection)或目标识别(object recognition)。

在目标检测中,我们通常使用边界框(bounding box)来描述对象的空间位置。 边界框是矩形的,由矩形左上角的以及右下角的 x 和 y 坐标决定。 另一种常用的边界框表示方法是边界框中心的 (x,y) 轴坐标以及框的宽度和高度。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考链接:
1.CNN: bounding box prediction 01 problem
2.CNN: bounding box prediction - specify bounding box
3.CNN: bounding box prediction - YOLO algo
4.CNN: 3.9 YOLO 算法 part1
5.CNN: 3.9 YOLO 算法 part2

[python] logging模块

那么在 Python 中,怎样才能算作一个比较标准的日志记录过程呢?或许很多人会使用 print 语句输出一些运行信息,然后再在控制台观察,运行的时候再将输出重定向到文件输出流保存到文件中,这样其实是非常不规范的,在 Python 中有一个标准的 logging 模块,我们可以使用它来进行标注的日志记录,利用它我们可以更方便地进行日志记录,同时还可以做更方便的级别区分以及一些额外日志信息的记录,如时间、运行模块信息等。
接下来我们先了解一下日志记录流程的整体框架。
在这里插入图片描述
在这里插入图片描述
参考链接:

  1. 是时候抛弃print了,开始体验下logging的强大吧!
  2. Python之日志处理(logging模块)

数据预处理

首先,根据我们数据统计的结果:[pytorch] Kaggle大型图像数据集 数据分析+可视化
数据图片的大小差异非常大,其中我们要检测的鲸鱼或海豚的位置也是乱七八糟
Things to know before starting image preprocessing
这里你可以看到一些极端案例

所以,图片处理的第一步就是确定海豚/鲸鱼的在图片中的位置,为此,我们使用了bounding box[YOLOv5].
Happywhale: BoundingBox [YOLOv5]
在这个代码中,我们将使用 YOLOv5 生成边界框。这么做的目的是为之后图像的crop提供方向,从而对大小各异的数据集图片进行裁剪,最终可以达到更好的分类结果.

我们使用 Whale Flute 数据集(另一个Kaggle竞赛数据,鲸鱼尾鳍定位)来训练和测试BoundingBox模型,我们总共有 1200 个带有边界框的样本。之后,我们将使用 Whale Flute 模型对我们的 Whale 和 Dolphin 数据集进行预测。

Whales Fluke 数据集中的边界框很大,而 Whales & Dolphin 数据集既有小边界框也有大边界框。 要调整此问题,您可以尝试更改 hyp.yaml 文件中的 scale 参数。 默认值为 0.5,您可以尝试增加该值。您也可以尝试将 bbox 放大,例如 1.5x 或 1.7x。 这将确保您不会裁剪到鲸鱼或海豚。

在确定好边界框的位置之后,我们继续对图像进行剪裁来得到我们分类所需要的图像
Happywhale: Cropped Dataset [YOLOv5]
在这里插入图片描述
最终,在调整大小之后,我们得到新的数据数据集图像。
数据集:JPEG Happywhale 384x384

代码详解

配置

!pip install timm
!pip install pytorch-metric-learning[with-hooks]

开源的度量学习库pytorch-metric-learning,集成了当前常用的各种度量学习方法,是一个非常好用的工具。

import os
import glob
import pandas as pd
import numpy as np
import logging
import timm
from tqdm.notebook import tqdm #进度条

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import Compose, Lambda, Normalize, AutoAugment, AutoAugmentPolicy

import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as LP
from pytorch_metric_learning.utils import common_functions
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import InferenceModel

for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler) #  remove exactly the preexisting handler object

logging.getLogger().setLevel(logging.INFO) # 获取logger实例 指定日志的最低输出级别
logging.info("VERSION %s" % pytorch_metric_learning.__version__) # 打印库版本
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device) # cuda:0
print(torch.cuda.get_device_name(0)) # NVIDIA RTX A6000

参数

MODEL_NAME='tf_efficientnet_b4_ns'
N_CLASSES=15587  #个体数
OUTPUT_SIZE = 1792
EMBEDDING_SIZE = 512
N_EPOCH=15
BATCH_SIZE=16
ACCUMULATION_STEPS = int(256 / BATCH_SIZE)
MODEL_LR = 1e-3
PCT_START=0.3
PATIENCE=5
N_WORKER=2
N_NEIGHBOURS = 750

读取csv数据

df = pd.read_csv('./happy-whale-and-dolphin/train.csv')
df.head()

在这里插入图片描述

df['label'] = df.groupby('individual_id').ngroup()
df['label'].describe()

在这里插入图片描述
实现了根据物种到标签数字的转化
在这里插入图片描述

  1. df.groupby
    groupby的过程就是将原有的DataFrame按照groupby的字段(这里是individual_id),划分为若干个分组DataFrame,被分为多少个组就有多少个分组DataFrame。Pandas教程 | 超好用的Groupby用法详解

  2. GroupBy.ngroup(self, ascending:bool = True) return=每个组的唯一编号。
    在这里插入图片描述

  3. 数据总结df.describe()
    会返回一个有多个行的所有数字列的统计表,每个行是一个统计指标,有总数、平均数、标准差、最大最小值、四分位数等,对我们初步了解数据还是很有作用。 如果是一个时间类型则会按时间相关的如开始结束时间、周期等信息。

划分数据集

训练集和验证集

valid_proportion = 0.1

valid_df = df.sample(frac=valid_proportion, replace=False, random_state=1).copy()
train_df = df[~df['image'].isin(valid_df['image'])].copy()

print(train_df.shape) # (45930, 4)
print(valid_df.shape) # (5103, 4)

Reset index on both since we want to use it for KNN lookups later:?

train_df.reset_index(drop=True, inplace=True)
valid_df.reset_index(drop=True, inplace=True)

读取图片数据

创建用于加载图像的dataset类。

class HappyWhaleDataset(Dataset):
    def __init__(
        self,
        df: pd.DataFrame,
        image_dir: str,
        return_labels=True,
    ):
        self.df = df
        self.images = self.df["image"]
        self.image_dir = image_dir
        self.image_transform = Compose(
            [
                AutoAugment(AutoAugmentPolicy.IMAGENET),
                Lambda(lambda x: x / 255),
                
            ]
        )
        self.return_labels = return_labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        
        image_path = os.path.join(self.image_dir, self.images.iloc[idx])
        image = read_image(path=image_path)
        image = self.image_transform(image)
        
        if self.return_labels:
            label = self.df['label'].iloc[idx] # iloc函数:通过行号来取行数据
            return image, label
        else:
            return image
train_dataset = HappyWhaleDataset(df=train_df, image_dir=TRAIN_DIR, return_labels=True)
len(train_dataset)#45930
valid_dataset = HappyWhaleDataset(df=valid_df, image_dir=TRAIN_DIR, return_labels=True)
len(valid_dataset)#5103
dataset_dict = {"train": train_dataset, "val": valid_dataset}

看一下训练集
在这里插入图片描述

未完待续…

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-28 15:28:52  更:2022-02-28 15:34:48 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 2:36:34-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码