自 2017 年 1 月 PyTorch 推出以来,其热度持续上升,一度有赶超?TensorFlow?的趋势。PyTorch 能在短时间内被众多研究人员和工程师接受并推崇是因为其有着诸多优点,如采用?Python?语言、动态图机制、网络构建灵活以及拥有强大的社群等。
?
1.Tensor(张量)
和TensorFlow 类似,PyTorch 的核心对象也是Tensor。
import torch
x = torch.Tensor(5, 3) # 创建一个5行3列的二维张量
print(x)
print(x.size()) # 得到张量的大小
输出:
tensor([[1.1708e-19, 7.2128e+22, 9.2216e+29],
[7.5546e+31, 1.6932e+22, 3.0728e+32],
[2.9514e+29, 2.8940e+12, 7.5338e+28],
[1.8037e+28, 3.4740e-12, 1.7743e+28],
[2.0535e-19, 1.4609e-19, 7.5630e+28]])
torch.Size([5, 3])
2.Operation(运算)
和TensorFlow 一样,有了Tensor 之后就可以用Operation 进行计算了。但是和TensorFlow 不同,TensorFlow 只是定义计算图但不会立即“执行”,而Pytorch 的Operation 是马上“执行”的。所以PyTorch 使用起来更加简单,当然PyTorch 也有计算图的执行引擎,但是它不对用户可见,它是“动态”编译的
import torch
x = torch.Tensor(5, 3) # 创建一个5行3列的二维张量
print(x)
print(x.size()) # 得到张量的大小
y = torch.rand(5, 3) # 创建一个5行3列的二维张量
print(y)
print(y.size())
print(x+y) # 这种方法会产生一个新的tensor
# 我们还可以创建一个张量存放运算的结果
result = torch.Tensor(5, 3)
torch.add(x, y, out=result)
print(result)
输出:
tensor([[0.4537, 0.4637, 0.6539],
[0.9365, 0.0710, 0.7438],
[0.7921, 0.5835, 0.1821],
[0.6770, 0.4274, 0.7033],
[0.4109, 0.3758, 0.6795]])
torch.Size([5, 3])
tensor([[0.4744, 0.6735, 0.0132],
[0.7808, 0.1955, 0.6542],
[0.3074, 0.6732, 0.4994],
[0.9840, 0.5362, 0.2304],
[0.8565, 0.5074, 0.2969]])
torch.Size([5, 3])
tensor([[0.9281, 1.1372, 0.6672],
[1.7173, 0.2665, 1.3980],
[1.0996, 1.2567, 0.6815],
[1.6610, 0.9636, 0.9337],
[1.2675, 0.8832, 0.9765]])
我们也可以用view 来修改Tensor 的shape,注意view 要求新的Tensor 的元素个数和原来是一样的
x = torch.randn(4, 4)
y = x.view(16)
z = x.view(-1, 8) # the size -1 is inferred from other dimensions
print(x.size(), y.size(), z.size())
输出:torch.Size([4, 4]) torch.Size([16]) torch.Size([2, 8])
3.numpy ndarray 的转换?
Tensor 转numpy
a = torch.ones(5)
print(a)
b = a.numpy()
print(b)
a.add_(2) # 把2加到a中
结果:
tensor([1., 1., 1., 1., 1.])
[1. 1. 1. 1. 1.]
tensor([3., 3., 3., 3., 3.])
numpy 转Tensor
import numpy as np
a = np.ones(5)
print(a)
b = torch.from_numpy(a)# 转换
print(b)
np.add(a, 2, out=a) # 把1加到a中去在输出a
结果:
[1. 1. 1. 1. 1.]
tensor([1., 1., 1., 1., 1.], dtype=torch.float64)
array([3., 3., 3., 3., 3.])
|