import re

import numba
import numpy as np
import scipy.sparse as sparse
import torch
@numba.njit
def reindex(node_idx, col):
    # Map each original node id in `node_idx` to its position, then relabel
    # the column indices in `col` accordingly. The output array is
    # preallocated because `np.array(list)` is not supported in nopython mode.
    node_dict = dict()
    for cnt, i in enumerate(node_idx):
        node_dict[i] = cnt
    new_col = np.empty(len(col), dtype=np.int64)
    for j in range(len(col)):
        new_col[j] = node_dict[col[j]]
    return new_col
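# Example (illustrative values, not from the original module):
#   reindex(np.array([7, 3, 5]), np.array([3, 5, 7])) returns array([1, 2, 0]),
#   since node 7 maps to 0, node 3 to 1 and node 5 to 2.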
class Data(object):
    r"""A plain old python object modeling a single graph with various
    (optional) attributes:

    Args:
        x (Tensor, optional): Node feature matrix with shape :obj:`[num_nodes,
            num_node_features]`. (default: :obj:`None`)
        edge_index (LongTensor, optional): Graph connectivity in COO format
            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)
        edge_attr (Tensor, optional): Edge feature matrix with shape
            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)
        y (Tensor, optional): Graph or node targets with arbitrary shape.
            (default: :obj:`None`)
        pos (Tensor, optional): Node position matrix with shape
            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)

    The data object is not restricted to these attributes and can be extended
    by any other additional data.
    """

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, pos=None, **kwargs):
        self.x = x
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.y = y
        self.pos = pos
        for key, item in kwargs.items():
            if key == "num_nodes":
                self.__num_nodes__ = item
            else:
                self[key] = item
        # Lazily built scipy CSR adjacency cache (name-mangled to
        # `_Data__adj`), populated on demand by `_build_adj_`.
        self.__adj = None
    @staticmethod
    def from_dict(dictionary):
        r"""Creates a data object from a python dictionary."""
        data = Data()
        for key, item in dictionary.items():
            data[key] = item
        return data
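    # Example: `Data.from_dict({"x": torch.randn(4, 8)})` builds the same
    # object as `Data(x=torch.randn(4, 8))`.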
    def __getitem__(self, key):
        r"""Gets the data of the attribute :obj:`key`."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        r"""Sets the attribute :obj:`key` to :obj:`value`."""
        setattr(self, key, value)

    @property
    def keys(self):
        r"""Returns all names of graph attributes."""
        # Skip private/internal attributes (leading underscore, which also
        # covers the name-mangled `_Data__adj` cache) and dunder entries.
        keys = [key for key in self.__dict__.keys() if self[key] is not None]
        keys = [key for key in keys if not key.startswith("_") and key[-2:] != "__"]
        return keys
    def __len__(self):
        r"""Returns the number of all present attributes."""
        return len(self.keys)

    def __contains__(self, key):
        r"""Returns :obj:`True`, if the attribute :obj:`key` is present in the
        data."""
        return key in self.keys

    def __iter__(self):
        r"""Iterates over all present attributes in the data, yielding their
        attribute names and content."""
        for key in sorted(self.keys):
            yield key, self[key]

    def __call__(self, *keys):
        r"""Iterates over all attributes :obj:`*keys` in the data, yielding
        their attribute names and content.
        If :obj:`*keys` is not given, this method will iterate over all
        present attributes."""
        for key in sorted(self.keys) if not keys else keys:
            if self[key] is not None:
                yield key, self[key]
    def cat_dim(self, key, value):
        r"""Returns the dimension in which the attribute :obj:`key` with
        content :obj:`value` gets concatenated when creating batches.

        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        # `*index*` and `*face*` should be concatenated in the last dimension,
        # everything else in the first dimension.
        return -1 if bool(re.search("(index|face)", key)) else 0

    def __inc__(self, key, value):
        r"""Returns the incremental count to cumulatively increase the value
        of the next attribute of :obj:`key` when creating batches.

        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        # Only `*index*` and `*face*` should be cumulatively summed up when
        # creating batches.
        return self.num_nodes if bool(re.search("(index|face)", key)) else 0

    def __cat_dim__(self, key, value):
        return self.cat_dim(key, value)
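    # Example (illustrative, not from the original module): when batching two
    # graphs with 3 and 4 nodes, their `edge_index` tensors are concatenated
    # along dim -1 (`cat_dim`) and the second graph's node indices are shifted
    # by the first graph's `__inc__("edge_index", ...) == 3`.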
    @property
    def num_edges(self):
        r"""Returns the number of edges in the graph."""
        for key, item in self("edge_index", "edge_attr"):
            return item.size(self.cat_dim(key, item))
        return None

    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph."""
        if self.x is None:
            return 0
        return 1 if self.x.dim() == 1 else self.x.size(1)

    @property
    def num_nodes(self):
        r"""Returns the number of nodes in the graph."""
        if hasattr(self, "__num_nodes__"):
            return self.__num_nodes__
        if self.x is not None:
            return self.x.shape[0]
        # Fall back to the largest node id referenced by an edge.
        return int(torch.max(self.edge_index)) + 1

    @num_nodes.setter
    def num_nodes(self, num_nodes):
        self.__num_nodes__ = num_nodes

    @property
    def num_classes(self):
        r"""Returns the number of classes, inferred from :obj:`y`."""
        if self.y is not None:
            return int(torch.max(self.y) + 1) if self.y.dim() == 1 else self.y.shape[-1]
    def is_coalesced(self):
        r"""Returns :obj:`True`, if the edge index does not contain duplicate
        entries."""
        row, col = self.edge_index
        index = self.num_nodes * row + col
        return row.size(0) == torch.unique(index).size(0)
    def apply(self, func, *keys):
        r"""Applies the function :obj:`func` to all attributes :obj:`*keys`.
        If :obj:`*keys` is not given, :obj:`func` is applied to all present
        attributes.
        """
        for key, item in self(*keys):
            if not isinstance(item, torch.Tensor):
                continue
            self[key] = func(item)
        return self

    def contiguous(self, *keys):
        r"""Ensures a contiguous memory layout for all attributes :obj:`*keys`.
        If :obj:`*keys` is not given, all present attributes are ensured to
        have a contiguous memory layout."""
        return self.apply(lambda x: x.contiguous(), *keys)

    def to(self, device, *keys):
        r"""Performs tensor dtype and/or device conversion to all attributes
        :obj:`*keys`.
        If :obj:`*keys` is not given, the conversion is applied to all present
        attributes."""
        return self.apply(lambda x: x.to(device), *keys)

    def cuda(self, *keys):
        r"""Moves all attributes :obj:`*keys` to CUDA memory."""
        return self.apply(lambda x: x.cuda(), *keys)

    def clone(self):
        r"""Returns a deep copy with all tensor attributes cloned."""
        return Data.from_dict({k: v.clone() for k, v in self})
    def _build_adj_(self):
        if self.__adj is not None:
            return self.__adj
        num_edges = self.edge_index.shape[1]
        edge_index = self.edge_index if self.edge_index.device.type == "cpu" else self.edge_index.cpu()
        # Cached per-edge rows (tensor) and columns (numpy), plus the set of
        # referenced node ids; consumed by `sample_adj`.
        self._edge_row = edge_index[0]
        self._edge_col = edge_index[1].numpy()
        self._node_idx = torch.unique(edge_index).numpy()
        edge_index_np = edge_index.numpy()
        num_nodes = self.num_nodes
        edge_attr_np = np.ones(num_edges)
        self.__adj = sparse.csr_matrix(
            (edge_attr_np, (edge_index_np[0], edge_index_np[1])), shape=(num_nodes, num_nodes)
        )
        return self.__adj

    def _eliminate_adj_(self):
        self.__adj = None
        self._edge_row = None
        self._edge_col = None
        self._node_idx = None
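    # Note (added comment): the adjacency cache is not invalidated
    # automatically; code that mutates `edge_index` afterwards is expected to
    # call `_eliminate_adj_` so the next query rebuilds it.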
    def subgraph(self, node_idx):
        r"""Returns the subgraph induced on the nodes in :obj:`node_idx`."""
        if self.__adj is None:
            self._build_adj_()
        if isinstance(node_idx, torch.Tensor):
            node_idx = node_idx.cpu().numpy()
        node_idx = np.unique(node_idx)
        # Slice the cached CSR adjacency; converting the slice to COO yields
        # the relabeled edges directly.
        adj = self.__adj[node_idx, :][:, node_idx]
        adj_coo = sparse.coo_matrix(adj)
        row, col = adj_coo.row, adj_coo.col
        # Note: `edge_attr` here comes from the binary adjacency, not from the
        # original edge features.
        edge_attr = torch.from_numpy(adj_coo.data).to(self.x.device)
        edge_index = torch.from_numpy(np.stack([row, col], axis=0)).to(self.x.device).long()
        keys = self.keys
        attrs = {key: self[key][node_idx] for key in keys if "edge" not in key and "node_idx" not in key}
        attrs["edge_index"] = edge_index
        if edge_attr is not None:
            attrs["edge_attr"] = edge_attr
        return Data(**attrs)
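    # Example (a sketch with hypothetical shapes): for a graph with
    # `x=[100, 16]`, `data.subgraph(torch.arange(10))` returns a new `Data`
    # holding the features of nodes 0..9 and the edges among them, relabeled
    # to the 0..9 range.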
    def edge_subgraph(self, edge_idx, require_idx=False):
        r"""Returns the subgraph induced on the edges in :obj:`edge_idx`."""
        if isinstance(edge_idx, torch.Tensor):
            edge_idx = edge_idx.cpu().numpy()
        edge_index = self.edge_index.T[edge_idx].cpu().numpy()
        node_idx = np.unique(edge_index)
        # Relabel endpoints to the compact 0..len(node_idx) - 1 range.
        idx_dict = {val: key for key, val in enumerate(node_idx)}

        def func(x):
            return [idx_dict[x[0]], idx_dict[x[1]]]

        edge_index = np.array([func(x) for x in edge_index]).transpose()
        edge_index = torch.from_numpy(edge_index).to(self.x.device)
        edge_attr = self.edge_attr[edge_idx] if self.edge_attr is not None else None
        keys = self.keys
        attrs = {key: self[key][node_idx] for key in keys if "edge" not in key}
        attrs["edge_index"] = edge_index
        if edge_attr is not None:
            attrs["edge_attr"] = edge_attr
        if require_idx:
            return Data(**attrs), node_idx, edge_idx
        return Data(**attrs)
    def sample_adj(self, batch, size=-1, replace=True):
        r"""Samples :obj:`size` neighbors with replacement for each node in
        :obj:`batch`; :obj:`size=-1` returns all cached edges. The
        :obj:`replace` argument is kept for API compatibility and is currently
        ignored."""
        assert size != 0
        if self.__adj is None:
            self._build_adj_()
        if isinstance(batch, torch.Tensor):
            batch = batch.cpu().numpy()
        adj = self.__adj[batch]
        batch_size = len(batch)
        if size == -1:
            row, col = self._edge_row, self._edge_col
            _node_idx = self._node_idx
        else:
            indices = torch.from_numpy(adj.indices)
            indptr = torch.from_numpy(adj.indptr)
            node_idx, (row, col) = self._sample_adj(batch_size, indices, indptr, size)
            # scipy stores indices as int32; cast for the int64-keyed numba dict.
            col = col.numpy().astype(np.int64)
            _node_idx = node_idx.numpy().astype(np.int64)
        # Reindexing: target nodes are always put at the front.
        _node_idx = np.concatenate((batch, np.setdiff1d(_node_idx, batch)))
        new_col = torch.as_tensor(reindex(_node_idx, col), dtype=torch.long)
        edge_index = torch.stack([row.long(), new_col])
        node_idx = torch.as_tensor(_node_idx, dtype=torch.long).to(self.x.device)
        edge_index = edge_index.long().to(self.x.device)
        return node_idx, edge_index
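    # Example (a sketch): `node_idx, edge_index = data.sample_adj(batch, size=2)`
    # draws two neighbors per batch node; `edge_index[0]` holds positions in
    # `batch` and `edge_index[1]` positions in `node_idx`, with the batch
    # (target) nodes occupying the first `len(batch)` slots of `node_idx`.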
    def _sample_adj(self, batch_size, indices, indptr, size):
        # Per-row degrees of the sliced CSR matrix.
        row_counts = (indptr[1:] - indptr[:-1]).long()
        # Draw `size` uniform picks per row and map them into each row's
        # [indptr[i], indptr[i + 1]) slice of `indices`. Rows are assumed to
        # be non-empty; a zero-degree row would pick from the next row.
        rand = torch.rand(batch_size, size)
        rand = rand * row_counts.view(-1, 1)
        rand = rand.long() + indptr[:-1].view(-1, 1)
        edge_cols = indices[rand].view(-1)
        row = torch.arange(0, batch_size).view(-1, 1).repeat(1, size).view(-1)
        node_idx = torch.unique(edge_cols)
        return node_idx, (row, edge_cols)
    @staticmethod
    def from_pyg_data(data):
        r"""Creates a data object from a PyTorch Geometric data object."""
        val = {k: v for k, v in data}
        return Data(**val)

    def __repr__(self):
        info = ["{}={}".format(key, list(item.size())) for key, item in self if not key.startswith("_")]
        return "{}({})".format(self.__class__.__name__, ", ".join(info))