1. How to view the layers of a model
print(model)
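2. Print the model's parameters
- Drawback: this prints every parameter tensor, but without names, so you cannot tell which layer each one belongs to and the output is hard to read.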
print("-------------------------------------------------")
print("Attributes of the model's bert")
for parameter in model.parameters():
    # Each item is a bare Parameter tensor; no name is attached.
    print(parameter)
print("-------------------------------------------------")
print("-------------------------------------------------")
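Since the raw tensors are hard to read, a compact summary may be more useful. The sketch below prints only the shape of each parameter and counts the total number of elements. The small nn.Sequential model is a hypothetical stand-in, because the original model is not defined in these notes.

```python
import torch.nn as nn

# Hypothetical stand-in model (assumption; replace with your own model).
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Print the shape of each parameter tensor instead of its full contents.
for parameter in model.parameters():
    print(tuple(parameter.shape), parameter.requires_grad)

# Count all parameter elements, and those that will receive gradients.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)
```

For this stand-in model the two Linear layers contribute 4*3 + 3 and 3*2 + 2 elements, so both counts are 23.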
3. Print each layer's name together with its parameters
print("-------------------------------------------------")
print("Attributes of the model's bert")
for name, parameter in model.named_parameters():
    # name is the dotted path of the parameter, e.g. classifier.weight
    print(name, parameter)
print("-------------------------------------------------")
print("-------------------------------------------------")
Output (truncated):
7.8309e-02, -2.3176e-02, 1.9839e-02, -1.7092e-02, 9.4321e-02,
-1.4221e-02, 6.2530e-02, -3.1816e-02, -7.9080e-02, 1.4354e-02,
-2.1350e-02, -5.6522e-02, -3.4564e-02], device='cuda:1',
requires_grad=True)
classifier.weight Parameter containing:
tensor([[ 0.0106, -0.0667, -0.0263, ..., -0.0182, -0.0581, 0.0074],
        [ 0.0011, -0.0372, -0.0474, ..., 0.0231, -0.0420, -0.0075],
        [ 0.0103, 0.0349, 0.0343, ..., -0.0065, -0.0091, 0.0024],
        ...,
        [-0.1062, -0.1443, -0.0232, ..., -0.0280, 0.0067, 0.0993],
        [ 0.0186, -0.0386, 0.0207, ..., -0.0693, -0.0363, 0.0977],
        [ 0.0094, 0.0551, -0.0461, ..., -0.0175, -0.0222, -0.0230]],
       device='cuda:1', requires_grad=True)
classifier.bias Parameter containing:
tensor([-0.3976, 0.1223, 0.3341, 0.3246, 0.2115, -0.0361, 0.2665, 0.1029,
        0.4019, 0.1331, -0.1092, 0.1410, -1.5911, -0.1935, 0.2342],
       device='cuda:1', requires_grad=True)
4. Return various information about the model
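import torch

# children(): only the direct submodules of the model
for module in model.children():
    print(module)
print("-------------------------------------------------")
# named_children(): direct submodules together with their attribute names
for name, module in model.named_children():
    print(name, module)
print("-------------------------------------------------")
# modules(): the model itself plus every nested submodule, recursively
for module in model.modules():
    print(module)
print("-------------------------------------------------")
# named_modules(): the same recursive traversal, with dotted names
for name, module in model.named_modules():
    print([name, module])
print("-------------------------------------------------")
# state_dict(): ordered mapping from parameter/buffer names to tensors
print(model.state_dict())
print(model.state_dict().keys())
# Save only the weights (not the model class) to disk
torch.save(model.state_dict(), "./weight.pth")
print("-------------------------------------------------")
# Load the saved state dict back; it is a dict of name -> tensor
weight = torch.load("./weight.pth")
print(weight)
for k, v in weight.items():
    print(k)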
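The saved weights are usually loaded back into a model with the same architecture. The sketch below shows one way to do that round trip with load_state_dict. The make_model helper and the temporary file location are assumptions made for illustration and are not part of the original notes.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small model used only for illustration.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


source = make_model()
target = make_model()  # same architecture, different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weight.pth")
    torch.save(source.state_dict(), path)
    # map_location="cpu" lets weights saved on a GPU load on a CPU-only machine.
    weight = torch.load(path, map_location="cpu")

# load_state_dict copies each saved tensor into the parameter with the same name.
result = target.load_state_dict(weight)
keys = list(weight.keys())
print(keys)

# After loading, both models hold identical tensors.
same = all(torch.equal(source.state_dict()[k], target.state_dict()[k]) for k in keys)
print(same)
```

Passing map_location is optional, but it can help when the weights were saved from a device such as cuda:1 that is not available where they are loaded.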
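Difference between modules() and children(), writing a nested model as [[1, 2], 3]:
modules() traverses recursively and yields: [[1, 2], 3], [1, 2], 1, 2, 3
children() yields only the direct children: [1, 2], 3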
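The [[1, 2], 3] pattern can be checked on a small nested model. This is a minimal sketch, assuming a hypothetical nn.Sequential that contains another nn.Sequential. The layer sizes are arbitrary.

```python
import torch.nn as nn

# Hypothetical nested model (assumption, for illustration only).
inner = nn.Sequential(nn.Linear(4, 4), nn.ReLU())  # plays the role of [1, 2]
model = nn.Sequential(inner, nn.Linear(4, 2))      # plays the role of [[1, 2], 3]

# named_children() visits only the first level.
children_names = [name for name, _ in model.named_children()]
print(children_names)  # ['0', '1']

# named_modules() visits the model itself (empty name) and every nested module.
module_names = [name for name, _ in model.named_modules()]
print(module_names)  # ['', '0', '0.0', '0.1', '1']
```

Two children versus five modules matches the two listings: children() stops at the first level, while modules() also descends into the inner container and includes the root.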