DGL团队发布了以生命科学为重点的软件包DGL-LifeSci。
尝试使用新的DGL-LifeSci建立Attentive FP模型,并可视化其预测结果。 基于深度图学习框架DGL 环境准备 PyTorch:深度学习框架 DGL:基于PyTorch的库,支持深度学习以处理图形 RDKit:用于构建分子图并从字符串表示形式绘制结构式 DGL-LifeSci:面向化学和生物领域的 GNN 算法库
DGL安装 conda install -c dglteam dgl #DGLv0.4.3
DGL-LifeSci安装 pip install dgllife
基于Attentive FP可视化训练模型导入库 import matplotlib.pyplot as plt import os from rdkit import Chem from rdkit.Chem import rdmolops, rdmolfiles from rdkit import RDPaths import dgl import numpy as np import random import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.utils.data import Dataset from dgl import model_zoo from dgllife.model import AttentiveFPPredictor from dgllife.utils import mol_to_complete_graph, mol_to_bigraph from dgllife.utils import atom_type_one_hot from dgllife.utils import atom_degree_one_hot from dgllife.utils import atom_formal_charge from dgllife.utils import atom_num_radical_electrons from dgllife.utils import atom_hybridization_one_hot from dgllife.utils import atom_total_num_H_one_hot from dgllife.utils import one_hot_encoding from dgllife.utils import CanonicalAtomFeaturizer from dgllife.utils import CanonicalBondFeaturizer from dgllife.utils import ConcatFeaturizer from dgllife.utils import BaseAtomFeaturizer from dgllife.utils import BaseBondFeaturizer from dgllife.utils import one_hot_encoding from dgl.data.utils import split_dataset from functools import partial from sklearn.metrics import roc_auc_score
# Helper functions. The code is adapted from dgl/example.


def chirality(atom):
    """Encode the chirality of an atom as a three-element list.

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
        Atom to featurize.

    Returns
    -------
    list
        One-hot encoding of the atom's ``_CIPCode`` property over
        ``['R', 'S']``, followed by the result of
        ``atom.HasProp('_ChiralityPossible')``. When the atom has no
        ``_CIPCode`` property the first two entries are ``False``.
    """
    try:
        # Only the property lookup is guarded: a missing property is the one
        # expected failure, anything else should surface to the caller.
        cip_code = atom.GetProp('_CIPCode')
    except KeyError:
        return [False, False] + [atom.HasProp('_ChiralityPossible')]
    return one_hot_encoding(cip_code, ['R', 'S']) + \
        [atom.HasProp('_ChiralityPossible')]


def collate_molgraphs(data):
    """Batch a list of datapoints for a dataloader.

    Parameters
    ----------
    data : list of 3-tuples or 4-tuples.
        Each tuple is for a single datapoint, consisting of
        a SMILES, a DGLGraph, all-task labels and optionally
        a binary mask indicating the existence of labels.

    Returns
    -------
    smiles : list
        List of smiles
    bg : BatchedDGLGraph
        Batched DGLGraphs
    labels : Tensor of dtype float32 and shape (B, T)
        Batched datapoint labels. B is len(data) and
        T is the number of total tasks.
    masks : Tensor of dtype float32 and shape (B, T)
        Batched datapoint binary mask, indicating the
        existence of labels. If binary masks are not
        provided, return a tensor with ones.
    """
    tuple_len = len(data[0])
    assert tuple_len in [3, 4], \
        'Expect the tuple to be of length 3 or 4, got {:d}'.format(tuple_len)
    if tuple_len == 3:
        smiles, graphs, labels = map(list, zip(*data))
        masks = None
    else:
        smiles, graphs, labels, masks = map(list, zip(*data))

    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    labels = torch.stack(labels, dim=0)

    if masks is None:
        # No masks supplied: treat every label as present.
        masks = torch.ones(labels.shape)
    else:
        masks = torch.stack(masks, dim=0)
    return smiles, bg, labels, masks
# Atom and bond featurizers.

# Element symbols that get their own one-hot slot; any other element falls
# into the extra "unknown" slot requested via encode_unknown=True.
_ATOM_SYMBOLS = [
    'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At',
]


