# bn层、卷积层测试 (BatchNorm layer / convolution layer test):
import torch
from torch import nn
def init_weights(m):
    """Set the weight (and bias, if present) of linear/conv layers to all ones.

    Meant to be passed to ``nn.Module.apply`` so that every ``Linear`` /
    ``Conv1d`` / ``Conv2d`` / ``Conv3d`` sub-module starts from the same fixed
    values, which makes gradient experiments reproducible. Any other module
    type (e.g. ``BatchNorm2d``) is left untouched.

    Args:
        m: the sub-module currently visited by ``nn.Module.apply``.
    """
    # Exact type match (not isinstance) is deliberate: it keeps subclasses such
    # as lazy or quantized variants untouched, as the Linear-only version did.
    # Conv layers are included so conv networks also get fixed initial weights.
    if type(m) in (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d):
        # nn.init.ones_ fills in place under no_grad, so the Parameter object
        # (and its requires_grad flag) is preserved.
        nn.init.ones_(m.weight)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.ones_(m.bias)
class Net(nn.Module):
    """Minimal conv -> BatchNorm network for inspecting gradients.

    Layers:
        conv1: 3x3 convolution, 3 -> 6 channels, stride 2, padding 1, no bias.
        bn1:   BatchNorm2d over the 6 output channels of conv1.
    """

    def __init__(self):
        super().__init__()
        # No conv bias: BatchNorm subtracts the per-channel mean (cancelling any
        # constant bias) and then applies its own learnable shift.
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(6)

    def forward(self, x):
        """Return ``bn1(conv1(x))``."""
        return self.bn1(self.conv1(x))
if __name__ == '__main__':
# Build the conv -> BatchNorm test network.
net=Net()
# Run init_weights on every sub-module; see init_weights for which layer
# types it actually modifies.
net.apply(init_weights) # intended to pin the network's initial parameters to fixed values
# No backward pass has run yet, so these .grad attributes are not populated
# (PyTorch leaves .grad as None until the first backward).
print(net.conv1.weight.grad,net.bn1.weight.grad,sep="\n")
print("-------------")
# Freeze conv1: its parameters get requires_grad=False and will not
# accumulate gradients during backward.
net.conv1.requires_grad_(False)
# Switch to training mode (returns the same module): BatchNorm normalizes
# with the current batch's statistics and updates its running mean/var.
net=net.train()
# Random input batch of shape (2, 3, 8, 8), values in [0, 1).
x=torch.rand([2,3,8,8],dtype=torch.float32)
y=net(x)
# Backpropagate a scalar (sum of all outputs). With conv1 frozen, only bn1's
# parameters still require grad and can receive .grad.
y.sum().backward()
print(net
|