import torch.nn as nn
class A(nn.Module):
    """Apply a shared pair of 1x1 convolutions repeatedly, residual on the first pass.

    One step is ``f(h) = c2(c1(h))``. The same two layers are reused on every
    iteration, so their weights are shared:

    * iteration 0:                ``h = x + f(x)``
    * iterations 1..num_iters-1:  ``h = f(h)``

    Args:
        channels: Input and output channel count of both convolutions.
            Defaults to 256, the value previously hard-coded.
        num_iters: How many times the conv pair is applied. Defaults to 3,
            the value previously hard-coded. Must be >= 1.

    Raises:
        ValueError: If ``num_iters`` is less than 1.
    """

    def __init__(self, channels=256, num_iters=3):
        super().__init__()
        if num_iters < 1:
            # With zero iterations forward() would have nothing to return.
            raise ValueError(f"num_iters must be >= 1, got {num_iters}")
        self.num_iters = num_iters
        # kernel_size=1 (stride 1, no padding) keeps the spatial size and the
        # channel count, so each step's output can be fed back in as input.
        self.c1 = nn.Conv2d(channels, channels, 1)
        self.c2 = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        """Run the iterated conv pair on ``x``.

        Args:
            x: Tensor accepted by ``nn.Conv2d`` with ``channels`` input
                channels (presumably (N, C, H, W); confirm with callers).

        Returns:
            The state after the last iteration, with the same shape as ``x``.
        """
        h = x
        for i in range(self.num_iters):
            out = self.c2(self.c1(h))
            # The residual connection to the original input is added only
            # on the first pass; later passes just chain the conv pair.
            h = x + out if i == 0 else out
        return h
# if __name__ == '__main__':
# # # from time import time
# import torch
# from tensorboardX import SummaryWriter
#
# pose = A() # .cuda()
# for param in pose.parameters():
# if param.requires_grad:
# print('param autograd')
# break
#
# # t0 = time()
# input = torch.randn(1, 256, 256, 256) # .cuda()
# # print(pose)
# output = pose(input) # type: torch.Tensor
# output[0][0].sum().backward()
#
# with SummaryWriter(comment='testnetwork') as w:
# w.add_graph(pose, (input,))
if __name__ == '__main__':
    import torch
    import torch.onnx
    import netron

    # Instantiate the network to export. Substitute your own model class
    # here (the original note refers to an hourglass network). Append
    # .cuda() to run on GPU.
    pose = A()
    dummy_input = torch.randn(1, 256, 256, 256)
    # Output file name. A relative path is resolved against the current
    # working directory.
    onnx_path = "pose.onnx"
    # Pass onnx_path here, not a duplicated literal, so the file that is
    # written and the file netron opens cannot drift apart.
    # CLI alternative for viewing: netron --host=localhost
    torch.onnx.export(pose, dummy_input, onnx_path)
    # Start the netron viewer on the exported graph; it jumps to the
    # netron page automatically.
    netron.start(onnx_path)