작성자 : 채윤병
논문의 목적 : 분자 그래프로부터 화학적 성질 예측을 위한 분자의 특성을 학습할 수 있고 그래프 동형(graph isomorphism)문제에 대해 invariant한 효과적인 머신러닝 모델을 제안하는 것
그래프 동형(graph isomorphism) - 노드의 순서의 차이로 인해 다르게 표현되지만 완전히 같은 구조를 가지고 있는 그래프
→ MPNN(Message Passing Neural Networks)
QM9 dataset - 화학 분자 benchmark dataset, 13만개의 분자의 DFT 연산으로 근사한 13가지 특성이 존재
DFT(Density functional theory)? - 기본 재료 특성과 같은 고차 매개변수를 요구하지 않고 양자 역학적 고려 사항을 기반으로 재료 거동을 예측하고 계산
QM9은 화학적 특성의 계산에 사용되는 원자들의 low energy conformation(낮은 에너지 상태의 형태)을 위한 공간적 정보가 들어있다. QM9은 원자 사이의 거리, 결합각과 같은 기하학적 정보가 있는 경우(a)와 원자(atom)와 결합(bond)에 대한 정보만 있고 원자들의 공간적인 위치에 따라 '정의되어 있는' 화학적 특성을 계산해야 하는 경우(b)를 고려한다.
(b)의 경우 모델이 내재적으로 low energy 3D conformation을 결정할 수 있게 학습이 이루어져야 한다!
Error
Key Contribution
Forward pass - 1. Message passing phase 2. Readout phase
Message passing phase - ① Message function, Mt ② Vertex update function, Ut
Mt, Ut, R은 모두 미분가능한 함수
R은 node state의 집합에 대해 적용되며 MPNN이 그래프 동형(graph isomorphism)문제에 대해 invariant하기 위해서 node state의 순서에 따라 invariant해야한다.
Message passing이 연결된 node와 연결된 edge를 각각 더하는 방법으로 정의되기 때문에 edge state와 node state간의 correlation을 나타낼 수 없다.
각 타임 스텝에서 큰 그래프의 일부분만을 이용해 message를 passing 함으로써 연산량을 줄였다.
그래프의 각 node의 target이 있는 경우, 그래프 level 단위의 target이 있는 경우, 각 timp step에서 node level이 영향을 미치는 경우 고려
Message passing 단계에서 edge representation e_vw를 업데이트
Graph laplacian L을 도입해서 graph에 convolution의 개념을 도입(GCN)
→ Moving forward.. 여러 MPNN의 예시에서 실용적인 중요도를 가지는 구체적인 활용에 대해 집중해야 한다. 그래야 모델링의 개선이나 활용의 detail을 정할 수 있다.
그동안 과학자들은 양자 역학의 근사를 위해서 DFT(Density functional theory), GW approximation, Quantum Monte-Carlo 등 다양한 방법을 사용했다.
그 중 DFT는 연산량이 크기 때문에 Neural Network로 이를 근사하려는 노력이 있었다. but 성공적이지 못함.
따라서 최근에 양자 역학의 근사를 Neural Network로 직접하려는 연구가 있었다.
이 두 가지 모두 대칭 함수와 같은 hand-engineered feature를 사용했는데 이는 generalize 문제(원자의 종류가 많아질 경우)와 graph isomorphism에 invariant 하지 않을 수 있다는 문제점이 있다.
H, C, O, N, F 등 9개의 원자로 이루어져 있으며 134k의 약물과 유사한 분자에 대한 정보를 담고 있는 dataset.
각 분자에 대해 합리적인 low energy structure를 찾고 atom position이 가능한지(availability) 판별하기 위해 DFT가 사용된다. 또한 여러 기본적인 화학적 특성이 계산되어 담겨있음.
Bond 관련 특성 4가지 + 분자의 진동에 대한 특성 2가지 + 전자 상태에 대한 특성 3가지 + 전자의 공간적 분포에 대한 특성 3가지
GG-NN을 baseline으로 잡고 다른 message function, output function 등을 시도해 보았다.
MPNN의 input은 그래프 노드의 feature vector x_v, 인접 행렬 A(두 원자의 공간적인 거리와 결합 종류의 정보까지 담고 있는 weighted matrix)
연결되어 있지 않은 node 쌍 간의 "virtual edge type"을 추가하고 모든 node와 특별한 edge type(virtual)으로 연결되어 있는 "master node"를 추가. Master node는 그래프의 global한 정보를 담고 있다.
첫 번째로 원래 GG-NN에서 사용한 readout function 사용, 다른 방법으로는 set2set model(Input으로 projection한 tuple을 사용하여 단순히 더하는 방법보다 표현력이 좋은 방법) 사용
d차원의 node embedding을 k개의 d/k차원의 node embedding으로 만들어 k개의 copies에 대해 각각 propagation step을 진행 → 연산의 효율 증가
Number of Hydrogens와 같은 feature로 atom feature에 포함시키는 것과 달리 수소 분자를 외부 노드로 만들어서도 실험을 진행했다(이 경우 그래프의 노드개수는 최대 29개 → QM9 dataset의 분자는 small molecule).
3가지의 edge representation
실험의 세부 설정들 - Uniform random hyperparameter(50 trials), Adam optimizer, learning rate 1e-5 ~ 5e-4, Target value normalize(mean 0 ,variance 1), Random 하게 13만개 molecule중에서 1만개 validation, 1만개 test 선택, MSE minimize하고 MAE로 evaluate
QM9 dataset에 대해 가장 적절한 input representation 뿐만 아니라 best MPNN을 찾기 위해서 여러가지 실험을 진행했다. 결합 타입과 공간적인 정보를 담고 있는 edge feature를 사용하고 수소 분자를 하나의 노드(explicit node)로 설정했을 때가 가장 성능이 좋았던 실험 설정이었다 + (message function - edge network, output function - set2set).
13개의 target value각각에 대해 따로 학습을 시키는 것이 최대 40%의 성능 향상이 있었다.(상당히 번거롭지 않을까..?, Weakness)
13개의 target value 모두 기존 모델들보다 좋은 결과를 보였으며 13개 중 11개의 target value가 chemical accuracy를 달성했다.
Spaital information을 넣지 않고 실험을 진행했을 때는 MPNN이 long range interaction을 고려하도록 하는 것이 성능향상에 큰 도움이 되었다.
d차원의 node embedding을 k개의 d/k차원의 node embedding으로 만들어 k개의 copies에 대해 각각 propagation step을 진행하는 multiple tower 방식이 정규화의 역할을 했다.
적절한 message, update, output function을 가진 MPNN 모델이 분자의 특성을 예측하기 위한 유용한 inductive bias를 제공한다. Master node나 set2set output과 같이 node간의 long range interaction을 고려하는 것이 중요하다.
미래에 MPNN이 해결해야할 문제 → Generalization - 더 큰 그래프(molecule)에 대해서도 generalization을 개선하는 것
공간적 정보를 사용할 때 큰 분자에서 generalization이 어려운 이유
☆ Attention을 적용한 MPNN 읽어보기
[Graph 스터디] DAGNN(Directed Acyclic Graph Neural Networks) (0) | 2022.11.23 |
---|---|
[Graph 스터디] Semi-supervised classification with Graph Convolution Networks (0) | 2022.11.10 |
[Graph 스터디] Graph Neural Network (0) | 2022.10.09 |
Predict then propagate : Graph Neural Network meet Personalized PageRank (0) | 2022.09.22 |
Graph neural network review(2) (0) | 2022.09.13 |
댓글 영역