import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import NNConv, Set2Set
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import AvgPooling, MaxPooling, SumPooling
class GATLayer(nn.Module):
    """Single-head graph attention layer bound to a graph given at construction.

    Parameters
    ----------
    g : graph object
        Stored on the layer. ``forward`` uses its ``ndata``, ``apply_edges``
        and ``update_all`` members (presumably a DGLGraph -- confirm with
        callers).
    in_dim : int
        Size of the input node features.
    out_dim : int
        Size of the projected node features.
    """

    def __init__(self, g, in_dim, out_dim):
        super().__init__()
        self.g = g
        # Bias-free projection of the node features.
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # Bias-free scorer over a concatenated (source, destination) pair.
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        """Return the un-normalised attention score ``'e'`` of every edge."""
        endpoints = torch.cat((edges.src["z"], edges.dst["z"]), dim=1)
        return {"e": F.leaky_relu(self.attn_fc(endpoints))}

    def message_func(self, edges):
        """Send the source projection ``'z'`` and the edge score ``'e'``."""
        return {"z": edges.src["z"], "e": edges.data["e"]}

    def reduce_func(self, nodes):
        """Softmax the received scores along dim 1 and sum the weighted messages."""
        mailbox = nodes.mailbox
        weights = F.softmax(mailbox["e"], dim=1)
        return {"h": (weights * mailbox["z"]).sum(dim=1)}

    def forward(self, h):
        """Project ``h``, run one round of attention and return the popped ``'h'``.

        The projection is written to ``ndata['z']`` of the stored graph and is
        left there afterwards; only ``'h'`` is removed.
        """
        graph = self.g
        graph.ndata["z"] = self.fc(h)
        graph.apply_edges(self.edge_attention)
        graph.update_all(self.message_func, self.reduce_func)
        return graph.ndata.pop("h")
class SELayer(nn.Module):
    """Squeeze-and-excitation gate over feature channels.

    Parameters
    ----------
    in_channels : int
        Size of the last dimension of the input.
    se_channels : int
        Size of the bottleneck between the two linear layers.
    """

    def __init__(self, in_channels, se_channels):
        super().__init__()
        self.in_channels, self.se_channels = in_channels, se_channels
        squeeze = nn.Linear(in_channels, se_channels)
        excite = nn.Linear(se_channels, in_channels)
        self.encoder_decoder = nn.Sequential(squeeze, nn.ELU(), excite, nn.Sigmoid())

    def forward(self, x):
        """Rescale ``x`` by a sigmoid gate computed from its mean over dim 0."""
        # One gate vector is derived from the dim-0 average and broadcast
        # back over every row of the input.
        gate = self.encoder_decoder(x.mean(dim=0))
        return x * gate
class ApplyNodeFunc(nn.Module):
    """Node update: MLP, then a normalisation layer, then ReLU.

    Parameters
    ----------
    mlp : nn.Module
        Module applied first; it must expose an ``output_dim`` attribute.
    use_selayer : bool
        If True the normalisation layer is an ``SELayer`` with a bottleneck
        of ``int(sqrt(output_dim))``; otherwise it is ``nn.BatchNorm1d``.
    """

    def __init__(self, mlp, use_selayer):
        super().__init__()
        self.mlp = mlp
        width = self.mlp.output_dim
        if use_selayer:
            self.bn = SELayer(width, int(np.sqrt(width)))
        else:
            self.bn = nn.BatchNorm1d(width)

    def forward(self, h):
        """Return ``relu(bn(mlp(h)))``."""
        return F.relu(self.bn(self.mlp(h)))