def _aromatic_placeholder(atom):
    """Return a constant one-element list standing in for aromatic information."""
    return [0]


def _constant_bond_feature(bond):
    """Return a ten-element all-zero feature list for any bond."""
    return [0] * 10


# NOTE(review): the order of this list presumably fixes the layout of the
# 'hv' feature vector, so keep it stable -- confirm against ConcatFeaturizer.
_atom_feature_funcs = [
    partial(atom_type_one_hot, allowable_set=_ATOM_SYMBOLS, encode_unknown=True),
    partial(atom_degree_one_hot, allowable_set=list(range(6))),
    atom_formal_charge,
    atom_num_radical_electrons,
    partial(atom_hybridization_one_hot, encode_unknown=True),
    _aromatic_placeholder,
    atom_total_num_H_one_hot,
    chirality,
]

atom_featurizer = BaseAtomFeaturizer({'hv': ConcatFeaturizer(_atom_feature_funcs)})
bond_featurizer = BaseBondFeaturizer({'he': _constant_bond_feature})
# Load the datasets and convert RDKit mol objects into graph objects.
# mol_to_bigraph, given the featurizers, converts an RDKit mol object into a
# graph object. In addition, smiles_to_bigraph can convert a SMILES string
# into a graph.


def _load_solubility_split(sdf_path):
    """Read one SDF file and derive SMILES, labels and graphs from it.

    Parameters
    ----------
    sdf_path : str
        Path of the SDF file; every record must carry a ``SOL`` property.

    Returns
    -------
    mols : list
        Molecules that RDKit could parse, in file order.
    smiles : list of str
        ``Chem.MolToSmiles`` output for each molecule.
    sol : torch.Tensor
        ``SOL`` values as a column tensor (reshaped to ``(-1, 1)``).
    graphs : list
        One graph per molecule, built with ``atom_featurizer`` and
        ``bond_featurizer``.
    """
    # The supplier yields None for records it cannot parse; drop those so the
    # calls below do not fail. Materialising the list also means the file is
    # parsed once rather than once per derived list.
    mols = [mol for mol in Chem.SDMolSupplier(sdf_path) if mol is not None]
    smiles = [Chem.MolToSmiles(mol) for mol in mols]
    sol = torch.tensor([float(mol.GetProp('SOL')) for mol in mols]).reshape(-1, 1)
    graphs = [
        mol_to_bigraph(mol,
                       node_featurizer=atom_featurizer,
                       edge_featurizer=bond_featurizer)
        for mol in mols
    ]
    return mols, smiles, sol, graphs


train_mols, train_smi, train_sol, train_graph = _load_solubility_split('solubility.train.sdf')
test_mols, test_smi, test_sol, test_graph = _load_solubility_split('solubility.test.sdf')
# Attentive FP model, together with the data loaders used for training and
# testing.
model = AttentiveFPPredictor(
    node_feat_size=39,
    edge_feat_size=10,
    num_layers=2,
    num_timesteps=2,
    graph_feat_size=200,
    n_tasks=1,
    dropout=0.2,
)
# The model is left on its default device here; use model.to('cuda:0') to
# move it to a GPU.


def _build_loader(smiles, graphs, targets):
    """Wrap aligned SMILES, graphs and labels in a DataLoader of 128-item batches."""
    return DataLoader(dataset=list(zip(smiles, graphs, targets)),
                      batch_size=128,
                      collate_fn=collate_molgraphs)


