You are on page 1of 39

Junction Tree Variational Autoencoder for

Molecular Graph Generation

王振

2021年11月15日
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Sample
• Tree decoder
• Graph decoder
• Conclusion
Junction Tree Variational Autoencoder for Molecular Graph Generation

本篇论文是基于Graph的,而在生成分子之前的很多工作中,很多都是基于SMILES的。

基于SMILES生成模型的两个关键缺点:
1.分子的 SMILES 表示不是为捕获分子相似度而设计的,会导致生成模型难以学习到平
滑的分子embedding。

2.比起SMILES表示,在图上更容易表达分子的一些重要的化学特性,比如分子的有效性。
这里作者假设,直接在图上操作可以改进有效化学结构的生成性建模。
‣ 基于原子(atom)的分子生成== 基于字母 ‣ 基于官能团(group)的分子生成 == 基
生成句子 于单词生成句子

‣ 原子和键 ‣ 环和键
‣ 基于原子的生成 ‣ 基于官能团的生成
‣ 中间步骤可能不具有化学意义 ‣ 每一步都具有化学意义

Jin et al., Junction Tree Variational Autoencoder for Molecular Graph Generation. arXiv:1802.04364
整体架构

4
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph decoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
分子→连接树

分子 连接树

N N N N N O O Cl S …
官能团词汇
N N N S C …

• group by group生成分子

• 词汇量:处理250K分子得到少于800个
分子→连接树

1.对每个原子编号,提取非环键和单环,划分成两类节点
2.共享原子数大于2的节点合并
3.若有3个及以上的节点共享一个原子,则将该
原子独立成新的节点
4.提取官能团

总词汇数<800
5.找到最小生成树(最短路径)
Tree Decomposition

Vocab
Tree Decomposition

SMILES
MolTree类
self.smiles
self.nodes node1 MolTreeNode类
node2 self.smiles
node3 self.mol
node4 self.clique
…… self.neighbor
self.is_leaf
self.nid
self.label
self.idx
self.wid
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Tree encoder
Old memory

embedding

𝑚𝑚𝑘𝑘𝑘𝑘
𝑚𝑚𝑖𝑖𝑖𝑖
𝑘𝑘 ∈ 𝑁𝑁 𝑖𝑖 \j
Final
Old memory memory

𝑥𝑥𝑖𝑖
embedding
fnode:存储该batch_size fmess:存储该batch_size中 mess_graph:存储所有边的所有 node_graph:存储所有节点的所
中所有node的word_id 所有边的初始节点的idx 前向边的idx 有前向边的idx(作为尾节点)
torch.size([num_nodes, ]) torch.size([num_edges, ]) torch.size([num_edges, Max_NB]) torch.size([num_nodes, Max_NB])
Max_NB Max_NB
0 73 0 23 0 20 5 0
0 20 5 0
1 5 1 54 1 11 16 45
1 11 16 45
2 65 2 11 2 2 0 0
2 2 0 0
3 12 3 12 3 3 6 0
3 3 6 0
…… …… … …
… …

scope=List[Tuple(int: start_idx, int: len)]


