import torch
import torch.autograd as autograd
def batchify_with_label(input_batch_list, gpu, num_layer, volatile_flag=False):
    """
    input_batch_list: batch_size entries, each a list of 7 fields:
        0. word_id:     (max_len)
        1. biword_id:   (max_len)
        2. ----  (unused here)
        3. gaz_id:      max_len entries of [[id1, ...], [len1, ...]]
        4. tag_id:      (max_len)
        5. layer_gazs:  (max_len, num_layer)
        6. gaz_mask:    (max_len, num_layer)
    """
    batch_size = len(input_batch_list)
    # Split the 7 fields out of each sentence.
    words = [sent[0] for sent in input_batch_list]
    biwords = [sent[1] for sent in input_batch_list]
    gazs = [sent[3] for sent in input_batch_list]
    labels = [sent[4] for sent in input_batch_list]
    layer_gazs = [sent[5] for sent in input_batch_list]
    gaz_mask = [sent[6] for sent in input_batch_list]
    word_seq_lengths = torch.LongTensor(list(map(len, words)))
    max_seq_len = word_seq_lengths.max()
    # Pre-allocate zero-padded tensors of shape (batch_size, max_seq_len[, num_layer]).
    # autograd.Variable is kept to match the original code; it is a no-op in PyTorch >= 0.4.
    word_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
    biword_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
    label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len))).long()
    layer_gaz_tensor = torch.zeros(batch_size, max_seq_len, num_layer).long()
    mask = autograd.Variable(torch.zeros((batch_size, max_seq_len))).byte()
    gaz_mask_tensor = torch.zeros((batch_size, max_seq_len, num_layer)).byte()
    # Copy each sentence into the left part of its row; the tail stays zero (padding).
    for idx, (seq, biseq, label, seq_len, layergaz, gazmask) in enumerate(
            zip(words, biwords, labels, word_seq_lengths, layer_gazs, gaz_mask)):
        word_seq_tensor[idx, :seq_len] = torch.LongTensor(seq)
        biword_seq_tensor[idx, :seq_len] = torch.LongTensor(biseq)
        label_seq_tensor[idx, :seq_len] = torch.LongTensor(label)
        layer_gaz_tensor[idx, :seq_len] = torch.LongTensor(layergaz)
        mask[idx, :seq_len] = torch.Tensor([1] * int(seq_len))
        gaz_mask_tensor[idx, :seq_len] = torch.LongTensor(gazmask)
    if gpu:
        word_seq_tensor = word_seq_tensor.cuda()
        biword_seq_tensor = biword_seq_tensor.cuda()
        word_seq_lengths = word_seq_lengths.cuda()
        label_seq_tensor = label_seq_tensor.cuda()
        layer_gaz_tensor = layer_gaz_tensor.cuda()
        gaz_mask_tensor = gaz_mask_tensor.cuda()
        mask = mask.cuda()
    return gazs, word_seq_tensor, biword_seq_tensor, word_seq_lengths, label_seq_tensor, layer_gaz_tensor, gaz_mask_tensor, mask
# A single-sentence batch (batch_size = 1, sentence length 6, num_layer = 4).
input_batch_list = [[
    [1, 3, 2, 1, 2, 3],                                # 0. word ids
    [2, 3, 4, 2, 2, 6],                                # 1. biword ids
    [],                                                # 2. unused slot
    [[[2, 1], [3, 3]], [], [[[2], [3]]], [], [], []],  # 3. gaz ids per position
    [2, 2, 2, 1, 4, 2],                                # 4. tag ids
    [[3, 0, 0, 0], [0, 6, 0, 0], [0, 0, 0, 0], [9, 0, 0, 0], [0, 0, 3, 0], [0, 0, 0, 0]],  # 5. layer_gazs
    [[0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1], [0, 1, 1, 1], [1, 1, 0, 1], [1, 1, 1, 1]],  # 6. gaz_mask
]]
gpu, num_layer = 0, 4
(gazs, word_seq_tensor, biword_seq_tensor, word_seq_lengths, label_seq_tensor,
 layer_gaz_tensor, gaz_mask_tensor, mask) = batchify_with_label(input_batch_list, gpu, num_layer)
print('gazs:', gazs)
print('word_seq_tensor:', word_seq_tensor)
print('biword_seq_tensor:', biword_seq_tensor)
print('word_seq_lengths:', word_seq_lengths)
print('label_seq_tensor:', label_seq_tensor)
print('layer_gaz_tensor:', layer_gaz_tensor)
print('gaz_mask_tensor:', gaz_mask_tensor)
print('mask:', mask)
Output:
gazs: [[[[2, 1], [3, 3]], [], [[[2], [3]]], [], [], []]]
word_seq_tensor: tensor([[1, 3, 2, 1, 2, 3]])
biword_seq_tensor: tensor([[2, 3, 4, 2, 2, 6]])
word_seq_lengths: tensor([6])
label_seq_tensor: tensor([[2, 2, 2, 1, 4, 2]])
layer_gaz_tensor: tensor([[[3, 0, 0, 0],
         [0, 6, 0, 0],
         [0, 0, 0, 0],
         [9, 0, 0, 0],
         [0, 0, 3, 0],
         [0, 0, 0, 0]]])
gaz_mask_tensor: tensor([[[0, 1, 1, 1],
         [1, 0, 1, 1],
         [1, 1, 1, 1],
         [0, 1, 1, 1],
         [1, 1, 0, 1],
         [1, 1, 1, 1]]], dtype=torch.uint8)
mask: tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.uint8)
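
The sample batch above contains a single sentence whose length equals max_seq_len, so the zero-padding never becomes visible in the output. A minimal sketch below (with made-up ids; only the shapes matter) adds a hypothetical second, shorter sentence so the padding and the mask can be seen:

# Hypothetical 3-token sentence, same 7-field layout as above (illustration only).
short_sent = [
    [5, 7, 9],             # 0. word ids
    [6, 8, 2],             # 1. biword ids
    [],                    # 2. unused slot
    [[], [], []],          # 3. no gazetteer matches
    [1, 2, 1],             # 4. tag ids
    [[0, 0, 0, 0]] * 3,    # 5. layer_gazs, shape (3, num_layer)
    [[1, 1, 1, 1]] * 3,    # 6. gaz_mask, shape (3, num_layer)
]
out = batchify_with_label(input_batch_list + [short_sent], gpu, num_layer)
print(out[1])    # word_seq_tensor: second row is [5, 7, 9, 0, 0, 0]
print(out[-1])   # mask: second row is [1, 1, 1, 0, 0, 0]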