IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 关于DG(域泛化)领域的PCL方法的代码实例 -> 正文阅读

[人工智能]关于DG(域泛化)领域的PCL方法的代码实例

分享一下文章PCL: Proxy-based Contrastive Learning for Domain Generalization,代码已经在GitHub上已经开源,其使用的是在DomainBed框架基础上实现的优化框架SWAD上改进的框架本文主要就放一些精简的代码,对于每个模块只保留了一个算法。
DomainBed库主要是为了DG领域的多种方法的实现,所以框架写的很复杂,封装了很多东西,对初次使用的同学真的很不友好,甚至可能连输入输出都看不懂,如果对DG和DA感兴趣的同学,这里推荐一个大佬实现的DA和DG的库,迁移学习代码库,比较容易看懂!!
话不多说,直接上代码

主文件

main.py

import torch
import algorithm
from torch.autograd import Variable
from torchvision import datasets, transforms
#使用swad调优的话,异步文章最后
#import swa_utils
#import swad as swad_module
import torch.nn as nn

train_transforms= transforms.Compose([
    transforms.Resize(256),
    transforms.RandomRotation((5), expand=True),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),

    transforms.ColorJitter(.3, .3, .3, .3),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
val_dataTrans = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

train_data_dir = '../../data/train'
val_data_dir = '../../data/val'
test_data_dir='../../data/test'

train_dataset = datasets.ImageFolder(train_data_dir, train_transforms)
val_dataset = datasets.ImageFolder(val_data_dir,val_dataTrans)
test_dataset = datasets.ImageFolder(test_data_dir, val_dataTrans)
#根据需要可重新划分数据集
# train_dataset = torch.utils.data.ConcatDataset([train_dataset1, test_dataset])
# val_dataset = datasets.ImageFolder(val_data_dir, _dataTrans)
# val_dataset, val_dataset_ = torch.utils.data.random_split(val_dataset, [5, len(val_dataset) - 5])
# train_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset_])

train_dataloder = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True)
val_dataloder = torch.utils.data.DataLoader(val_dataset,batch_size=4,shuffle=True)
device = "cuda" if torch.cuda.is_available() else"cpu"
# setup hparams
algorithm = algorithm.ERM(input_shape=[3, 244, 244], num_classes=4)
use_swad=False#是否使用swad优化
#优化器的选择在algorithm文件里
if use_swad:
    swad_algorithm = swa_utils.AveragedModel(algorithm)
    swad_cls = getattr(swad_module, 'LossValley')
    swad_kwargs={'n_converge': 3, 'n_tolerance': 6, 'tolerance_ratio': 0.3}
    swad = swad_cls(**swad_kwargs)
