IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> PointRend使用记录 -> 正文阅读

[人工智能]PointRend使用记录

下面是PointRend的源码位置,接下来先跑下看看

GitHub - zsef123/PointRend-PyTorch: A PyTorch implementation of PointRend: Image Segmentation as Renderinghttps://github.com/zsef123/PointRend-PyTorch

(1)数据准备?

数据就用公共数据集CamVid,该数据集加背景0共12个类,标签值为0-11,下面是一级目标,目录结构及文件名务必保持一致,因为我后面在数据读取的时候添加了读自己数据集的数据导入函数,文件夹名字是固定了的,当然你也可以改代码。

?二级目录,train/val/test,目录结构需要一致,另外如果test只有图像也可以不要labels文件夹

?(2)添加自己的数据加载模块

在?__init__.py文件中添加了get_own函数,加完以后在get_loader函数添加自己数据的引导,另外需要强调下,我自己添加的数据加载没有专门加数据扩充策略,你们自己加下,加了效果应该会好点。

__init__.py代码:?

import os
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.datasets.voc import VOCSegmentation
from torchvision.datasets.cityscapes import Cityscapes

from .transforms import Compose, Resize, ToTensor, Normalize, RandomCrop, RandomFlip, ConvertMaskID


def get_voc(C, split="train"):
    if split == "train":
        transforms = Compose([
            ToTensor(),
            RandomCrop((256, 256)),
            Resize((256, 256)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    else:
        transforms = Compose([
            ToTensor(),
            Resize((256, 256)),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    return VOCSegmentation(C['root'], download=True, image_set=split, transforms=transforms)


def get_cityscapes(C, split="train"):
    if split == "train":
        # Appendix B. Semantic Segmentation Details
        transforms = Compose([
            ToTensor(),
            RandomCrop(768),
            ConvertMaskID(Cityscapes.classes),
            RandomFlip(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    else:
        transforms = Compose([
            ToTensor(),
            Resize(768),
            ConvertMaskID(Cityscapes.classes),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    return Cityscapes(**C, split=split, transforms=transforms)

class get_own(torch.utils.data.Dataset):

    def __init__(self, C, split="train"):
        images_path = os.path.join(C['root'], split, 'images')
        labels_path = os.path.join(C['root'], split, 'labels')

        images_path_list = []
        labels_path_list = []

        imgs = os.listdir(images_path)
        for name in imgs:
            img_full_path = os.path.join(images_path, name)
            lab_full_path = os.path.join(labels_path, name)
            images_path_list.append(img_full_path)
            labels_path_list.append(lab_full_path)

        self.images_path_list = images_path_list
        self.labels_path_list = labels_path_list

        if split == "train":
            # Appendix B. Semantic Segmentation Details
            Transform = Compose([
                ToTensor(),
                # RandomCrop(256),
                # ConvertMaskID(Cityscapes.classes),
                # RandomFlip()
                # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            Transform = Compose([
                ToTensor(),
                # Resize(256),
                # ConvertMaskID(Cityscapes.classes),
                # Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        
        self.transform = Transform
       
    def __getitem__(self,index):  
        image_path = self.images_path_list[index]
        label_path = self.labels_path_list[index]

        # image = Image.open(image_path).convert('RGB')
        # label = Image.open(label_path)
        image = cv2.imread(image_path)
        label = cv2.imread(label_path, 0)
        image = np.array(image, np.float32) / 255.0
        label = np.array(label, np.float32)
        # image = self.transform('images':image)
        # label = self.transform('masks':label)
        image, label = self.transform(image, label)
        # image = image.type(torch.FloatTensor)
        # label = label.type(torch.FloatTensor)
             
        return image, label
        
    def __len__(self):
        return len(self.images_path_list)


def get_loader(C, split, distributed):
    """
    Args:
        C (Config): C.data
        split (str): args of dataset,
                    The image split to use, ``train``, ``test`` or ``val`` if split="gtFine"
                    otherwise ``train``, ``train_extra`` or ``val`
    """
    print(C.name, C.dataset, split)
    if C.name == "cityscapes":
        dset = get_cityscapes(C.dataset, split)
    elif C.name == "pascalvoc":
        dset = get_voc(C.dataset, split)
    elif C.name == "own":
        dset = get_own(C.dataset, split)

    if split == "train":
        shuffle = True
        drop_last = False
    else:
        shuffle = False
        drop_last = False

    sampler = None
    if distributed:
        sampler = DistributedSampler(dset, shuffle=shuffle)
        shuffle = None

    return DataLoader(dset, **C.loader, sampler=sampler,
                      shuffle=shuffle, drop_last=drop_last,
                      pin_memory=True)

?(3)训练

?这个GitHub项目结构比较好,训练模块在train.py中,不需要改,主要改main.py文件中的部分东西,由于这个项目用了apex来加速训练,而我这里安装不方便,还报错了,我main.py的主要改动就是注释掉apex相关的部分。

main.py代码

import os
import sys
import argparse
import logging
from tokenize import Double
from configs.parser import Parser

import torch

# from apex import amp
# from apex.parallel import DistributedDataParallel as ApexDDP

from model import deeplabv3, PointHead, PointRend
from datas import get_loader
from train import train
from utils.gpus import synchronize, is_main_process


def parse_args():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument("--config", type=str, default="./configs/default.yaml", help="It must be config/*.yaml")  #yaml文件是必要的配置文件,后面会简要说明
    parser.add_argument("--save", type=str, default="build", help="Save path in out directory")
    parser.add_argument("--local_rank", type=int, default=0, help="Using for Apex DDP")
    return parser.parse_args()


def amp_init(args):
    # Apex Initialize
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    torch.backends.cudnn.benchmark = True


def set_loggging(save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='[%y/%m/%d %H:%M:%S]')

    fh = logging.FileHandler(f"{save_dir}/log.txt")
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)


if __name__ == "__main__":
    args = parse_args()
    amp_init(args)

    parser = Parser(args.config)
    C = parser.C
    save_dir = f"{os.getcwd()}/outs/{args.save}"

    if is_main_process():
        if not os.path.exists(save_dir):
            os.makedirs(save_dir, mode=0o775)

        parser.dump(f"{save_dir}/config.yaml")

        set_loggging(save_dir)

    device = torch.device("cuda")
    train_loader = get_loader(C.data, "train", distributed=args.distributed)
    valid_loader = get_loader(C.data, "val", distributed=args.distributed)

    net = PointRend(
        deeplabv3(**C.net.deeplab),
        PointHead(**C.net.pointhead)
    ).to(device)

    params = [{"params": net.backbone.backbone.parameters(),   "lr": float(C.train.lr)},
              {"params": net.head.parameters(),                "lr": float(C.train.lr)},
              {"params": net.backbone.classifier.parameters(), "lr": float(C.train.lr) * 10}]

    # optim = torch.optim.SGD(params, momentum=C.train.momentum, weight_decay=C.train.weight_decay)
    #这里尝试了用adamw优化器训练
    optim = torch.optim.AdamW(params, lr=float(C.train.lr), weight_decay=float(C.train.weight_decay))
    
    #这里注释了需要apex加速的模块
    # net, optim = amp.initialize(net, optim, opt_level=C.apex.opt)
    # if args.distributed:
    #     net = ApexDDP(net, delay_allreduce=True)

    train(C.run, save_dir, train_loader, valid_loader, net, optim, device)



#Apex混合精度加速 介绍:为了帮助提高Pytorch的训练效率,英伟达提供了混合精度训练工具Apex。
# 号称能够在不降低性能的情况下,将模型训练的速度提升2-4倍,训练显存消耗减少为之前的一半。
# 该项目开源于:https://github.com/NVIDIA/apex ,文档地址是:https://nvidia.github.io/apex/index.html该工具提供了三个功能,amp、parallel和normalization。

训练用的default.yaml文件

data:
  name: "own"
  dataset:
    root: "./datasets/CamVid/"
    mode: "fine"
    target_type: "semantic"
  loader:
    batch_size: 5
    num_workers: 0

net:
  deeplab:
    pretrained: False
    resnet: "res101"
    head_in_ch: 2048
    num_classes: 12
  pointhead:
    in_c: 524 # 512 + num_classes
    num_classes: 12
    k: 3
    beta: 0.75

run:
  epochs: 101

train:
  lr: 1e-3       
  momentum: 0.9
  weight_decay: 1e-3

apex:
  opt: "O0"

?

(4)预测

原始的预测用的是infer.py文件,这个预测要加载标签,而且会给出精度评价,我考虑到会有直接预测而不加标签预测的情况,改了一个预测代码

predict.py代码:

import os
import time
import logging
import cv2
from PIL import Image
import numpy as np
import torch
import argparse
from configs.parser import Parser
from model import deeplabv3, PointHead, PointRend
from utils.metrics import ConfusionMatrix
from utils.gpus import synchronize, is_main_process

@torch.no_grad()
def infer(loader, net, device):
    net.eval()
    num_classes = 2 # Hard coding for Cityscapes
    metric = ConfusionMatrix(num_classes)
    for i, (x, gt) in enumerate(loader):
        x = x.to(device, non_blocking=True)
        gt = gt.squeeze(1).to(device, dtype=torch.long, non_blocking=True)

        pred = net(x)["fine"].argmax(1)

        metric.update(pred, gt)

    mIoU = metric.mIoU()
    logging.info(f"[Infer] mIOU : {mIoU}")
    return mIoU

def parse_args():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument("--config", type=str, default="./configs/default.yaml", help="It must be config/*.yaml")
    parser.add_argument("--save", type=str, default="build", help="Save path in out directory")
    parser.add_argument("--local_rank", type=int, default=0, help="Using for Apex DDP")
    return parser.parse_args()

def amp_init(args):
    # Apex Initialize
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    torch.backends.cudnn.benchmark = True

def predict(data_path, model_path, net, save_path):
    net.load_state_dict(torch.load(model_path))
    net.eval()
    img_names = os.listdir(data_path)
    for ele in img_names:
        full_path = os.path.join(data_path, ele)
        # image = Image.open(full_path).convert('RGB')
        image = cv2.imread(full_path)
        image = np.array(image, np.float32) / 255.0
        # image = np.array(image)
        image = image.transpose(2,0,1)
        image = np.expand_dims(image, axis=0)
        # image = torch.from_numpy(image)
        image = torch.FloatTensor(image)
        x = image.to(device, non_blocking=True)
        pred = net(x)["fine"].argmax(1)
        # pred = net(x)["fine"]
        save_full_path = os.path.join(save_path, ele)
        pred = pred.cpu().numpy()
        cv2.imwrite(save_full_path, pred[0])


if __name__ == "__main__":
    path = './datasets/CamVid/test/'
    save_path = './datasets/pred/'
    model_path = './outs/CamVid/epoch_0100_loss_0.54185.pth'
    args = parse_args()
    amp_init(args)

    parser = Parser(args.config)
    C = parser.C

    device = torch.device("cuda")
    net = PointRend(
        deeplabv3(**C.net.deeplab),
        PointHead(**C.net.pointhead)
    ).to(device)

    predict(path, model_path, net, save_path)

效果:

??

? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 图像? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? 标签

?预测结果

从结果看,很明显效果不理想,不过不要太过悲观,因为我去掉了加速模块,这个训练有点慢,我训练了100个epoch就停掉了,并且数据也没有做增强,效果肯定还可以提高的。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-18 17:43:23  更:2022-04-18 17:46:55 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/8 3:29:50-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码