背景还是我在高DQN算法的时候遇到的,下面代码的第七行。完整代码参考这个博客。
def optimize_model():
if len(memory) < BATCH_SIZE:
return
transitions = memory.sample(BATCH_SIZE)
batch = Transition(*zip(*transitions))
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),device=device, dtype=torch.bool)
non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
state_batch = torch.cat(batch.state)
action_batch = torch.cat(batch.action)
reward_batch = torch.cat(batch.reward)
state_action_values = policy_net(state_batch).gather(1, action_batch)
next_state_values = torch.zeros(BATCH_SIZE, device=device)
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
criterion = nn.MSELoss()
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
optimizer.zero_grad()
loss.backward()
for param in policy_net.parameters():
param.grad.data.clamp_(-1, 1)
optimizer.step()
根据对DQN的理解,这个transiton是一个剧名数组,构造方式如下:
Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))
再训练DQN的时候,我们需要从replay buffer中提取Transition ,然后将transition的四个变量提取出来使用,这个时候就需要使用解包操作。
1、使用说明
(1)解包的意义就是将传递给函数的一个列表,元组,字典,拆分成独立的多个元素然后赋值给函数中的形参变量。
(2)解包字典有两种解法,一种用
?
*
?解的只有key,一种用
?
?
**
??解的有key、value。但是这个方法**只能在函数定义中使用。
2. 解包方法
解包的方法分类两种,
?
*
?和
?
?
**
??
其中
?
?
**
??是针对字典的。 我们先举
?
*
?的例子,也用数组表示吧,其实列表list也一样。 来一个数组
a = (1,2,3)
常规的解包操作是这样的
a = (1,2,3)
a1,a2,a3 = (1,2,3)
print(a1)
print(a2)
print(a3)
输出是:
1
2
3
如果使用*方法解包,那么就省事很多了。
但是,一般来讲,我们会把zip(*)在一起用。 下面举个例子
Transition = namedtuple('Transition',('state', 'action', 'next_state', 'reward'))
a = [(1,2,3,4),(11,12,13,14),(21,22,23,24),(31,32,33,34)]
b = zip(*a)
b = list(b)
c = Transition(*zip(*a))
c = list(c)
print(c)
|