algorithm.to(device)
lossfunc=nn.CrossEntropyLoss()
epochs=10
if __name__ == '__main__':

    for epoch in range(epochs):
        running_loss = 0
        running_corrects = 0
        algorithm.train()
        for step ,(inputs, labels) in enumerate(train_dataloder):
            inputs = Variable(inputs.cuda())
            labels = Variable(labels.cuda())
            step_vals = algorithm.update(inputs, labels)
            if use_swad:
                swad_algorithm.update_parameters(algorithm, step=step)
            _, outputs = algorithm.predict(inputs)
            _, preds = torch.max(outputs.data, 1)
            train_loss = lossfunc(outputs, labels)
            # statistics
            running_loss += loss.data
            train_acc=torch.sum(preds == labels.data).cpu().to(torch.float32)
            running_corrects += train_acc


        tr_epoch_loss = running_loss / len(train_dataloder)
        tr_epoch_acc = running_corrects / len(train_dataloder)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(epoch, tr_epoch_loss, tr_epoch_acc))
        with torch.no_grad():
            algorithm.eval()
            running_loss = 0
            running_corrects = 0
            for step, (inputs, labels) in enumerate(val_dataloder ):
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda())
                _, outputs = algorithm.predict(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = lossfunc(outputs, labels)
                # statistics
                running_loss += loss.data
                val_acc = torch.sum(preds == labels.data).cpu().to(torch.float32)
                running_corrects += val_acc
            te_epoch_loss = running_loss / len(val_dataloder)
            te_epoch_acc = running_corrects / len(val_dataloder)
        if use_swad:
            swad.update_and_evaluate(swad_algorithm, te_epoch_acc)
            swad_algorithm = swa_utils.AveragedModel(algorithm)  # reset
        filename = r'epoch{}_Loss{:.4f}_Acc{:.4f}_Loss{:.4f}_Acc{:.4f}.pth'.format(
            epoch, tr_epoch_loss, tr_epoch_acc, te_epoch_loss, te_epoch_acc)
        torch.save(algorithm.state_dict(), filename, _use_new_zipfile_serialization=False)

主算法

主算法部分,这里使用的是Empirical Risk Minimization (ERM, Vapnik, 1998),原DomainBed框架提供了很多算法,如IRMGroupDRORSC等,可以根据需要自行取用。
————————————————————
algorithm.py

import math
from model import *
from losses import ProxyLoss, ProxyPLoss
import torch


class ERM(torch.nn.Module):
    """
    Empirical Risk Minimization (ERM)
    """

    def __init__(self, input_shape, num_classes):
        super(ERM, self).__init__()
        self.encoder, self.scale, self.pcl_weights = encoder()
        self._initialize_weights(self.encoder)
        self.fea_proj, self.fc_proj = fea_proj()
        nn.init.kaiming_uniform_(self.fc_proj, mode='fan_out', a=math.sqrt(5))
        self.featurizer = ResNet()

        self.classifier = nn.Parameter(torch.FloatTensor(num_classes,256))
        nn.init.kaiming_uniform_(self.classifier, mode='fan_out', a=math.sqrt(5))

        self.optimizer = torch.optim.Adam([
            {'params': self.featurizer.parameters()},
            {'params': self.encoder.parameters()},
            {'params': self.fea_proj.parameters()},
            {'params': self.fc_proj},
            {'params': self.classifier},
        ], lr=0.002, weight_decay=0.0)

        self.proxycloss = ProxyPLoss(num_classes=num_classes, scale=self.scale)

    def _initialize_weights(self, modules):
        for m in modules:
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

    def update(self, x, y, **kwargs):
        all_x = x
        all_y = y
        rep, pred = self.predict(all_x)
        loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), all_y)

        fc_proj = F.linear(self.classifier, self.fc_proj)
        assert fc_proj.requires_grad == True
        loss_pcl = self.proxycloss(rep, all_y, fc_proj)

        loss = loss_cls + self.pcl_weights * loss_pcl

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {"loss_cls": loss_cls.item(), "loss_pcl": loss_pcl.item()}

    def predict(self, x):
        x = self.featurizer(x)
        x = self.encoder(x)
        rep = self.fea_proj(x)
        pred = F.linear(x, self.classifier)

        return rep, pred


网络结构

接下来是主要的网络结构模块,这里是用的ResNet50进行图片的特征提取,然后用全连接层进行encoder
————————————————————
model.py


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models


class Identity(nn.Module):
    """An identity layer"""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
class SqueezeLastTwo(nn.Module):
    """
    A module which squeezes the last two dimensions,
    ordinary squeeze can be a problem for batch size 1
    """

    def __init__(self):
        super(SqueezeLastTwo, self).__init__()

    def forward(self, x):
        return x.view(x.shape[0], x.shape[1])
class ResNet(torch.nn.Module):
    """ResNet with the softmax chopped off and the batchnorm frozen"""

    def __init__(self):
        super(ResNet, self).__init__()
        #如果要用其他的网络进行特征提取,可以在这里改
        #但是要把下面encoder模块的全连接的层的输入和新的网络最后的全连接输出相同
        network = torchvision.models.resnet50(pretrained=False)
        # network = resnet50(pretrained=hparams["pretrained"])
        self.network = network

        # adapt number of channels

        # save memory
        # del self.network.fc
        #把新的网络的输出层替换为空,用来提供encoder的接口
        #tips;最后一层大部分都是model.fc或model.head
        self.network.fc = Identity()
        self.dropout = nn.Dropout(0.1)
        self.freeze_bn()

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """
        Override the default train() to freeze the BN parameters
        """
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()


