from collections import namedtuple
import dgl
from dgl import DGLGraph
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, TensorDataset
from torch.utils.data import DataLoader
import torch.optim as optim
from rdkit import Chem
from rdkit.Chem import RDConfig
import os
import copy
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
Define the element list
ELEM_LIST = ['C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', 'Na', 'Ca', 'Fe', 'Al', 'I', 'B', 'K', 'Se', 'Zn', 'H', 'Cu', 'Mn', 'unknown']
ATOM_FDIM = len(ELEM_LIST) + 6 + 5 + 1  # element one-hot (23) + degree 0-5 (6) + formal charge (5) + aromatic flag (1)
MAX_ATOMNUM = 60
BOND_FDIM = 5
MAX_NB = 10
The code below is adapted from DGL's junction tree example and builds molecular structure graphs
PAPER = os.getenv('PAPER', False)
def onek_encoding_unk(x, allowable_set):
    # one-hot encode x; anything outside the allowable set falls back to the last entry ('unknown')
    if x not in allowable_set:
        x = allowable_set[-1]
    return [x == s for s in allowable_set]
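A quick check of the fallback behaviour ('X' is just an arbitrary symbol outside ELEM_LIST):
print(onek_encoding_unk('Si', ELEM_LIST).index(True))   # position of 'Si' in ELEM_LIST
print(onek_encoding_unk('X', ELEM_LIST)[-1])            # True: unknown symbols map to 'unknown'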
# Note that during graph decoding they don't predict stereochemistry-related
# characteristics (i.e. Chiral Atoms, E-Z, Cis-Trans). Instead, they decode
# the 2-D graph first, then enumerate all possible 3-D forms and find the
# one with highest score.
'''
def atom_features(atom):
    return (torch.Tensor(onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
            + onek_encoding_unk(atom.GetDegree(), [0,1,2,3,4,5])
            + onek_encoding_unk(atom.GetFormalCharge(), [-1,-2,1,2,0])
            + [atom.GetIsAromatic()]))
'''
def atom_features(atom):
    return (onek_encoding_unk(atom.GetSymbol(), ELEM_LIST)
            + onek_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
            + onek_encoding_unk(atom.GetFormalCharge(), [-1, -2, 1, 2, 0])
            + [atom.GetIsAromatic()])
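A quick sanity check that the feature length matches ATOM_FDIM (ethanol is just an arbitrary example molecule):
_mol = Chem.MolFromSmiles('CCO')
print(len(atom_features(_mol.GetAtomWithIdx(0))), ATOM_FDIM)   # both should be 23 + 6 + 5 + 1 = 35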
def bond_features(bond):
    bt = bond.GetBondType()
    return torch.Tensor([bt == Chem.rdchem.BondType.SINGLE,
                         bt == Chem.rdchem.BondType.DOUBLE,
                         bt == Chem.rdchem.BondType.TRIPLE,
                         bt == Chem.rdchem.BondType.AROMATIC,
                         bond.IsInRing()])
def mol2dgl_single(mols):
    # convert a list of RDKit molecules into DGLGraphs with atom features on the nodes
    cand_graphs = []
    n_nodes = 0
    n_edges = 0
    bond_x = []  # bond features are collected but not attached; the GCN below uses node features only
    for mol in mols:
        n_atoms = mol.GetNumAtoms()
        n_bonds = mol.GetNumBonds()
        g = DGLGraph()
        nodeF = []
        for i, atom in enumerate(mol.GetAtoms()):
            assert i == atom.GetIdx()
            nodeF.append(atom_features(atom))
        g.add_nodes(n_atoms)
        bond_src = []
        bond_dst = []
        for i, bond in enumerate(mol.GetBonds()):
            a1 = bond.GetBeginAtom()
            a2 = bond.GetEndAtom()
            begin_idx = a1.GetIdx()
            end_idx = a2.GetIdx()
            features = bond_features(bond)
            # add each bond in both directions so the graph behaves as undirected
            bond_src.append(begin_idx)
            bond_dst.append(end_idx)
            bond_x.append(features)
            bond_src.append(end_idx)
            bond_dst.append(begin_idx)
            bond_x.append(features)
        g.add_edges(bond_src, bond_dst)
        g.ndata['h'] = torch.Tensor(nodeF)
        cand_graphs.append(g)
    return cand_graphs
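A short usage check (the SMILES strings here are arbitrary examples, not the training data):
example_mols = [Chem.MolFromSmiles(smi) for smi in ['c1ccccc1', 'CCO', 'CC(=O)O']]
example_graphs = mol2dgl_single(example_mols)
print(example_graphs[0].number_of_nodes(), example_graphs[0].ndata['h'].shape)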
msg = fn.copy_src(src="h", out="m")
reduce = fn.sum(msg="m", out="h")  # built-in reducer that sums neighbor messages into 'h'; used by GCN.forward below
Define the GCN
class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h': h}
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        g.update_all(msg, reduce)
        g.apply_nodes(func=self.apply_mod)
        h = g.ndata.pop('h')
        return h
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.layers = nn.ModuleList([GCN(in_dim, hidden_dim, F.relu),
                                     GCN(hidden_dim, hidden_dim, F.relu)])
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        h = g.ndata['h']
        for conv in self.layers:
            h = conv(g, h)
        g.ndata['h'] = h
        hg = dgl.mean_nodes(g, 'h')  # graph-level readout: average the node features
        return self.classify(hg)
Model training
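The loop below assumes that the model, loss function, optimizer, and a DataLoader over (graph, label) pairs have already been created. A minimal sketch of that setup; trainset, the hidden size, and the class count are placeholders rather than values from the original run:
def collate(samples):
    # batch a list of (DGLGraph, label) pairs into one batched graph and a label tensor
    graphs, labels = map(list, zip(*samples))
    return dgl.batch(graphs), torch.tensor(labels, dtype=torch.long)

data_loader = DataLoader(trainset, batch_size=32, shuffle=True, collate_fn=collate)
model = Classifier(ATOM_FDIM, 256, 2)        # hidden size 256 and 2 classes are assumptions
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)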
epoch_losses = []
for epoch in range(200):
    epoch_loss = 0
    for i, (bg, label) in enumerate(data_loader):
        bg.set_e_initializer(dgl.init.zero_initializer)
        bg.set_n_initializer(dgl.init.zero_initializer)
        pred = model(bg)
        loss = loss_func(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (i + 1)
    if (epoch + 1) % 20 == 0:
        print('Epoch {}, loss {:.4f}'.format(epoch + 1, epoch_loss))
    epoch_losses.append(epoch_loss)
plt.plot(epoch_losses, c='b')
Model evaluation
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
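test_y and pred_y below are assumed to come from running the trained model over a held-out loader; a minimal sketch, assuming a test_loader built with the same collate function:
model.eval()
test_y, pred_y = [], []
with torch.no_grad():
    for bg, label in test_loader:
        bg.set_e_initializer(dgl.init.zero_initializer)
        bg.set_n_initializer(dgl.init.zero_initializer)
        logits = model(bg)
        pred_y.extend(logits.argmax(dim=1).tolist())
        test_y.extend(label.tolist())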
accuracy_score(test_y, pred_y)
0.7665369649805448
print(classification_report(test_y, pred_y))
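Since confusion_matrix and seaborn are already imported, the per-class errors can also be inspected as a heatmap (an optional sketch):
sns.heatmap(confusion_matrix(test_y, pred_y), annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')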
Import packages
from rdkit.Chem import AllChem
from rdkit.Chem import rdMolDescriptors
from sklearn.preprocessing import normalize
from sklearn.ensemble import RandomForestClassifier
Define a 3D-descriptor calculation function and compute the 3D descriptors
def calc_dragon_type_desc(mol):
    return rdMolDescriptors.CalcAUTOCORR3D(mol) + rdMolDescriptors.CalcMORSE(mol) + \
           rdMolDescriptors.CalcRDF(mol) + rdMolDescriptors.CalcWHIM(mol)
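These 3D descriptors require molecules with embedded conformers, so train_mols2 and test_mols2 are assumed to hold 3D-embedded molecules. A minimal sketch of how they could be prepared; train_mols and test_mols are placeholders for the 2D molecules used earlier:
def embed_3d(mols):
    out = []
    for m in mols:
        m = Chem.AddHs(m)
        AllChem.EmbedMolecule(m, randomSeed=42)   # generate one 3D conformer
        AllChem.MMFFOptimizeMolecule(m)           # quick force-field relaxation
        out.append(m)
    return out

train_mols2 = embed_3d(train_mols)
test_mols2 = embed_3d(test_mols)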
train_X = normalize([calc_dragon_type_desc(m) for m in train_mols2])
test_X = normalize([calc_dragon_type_desc(m) for m in test_mols2])
Random forest classification
rfc = RandomForestClassifier(n_estimators=100)
rfc.fit(train_X, train_y)
Model evaluation