IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> GHMC_Loss pytorch实现 -> 正文阅读

[人工智能]GHMC_Loss pytorch实现

GHMC_Loss pytorch实现

前言

做图像分割,想试试GHMC_Loss,但没找到一个完整,复杂度较好的代码,所以写了一个自我感觉良好的代码。贴出来方便小白参考和接受大佬指正错误.
学习原理可以看5分钟理解Focal Loss与GHM——解决样本不平衡利器

存在的问题

训练到后面会出现验证loss飙升但数据指标正常的情况,还是小白不清楚是不是哪里错了.

代码

import torch
import numpy as np
from torch import nn

class GHM_Loss(nn.Module):
    def __init__(self, bins, alpha, device, is_split_batch=True):
        super(GHM_Loss, self).__init__()
        self._bins = bins
        self._alpha = alpha
        self._last_bin_count = None
        self._device = device
        self.is_split_batch = is_split_batch
        self.is_evaluation = False

    def set_evaluation(self, is_evaluation):
        """
        评估时可以(也可以不管)将is_evaluation设为True,训练时设为False,这样就类似于直接计算CEL
        :param is_evaluation: bool
        :return:
        """
        self.is_evaluation = is_evaluation

    def _g2bin(self, g, bin):
        return torch.floor(g * (bin - 0.0001)).long()

    def _custom_loss(self, x, target, weight):
        raise NotImplementedError

    def _custom_loss_grad(self, x, target):
        raise NotImplementedError

    def use_alpha(self, bin_count):
        if (self._alpha != 0):
            if (self.is_evaluation):
                if (self._last_bin_count == None):
                    self._last_bin_count = bin_count
                else:
                    bin_count = self._alpha * self._last_bin_count + (1 - self._alpha) * bin_count
                    self._last_bin_count = bin_count
        return bin_count

    def forward(self, x, target):
        """
        :param x: torch.Tensor,[B,C,*]
        :param target: torch.Tensor,[B,*]
        :return: loss
        """
        g = torch.abs(self._custom_loss_grad(x, target)).detach()
        weight = torch.zeros((x.size(0), x.size(2), x.size(3)))
        if self.is_split_batch:
            #是否对每个batch分开统计梯度,我实验时发现分开统计loss会更容易收敛,可能因为模型中用了batch normalization?
            N = x.size(2) * x.size(3)
            bin = (int)(N // self._bins)
            bin_idx = self._g2bin(g, bin)
            bin_idx = torch.clamp(bin_idx, max=bin - 1)
            bin_count = torch.zeros((x.size(0), bin))
            for i in range(x.size(0)):
                bin_count[i] = torch.from_numpy(np.bincount(torch.flatten(bin_idx[i].cpu()), minlength=bin))
                bin_count[i] *= (bin_count[i] > 0).sum()

            bin_count = self.use_alpha(bin_count)
            gd = torch.clamp(bin_count, min=1)
            beta = N * 1.0 / gd
            for i in range(x.size(0)):
                weight[i] = beta[i][bin_idx[i]]
        else:
            N = x.size(0) * x.size(2) * x.size(3)
            bin = (int)(N // self._bins)
            bin_idx = self._g2bin(g, bin)
            bin_idx = torch.clamp(bin_idx, max=bin - 1)
            bin_count = torch.from_numpy(np.bincount(torch.flatten(bin_idx.cpu()), minlength=bin))
            bin_count *= (bin_count > 0).sum()

            bin_count = self.use_alpha(bin_count)
            gd = torch.clamp(bin_count, min=1)
            beta = N * 1.0 / gd
            weight = beta[bin_idx]

        return self._custom_loss(x, target, weight)


class GHMC_Loss(GHM_Loss):
    def __init__(self, bins, alpha, device, num_classes, ignore_classes=None, class_weights=None, is_split_batch=True):
        """
        :param bins: int。不是bin,这里将取数据[B,C,X,Y]的size计算bin=[B*]X*Y,B不一定乘
        :param alpha: float。
        :param device:
        :param num_classes: int。分类数量。
        :param ignore_classes: [int]。不计算的
        :param class_weights: torch.Tensor,每个类型的权重
        :param is_split_batch: bool,是否分离batch统计
        """
        super(GHMC_Loss, self).__init__(bins, alpha, device, is_split_batch)
        self.num_classes = num_classes
        self.ignore_classes = ignore_classes
        self.class_weights = class_weights

    def _custom_loss(self, x, target, weight):
        """
        计算loss
        :param x: torch.Tensor,[B,C,*]
        :param target: torch.Tensor,[B,*]
        :param weight: torch.Tensor,[B,C,*]
        :return: loss
        """
        if (self.is_evaluation):
            return torch.mean(
                (torch.nn.NLLLoss(weight=self.class_weights, reduction='none')(
                    torch.log_softmax(x, 1), target)))
        else:
            return torch.mean(
                (torch.nn.NLLLoss(weight=self.class_weights, reduction='none')(
                    torch.log_softmax(x, 1), target)).mul(weight.to(self._device).detach()))

    def _custom_loss_grad(self, x, target):
        """
        统计梯度
        :param x: torch.Tensor,[B,C,*]
        :param target: torch.Tensor,[B,*]
        :return: 梯度信息
        """
        g = (torch.softmax(x, 1).detach() - make_one_hot(target.unsqueeze(1), self.num_classes).to(self._device)). \
            gather(1, target.unsqueeze(1)).squeeze(1)
        if self.ignore_classes != None:
            a = torch.tensor(0.0, dtype=torch.float32).to(self._device)
            for class_id in self.ignore_classes:
                g = torch.where(target != class_id, g, a)
        return g

def make_one_hot(input, num_classes):
    """Convert class index tensor to one hot encoding tensor.
    Args:
         input: A tensor of shape [N, 1, *]
         num_classes: An int of number of class
    Returns:
        A tensor of shape [N, num_classes, *]
    """
    # input=torch.squeeze(input,dim=-1)
    shape = np.array(input.shape)
    shape[1] = num_classes
    shape = tuple(shape)
    result = torch.zeros(shape)
    result = result.scatter_(1, input.cpu(), 1)

    return result
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-21 15:22:26  更:2021-08-21 15:29:38 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/1 12:25:37-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码