2021SC@SDUSC
我们继续分析dataset类,dataset类位于lastDataset.py文件中,是该算法的核心代码之一。dataset类中一共有20个类函数,我将会挑选核心的函数来分析。
首先是对数据集建立词表的build_ent_vocab函数。
def build_ent_vocab(self,path,unkat=0):
ents = ""
with open(path,encoding='utf-8') as f:
for l in f:
ents += " "+l.split("\t")[1]
itos = sorted(list(set(ents.split(" "))))
itos[0] == "<unk>"; itos[1] == "<pad>"
stoi = {x:i for i,x in enumerate(itos)}
return itos,stoi
参数中的path就是数据集所在的路径,调用的时候传入。unkat参数初始值为0,意为为转换。ents是声明的字符串变量,存储遍历读取到的字符串数据集。itos是一个列表变量,每个元素都是ents中根据“ ”切割出的分词。比如ents='A? B',那么itos则为['A','B'],初始化itos第一个值为unk,第二个值为pad,enumerate()函数将itos组合为索引序列,结果组合为stoi变量。返回数据对象itos和索引序列stoi。
接下来是mkGraphs函数。
def mkGraphs(self,r,ent):
……
return (adj,rel)
这个函数的作用是用adj和rel矩阵将三元组转换为entlist。具体操作非关键代码,此处不再赘述。
接下来是mkVocabs函数。
def mkVocabs(self,args):
args.path = args.datadir + args.data
self.INP = data.Field(sequential=True, batch_first=True,init_token="<start>", eos_token="<eos>",include_lengths=True)
self.OUTP = data.Field(sequential=True, batch_first=True,init_token="<start>", eos_token="<eos>",include_lengths=True)
self.TGT = data.Field(sequential=True, batch_first=True,init_token="<start>", eos_token="<eos>")
self.NERD = data.Field(sequential=True, batch_first=True,eos_token="<eos>")
self.ENT = data.RawField()
self.REL = data.RawField()
self.SORDER = data.RawField()
self.SORDER.is_target = False
self.REL.is_target = False
self.ENT.is_target = False
self.fields=[("src",self.INP),("ent",self.ENT),("nerd",self.NERD),("rel",self.REL),("out",self.OUTP),("sorder",self.SORDER)]
该段代码就是对这些参数进行操作,Field类和RawField类在之前已经详细分析过,此处不再单独分析这两个类。它设置了处理后保存的路径,设置INP和OUTP为顺序数据、先生成batch dimension的tensor、以“<start>”为开始标记、以“<eos>”为结束标记、返回带填充的minibatch和的元组。
if args.eval:
train = data.TabularDataset(path=args.datadir+args.traindata, format='tsv',fields=self.fields)
else:
train = data.TabularDataset(path=args.path, format='tsv',fields=self.fields)
print('building vocab')
train变量为把data定义为以TSV格式存储的列的数据集。TabularDataset是一个类,用来定义以CSV、TSV或JSON格式存储的列的数据集。如果使用dict,键应该是JSON键或CSV/TSV列的子集,值应该是(name, field)的元组。这会允许我们从其JSON/CSV/TSV键名重命名列,还允许选择要加载的列的子集。
self.OUTP.build_vocab(train, min_freq=args.outunk)
generics =['<method>','<material>','<otherscientificterm>','<metric>','<task>']
self.OUTP.vocab.itos.extend(generics)
for x in generics:
self.OUTP.vocab.stoi[x] = self.OUTP.vocab.itos.index(x)
self.TGT.vocab = copy(self.OUTP.vocab)
specials = "method material otherscientificterm metric task".split(" ")
for x in specials:
for y in range(40):
s = "<"+x+"_"+str(y)+">"
self.TGT.vocab.stoi[s] = len(self.TGT.vocab.itos)+y
self.NERD.build_vocab(train,min_freq=0)
for x in generics:
self.NERD.vocab.stoi[x] = self.OUTP.vocab.stoi[x]
首先对要输出的变量进行build_vocab操作,该函数为Field的类函数,之前已分析过,此处不再赘述。generics是作者(不是我,是写代码的人)在数据集中找的一个实例。接下来就是对这个数据集进行扩大、切割、存储操作,specials就是把“method material otherscientificterm metric task”这个字符串根据" "进行分割,也就是generics。
接下来看一个批处理函数fixBatch()。
def fixBatch(self,b):
ent,phlens = zip(*b.ent)
ent,elens = self.adjToBatch(ent)
ent = ent.to(self.args.device)
adj,rel = zip(*b.rel)
if self.args.sparse:
b.rel = [adj,self.listTo(rel)]
else:
b.rel = [self.listTo(adj),self.listTo(rel)]
if self.args.plan:
b.sordertgt = self.listTo(self.pad_list(b.sordertgt))
phlens = torch.cat(phlens,0).to(self.args.device)
elens = elens.to(self.args.device)
b.ent = (ent,phlens,elens)
return b
参数b为传入的地址。ent,phlens = zip(*b.ent)和adj,rel = zip(*b.rel)为解压b,解压后仍为元组,对解压后的元组调用adjToBatch函数进行生成邻接矩阵的批处理操作,最后返回的是矩阵。最后b直接变为三元组并返回。
|