安装
官方网站:https://torchmetrics.readthedocs.io/en/stable/。
安装方法:
conda install -c conda-forge torchmetrics
使用
大家如果用过sklearn.metrics ,那么就会很容易使用这个,其实我一开始都是一直使用前者的,但是前者只支持numpy以及list,所以每次使用sklearn的使用,我们需要将tensor从gpu搬到cpu,然后再由tensor转化为numpy,现在的话,TorchMetrics则不会需要啦。
import torchmetrics
import torch
preds = torch.randn(2, 2).softmax(dim=-1)
target = torch.randint(2, (2,))
print(preds)
print(target)
从上面我们可以看到,两个样本都是输出第0个类别的概率高,标签也是第0个类别,那么准确率为100%,下面验证一下:
acc = torchmetrics.functional.accuracy(preds, target)
print(acc)
tensor(1.)
我们发现,TorchMetrics直接支持tensor计算,其实其也支持gpu上直接计算,如下:
device=torch.device("cuda:1")
acc = torchmetrics.functional.accuracy(preds.to(device), target.to(device))
print(acc)
tensor(1., device=‘cuda:1’)
|