fnode:存储该batch_size中 fmess:存储该batch_size中所 mess_graph:存储所有边的所有
所有node的word_id 有边的初始节点的idx 前向边的idx
torch.size([num_nodes, ]) torch.size( [num_edges, ]) torch.size([num_edges, Max_NB])
Max_NB
初始化 node_graph:存储所有节点的所
0 73 0 0(pad)
所有边的特征 有前向边的idx
0 P P P
1 5 1 54 torch.size([num_nodes, Max_NB])
hidden_size
2 65 1 11 16 45
2 11 Max_NB
3 12 2 2 0 0 0 0 0 0 0
3 12
… … 3 3 6 0 1 0 0 0 0 0 20 5 0
… …
… … 2 0 0 0 0 1 11 16 45
3 0 0 0 0 2 2 0 0
nn.Embedding
(vocab_size, hidden_size) hidden_size=450 … … 3 3 6 0
0 p p p p GRU messages … …
1 torch.size([num_
hidden_size=450 edges, 450 ])
2
0
3 hidden_size=450
1
… … 0
2 1
edge_begin_node_features
3 torch.size([num_edges, 450 ]) 2
… … 3
… …
node_features
torch.size([num_nodes, 450 ]) messages
torch.size([num_edges, 450 ])
node_graph:存储所有节点的所 node_features
messages 有前向边的idx torch.size([num_nodes,
torch.size([num_edges, 450 ]) torch.size([num_nodes, Max_NB]) 450 ])
hidden_size=450
hidden_size=450 0 20 5 0 hidden_size + hidden_size hidden_size
0
0 1 11 16 45 0
1 Cat(dim=1)0
1 2 2 0 0
2
1 Linear 1
2
3 3 6 0 2 2
3 3
… … 3 3
… … … …
… … … …
hidden_size Node_features
hidden_size=450 torch.size([num_nodes,
Max_NB 450 ])
0
mess_nei.sum(dim=1) 1
num_nodes hidden_size
2
scope=List[Tuple( start_idx, len)] 0
3
mess_nei 1
… … 之前设置了每个tree的第1个节
torch.size([num_nodes, Max_NB, 450 ])
点为根节点 2
mess_nei
torch.size([num_nodes, … …
450 ])
Tree_features
torch.size(
[batch_size, 450 ])
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Graph encoder
fatoms:存储该batch_size fbonds:存储该batch_size agraph:存储所有atom的所有前向 bgraph:存储所有bond的所有
中所有atom的特征 中所有bond的特征 边的idx 前向边的idx
torch.size([num_atoms, 39]) torch.size([num_bonds, 50]) torch_size([num_atoms, Max_NB]) torch_size([num_bonds,
39 Max_NB=6 Max_NB])
39+11=50 Max_NB=6
0 1 0 . 1 0 1 0 1 . 0 0 0 20 5 0 0 20 5 0
1 0 1 . 0 1 1 1 0 . 0 1 1 11 16 45 1 11 16 45
2 0 1 . 1 2 0 1 0 . 1 1 2 2 0 0 2 2 0 0
3 1 0 . 0 3 0 0 1 . 1 0 3 3 6 0 3 3 6 0
…… . …… . … … … …
one-hot encoding one-hot encoding
Symbol:23 Bond Type:5
Degree(NB):6 Stereo :6
Formal Charge:5 scope=List[Tuple(int: start_idx, int: len)]
Chiral:4 Bond_Feature_Dim
Aromatic :1 =5+6=11
Atom_Feature_Dim
=23+6+5+4+1=39
fbonds:存储该batch_size中所有 binput messages bgraph:存储所有bond的
bond的特征 torch.size([num_bonds, torch.size([num_bonds, 所有前向边的idx
torch.size([num_bonds, 50]) hidden_size]) hidden_size]) torch_size([num_bonds,
Max_NB])
hidden_size=450 hidden_size=450 Max_NB=6
39+11=50
0 0 0 20 5 0
0 1 0 1 . 0 0
nn.Linear(50, 1 1 1 11 16 45
1 1 1 0 . 0 1 hidden_sze) ReLU
2 2 2 2 0 0
2 0 1 0 . 1 1
3 0 0 1 . 1 0 3 3 3 3 6 0
… … . … … … … … …
one-hot encoding

hidden_size=450 hidden_size

