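Parameters in PyTorch
When working with a model you will meet both model.parameters() and model.state_dict(). The former is used to initialize an optimizer; the latter is mostly used to save the model, as in the following two excerpts: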
self._weight_optimizer = torch.optim.Adam(
    self._net.weight_parameters(),
    lr=self._weight_lr,
    weight_decay=self._weight_decay
)
def _save(self, mode):
    save_dir = os.path.join(param_path, self._name)
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    states = {
        'net': self._net.state_dict(),
        'arch_optimizer': self._arch_optimizer.state_dict(),
        'arch_optimizer_scheduler': self._arch_optimizer_scheduler.state_dict(),
        'weight_optimizer': self._weight_optimizer.state_dict(),
        'weight_optimizer_scheduler': self._weight_optimizer_scheduler.state_dict(),
        'best_epoch': self._best_epoch,
        'valid_records': self._valid_records
    }
    filename = os.path.join(save_dir, '%s.pth' % mode)
    torch.save(obj=states, f=filename)
    logging.info('[eval]\tepoch[%d]\tsave parameters to %s', self._best_epoch, filename)
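The same pattern can be tried end to end on a toy model. This is only a minimal sketch: the nn.Linear layer, the learning rate, and the in-memory buffer are assumptions made for illustration (the excerpt above writes to a .pth file instead), and only the 'net' and 'weight_optimizer' entries are reproduced.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for self._net in the excerpt above.
net = nn.Linear(2, 2)
# parameters() feeds the optimizer, as in the Adam call above.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)

# Bundle both state dicts into one checkpoint, mirroring the 'states' dict.
states = {'net': net.state_dict(), 'weight_optimizer': optimizer.state_dict()}

# An in-memory buffer replaces the .pth file so the sketch leaves no files behind.
buffer = io.BytesIO()
torch.save(states, buffer)
buffer.seek(0)

# Restore into freshly constructed objects.
checkpoint = torch.load(buffer)
restored_net = nn.Linear(2, 2)
restored_net.load_state_dict(checkpoint['net'])
restored_optimizer = torch.optim.Adam(restored_net.parameters())
restored_optimizer.load_state_dict(checkpoint['weight_optimizer'])

# Every restored tensor should equal the one that was saved.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(net.state_dict().values(), restored_net.state_dict().values())
)
print(weights_match)
```

Loading goes through load_state_dict() on an object that already has the right structure, which is why the checkpoint only needs to hold the dictionaries rather than the objects themselves.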
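When tuning a network, or checking whether its parameters are reproducible, we may want to look at the parameters themselves. The source of parameters() reads: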
def parameters(self, recurse=True):
    r"""Returns an iterator over module parameters.

    This is typically passed to an optimizer.

    Args:
        recurse (bool): if True, then yields parameters of this module
            and all submodules. Otherwise, yields only parameters that
            are direct members of this module.

    Yields:
        Parameter: module parameter

    Example::

        >>> for param in model.parameters():
        >>>     print(type(param.data), param.size())
        <class 'torch.FloatTensor'> (20L,)
        <class 'torch.FloatTensor'> (20L, 1L, 5L, 5L)

    """
    for name, param in self.named_parameters(recurse=recurse):
        yield param
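Because the body uses yield, parameters() is a generator function: calling it returns a generator rather than a list. To get data out of a generator we need a loop or next(). Here is a simple network as an example: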
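import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        # Apply the linear layer to the input
        out = self.linear(x)
        return out

net = Net()
# Iterate over the generator to print each parameter (weight, then bias)
for para in net.parameters():
    print(para)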
Parameter containing:
tensor([[ 0.2593, -0.3468],
        [-0.2661, 0.0250]], requires_grad=True)
Parameter containing:
tensor([0.5084, 0.4834], requires_grad=True)
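The generator behaviour can also be checked with next() instead of a loop. A minimal sketch, assuming a bare nn.Linear(2, 2) in place of the Net above; the variable names are illustrative only.

```python
import inspect

import torch.nn as nn

net = nn.Linear(2, 2)  # assumed stand-in for the Net defined above

params = net.parameters()
print(inspect.isgenerator(params))  # True: the call returns a generator object

first = next(params)   # the weight, shape (2, 2)
second = next(params)  # the bias, shape (2,)
print(first.shape, second.shape)

# The generator is now exhausted; call parameters() again to iterate anew.
remaining = list(params)
all_params = list(net.parameters())
print(len(remaining), len(all_params))  # 0 2
```

Since a generator can be consumed only once, wrap the call in list() if the parameters need to be traversed several times.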
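As for the state_dict() method, the "dict" in its name already tells us that it returns a dictionary. Printing it directly with print(net.state_dict()) gives: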
OrderedDict([('linear.weight', tensor([[0.2781, 0.3600],
        [0.4755, 0.1770]])), ('linear.bias', tensor([-0.4715, 0.6185]))])
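To make the relationship between the two views concrete, the sketch below compares the keys of state_dict() with the names from named_parameters(). It redefines the small Net so that it runs on its own. Note one difference that the printed output above also hints at: by default the tensors in state_dict() are detached, so they do not carry requires_grad=True.

```python
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

net = Net()
state = net.state_dict()
param_names = [name for name, _ in net.named_parameters()]

# For this network the dictionary keys are exactly the parameter names.
print(list(state.keys()))  # ['linear.weight', 'linear.bias']
print(param_names)         # ['linear.weight', 'linear.bias']

# Entries of state_dict() are detached tensors by default.
print(state['linear.weight'].requires_grad)  # False
print(net.linear.weight.requires_grad)       # True
```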
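The values differ from run to run because the parameters are initialized randomly. If you set a random seed (for example with torch.manual_seed) before constructing the network, the initial parameters are the same every time.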
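The effect of the seed can be verified in a few lines. This is a sketch under assumptions: build_net is a hypothetical helper, and the seeds 0 and 1 are arbitrary choices.

```python
import torch
import torch.nn as nn

def build_net(seed):
    # Hypothetical helper: seed the global RNG right before construction
    # so that the random initialization is repeatable.
    torch.manual_seed(seed)
    return nn.Linear(2, 2)

net_a = build_net(0)
net_b = build_net(0)
net_c = build_net(1)

# Same seed: every parameter tensor is identical.
same_seed_equal = all(
    torch.equal(p, q) for p, q in zip(net_a.parameters(), net_b.parameters())
)
# Different seed: the initial values differ.
different_seed_equal = all(
    torch.equal(p, q) for p, q in zip(net_a.parameters(), net_c.parameters())
)
print(same_seed_equal, different_seed_equal)
```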
Reference: https://zhuanlan.zhihu.com/p/270344655