train_loader = _build_loader(train_smi, train_graph, train_sol)
test_loader = _build_loader(test_smi, test_graph, test_sol)
# Visualization function.
def drawmol(idx, dataset, timestep):
    """Draw one molecule with its atoms coloured by model atom weight.

    Parameters
    ----------
    idx : int
        Index of the datapoint in ``dataset``.
    dataset : sequence of 3-tuples
        Each item unpacks as ``(smiles, graph, label)``; the label is ignored.
    timestep : int
        Index of the readout round whose atom weights are drawn.

    Returns
    -------
    mol : rdkit.Chem.rdchem.Mol
        Molecule parsed from the datapoint's SMILES.
    atom_weights : numpy.ndarray
        Atom weights of the selected readout round, min-max scaled to [0, 1]
        (all zeros when every weight is equal).
    svg : str
        SVG text of the highlighted drawing.

    Notes
    -----
    Uses the module-level ``model``, prints the SMILES, and shows a
    horizontal colour bar with ``plt.show()``.
    NOTE(review): the model is called in whatever mode it is currently in;
    call ``model.eval()`` beforehand if dropout should be disabled.
    """
    # These names are not imported at the top of the file, so bring them into
    # scope here to keep the function self-contained.
    import matplotlib.colors
    from matplotlib import cm
    from rdkit.Chem import rdDepictor
    from rdkit.Chem.Draw import rdMolDraw2D

    smiles, graph, _ = dataset[idx]
    print(smiles)
    bg = dgl.batch([graph])
    atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']

    # Follow the model's device instead of "CUDA if available": the model is
    # not necessarily on the GPU, and mixing devices fails at call time.
    device = next(model.parameters()).device
    if device.type == 'cuda':
        print('use cuda')
        # NOTE(review): some DGL releases appear to move features in place and
        # return None, others return the moved graph -- handle both; confirm
        # against the installed DGL version.
        moved = bg.to(device)
        if moved is not None:
            bg = moved
        atom_feats = atom_feats.to(device)
        bond_feats = bond_feats.to(device)

    with torch.no_grad():
        _, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)
    assert timestep < len(atom_weights), 'Unexpected id for the readout round'
    atom_weights = atom_weights[timestep].detach().cpu()

    # Min-max scale; guard the degenerate case where all weights are equal,
    # which would otherwise divide by zero and produce NaN colours.
    min_value = torch.min(atom_weights)
    span = torch.max(atom_weights) - min_value
    if span > 0:
        atom_weights = (atom_weights - min_value) / span
    else:
        atom_weights = torch.zeros_like(atom_weights)

    # vmax is deliberately above 1, so a weight of 1.0 maps below the top of
    # the 'bwr' colour map.
    norm = matplotlib.colors.Normalize(vmin=0, vmax=1.28)
    plt_colors = cm.ScalarMappable(norm=norm, cmap='bwr')
    num_atoms = bg.number_of_nodes()
    atom_colors = {i: plt_colors.to_rgba(atom_weights[i].item())
                   for i in range(num_atoms)}

    # NOTE(review): assumes atom indices of the molecule parsed from SMILES
    # line up with the graph's node indices -- TODO confirm.
    mol = Chem.MolFromSmiles(smiles)
    rdDepictor.Compute2DCoords(mol)
    drawer = rdMolDraw2D.MolDraw2DSVG(280, 280)
    drawer.SetFontSize(1)
    mol = rdMolDraw2D.PrepareMolForDrawing(mol)
    drawer.DrawMolecule(mol,
                        highlightAtoms=list(range(num_atoms)),
                        highlightBonds=[],
                        highlightAtomColors=atom_colors)
    drawer.FinishDrawing()
    svg = drawer.GetDrawingText().replace('svg:', '')

    # Stand-alone colour bar: draw a tiny image with the same colour map,
    # hide its axes, and attach a horizontal colour bar to it.
    plt.figure(figsize=(9, 1.5))
    plt.imshow(np.array([[0, 1]]), cmap="bwr")
    plt.gca().set_visible(False)
    cax = plt.axes([0.1, 0.2, 0.8, 0.2])
    plt.colorbar(orientation='horizontal', cax=cax)
    plt.show()
    return (Chem.MolFromSmiles(smiles), atom_weights.numpy(), svg)
# Draw molecules from the test dataset.
# The model predicts solubility; in the colouring, red marks a positive
# influence on solubility and blue a negative one.
# NOTE(review): `display` and `SVG` are not imported in the visible source;
# they are presumably provided by the notebook (IPython.display) -- confirm.
target = test_loader.dataset
num_to_draw = min(5, len(target))  # only the first five molecules are drawn
for i in range(num_to_draw):
    mol, aw, svg = drawmol(i, target, 0)
    print(aw.min(), aw.max())
    display(SVG(svg))
参考资料 https://github.com/dmlc/dgl/tree/master/apps/life_sci https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/attentive_fp.py https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00387 https://github.com/awslabs/dgl-lifesci
|