Source code for cogdl.data.dataset

import collections.abc
import os.path as osp
from itertools import repeat
import numpy as np

import torch
import torch.utils.data

from cogdl.data import Adjacency, Graph
from cogdl.utils import makedirs
from cogdl.utils import Accuracy, CrossEntropyLoss


def to_list(x):
    if not isinstance(x, collections.abc.Iterable) or isinstance(x, str):
        x = [x]
    return x


def files_exist(files):
    return all([osp.exists(f) for f in files])


class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.

    See `here <https://rusty1s.github.io/pycogdl/build/html/notes/create_dataset.html>`__
    for the accompanying tutorial.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in a
            :obj:`cogdl.data.Data` object and returns a transformed version.
            The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            a :obj:`cogdl.data.Data` object and returns a transformed version.
            The data object will be transformed before being saved to disk.
            (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in a
            :obj:`cogdl.data.Data` object and returns a boolean value,
            indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    @staticmethod
    def add_args(parser):
        """Add dataset-specific arguments to the parser."""
        pass

    @property
    def raw_file_names(self):
        r"""The names of the files to find in the :obj:`self.raw_dir` folder
        in order to skip the download."""
        raise NotImplementedError

    @property
    def processed_file_names(self):
        r"""The names of the files to find in the :obj:`self.processed_dir`
        folder in order to skip the processing."""
        raise NotImplementedError

    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError

    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError

    def __len__(self):
        r"""The number of examples in the dataset."""
        return 1

    def get(self, idx):
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super(Dataset, self).__init__()
        self.root = osp.expanduser(osp.normpath(root))
        self.raw_dir = osp.join(self.root, "raw")
        self.processed_dir = osp.join(self.root, "processed")
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter

        self._download()
        self._process()

    @property
    def num_features(self):
        r"""Returns the number of features per node in the graph."""
        if hasattr(self, "data") and isinstance(self.data, Graph):
            return self.data.num_features
        elif hasattr(self, "data") and isinstance(self.data, list):
            return self.data[0].num_features
        else:
            return 0

    @property
    def raw_paths(self):
        r"""The filepaths to find in order to skip the download."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]

    @property
    def processed_paths(self):
        r"""The filepaths to find in the :obj:`self.processed_dir` folder in
        order to skip the processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]

    def _download(self):
        if files_exist(self.raw_paths):  # pragma: no cover
            return

        makedirs(self.raw_dir)
        self.download()

    def _process(self):
        if files_exist(self.processed_paths):  # pragma: no cover
            return

        print("Processing...")
        makedirs(self.processed_dir)
        self.process()
        print("Done!")

    def get_evaluator(self):
        return Accuracy()

    def get_loss_fn(self):
        return CrossEntropyLoss()

    def __getitem__(self, idx):  # pragma: no cover
        r"""Gets the data object at index :obj:`idx` and transforms it (in
        case a :obj:`self.transform` is given)."""
        assert idx == 0
        data = self.data
        data = data if self.transform is None else self.transform(data)
        return data

    @property
    def num_classes(self):
        r"""The number of classes in the dataset."""
        if hasattr(self, "y") and self.y is not None:
            y = self.y
        elif hasattr(self, "data") and hasattr(self.data, "y") and self.data.y is not None:
            y = self.data.y
        else:
            return 0
        return y.max().item() + 1 if y.dim() == 1 else y.size(1)

    @property
    def edge_attr_size(self):
        return None

    @property
    def max_degree(self):
        return self.data.degrees().max().item() + 1

    @property
    def max_graph_size(self):
        return self.data.num_nodes

    @property
    def num_graphs(self):
        return 1

    def __repr__(self):  # pragma: no cover
        return "{}".format(self.__class__.__name__)
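
# A minimal sketch (illustrative, not part of cogdl) of how the hooks above
# fit together in a single-graph subclass: `raw_file_names` and
# `processed_file_names` gate `_download()` and `_process()`, `process()`
# writes the graph to `self.processed_paths[0]`, and the loaded graph is
# served through `self.data`. The file names, the toy edge list, and the
# assumption that `Graph` accepts `x`/`edge_index`/`y` keyword tensors are
# hypothetical placeholders, not the API of any real cogdl dataset.
class _ExampleNodeDataset(Dataset):
    def __init__(self, root):
        super(_ExampleNodeDataset, self).__init__(root)
        self.data = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # If this file already exists in `self.raw_dir`, `download()` is skipped.
        return ["edges.txt"]

    @property
    def processed_file_names(self):
        # If this file already exists in `self.processed_dir`, `process()` is skipped.
        return ["data.pt"]

    def download(self):
        # A real subclass would fetch its raw files here; this sketch just
        # writes a toy edge list so the example is self-contained.
        with open(self.raw_paths[0], "w") as f:
            f.write("0 1\n1 2\n")

    def process(self):
        edges = np.loadtxt(self.raw_paths[0], dtype=np.int64)
        edge_index = torch.from_numpy(edges).t()
        data = Graph(x=torch.ones(3, 8), edge_index=edge_index, y=torch.zeros(3, dtype=torch.long))
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        torch.save(data, self.processed_paths[0])

    def get(self, idx):
        # The base `__getitem__` reads `self.data` directly and asserts
        # `idx == 0`; `get` is implemented here only to satisfy the hook.
        assert idx == 0
        return self.data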

class MultiGraphDataset(Dataset):
    def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None):
        super(MultiGraphDataset, self).__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = None, None

    @property
    def num_classes(self):
        if hasattr(self, "y"):
            y = self.y
        elif hasattr(self, "data") and hasattr(self.data[0], "y"):
            y = torch.cat([x.y for x in self.data], dim=0)
        else:
            return 0
        return y.max().item() + 1 if y.dim() == 1 else y.size(1)

    @property
    def num_features(self):
        if isinstance(self[0], Graph):
            return self[0].num_features
        else:
            return 0

    @property
    def max_degree(self):
        max_degree = [x.degrees().max().item() for x in self.data]
        return np.max(max_degree) + 1

    @property
    def num_graphs(self):
        return len(self.data)

    @property
    def max_graph_size(self):
        return np.max([g.num_nodes for g in self.data])

    def len(self):
        if isinstance(self.data, list):
            return len(self.data)
        else:
            for item in self.slices.values():
                return len(item) - 1
        return 0
    def _get(self, idx):
        data = self.data.__class__()

        if hasattr(self.data, "__num_nodes__"):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.__old_keys__():
            item, slices = self.data[key], self.slices[key]
            start, end = int(slices[idx]), int(slices[idx + 1])
            if key == "edge_index":
                data[key] = (item[0][start:end], item[1][start:end])
            else:
                if torch.is_tensor(item):
                    s = list(repeat(slice(None), item.dim()))
                    s[self.data.__cat_dim__(key, item)] = slice(start, end)
                    # Indexing a tensor with a list of slices is deprecated;
                    # use a tuple instead.
                    s = tuple(s)
                elif start + 1 == end:
                    s = slices[start]
                else:
                    s = slice(start, end)
                data[key] = item[s]
        return data
    def get(self, idx):
        try:
            idx = int(idx)
        except (TypeError, ValueError):
            pass
        if torch.is_tensor(idx):
            idx = idx.numpy().tolist()
        if isinstance(idx, int):
            if self.slices is not None:
                return self._get(idx)
            return self.data[idx]
        if isinstance(idx, slice):
            start = idx.start if idx.start is not None else 0
            end = idx.stop if idx.stop is not None else self.len()
            step = idx.step if idx.step else 1
            idx = list(range(start, end, step))
        # `idx` is now a list of integer indices.
        if self.slices is not None:
            return [self._get(int(i)) for i in idx]
        return [self.data[i] for i in idx]
    def __getitem__(self, item):
        return self.get(item)

    def __len__(self):
        return len(self.data)

    def __repr__(self):  # pragma: no cover
        return "{}({})".format(self.__class__.__name__, len(self))
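
# Usage sketch (illustrative, not part of cogdl): a tiny in-memory
# `MultiGraphDataset`. Overriding `_download`/`_process` as no-ops and
# filling `self.data` with a list of `Graph` objects is enough to get
# integer indexing, slicing via `get`, and the aggregate properties above.
# The root path and the toy graphs are hypothetical placeholders.
class _ToyMultiGraphDataset(MultiGraphDataset):
    def __init__(self, graphs, root="/tmp/toy_dataset"):  # hypothetical root
        super(_ToyMultiGraphDataset, self).__init__(root)
        self.data = graphs  # list of `Graph` objects; `self.slices` stays None

    def _download(self):  # nothing to fetch for in-memory data
        pass

    def _process(self):  # nothing to precompute
        pass

# For example (again assuming `Graph` accepts `x`/`edge_index` keyword tensors):
#
#     graphs = [Graph(x=torch.ones(4, 3), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]))
#               for _ in range(3)]
#     dataset = _ToyMultiGraphDataset(graphs)
#     dataset[0]          # a single Graph, via `get(int)`
#     dataset[0:2]        # a list of Graphs, via `get(slice)`
#     dataset.num_graphs  # 3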