基于人工智能的多肽药物分析问题(十四)
2021SC@SDUSC
1. 前言
在上篇博客中,我们已经将该项目的代码基本上全部分析完毕了,但在Refine_module方法中,还有一个比较有意思的模块,Regen_Network,现在奉上对其的分析
2. 代码分析
2.1 Regen_Network模型
class Regen_Network(nn.Module):
def __init__(self,
node_dim_in=64,
node_dim_hidden=64,
edge_dim_in=128,
edge_dim_hidden=64,
state_dim=8,
nheads=4,
nblocks=3,
dropout=0.0):
super(Regen_Network, self).__init__()
self.norm_node = LayerNorm(node_dim_in)
self.norm_edge = LayerNorm(edge_dim_in)
self.embed_x = nn.Sequential(nn.Linear(node_dim_in+21, node_dim_hidden),
LayerNorm(node_dim_hidden))
self.embed_e = nn.Sequential(nn.Linear(edge_dim_in+2, edge_dim_hidden),
LayerNorm(edge_dim_hidden))
blocks = [UniMPBlock(node_dim_hidden,edge_dim_hidden,nheads,dropout) for _ in range(nblocks)]
self.transformer = nn.Sequential(*blocks)
self.get_xyz = nn.Linear(node_dim_hidden,9)
self.norm_state = LayerNorm(node_dim_hidden)
self.get_state = nn.Linear(node_dim_hidden, state_dim)
?
该模型的作用是用于通过最终的残基对特征和经过前面的模型预测出的参数重新生成初始坐标,然后通过多个SE3 transformer block对其进行细化 。
在上述代码的第14~15行,我们可以看到,定义了两个LayerNorm层,然后在17、19行定义了embedding层,是为节点和边缘特征做嵌入的嵌入层。嵌入层可以把我们的稀疏矩阵,通过一些线性变换(在CNN中用全连接层进行转换,也称为查表操作),变成了一个密集矩阵,在这个密集矩阵中,表象上代表着密集矩阵跟单个字的一一对应关系,实际上还蕴含了大量的字与字之间,词与词之间甚至句子与句子之间的内在关系。他们之间的关系,用的是嵌入层学习来的参数进行表征。从稀疏矩阵到密集矩阵的过程,叫做embedding,很多人也把它叫做查表,因为他们之间也是一个一一映射的关系。
更重要的是,这种关系在反向传播的过程中,是一直在更新的,因此能在多次epoch后,使得这个关系变成相对成熟,即:正确的表达整个语义以及各个语句之间的关系。这个成熟的关系,就是embedding层的所有权重参数。
Embedding是NPL领域最重要的发明之一,他把独立的向量一下子就关联起来了。这就相当于什么呢,相当于你是你爸的儿子,你爸是A的同事,B是A的儿子,似乎跟你是八竿子才打得着的关系。结果你一看B,是你的同桌。Embedding层就是用来发现这个秘密的武器。
再之后,在第24行定义了一个graph transformer。即前面提及的通过该层次进行坐标的细化。
最后是用全连接层等层次结构对结果进行输出。
def forward(self, seq1hot, idx, node, edge):
B, L = node.shape[:2]
node = self.norm_node(node)
edge = self.norm_edge(edge)
node = torch.cat((node, seq1hot), dim=-1)
node = self.embed_x(node)
seqsep = get_seqsep(idx)
neighbor = get_bonded_neigh(idx)
edge = torch.cat((edge, seqsep, neighbor), dim=-1)
edge = self.embed_e(edge)
G = make_graph(node, idx, edge)
Gout = self.transformer(G)
xyz = self.get_xyz(Gout.x)
state = self.get_state(self.norm_state(Gout.x))
return xyz.reshape(B, L, 3, 3) , state.reshape(B, L, -1)
在Regen_Network模型的forward方法中,说明了如何调用上述模型
2~7行是对在构造函数中定义好的模型的一些赋值工作,之后调用了get_seqsep进行数据处理
def get_seqsep(idx):
seqsep = idx[:,None,:] - idx[:,:,None]
sign = torch.sign(seqsep)
seqsep = torch.log(torch.abs(seqsep) + 1.0)
seqsep = torch.clamp(seqsep, 0.0, 5.5)
seqsep = sign * seqsep
return seqsep.unsqueeze(-1)
该方法的输入为给定蛋白质序列的残基参数,经过一些数学上的处理,以(B, L, L, 1)的四元组形式输出序列的分离特性,通过测试发现该处理会对模型的准确度有一点点的帮助。
之后调用get_bonded_neigh方法
def get_bonded_neigh(idx):
neighbor = idx[:,None,:] - idx[:,:,None]
neighbor = neighbor.float()
sign = torch.sign(neighbor)
neighbor = torch.abs(neighbor)
neighbor[neighbor > 1] = 0.0
neighbor = sign * neighbor
return neighbor.unsqueeze(-1)
该方法的输入同样是给定蛋白质序列的残基参数,最终的输出是相邻节点的相关信息。
下面细看以下用于细化坐标的graph transformer模型
2.2 graph transformer模型
class UniMPBlock(nn.Module):
'''https://arxiv.org/pdf/2009.03509.pdf'''
def __init__(self,
node_dim=64,
edge_dim=64,
heads=4,
dropout=0.15):
super(UniMPBlock, self).__init__()
self.TConv = TransformerConv(node_dim, node_dim, heads, dropout=dropout, edge_dim=edge_dim)
self.LNorm = LayerNorm(node_dim*heads)
self.Linear = nn.Linear(node_dim*heads, node_dim)
self.Activ = nn.ELU(inplace=True)
@torch.cuda.amp.autocast(enabled=True)
def forward(self, G):
xin, e_idx, e_attr = G.x, G.edge_index, G.edge_attr
x = self.TConv(xin, e_idx, e_attr)
x = self.LNorm(x)
x = self.Linear(x)
out = self.Activ(x+xin)
return Data(x=out, edge_index=e_idx, edge_attr=e_attr)
该模型中最关键的一层是在第10行调用的TransformerConv,是基于图神经网络的节点表征学习
GCN全称Graph convolutional network,图卷积网络。GCN于2017年提出,它的到来标志着图神经网络时代的出现。
GCN与我们常见的CNN(卷积神经网络)听起来名字很相似,其实理解起来也比较类似,都可以理解为是一种特征提取器。不同的是,CNN提取的是张量数据特征,而GCN提取的是图结构数据特征。图的结构一般来说是十分不规则的,可以认为是无限维的一种数据,所以它没有平移不变性。每一个节点的周围结构可能都是独一无二的,这种结构的数据,就让传统的CNN、RNN瞬间失效。所以很多学者从上个世纪就开始研究怎么处理这类数据了。这里涌现出了很多方法,例如GNN、DeepWalk、node2vec等等,GCN只是其中一种。图卷积神经网络,实际上跟CNN的作用一样,就是一个特征提取器,只不过它的对象是图数据。GCN精妙地设计了一种从图数据中提取特征的方法,从而让我们可以使用这些特征去对图数据进行节点分类(node classification)、图分类(graph classification)、边预测(link prediction) ,还可以顺便得到 图的嵌入表示(graph embedding)GCN也是一个神经网络层。
图卷积的核心思想是利用边的信息对节点信息进行聚合从而生成新的节点表示。GCN的本质目的就是用来提取拓扑图的空间特征。而pytorch中提供的一个图卷积模块。
该模型的作用就是通过多层的图卷积网络层对坐标进行细化。
下面来看一下e2e模式中用到的另一个模型
2.3 Refine_Network模型
class Refine_Network(nn.Module):
def __init__(self, d_node=64, d_pair=128, d_state=16,
SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32}, p_drop=0.0):
super(Refine_Network, self).__init__()
self.norm_msa = LayerNorm(d_node)
self.norm_pair = LayerNorm(d_pair)
self.norm_state = LayerNorm(d_state)
self.embed_x = nn.Linear(d_node+21+d_state, SE3_param['l0_in_features'])
self.embed_e1 = nn.Linear(d_pair, SE3_param['num_edge_features'])
self.embed_e2 = nn.Linear(SE3_param['num_edge_features']+36+1, SE3_param['num_edge_features'])
self.norm_node = LayerNorm(SE3_param['l0_in_features'])
self.norm_edge1 = LayerNorm(SE3_param['num_edge_features'])
self.norm_edge2 = LayerNorm(SE3_param['num_edge_features'])
self.se3 = SE3Transformer(**SE3_param)
在Refine_Network模型中,先是定义了一些LayerNorm和全连接层,将不属于经典深度学习网络模型的参数,变成了模型中可迭代训练的参数通常并将相关参数全部保存,最后定义的是一个SE3Transformer模型。
class SE3Transformer(nn.Module):
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
si_m='1x1', si_e='att',
l0_in_features=32, l0_out_features=32,
l1_in_features=3, l1_out_features=3,
num_edge_features=32, x_ij=None):
super().__init__()
self.num_layers = num_layers
self.num_channels = num_channels
self.num_degrees = num_degrees
self.edge_dim = num_edge_features
self.div = div
self.n_heads = n_heads
self.si_m, self.si_e = si_m, si_e
self.x_ij = x_ij
if l1_out_features > 0:
fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
'mid': Fiber(self.num_degrees, self.num_channels),
'out': Fiber(dictionary={0: l0_out_features, 1: l1_out_features})}
else:
fibers = {'in': Fiber(dictionary={0: l0_in_features, 1: l1_in_features}),
'mid': Fiber(self.num_degrees, self.num_channels),
'out': Fiber(dictionary={0: l0_out_features})}
blocks = self._build_gcn(fibers)
self.Gblock = blocks
该模型的主要作用是从相对位置计算等变权重,从查阅的资料上来看,应该是将注意力机制和图神经网络GCN相结合而定义的一个自定义模型,在代码的第20行,用到了一个自定义的Fiber的数据结构,主要是用于表征与该模型相关的信息。
该类的代码如下:
class Fiber(object):
"""A Handy Data Structure for Fibers"""
def __init__(self, num_degrees: int=None, num_channels: int=None,
structure: List[Tuple[int,int]]=None, dictionary=None):
if structure:
self.structure = structure
elif dictionary:
self.structure = [(dictionary[o], o) for o in sorted(dictionary.keys())]
else:
self.structure = [(num_channels, i) for i in range(num_degrees)]
self.multiplicities, self.degrees = zip(*self.structure)
self.max_degree = max(self.degrees)
self.min_degree = min(self.degrees)
self.structure_dict = {k: v for v, k in self.structure}
self.dict = self.structure_dict
self.n_features = np.sum([i[0] * (2*i[1]+1) for i in self.structure])
self.feature_indices = {}
idx = 0
for (num_channels, d) in self.structure:
length = num_channels * (2*d + 1)
self.feature_indices[d] = (idx, idx + length)
idx += length
2.4 构建图卷积神经网络
等变化层
def _build_gcn(self, fibers):
Gblock = []
fin = fibers['in']
for i in range(self.num_layers):
Gblock.append(GSE3Res(fin, fibers['mid'], edge_dim=self.edge_dim,
div=self.div, n_heads=self.n_heads,
learnable_skip=True, skip='cat',
selfint=self.si_m, x_ij=self.x_ij))
Gblock.append(GNormBias(fibers['mid']))
fin = fibers['mid']
Gblock.append(
GSE3Res(fibers['mid'], fibers['out'], edge_dim=self.edge_dim,
div=1, n_heads=min(1, 2), learnable_skip=True,
skip='cat', selfint=self.si_e, x_ij=self.x_ij))
return nn.ModuleList(Gblock)
先是定义了一个列表,之后将GSE3Res、GNormBias等方法的返回值全部添加到该列表中,构建出多层的图神经网络。
下面具体分析图卷积神经网络的各个层:
GSE3Res
class GSE3Res(nn.Module):
"""Graph attention block with SE(3)-equivariance and skip connection"""
def __init__(self, f_in: Fiber, f_out: Fiber, edge_dim: int=0, div: float=4,
n_heads: int=1, learnable_skip=True, skip='cat', selfint='1x1', x_ij=None):
super().__init__()
self.f_in = f_in
self.f_out = f_out
self.div = div
self.n_heads = n_heads
self.skip = skip
f_mid_out = {k: int(v // div) for k, v in self.f_out.structure_dict.items()}
self.f_mid_out = Fiber(dictionary=f_mid_out)
f_mid_in = {d: m for d, m in f_mid_out.items() if d in self.f_in.degrees}
self.f_mid_in = Fiber(dictionary=f_mid_in)
self.edge_dim = edge_dim
self.GMAB = nn.ModuleDict()
self.GMAB['v'] = GConvSE3Partial(f_in, self.f_mid_out, edge_dim=edge_dim, x_ij=x_ij)
self.GMAB['k'] = GConvSE3Partial(f_in, self.f_mid_in, edge_dim=edge_dim, x_ij=x_ij)
self.GMAB['q'] = G1x1SE3(f_in, self.f_mid_in)
self.GMAB['attn'] = GMABSE3(self.f_mid_out, self.f_mid_in, n_heads=n_heads)
if self.skip == 'cat':
self.cat = GCat(self.f_mid_out, f_in)
if selfint == 'att':
self.project = GAttentiveSelfInt(self.cat.f_out, f_out)
elif selfint == '1x1':
self.project = G1x1SE3(self.cat.f_out, f_out, learnable=learnable_skip)
elif self.skip == 'sum':
self.project = G1x1SE3(self.f_mid_out, f_out, learnable=learnable_skip)
self.add = GSum(f_out, f_in)
assert self.add.f_out.structure_dict == f_out.structure_dict, \
'skip connection would change output structure'
在该模型中,先是对传入的参数进行赋值,之后对输出进行了一些处理,f_mid_out和’f_out’有相同的结构,但是channels被’div’分开 这将用于构建value,f_mid_in的结构与f_mid_out相同,但只包含f_in中的度数 这将用于key和queue,又因为该层中定义的查询仅仅是投影,所以对应的输入必须是对应格式的,最后对skip connection进行一个检测,避免输入的维度和输出的维度不匹配。
GNormBias:
class GNormBias(nn.Module):
"""Norm-based SE(3)-equivariant nonlinearity with only learned biases."""
def __init__(self, fiber, nonlin=nn.ReLU(inplace=True),
num_layers: int = 0):
super().__init__()
self.fiber = fiber
self.nonlin = nonlin
self.num_layers = num_layers
self.eps = 1e-12
self.bias = nn.ParameterDict()
for m, d in self.fiber.structure:
self.bias[str(d)] = nn.Parameter(torch.randn(m).view(1, m))
该层以Fiber作为输入,用于计算标准化的特征,并进行了正则化防止梯度爆炸。
这样,该模型就构建好了
3. 总结
代码分析工作基本上已经全部结束了。
|