def encoder():
    scale_weights = 12
    pcl_weights = 1
    dropout = nn.Dropout(0.25)
    hidden_size = 512
    out_dim = 256
    #换了新网络要注意改这里
    n_outputs = 2048
    encoder = nn.Sequential(
        nn.Linear(n_outputs, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        dropout,
        nn.Linear(hidden_size, out_dim),
        )

    return encoder, scale_weights, pcl_weights


def fea_proj():
    dropout = nn.Dropout(0.25)
    hidden_size = 256
    out_dim = 256
    fea_proj = nn.Sequential(
        nn.Linear(out_dim,
                  out_dim),
    )
    fc_proj = nn.Parameter(
        torch.FloatTensor(out_dim,
                          out_dim)
    )

    return fea_proj, fc_proj

损失函数

PCL中提到的损失函数
————————————————————————
losses.py

# coding: utf-8

'''
custom loss function
'''

import math
import numpy as np

import torch
import torch.nn as nn

import torch.nn.functional as F

# # =========================  proxy Contrastive loss ==========================
class ProxyLoss(nn.Module):
	'''
	pass
	'''

	def __init__(self, scale=1, thres=0.1):
		super(ProxyLoss, self).__init__()
		self.scale = scale
		self.thres = thres

	def forward(self, feature, pred, target):
		feature = F.normalize(feature, p=2, dim=1)  # normalize
		feature = torch.matmul(feature, feature.transpose(1, 0))  # (B, B)
		label_matrix = target.unsqueeze(1) == target.unsqueeze(0)
		feature = feature * ~label_matrix  # get negative matrix
		feature = feature.masked_fill(feature < self.thres, -np.inf)
		pred = torch.cat([pred, feature], dim=1)  # (N, C+N)

		loss = F.nll_loss(F.log_softmax(self.scale * pred, dim=1), \
		                  target)

		return loss


class ProxyPLoss(nn.Module):
	'''
	pass
	'''
	
	def __init__(self, num_classes, scale):
		super(ProxyPLoss, self).__init__()
		self.soft_plus = nn.Softplus()
		self.label = torch.LongTensor([i for i in range(num_classes)]).cuda()
		self.scale = scale
	
	def forward(self, feature, target, proxy):
		feature = F.normalize(feature, p=2, dim=1)
		pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C)
		
		label = (self.label.unsqueeze(1) == target.unsqueeze(0))
		pred = torch.masked_select(pred.transpose(1, 0), label)
		pred = pred.unsqueeze(1)
		
		feature = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
		label_matrix = target.unsqueeze(1) == target.unsqueeze(0)
		
		index_label = torch.LongTensor([i for i in range(feature.shape[0])])  # generate index label
		index_matrix = index_label.unsqueeze(1) == index_label.unsqueeze(0)  # get index matrix
		
		feature = feature * ~label_matrix  # get negative matrix
		feature = feature.masked_fill(feature < 1e-6, -np.inf)
		
		logits = torch.cat([pred, feature], dim=1)  # (N, C+N)
		label = torch.zeros(logits.size(0), dtype=torch.long).cuda()
		loss = F.nll_loss(F.log_softmax(self.scale * logits, dim=1), label)
		
		return loss


class PosAlign(nn.Module):
	'''
	pass
	'''
	
	def __init__(self):
		super(PosAlign, self).__init__()
		self.soft_plus = nn.Softplus()
	
	def forward(self, feature, target):
		feature = F.normalize(feature, p=2, dim=1)
		
		feature = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)
		label_matrix = target.unsqueeze(1) == target.unsqueeze(0)
		
		positive_pair = torch.masked_select(feature, label_matrix)
		
		# print("positive_pair.shape", positive_pair.shape)
		
		loss = 1. * self.soft_plus(torch.logsumexp(positive_pair, 0))
		
		return loss

SWAD调参

如果要使用SWAD调参,可以从去GitHub自取swadswa_utils

以后有时间再逐行分析吧…

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-08-06 10:44:55  更:2022-08-06 10:47:43 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 8:38:43-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计