| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> VIT训练 -> 正文阅读 |
|
[人工智能]VIT训练 |
加载imagenet预训练模型: 首先需要安装timm库: pip install timm timm库的使用方法可参考 https://rwightman.github.io/pytorch-image-models/ 这里加载了输入图像大小为224*224,patch大小为16*16的vit模型 from timm import create_model as creat model = creat('vit_base_patch16_224', pretrained=True, num_classes=1000) 从头开始训练模型的话需要安装vit-pytorch库: pip install vit-pytorch 这里建议添加豆瓣镜像源来加速下载: pip config set global.index-url https://pypi.douban.com/simple/  # 统一更改为豆瓣源(用户) pip config set global.index-url https://pypi.python.org/simple/  # 恢复默认 由于imagenet验证集的图像没有按类别分文件夹存放,这里写了个数据加载模块,从txt标签文件(每行为"图像文件名 标签")中读取图像路径和对应的标签 class ImageTxtDataset(data.Dataset): def __init__(self, txt_path: str, folder_name, transform): self.transform = transform self.data_dir = os.path.dirname(txt_path) self.imgs_path = [] self.labels = [] self.folder_name = folder_name with open(txt_path, 'r') as f: lines = f.readlines() for line in lines: img_path, label = line.split() label = int(label.strip()) img_path = os.path.join(self.data_dir, self.folder_name, img_path) self.labels.append(label) self.imgs_path.append(img_path) def __len__(self): return len(self.imgs_path) def __getitem__(self, i): path, label = self.imgs_path[i], self.labels[i] image = Image.open(path).convert("RGB") if self.transform is not None: image = self.transform(image) return image, label 完整代码如下: import argparse from timm import create_model as creat import torch import torch.nn as nn import torch.utils.data as data import matplotlib.pyplot as plt import os from PIL import Image import random import shutil import time import warnings import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.distributed as dist import torch.optim import torch.multiprocessing as mp import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms plt.ion() # interactive mode from vit_pytorch import ViT class ImageTxtDataset(data.Dataset): def __init__(self, txt_path: str, folder_name, transform): self.transform = transform self.data_dir = os.path.dirname(txt_path) self.imgs_path = 
[] self.labels = [] self.folder_name = folder_name with open(txt_path, 'r') as f: lines = f.readlines() for line in lines: img_path, label = line.split() label = int(label.strip()) img_path = os.path.join(self.data_dir, self.folder_name, img_path) self.labels.append(label) self.imgs_path.append(img_path) def __len__(self): return len(self.imgs_path) def __getitem__(self, i): path, label = self.imgs_path[i], self.labels[i] image = Image.open(path).convert("RGB") if self.transform is not None: image = self.transform(image) return image, label parser = argparse.ArgumentParser(description='VIT ImageNet Training') parser.add_argument('--train_label', metavar='TRAIN_LABEL_DIR', default='E:/imagenet/train_label.txt', help='path to train_label') parser.add_argument('--val_label', metavar='VAL_LABEL_DIR', default='E:/imagenet/validation_label.txt', help='path to validation_label') parser.add_argument('--train_data', metavar='TRAIN', default='train', help='path to train_dataset') parser.add_argument('--val_data', metavar='VAL', default='val', help='path to val_dataset') parser.add_argument('-j', '--workers', default=0, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=300, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)') parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N', help='mini-batch size (default: 256), this is the total ' 'batch size of all GPUs on the current node when ' 'using Data Parallel or Distributed Data Parallel') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, metavar='W', help='weight decay (default: 
1e-4)', dest='weight_decay') parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N', help='print frequency (default: 10)') parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training') parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training') parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.') parser.add_argument('--multiprocessing-distributed', action='store_true', help='Use multi-processing distributed training to launch ' 'N processes per node, which has N GPUs. This is the ' 'fastest way to use PyTorch for either single node or ' 'multi node data parallel training') best_acc1 = 0 def main(): args = parser.parse_args() if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') if args.gpu is not None: warnings.warn('You have chosen a specific GPU. 
This will completely ' 'disable data parallelism.') if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) args.distributed = args.world_size > 1 or args.multiprocessing_distributed ngpus_per_node = torch.cuda.device_count() # print(ngpus_per_node) if args.multiprocessing_distributed: args.world_size = ngpus_per_node * args.world_size mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) else: main_worker(args.gpu, ngpus_per_node, args) def main_worker(gpu, ngpus_per_node, args): global best_acc1 print(args) args.gpu = gpu if args.gpu is not None: print("Use GPU: {} for training".format(args.gpu)) if args.distributed: if args.dist_url == "env://" and args.rank == -1: args.rank = int(os.environ["RANK"]) if not args.multiprocessing_distributed: args.rank = args.rank * ngpus_per_node + gpu dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) if args.pretrained: model = creat('vit_base_patch16_224', pretrained=True, num_classes=1000) else: model = ViT( image_size = 224, patch_size = 16, num_classes = 1000, dim = 1024, depth = 6, heads = 16, mlp_dim = 2048, dropout = 0.1, emb_dropout = 0.1 ) if not torch.cuda.is_available(): print('using CPU, this will be slow') elif args.distributed: if args.gpu is not None: torch.cuda.set_device(args.gpu) model.cuda(args.gpu) args.batch_size = int(args.batch_size / ngpus_per_node) args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) else: model.cuda() model = torch.nn.parallel.DistributedDataParallel(model) elif args.gpu is not None: torch.cuda.set_device(args.gpu) model = model.cuda(args.gpu) else: model = torch.nn.DataParallel(model).cuda() # define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda(args.gpu) optimizer = torch.optim.SGD(model.parameters(), args.lr, 
momentum=args.momentum, weight_decay=args.weight_decay) # optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) if args.gpu is None: checkpoint = torch.load(args.resume) else: # Map model to be loaded to specified single gpu. loc = 'cuda:{}'.format(args.gpu) checkpoint = torch.load(args.resume, map_location=loc) args.start_epoch = checkpoint['epoch'] best_acc1 = checkpoint['best_acc1'] if args.gpu is not None: # best_acc1 may be from a checkpoint from a different GPU best_acc1 = best_acc1.to(args.gpu) model.load_state_dict(checkpoint['state_dict']) optimizer.load_state_dict(checkpoint['optimizer']) print("=> loaded checkpoint '{}' (epoch {})" .format(args.resume, checkpoint['epoch'])) else: print("=> no checkpoint found at '{}'".format(args.resume)) cudnn.benchmark = True normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_dataset = ImageTxtDataset(args.train_label, args.train_data, transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ])) if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) else: train_sampler = None # args.workers=0 train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler) val_dataset = ImageTxtDataset(args.val_label, args.val_data, transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize,])) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) if args.evaluate: validate(val_loader, model, criterion, args) return for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) adjust_learning_rate(optimizer, 
epoch, args) # train for one epoch train(train_loader, model, criterion, optimizer, epoch, args) # evaluate on validation set acc1 = validate(val_loader, model, criterion, args) # remember best acc@1 and save checkpoint is_best = acc1 > best_acc1 best_acc1 = max(acc1, best_acc1) if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): save_checkpoint({ 'epoch': epoch + 1, 'arch': 'vit', 'state_dict': model.state_dict(), 'best_acc1': best_acc1, 'optimizer' : optimizer.state_dict(), }, is_best) def train(train_loader, model, criterion, optimizer, epoch, args): batch_time = AverageMeter('Time', ':6.3f') data_time = AverageMeter('Data', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') top5 = AverageMeter('Acc@5', ':6.2f') progress = ProgressMeter( len(train_loader), [batch_time, data_time, losses, top1, top5], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() # print('adg') for i, (images, target) in enumerate(train_loader): # measure data loading time data_time.update(time.time() - end) if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) if torch.cuda.is_available(): target = target.cuda(args.gpu, non_blocking=True) # compute output output = model(images) loss = criterion(output, target) # measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), images.size(0)) top1.update(acc1[0], images.size(0)) top5.update(acc5[0], images.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) def validate(val_loader, model, criterion, args): batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') top5 = AverageMeter('Acc@5', ':6.2f') progress = 
ProgressMeter( len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ') # switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): if args.gpu is not None: images = images.cuda(args.gpu, non_blocking=True) if torch.cuda.is_available(): target = target.cuda(args.gpu, non_blocking=True) # compute output output = model(images) loss = criterion(output, target) # measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), images.size(0)) top1.update(acc1[0], images.size(0)) top5.update(acc5[0], images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) return top1.avg def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): torch.save(state, filename) if is_best: shutil.copyfile(filename, 'model_best.pth.tar') class AverageMeter(object): def __init__(self, name, fmt=':f'): self.name = name self.fmt = fmt self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def __str__(self): fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' return fmtstr.format(**self.__dict__) class ProgressMeter(object): def __init__(self, num_batches, meters, prefix=""): self.batch_fmtstr = self._get_batch_fmtstr(num_batches) self.meters = meters self.prefix = prefix def display(self, batch): entries = [self.prefix + self.batch_fmtstr.format(batch)] entries += [str(meter) for meter in self.meters] print('\t'.join(entries)) def _get_batch_fmtstr(self, num_batches): num_digits = len(str(num_batches // 1)) fmt = '{:' + str(num_digits) + 'd}' return '[' + fmt + '/' + fmt.format(num_batches) + ']' def adjust_learning_rate(optimizer, epoch, args): 
lr = args.lr * (0.1 ** (epoch // 30)) for param_group in optimizer.param_groups: param_group['lr'] = lr def accuracy(output, target, topk=(1,)): with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res if __name__ == '__main__': main() 使用方法如下: 文件名为main.py,数据集通过 --train_label 和 --val_label 指定标签文件路径(图像文件夹位于标签文件所在目录下,文件夹名由 --train_data 和 --val_data 指定,默认为 train 和 val) CPU训练: python main.py --lr 0.001 --train_label (自己的训练集标签文件路径) --val_label (自己的验证集标签文件路径) 单GPU从头训练: python main.py --gpu 0 --lr 0.001 --train_label (自己的训练集标签文件路径) --val_label (自己的验证集标签文件路径) 单GPU使用预训练模型训练: python main.py --gpu 0 --lr 0.001 --train_label (自己的训练集标签文件路径) --val_label (自己的验证集标签文件路径) --pretrained 多GPU从头训练(四卡为例): CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --lr 0.001 --train_label (自己的训练集标签文件路径) --val_label (自己的验证集标签文件路径) --world-size 1 --rank 0 -b 64 多GPU使用预训练模型训练(四卡为例): CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py --lr 0.001 --train_label (自己的训练集标签文件路径) --val_label (自己的验证集标签文件路径) --world-size 1 --rank 0 -b 64 --pretrained |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 | -2025/1/11 16:50:11- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |