A BERT model built on molecular graphs. Paper: MG-BERT: leveraging unsupervised atomic representation learning for molecular property prediction. Paper breakdown: MG-BERT | unsupervised atomic representation learning for molecular property prediction | BERT applied to molecular graphs | GNN | unsupervised learning (masked-atom pretraining) | attention. Code: Molecular-graph-BERT. Building on the previous two posts, this one walks through the attention_visualize part.
1.run
medium = {'name':'Medium','num_layers': 6, 'num_heads': 8, 'd_model': 256,'path':'medium_weights','addH':True}
arch = medium
trained_epoch = 8
num_layers = arch['num_layers']
num_heads = arch['num_heads']
d_model = arch['d_model']
addH = arch['addH']
dff = d_model * 2
vocab_size = 17
dropout_rate = 0.1
seed = 7
np.random.seed(seed=seed)
tf.random.set_seed(seed=seed)
task = 'logD'
df = pd.read_csv('data/reg/logD.txt',sep='\t')
sml_list = df['SMILES'].tolist()
inference_dataset = Inference_Dataset(['C=C(CC)C(=O)c1ccc(OCC(=O)[O-])c(Cl)c1Cl',
'CN(Cc1c(C(=O)NC2CCCC(O)C2)noc1-c1ccc(C(F)(F)F)cc1)C1CCOC1',
'CC(=O)Nc1ccc(O)cc1',
'CC(=O)CC(c1ccc([N+](=O)[O-])cc1)c1c(O)c2ccccc2oc1=O',
'Cc1cccc(NC(=S)Oc2ccc3ccccc3c2)c1',
'CC(=O)Nc1nnc(S(N)(=O)=O)s1',
'CC(=O)c1ccc(S(=O)(=O)NC(=O)NC2CCCCC2)cc1',
'N#Cc1c(Cl)c2ccccc2n2c1nc1ccccc12',
'CN(C(=O)C(Cc1ccccc1Cl)C[NH+]1CCC2(CC1)OCCc1cc(F)sc12)C1CC1'],addH=addH).get_data()
- A regression task predicting logD; the nine SMILES strings above are wrapped in an Inference_Dataset.
2.Inference_Dataset
class Inference_Dataset(object):
def __init__(self,sml_list,max_len=100,addH=True):
self.vocab = str2num
self.devocab = num2str
self.sml_list = [i for i in sml_list if len(i)<max_len]
self.addH = addH
def get_data(self):
self.dataset = tf.data.Dataset.from_tensor_slices((self.sml_list,))
self.dataset = self.dataset.map(self.tf_numerical_smiles).padded_batch(64, padded_shapes=(
tf.TensorShape([None]), tf.TensorShape([None,None]),tf.TensorShape([1]),tf.TensorShape([None]))).cache().prefetch(20)
return self.dataset
def numerical_smiles(self, smiles):
smiles_origin = smiles
smiles = smiles.numpy().decode()
atoms_list, adjoin_matrix = smiles2adjoin(smiles,explicit_hydrogens=self.addH)
atoms_list = ['<global>'] + atoms_list
nums_list = [str2num.get(i,str2num['<unk>']) for i in atoms_list]
temp = np.ones((len(nums_list),len(nums_list)))
temp[1:,1:] = adjoin_matrix
adjoin_matrix = (1-temp)*(-1e9)
x = np.array(nums_list).astype('int64')
return x, adjoin_matrix,[smiles], atoms_list
def tf_numerical_smiles(self, smiles):
x,adjoin_matrix,smiles,atom_list = tf.py_function(self.numerical_smiles, [smiles], [tf.int64, tf.float32,tf.string, tf.string])
x.set_shape([None])
adjoin_matrix.set_shape([None,None])
smiles.set_shape([1])
atom_list.set_shape([None])
return x, adjoin_matrix,smiles,atom_list
x, adjoin_matrix, smiles ,atom_list = next(iter(inference_dataset.take(1)))
seq = tf.cast(tf.math.equal(x, 0), tf.float32)
mask = seq[:, tf.newaxis, tf.newaxis, :]
print(x, adjoin_matrix, smiles ,atom_list,mask)
tf.Tensor(
[[16 2 2 2 2 2 4 2 2 2 2 4 2 2 4 4 2 7 2 7 1 1 1 1
1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 3 2 2 2 2 4 3 2 2 2 2 2 4 2 3 4 2 2 2 2 2 2
5 5 5 2 2 2 2 2 4 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0]
[16 2 2 4 3 2 2 2 2 4 2 2 1 1 1 1 1 1 1 1 1 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 4 2 2 2 2 2 2 3 4 4 2 2 2 2 4 2 2 2 2 2 2
4 2 4 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 2 2 2 2 3 2 6 4 2 2 2 2 2 2 2 2 2 2 2 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 4 3 2 3 3 2 6 3 4 4 6 1 1 1 1 1 1 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 2 4 2 2 2 2 6 4 4 3 2 4 3 2 2 2 2 2 2 2 2 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 3 2 2 2 7 2 2 2 2 2 2 3 2 3 2 2 2 2 2 2 1 1 1
1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[16 2 3 2 4 2 2 2 2 2 2 2 2 7 2 3 2 2 2 2 2 4 2 2
2 2 2 5 6 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]], shape=(9, 64), dtype=int64)
tf.Tensor(
[[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... -0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 0.e+00 0.e+00]
...
[-0.e+00 -1.e+09 -1.e+09 ... -0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
...
[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... 0.e+00 0.e+00 0.e+00]
...
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]
[ 0.e+00 0.e+00 0.e+00 ... 0.e+00 0.e+00 0.e+00]]
[[-0.e+00 -0.e+00 -0.e+00 ... -0.e+00 -0.e+00 -0.e+00]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 -1.e+09 -1.e+09]
[-0.e+00 -0.e+00 -0.e+00 ... -1.e+09 -1.e+09 -1.e+09]
...
[-0.e+00 -1.e+09 -1.e+09 ... -0.e+00 -1.e+09 -1.e+09]
[-0.e+00 -1.e+09 -1.e+09 ... -1.e+09 -0.e+00 -1.e+09]
[-0.e+00 -1.e+09 -1.e+09 ... -1.e+09 -1.e+09 -0.e+00]]], shape=(9, 64, 64), dtype=float32)
tf.Tensor(
[[b'C=C(CC)C(=O)c1ccc(OCC(=O)[O-])c(Cl)c1Cl']
[b'CN(Cc1c(C(=O)NC2CCCC(O)C2)noc1-c1ccc(C(F)(F)F)cc1)C1CCOC1']
[b'CC(=O)Nc1ccc(O)cc1']
[b'CC(=O)CC(c1ccc([N+](=O)[O-])cc1)c1c(O)c2ccccc2oc1=O']
[b'Cc1cccc(NC(=S)Oc2ccc3ccccc3c2)c1']
[b'CC(=O)Nc1nnc(S(N)(=O)=O)s1']
[b'CC(=O)c1ccc(S(=O)(=O)NC(=O)NC2CCCCC2)cc1']
[b'N#Cc1c(Cl)c2ccccc2n2c1nc1ccccc12']
[b'CN(C(=O)C(Cc1ccccc1Cl)C[NH+]1CCC2(CC1)OCCc1cc(F)sc12)C1CC1']], shape=(9, 1), dtype=string)
tf.Tensor(
[[b'<global>' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'O' b'C'
b'C' b'O' b'O' b'C' b'Cl' b'C' b'Cl' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'']
[b'<global>' b'C' b'N' b'C' b'C' b'C' b'C' b'O' b'N' b'C' b'C' b'C' b'C'
b'C' b'O' b'C' b'N' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'F' b'F' b'F'
b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'']
[b'<global>' b'C' b'C' b'O' b'N' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'N' b'O' b'O'
b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'O'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'C' b'C' b'C' b'C' b'N' b'C' b'S' b'O' b'C' b'C'
b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'O' b'N' b'C' b'N' b'N' b'C' b'S' b'N' b'O' b'O'
b'S' b'H' b'H' b'H' b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'']
[b'<global>' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'S' b'O' b'O' b'N' b'C'
b'O' b'N' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'']
[b'<global>' b'N' b'C' b'C' b'C' b'Cl' b'C' b'C' b'C' b'C' b'C' b'C'
b'N' b'C' b'N' b'C' b'C' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b'' b''
b'' b'' b'']
[b'<global>' b'C' b'N' b'C' b'O' b'C' b'C' b'C' b'C' b'C' b'C' b'C' b'C'
b'Cl' b'C' b'N' b'C' b'C' b'C' b'C' b'C' b'O' b'C' b'C' b'C' b'C' b'C'
b'F' b'S' b'C' b'C' b'C' b'C' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H'
b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H' b'H']], shape=(9, 64), dtype=string)
tf.Tensor(
[[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]]], shape=(9, 1, 1, 64), dtype=float32)
- The nine SMILES passed to Inference_Dataset are converted into the data printed above:
- x is the numerical representation of each SMILES, mapped through the str2num vocabulary defined in dataset (read off the printed batch: <pad>=0, H=1, C=2, N=3, O=4, F=5, S=6, Cl=7, <global>=16)
- adjoin_matrix is the adjacency-derived attention bias of the molecular graph: positions joined by a bond (plus the self and <global> entries) are -0.e+00, unconnected positions are -1.e+09, so their attention is suppressed after the softmax
- smiles is the input SMILES string
- atom_list is the atom list of each SMILES, including the added supernode <global>
- mask flags padding: padded positions are 1 and real tokens are 0; the last SMILES yields the longest atom sequence, so all others are padded to its length and only the last row of the mask is all zeros (both constructions are sketched right below)
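- A minimal sketch of the two mask constructions, using a hypothetical 3-atom adjacency matrix in place of the smiles2adjoin output:
import numpy as np
import tensorflow as tf
# hypothetical 3-atom graph: 1 = self or bonded neighbour, 0 = no bond
adjoin = np.array([[1, 1, 0],
                   [1, 1, 1],
                   [0, 1, 1]], dtype=np.float32)
temp = np.ones((4, 4), dtype=np.float32)      # one extra row/column for <global>
temp[1:, 1:] = adjoin
adjoin_matrix = (1 - temp) * (-1e9)           # bonded / <global> -> -0.0, non-bonded -> -1e9
x = tf.constant([[16, 2, 3, 4, 0, 0]], dtype=tf.int64)   # padded token ids, 0 = <pad>
seq = tf.cast(tf.math.equal(x, 0), tf.float32)
mask = seq[:, tf.newaxis, tf.newaxis, :]      # shape (1, 1, 1, 6): 1 at padded positions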
model = PredictModel_test(num_layers=num_layers, d_model=d_model, dff=dff, num_heads=num_heads, vocab_size=vocab_size,dense_dropout=0.15)
pred = model(x,mask=mask,training=True,adjoin_matrix=adjoin_matrix)
model.load_weights('regression_weights/logD.h5')
- Build the inference model and load the weights fine-tuned on the regression task; calling the model on one batch first creates its variables so that load_weights can restore the logD parameters.
3.PredictModel_test
class PredictModel_test(tf.keras.Model):
def __init__(self,num_layers = 6,d_model = 256,dff = 512,num_heads = 8,vocab_size =17,dropout_rate = 0.1,dense_dropout=0.5):
super(PredictModel_test, self).__init__()
self.encoder = Encoder_test(num_layers=num_layers,d_model=d_model,
num_heads=num_heads,dff=dff,input_vocab_size=vocab_size,maximum_position_encoding=200,rate=dropout_rate)
self.fc1 = tf.keras.layers.Dense(256, activation=tf.keras.layers.LeakyReLU(0.1))
self.dropout = tf.keras.layers.Dropout(dense_dropout)
self.fc2 = tf.keras.layers.Dense(1)
def call(self,x,adjoin_matrix,mask,training=False):
x,att,xs = self.encoder(x,training=training,mask=mask,adjoin_matrix=adjoin_matrix)
x = x[:, 0, :]
x = self.fc1(x)
x = self.dropout(x, training=training)
x = self.fc2(x)
return x,att,xs
- The encoder part is Encoder_test; it differs from the training-time encoder in what its call method returns (it additionally exposes each layer's attention weights and hidden states).
3.1.Encoder_test
class Encoder_test(tf.keras.Model):
def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
maximum_position_encoding, rate=0.1):
super(Encoder_test, self).__init__()
self.d_model = d_model
self.num_layers = num_layers
self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate)
for _ in range(num_layers)]
self.dropout = tf.keras.layers.Dropout(rate)
def call(self, x, training, mask,adjoin_matrix):
seq_len = tf.shape(x)[1]
adjoin_matrix = adjoin_matrix[:,tf.newaxis,:,:]
x = self.embedding(x)
x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
x = self.dropout(x, training=training)
attention_weights_list = []
xs = []
for i in range(self.num_layers):
x,attention_weights = self.enc_layers[i](x, training, mask,adjoin_matrix)
attention_weights_list.append(attention_weights)
xs.append(x)
return x,attention_weights_list,xs
- x first passes through the Embedding layer and then through the 6 EncoderLayer blocks; after each EncoderLayer, the output x and the attention weights are recorded.
- The shapes of the three outputs are as follows:
x, adjoin_matrix, smiles ,atom_list = next(iter(inference_dataset.take(1)))
seq = tf.cast(tf.math.equal(x, 0), tf.float32)
mask = seq[:, tf.newaxis, tf.newaxis, :]
x,atts,xs= model(x,mask=mask,training=True,adjoin_matrix=adjoin_matrix)
np.asarray(x).shape,np.asarray(atts).shape,np.asarray(xs).shape
((9, 1), (6, 9, 8, 64, 64), (6, 9, 64, 256))
- x is the output of the full model: the predicted logD for each of the 9 molecules, shape (9, 1). atts and xs are the attention weights and hidden states collected from each of the 6 EncoderLayer blocks; per layer the attention weights have shape (9, 8, 64, 64) and the hidden states have shape (9, 64, 256).
- The inputs to the model have shapes (9, 64), (9, 64, 64) and (9, 1, 1, 64). x is 9 vectors of length 64 whose elements are atom-type indices; after the Embedding layer it becomes (9, 64, 256), i.e. each atom is mapped to a 256-dimensional vector (a toy check follows below).
- The inputs x, mask and adjoin_matrix to each EncoderLayer have shapes (9, 64, 256), (9, 1, 1, 64) and (9, 1, 64, 64); the extra axis on adjoin_matrix is added in Encoder_test.call so that it broadcasts over the attention heads.
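- A toy check of the embedding step, using only the vocab_size=17 and d_model=256 set above:
import tensorflow as tf
emb = tf.keras.layers.Embedding(17, 256)                  # vocab_size=17, d_model=256
ids = tf.zeros((9, 64), dtype=tf.int64)                   # stand-in for the atom-index batch x
out = emb(ids) * tf.math.sqrt(tf.cast(256, tf.float32))   # same scaling as in Encoder_test
print(out.shape)                                          # (9, 64, 256)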
3.2.atts & xs
class EncoderLayer(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads, dff, rate=0.1):
super(EncoderLayer, self).__init__()
self.mha = MultiHeadAttention(d_model, num_heads)
self.ffn = point_wise_feed_forward_network(d_model, dff)
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
self.dropout1 = tf.keras.layers.Dropout(rate)
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, x, training, mask,adjoin_matrix):
attn_output, attention_weights = self.mha(x, x, x, mask,adjoin_matrix)
attn_output = self.dropout1(attn_output, training=training)
out1 = self.layernorm1(x + attn_output)
ffn_output = self.ffn(out1)
ffn_output = self.dropout2(ffn_output, training=training)
out2 = self.layernorm2(out1 + ffn_output)
return out2,attention_weights
- out2 and attention_weights here are the individual elements appended to xs and attention_weights_list in Encoder_test.
- x enters the multi-head attention layer to get the attention output; the feed-forward helper used afterwards is sketched below.
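- point_wise_feed_forward_network is not reproduced in this post; in the standard Transformer formulation it is the two-layer MLP sketched here (the activation used in the repo may differ, e.g. GELU instead of ReLU):
import tensorflow as tf
def point_wise_feed_forward_network(d_model, dff):
    # (batch, seq_len, d_model) -> (batch, seq_len, dff) -> (batch, seq_len, d_model)
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),
        tf.keras.layers.Dense(d_model)
    ])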
3.3.MultiHeadAttention
class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
assert d_model % self.num_heads == 0
self.depth = d_model // self.num_heads
self.wq = tf.keras.layers.Dense(d_model)
self.wk = tf.keras.layers.Dense(d_model)
self.wv = tf.keras.layers.Dense(d_model)
self.dense = tf.keras.layers.Dense(d_model)
def split_heads(self, x, batch_size):
"""Split the last dimension into (num_heads, depth).
Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
"""
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, v, k, q, mask,adjoin_matrix):
batch_size = tf.shape(q)[0]
q = self.wq(q)
k = self.wk(k)
v = self.wv(v)
q = self.split_heads(q, batch_size)
k = self.split_heads(k, batch_size)
v = self.split_heads(v, batch_size)
scaled_attention, attention_weights = scaled_dot_product_attention(
q, k, v, mask,adjoin_matrix)
scaled_attention = tf.transpose(scaled_attention,
perm=[0, 2, 1, 3])
concat_attention = tf.reshape(scaled_attention,
(batch_size, -1, self.d_model))
output = self.dense(concat_attention)
return output, attention_weights
- The (9, 64, 256) x is fed in as q, k and v; after the dense projections the shape is still (9, 64, 256), and split_heads reshapes each of them to (9, 8, 64, 32). scaled_dot_product_attention then returns the (9, 8, 64, 64) attention_weights and an output that is merged back to (9, 64, 256); after the output dense layer, the ffn and the layer norms, this becomes the x of the next EncoderLayer. scaled_dot_product_attention itself is not quoted in this post; a sketch follows.
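- A sketch of scaled_dot_product_attention consistent with how mask and adjoin_matrix are used above (standard Transformer attention; treating mask and adjoin_matrix as additive -1e9 biases is an assumption that matches the masks built earlier):
import tensorflow as tf
def scaled_dot_product_attention(q, k, v, mask, adjoin_matrix):
    # q, k, v: (batch, heads, seq_len, depth)
    matmul_qk = tf.matmul(q, k, transpose_b=True)            # (batch, heads, seq_len, seq_len)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    if mask is not None:
        scaled_attention_logits += mask * -1e9               # hide padded positions
    if adjoin_matrix is not None:
        scaled_attention_logits += adjoin_matrix             # hide non-bonded atom pairs
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    output = tf.matmul(attention_weights, v)                 # (batch, heads, seq_len, depth)
    return output, attention_weights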
4.plot_weights
i = 0
print(x)
smiles_plot = smiles[i].numpy().tolist()[0].decode()
mol = Chem.MolFromSmiles(smiles_plot)
num_atoms = mol.GetNumAtoms()
attentions_plot = tf.concat([att[i:(i+1),:,:num_atoms+1,:num_atoms+1] for att in atts],axis=0)
plot_weights(smiles_plot,attentions_plot)
- x is the predicted logD values; smiles_plot is the i-th SMILES string.
- atts has shape (6, 9, 8, 64, 64), and each att is the attention-weight tensor from one EncoderLayer. att[i:(i+1),:,:num_atoms+1,:num_atoms+1] selects the attention weights of the i-th molecule, and the :num_atoms+1 slices drop the padded positions, so here we get the 8 heads' attention for molecule 0. Concatenating the 6 per-layer tensors gives a final shape of (6, 8, 20, 20), where the 20 positions include the supernode.
The output is as follows:
tf.Tensor(
[[ 0.00611259]
[ 0.10352743]
[-0.09410169]
[-0.01148836]
[ 0.29321525]
[-0.2280695 ]
[ 0.03325798]
[ 0.24569517]
[ 0.2692453 ]], shape=(9, 1), dtype=float32)
[[0.02703399 0.03738527 0.0341253 0.04261673 0.03024014 0.046207
0.04024169 0.03412233]
[0.00220853 0.01803984 0.01086219 0.10811627 0.0314402 0.20126604
0.03202052 0.01011109]
[0.01053316 0.02646094 0.02683227 0.05374964 0.00403768 0.01397272
0.03217852 0.00420805]
[0.01578051 0.00886862 0.00608971 0.17516439 0.0039229 0.01038222
0.0165774 0.04300734]
[0.02164022 0.05754587 0.01020102 0.01972281 0.04598003 0.0355869
0.0741061 0.03371757]
[0.02877988 0.02249552 0.03075971 0.04478414 0.04234883 0.03901222
0.05866967 0.02472056]]
[['O13', '0.07'], ['C12', '0.05'], ['O14', '0.05'], ['C11', '0.05'], ['O5', '0.05'], ['O10', '0.04'], ['C3', '0.03'], ['C2', '0.03'], ['C9', '0.03'], ['Cl16', '0.03'], ['C0', '0.02'], ['C8', '0.02'], ['C4', '0.02'], ['C1', '0.02'], ['C7', '0.02'], ['Cl18', '0.02'], ['C15', '0.02'], ['C6', '0.02'], ['C17', '0.02']]
def plot_weights(smiles, attention_plot, max=5):
    mol = Chem.MolFromSmiles(smiles)
    mol = Chem.RemoveHs(mol)
    num_atoms = mol.GetNumAtoms()
    atoms = []
    for i in range(num_atoms):
        atom = mol.GetAtomWithIdx(i)
        atoms.append(atom.GetSymbol() + str(i))
    # average the supernode row over the last 3 layers and the 8 heads, then drop the supernode itself
    att = tf.reduce_mean(tf.reduce_mean(attention_plot[3:, :, 0, :], axis=0), axis=0)[1:].numpy()
    print(attention_plot[:, :, 0, 0].numpy())
    indices = (-att).argsort()
    highlight = indices.tolist()
    print([[atoms[indices[i]], ('%.2f' % att[indices[i]])] for i in range(len(indices))])
    drawer = rdMolDraw2D.MolDraw2DSVG(800, 600)
    opts = drawer.drawOptions()
    drawer.drawOptions().updateAtomPalette({k: (0, 0, 0) for k in DrawingOptions.elemDict.keys()})
    colors = {}
    # shade each atom red in proportion to its attention: max attention -> pure red, min -> white
    for i, h in enumerate(highlight):
        colors[h] = (1,
                     1 - 1 * (att[h] - att[highlight[-1]]) / (att[highlight[0]] - att[highlight[-1]]),
                     1 - 1 * (att[h] - att[highlight[-1]]) / (att[highlight[0]] - att[highlight[-1]]))
    drawer.DrawMolecule(mol, highlightAtoms=highlight, highlightAtomColors=colors, highlightBonds=[])
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText().replace('svg:', '')
    display(SVG(svg))
- tf.reduce_mean(attentions_plot[3:,:,0,:], axis=0) averages (not sums) the supernode's attention row over the last 3 layers, giving an (8, 20) matrix; the second reduce_mean averages over the 8 heads to give an attention value for each of the 20 positions, and [1:] drops the added supernode, leaving a (19,) vector of per-atom attention. The output is as follows:
tf.reduce_mean(tf.reduce_mean(attentions_plot[3:,:,0,:],axis=0),axis=0)[1:]
<tf.Tensor: shape=(19,), dtype=float32, numpy=
array([0.02398169, 0.02298377, 0.03124303, 0.03180713, 0.0232711 ,
0.04853497, 0.0192481 , 0.02276288, 0.02380962, 0.02621273,
0.03749018, 0.04908869, 0.05493892, 0.06948597, 0.04930567,
0.02019732, 0.0252701 , 0.01770476, 0.02110688], dtype=float32)>
- Sorting the attention values ranks each atom by its contribution (attention) to the predicted property; the output is:
indices = (-att).argsort()
highlight = indices.tolist()
highlight
[13, 12, 14, 11, 5, 10, 3, 2, 9, 16, 0, 8, 4, 1, 7, 18, 15, 6, 17]
- This output means that the atom contributing most to this molecule's predicted logD has index 13, the second-largest contributor has index 12, and so on. A tiny example of how (-att).argsort() yields this descending order is sketched below.
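- A toy illustration of the descending ranking (hypothetical scores):
import numpy as np
att = np.array([0.02, 0.07, 0.05])
print((-att).argsort())   # [1 2 0]: atom 1 has the largest score, then atom 2, then atom 0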
print([[atoms[indices[i]],('%.2f'%att[indices[i]])] for i in range(len(indices))])
- When i is 0 the top-ranked atom is taken: indices[i] gives its atom index, atoms[indices[i]] gives the atom label, and ('%.2f' % att[indices[i]]) gives its contribution value, here ['O13', '0.07'], meaning the oxygen at atom index 13 contributes 0.07 to the predicted logD. Drawing the molecule and highlighting atoms by these contribution (attention) values produces the picture shown.
- The remaining molecules are visualized in exactly the same way as the first one; a sketch for another batch index follows.
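- For example, to plot molecule i = 3 (same steps as the i = 0 cell above; 3 is just an arbitrary index):
i = 3
smiles_plot = smiles[i].numpy().tolist()[0].decode()
mol = Chem.MolFromSmiles(smiles_plot)
num_atoms = mol.GetNumAtoms()
attentions_plot = tf.concat([att[i:(i+1),:,:num_atoms+1,:num_atoms+1] for att in atts],axis=0)
plot_weights(smiles_plot,attentions_plot)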