hidden_size=450 0
6
1 nei_message.
0 nn.Linear(
2 sum(dim=1) num_
1 2*hidden_size, Cat
hidden_sze) (dim=1) 3 bonds
2 Linear
… …
3 nei_message
nei_message torch.size([num_bonds, 6, 450 ])
… …
torch.size([num_bonds, 450 ])
messages
torch.size([num_bonds,
450 ])
fatoms:存储该batch_size中所有
messages agraph:存储所有atom的 atom的特征
torch.size([num_bonds, 所有前向边的idx torch.size([num_atoms, 39]) hidden_size=450
hidden_size]) torch_size([num_atoms, 39
Max_NB]) 0
hidden_size=450 Max_NB=6
0 1 0 . 1 1
0 0 20 5 0 1 0 1 . 0
Cat Linear(489,450) 2
1 1 11 16 45 2 0 1 . 1
3
2 2 2 0 0 3 1 0 . 0
… …
3 3 6 0 … … .
3 atom_features
one-hot encoding
… … … … torch.size([num_atoms,
450 ])
scope=List[Tuple( start_idx,
hidden_size=450
len)]
hidden_size
0 得到每个分子的每个原子
hidden_size
6 1 的特征,取均值
2 0
atom_nei_message.
num_ sum(dim=1) 3 1
atoms
… … 2
… …
atom_nei_message atom_nei_message graph_features
torch.size([num_atoms, torch.size([num_atoms, torch.size(
6, 450 ]) 450 ]) [batch_size, 450 ])
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
latent_size=28
resample
0
hidden_size hidden_size
nn.Linear(hidden_size, 1 𝜇𝜇
0 0 latent_size) 2
torch.size(
1 … … [batch_size, latent_size ])
1
2 2 latent_size=28
… … … …
0
𝜎𝜎 2
Tree_features
torch.size(
graph_features
torch.size(
nn.Linear(hidden_size, 1 𝑙𝑙𝑙𝑙𝑙𝑙
[batch_size, [batch_size, latent_size) 2
torch.size(
hidden_size ]) hidden_size ]) … … [batch_size, latent_size ])

若从𝑁𝑁 𝜇𝜇, 𝜎𝜎 中直接采样一个样本进行解码,则在反向传播的时候会造成梯度断裂。


从𝑁𝑁 0, 1 中采样𝜀𝜀,则𝑁𝑁 𝜇𝜇, 𝜎𝜎 中的样本可表示为𝜀𝜀𝜀𝜀 + 𝜇𝜇

𝜇𝜇
𝑧𝑧𝐺𝐺
𝜎𝜎 2 exp
𝑙𝑙𝑙𝑙𝑙𝑙
[batch_size, latent_size ]
1
𝜀𝜀
[batch_size, latent_size ]
2
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph decoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Tree Decoder
Tree Decoder

node_ Node_ directi


x y on
扩展到整个batch_size中的所有tree,每行表示1个tree
1 2 1
2 3 1
(1,2,1) (2,3,1) (3,4,1) (4,5,1) (5,4,0) (4,3,0) (3,6,1) (6,3,0) (3,2,0) (2,1,0) (1,7,1) (7,1,0)
dfs 3 4 1
4 5 1
batch_
5 4 0 size
4 3 0
3 6 1
6 3 0
3 2 0
prop_list
2 1 0
1 7 1
7 1 0
Tree Decoder 1
root
root root 1 2
1 1
2
Add node 2 as 3
Get embedding
with wid of 1 2 neighbor of 1
3 4 Add node 4 as
neighbor of 5

Get embedding 4
with wid of 2 5 Get embedding


initial with wid of 4
GRU
message 5 Add node 5 as
GRU neighbor of 4
node_x Node_y directio
n
1 2 1
1 1 Get embedding
2 3 1
with wid of 5 3 4 1
new_message
4 5 1
2 new_message
GRU 5 4 0
4 3 0
4 3 6 1
Stop pre new_message 6 3 0
Label pre(if direction =1) 3 2 0
TP 5 2 1 0
LP TP 1 7 1
LP 7 1 0
Tree Decoder
目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Graph Decoder

Decoding graph
Set of possible candidates
graphs of tree 𝑇𝑇�
Graph Decoder
i

Atoms of subgraph Bond Features


u-v bond(edge)
Atom Features Old message from
neighbors
j

Tree message

Subgraph of ground truth Subgraphs from model

Ground truth Set of possible candidate subgraphs


目录
• 简介
• Tree decomposition
• Tree encoder
• Graph encoder
• Resample
• Tree decoder
• Graph decoder
• Conclusion
Conclusion

生成分子的合法性
谢 谢!

You might also like