记:关于Pytorch中Linear结构与参数权重查看
代码段:
import torch
x = torch.ones(1,3)
y = torch.nn.Linear(3,3,bias=True)
print(x)
print(y)
out = y.forward(x)
print(out)
print(y.state_dict().keys()) # 查看有哪些参量
print(y.weight) # 输出weight参量
print(y.bias) # 输出bias参量
运行结果:
tensor([[1., 1., 1.]])
Linear(in_features=3, out_features=3, bias=True)
tensor([[ 0.0577, -0.6237, -0.2998]], grad_fn=<AddmmBackward>)
odict_keys(['weight', 'bias'])
Parameter containing:
tensor([[ 0.0177, 0.4276, -0.1507],
[-0.0755, -0.2118, 0.1403],
[ 0.1560, -0.3528, -0.4536]], requires_grad=True)
Parameter containing:
tensor([-0.2369, -0.4766, 0.3506], requires_grad=True)
|