class MLP(nn.Module):
    """MLP with linear output"""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_selayer):
        """MLP layers construction

        Parameters
        ----------
        num_layers: int
            The number of linear layers
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The dimensionality of the output
        use_selayer: bool
            If True, each hidden layer is followed by an ``SELayer`` instead
            of ``nn.BatchNorm1d``. Ignored when ``num_layers`` is 1.

        Raises
        ------
        ValueError
            If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        # Stays True only for the single-layer (purely linear) configuration.
        self.linear_or_not = True
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")
        if num_layers == 1:
            self.linear = nn.Linear(input_dim, output_dim)
            return

        self.linear_or_not = False
        # input -> hidden -> ... -> hidden -> output, num_layers linears in total.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = torch.nn.ModuleList(
            [
                nn.Linear(fan_in, fan_out)
                for fan_in, fan_out in zip(widths[:-1], widths[1:])
            ]
        )
        # One normalisation layer per hidden layer; the output layer has none.
        self.batch_norms = torch.nn.ModuleList(
            [
                SELayer(hidden_dim, int(np.sqrt(hidden_dim)))
                if use_selayer
                else nn.BatchNorm1d(hidden_dim)
                for _ in range(num_layers - 1)
            ]
        )

    def forward(self, x):
        """Apply the linear model, or the hidden stack followed by the output layer."""
        if self.linear_or_not:
            return self.linear(x)

        hidden = x
        # zip stops after the num_layers - 1 normalised hidden layers.
        for linear, norm in zip(self.linears, self.batch_norms):
            hidden = F.relu(norm(linear(hidden)))
        return self.linears[-1](hidden)
class UnsupervisedGAT(nn.Module):
    """Stack of multi-head graph attention layers.

    The stack is built from DGL's ``GATConv``. The ``GATLayer`` defined in
    this file takes ``(g, in_dim, out_dim)`` and is single-head, so it cannot
    accept the multi-head keyword arguments this model needs.

    Parameters
    ----------
    node_input_dim : int
        Size of the input node features.
    node_hidden_dim : int
        Size of the node features produced by every layer. Must be divisible
        by ``num_heads``.
    edge_input_dim : int
        Accepted for signature parity with the other encoders; not used.
    num_layers : int
        Number of attention layers.
    num_heads : int
        Number of attention heads per layer.
    """

    def __init__(
        self, node_input_dim, node_hidden_dim, edge_input_dim, num_layers, num_heads
    ):
        super().__init__()
        # Local import keeps this class self-contained with respect to the
        # file's import header.
        from dgl.nn.pytorch import GATConv

        assert (
            node_hidden_dim % num_heads == 0
        ), "node_hidden_dim must be divisible by num_heads"
        # Each head emits node_hidden_dim // num_heads features, so the
        # flattened heads add up to node_hidden_dim again.
        # NOTE(review): depending on the DGL version, GATConv may reject
        # graphs containing zero-in-degree nodes -- confirm against the
        # graphs fed to this encoder.
        self.layers = nn.ModuleList(
            [
                GATConv(
                    in_feats=node_input_dim if i == 0 else node_hidden_dim,
                    out_feats=node_hidden_dim // num_heads,
                    num_heads=num_heads,
                    feat_drop=0.0,
                    attn_drop=0.0,
                    negative_slope=0.2,
                    residual=False,
                    # No non-linearity after the last layer.
                    activation=F.leaky_relu if i + 1 < num_layers else None,
                )
                for i in range(num_layers)
            ]
        )

    def forward(self, g, n_feat, e_feat):
        """Run every attention layer and return the final node features.

        Parameters
        ----------
        g : DGLGraph
            Input graph.
        n_feat : tensor
            Node features whose last dimension is ``node_input_dim``.
        e_feat : object
            Ignored.

        Returns
        -------
        tensor
            Node features with the attention heads flattened into the last
            dimension.
        """
        for layer in self.layers:
            # GATConv yields (num_nodes, num_heads, out_feats); merge the heads.
            n_feat = layer(g, n_feat).flatten(1)
        return n_feat
class UnsupervisedMPNN(nn.Module):
    """
    MPNN node encoder from
    `Neural Message Passing for Quantum Chemistry <https://arxiv.org/abs/1704.01212>`__

    Parameters
    ----------
    output_dim : int
        Accepted but not used by this class, default to be 32.
    node_input_dim : int
        Dimension of input node feature, default to be 32.
    node_hidden_dim : int
        Dimension of node feature in hidden layers, default to be 32.
    edge_input_dim : int
        Dimension of input edge feature, default to be 32.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers, default to be 32.
    num_step_message_passing : int
        Number of message passing steps, default to be 6.
    lstm_as_gate : bool
        If True an LSTM updates the node state after each step, otherwise a
        GRU does. Default to be False.
    """

    def __init__(
        self,
        output_dim=32,
        node_input_dim=32,
        node_hidden_dim=32,
        edge_input_dim=32,
        edge_hidden_dim=32,
        num_step_message_passing=6,
        lstm_as_gate=False,
    ):
        super().__init__()
        self.num_step_message_passing = num_step_message_passing
        self.lin0 = nn.Linear(node_input_dim, node_hidden_dim)
        # Maps every edge feature to a flattened
        # node_hidden_dim x node_hidden_dim block consumed by NNConv.
        edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, edge_hidden_dim),
            nn.ReLU(),
            nn.Linear(edge_hidden_dim, node_hidden_dim * node_hidden_dim),
        )
        self.conv = NNConv(
            in_feats=node_hidden_dim,
            out_feats=node_hidden_dim,
            edge_func=edge_mlp,
            aggregator_type="sum",
        )
        self.lstm_as_gate = lstm_as_gate
        # Only the selected recurrent gate is created.
        if not lstm_as_gate:
            self.gru = nn.GRU(node_hidden_dim, node_hidden_dim)
        else:
            self.lstm = nn.LSTM(node_hidden_dim, node_hidden_dim)

    def forward(self, g, n_feat, e_feat):
        """Encode the nodes of ``g``

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph
        n_feat : tensor of shape (B1, D1)
            Node features. B1 for number of nodes and D1 for
            the node feature size.
        e_feat : tensor of shape (B2, D2)
            Edge features. B2 for number of edges and D2 for
            the edge feature size.

        Returns
        -------
        tensor of shape (B1, H1)
            Node states after the last message passing step, H1 being
            ``node_hidden_dim``.
        """
        node_state = F.relu(self.lin0(n_feat))  # (B1, H1)
        hidden = node_state.unsqueeze(0)  # (1, B1, H1)
        cell = torch.zeros_like(hidden)  # only consumed by the LSTM gate

        for _ in range(self.num_step_message_passing):
            # The recurrent gates expect a leading length-1 sequence axis.
            message = F.relu(self.conv(g, node_state, e_feat)).unsqueeze(0)
            if self.lstm_as_gate:
                gated, (hidden, cell) = self.lstm(message, (hidden, cell))
            else:
                gated, hidden = self.gru(message, hidden)
            node_state = gated.squeeze(0)

        return node_state
class UnsupervisedGIN(nn.Module):
    """GIN model"""

    def __init__(
        self,
        num_layers,
        num_mlp_layers,
        input_dim,
        hidden_dim,
        output_dim,
        final_dropout,
        learn_eps,
        graph_pooling_type,
        neighbor_pooling_type,
        use_selayer,
    ):
        """model parameters setting

        Parameters
        ----------
        num_layers: int
            The number of layers, counting the input layer; ``num_layers - 1``
            GIN convolutions are created
        num_mlp_layers: int
            The number of linear layers in mlps
        input_dim: int
            The dimensionality of input features
        hidden_dim: int
            The dimensionality of hidden units at ALL layers
        output_dim: int
            The dimensionality of the per-graph score
        final_dropout: float
            dropout ratio on the final linear layer
        learn_eps: boolean
            If True, learn epsilon to distinguish center nodes from neighbors
            If False, aggregate neighbors and center nodes altogether.
        graph_pooling_type: str
            how to aggregate entire nodes in a graph (sum, mean or max)
        neighbor_pooling_type: str
            how to aggregate neighbors (sum, mean, or max)
        use_selayer: boolean
            If True, ``SELayer`` replaces ``nn.BatchNorm1d`` everywhere

        Raises
        ------
        NotImplementedError
            If ``graph_pooling_type`` is not one of sum, mean or max.
        """
        super().__init__()
        self.num_layers = num_layers
        self.learn_eps = learn_eps

        # One GIN convolution plus one normalisation layer per hidden layer.
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for depth in range(self.num_layers - 1):
            # Only the first convolution sees the raw input width.
            mlp_in = input_dim if depth == 0 else hidden_dim
            mlp = MLP(num_mlp_layers, mlp_in, hidden_dim, hidden_dim, use_selayer)
            conv = GINConv(
                ApplyNodeFunc(mlp, use_selayer),
                neighbor_pooling_type,
                0,
                self.learn_eps,
            )
            self.ginlayers.append(conv)
            if use_selayer:
                norm = SELayer(hidden_dim, int(np.sqrt(hidden_dim)))
            else:
                norm = nn.BatchNorm1d(hidden_dim)
            self.batch_norms.append(norm)

        # One linear read-out per layer (input included), each mapping the
        # pooled representation of that layer to a score.
        self.linears_prediction = torch.nn.ModuleList(
            [
                nn.Linear(input_dim if depth == 0 else hidden_dim, output_dim)
                for depth in range(num_layers)
            ]
        )
        self.drop = nn.Dropout(final_dropout)

        if graph_pooling_type == "sum":
            self.pool = SumPooling()
        elif graph_pooling_type == "mean":
            self.pool = AvgPooling()
        elif graph_pooling_type == "max":
            self.pool = MaxPooling()
        else:
            raise NotImplementedError

    def forward(self, g, h, efeat):
        """Return the summed per-layer scores and the pooled hidden layers.

        ``efeat`` is accepted but not used. The second return value holds the
        pooled representation of every layer except the input layer.
        """
        # Hidden representation at each layer, the input included.
        hidden_rep = [h]
        for conv, norm in zip(self.ginlayers, self.batch_norms):
            h = F.relu(norm(conv(g, h)))
            hidden_rep.append(h)

        score_over_layer = 0
        all_outputs = []
        # Pool over all nodes of each graph, layer by layer.
        for idx, rep in enumerate(hidden_rep):
            pooled = self.pool(g, rep)
            all_outputs.append(pooled)
            score_over_layer += self.drop(self.linears_prediction[idx](pooled))

        return score_over_layer, all_outputs[1:]
class GraphEncoder(nn.Module):
    """
    Graph encoder that assembles node inputs from graph data and embeds every
    graph with a selectable GNN backbone.

    Parameters
    ----------
    positional_embedding_size : int
        Assumed width of ``g.ndata["pos_undirected"]``; used to size the
        backbone input. Default to be 32.
    max_node_freq : int
        Stored on the module; not otherwise used by this class. Default to be 8.
    max_edge_freq : int
        Stored on the module; not otherwise used by this class. Default to be 8.
    max_degree : int
        In-degrees are clamped to ``[0, max_degree]`` before the embedding
        lookup. Default to be 128.
    freq_embedding_size : int
        Only used to derive the edge input size (``freq_embedding_size + 1``).
        Default to be 32.
    degree_embedding_size : int
        Width of the degree embedding. Default to be 32.
    output_dim : int
        Dimension of the graph embedding. Default to be 32.
    node_hidden_dim : int
        Dimension of node feature in hidden layers. Default to be 32.
    edge_hidden_dim : int
        Dimension of edge feature in hidden layers (MPNN only). Default to be 32.
    num_layers : int
        Number of backbone layers / message passing steps. Default to be 6.
    num_heads : int
        Number of attention heads (GAT only). Default to be 4.
    num_step_set2set : int
        Number of set2set steps. Default to be 6.
    num_layer_set2set : int
        Number of set2set layers. Default to be 3.
    norm : bool
        If True the output is L2-normalised. Default to be False.
    gnn_model : str
        Backbone to use: "mpnn", "gat" or "gin". Default to be "mpnn".
    degree_input : bool
        If True an embedding of the clamped in-degree is appended to the node
        input. Default to be False.
    lstm_as_gate : bool
        Forwarded to the MPNN backbone. Default to be False.

    Raises
    ------
    ValueError
        If ``gnn_model`` is not one of the supported names.
    """

    def __init__(
        self,
        positional_embedding_size=32,
        max_node_freq=8,
        max_edge_freq=8,
        max_degree=128,
        freq_embedding_size=32,
        degree_embedding_size=32,
        output_dim=32,
        node_hidden_dim=32,
        edge_hidden_dim=32,
        num_layers=6,
        num_heads=4,
        num_step_set2set=6,
        num_layer_set2set=3,
        norm=False,
        gnn_model="mpnn",
        degree_input=False,
        lstm_as_gate=False,
    ):
        super(GraphEncoder, self).__init__()

        # The trailing +1 accounts for the scalar "seed" indicator that
        # forward() concatenates to the node input.
        if degree_input:
            node_input_dim = positional_embedding_size + degree_embedding_size + 1
        else:
            node_input_dim = positional_embedding_size + 1
        edge_input_dim = freq_embedding_size + 1

        if gnn_model == "mpnn":
            self.gnn = UnsupervisedMPNN(
                output_dim=output_dim,
                node_input_dim=node_input_dim,
                node_hidden_dim=node_hidden_dim,
                edge_input_dim=edge_input_dim,
                edge_hidden_dim=edge_hidden_dim,
                num_step_message_passing=num_layers,
                lstm_as_gate=lstm_as_gate,
            )
        elif gnn_model == "gat":
            self.gnn = UnsupervisedGAT(
                node_input_dim=node_input_dim,
                node_hidden_dim=node_hidden_dim,
                edge_input_dim=edge_input_dim,
                num_layers=num_layers,
                num_heads=num_heads,
            )
        elif gnn_model == "gin":
            self.gnn = UnsupervisedGIN(
                num_layers=num_layers,
                num_mlp_layers=2,
                input_dim=node_input_dim,
                hidden_dim=node_hidden_dim,
                output_dim=output_dim,
                final_dropout=0.5,
                learn_eps=False,
                graph_pooling_type="sum",
                neighbor_pooling_type="sum",
                use_selayer=False,
            )
        else:
            # Fail at construction instead of leaving self.gnn undefined and
            # crashing with AttributeError on the first forward pass.
            raise ValueError(
                "unknown gnn_model {!r}; expected 'mpnn', 'gat' or 'gin'".format(
                    gnn_model
                )
            )
        self.gnn_model = gnn_model

        self.max_node_freq = max_node_freq
        self.max_edge_freq = max_edge_freq
        self.max_degree = max_degree
        self.degree_input = degree_input

        if degree_input:
            # max_degree + 1 rows so the clamped value max_degree is a valid index.
            self.degree_embedding = nn.Embedding(
                num_embeddings=max_degree + 1, embedding_dim=degree_embedding_size
            )

        # The read-out modules are always created so the set of parameters
        # does not depend on the chosen backbone.
        self.set2set = Set2Set(node_hidden_dim, num_step_set2set, num_layer_set2set)
        self.lin_readout = nn.Sequential(
            nn.Linear(2 * node_hidden_dim, node_hidden_dim),
            nn.ReLU(),
            nn.Linear(node_hidden_dim, output_dim),
        )
        self.norm = norm

    def forward(self, g, return_all_outputs=False):
        """Embed every graph contained in ``g``

        Parameters
        ----------
        g : DGLGraph
            Input DGLGraph. ``g.ndata`` must provide ``"pos_undirected"`` and
            ``"seed"``.
        return_all_outputs : bool
            If True, also return the per-layer pooled outputs of the backbone.

        Returns
        -------
        x : tensor
            Graph embeddings, L2-normalised when ``norm`` was set.
        all_outputs : list or None
            Only returned when ``return_all_outputs`` is True. Per-layer pooled
            outputs for the "gin" backbone, ``None`` for the others.
        """
        seed = g.ndata["seed"]
        pieces = [g.ndata["pos_undirected"]]
        if self.degree_input:
            # .to() is a no-op when the degrees already live on the right
            # device and also covers non-CUDA accelerators.
            degrees = g.in_degrees().to(seed.device)
            pieces.append(self.degree_embedding(degrees.clamp(0, self.max_degree)))
        pieces.append(seed.unsqueeze(1).float())
        n_feat = torch.cat(pieces, dim=-1)

        # NOTE(review): no edge features are built here; the "mpnn" backbone
        # hands e_feat to its edge network, so confirm that path is usable.
        e_feat = None

        if self.gnn_model == "gin":
            # GIN already pools to graph level, so it bypasses the read-out.
            x, all_outputs = self.gnn(g, n_feat, e_feat)
        else:
            # Node-level backbones: pool the nodes, then project to output_dim.
            x, all_outputs = self.gnn(g, n_feat, e_feat), None
            x = self.set2set(g, x)
            x = self.lin_readout(x)

        if self.norm:
            x = F.normalize(x, p=2, dim=-1, eps=1e-5)
        if return_all_outputs:
            return x, all_outputs
        return x