import os
import torch
from ogb.nodeproppred import NodePropPredDataset
from ogb.nodeproppred import Evaluator as NodeEvaluator
from ogb.graphproppred import GraphPropPredDataset
from ogb.linkproppred import LinkPropPredDataset
from cogdl.data import Dataset, Graph, DataLoader
from cogdl.utils import CrossEntropyLoss, Accuracy, remove_self_loops, coalesce, BCEWithLogitsLoss
class OGBNDataset(Dataset):
    """Single-graph node property prediction dataset built from OGB.

    The constructor loads the processed graph from
    ``self.processed_paths[0]``; ``process`` is what builds that file from
    ``NodePropPredDataset`` and saves it.
    """

    def __init__(self, root, name, transform=None):
        """Store the dataset under ``root/<name with '-' replaced by '_'>``.

        Args:
            root: parent directory for the dataset folder.
            name: OGB dataset name, e.g. ``"ogbn-arxiv"``.
            transform: accepted but not used by this constructor.
        """
        # The folder/attribute name uses underscores; ``process`` converts
        # it back to dashes before handing it to OGB.
        self.name = name.replace("-", "_")
        super().__init__(os.path.join(root, self.name))
        # NOTE(review): ``transform`` is ignored and the attribute is always
        # reset to None -- confirm whether that is intentional.
        self.transform = None
        self.data = torch.load(self.processed_paths[0])

    def get(self, idx):
        """Return the single stored graph; ``idx`` must be 0."""
        assert idx == 0
        return self.data

    def get_loss_fn(self):
        """Return a fresh ``CrossEntropyLoss`` instance."""
        return CrossEntropyLoss()

    def get_evaluator(self):
        """Return a fresh ``Accuracy`` instance."""
        return Accuracy()

    def _download(self):
        # Intentionally a no-op in this class.
        pass

    @property
    def processed_file_names(self):
        """File name of the processed graph."""
        return "data_cogdl.pt"

    def process(self):
        """Build the graph from OGB, save it to ``processed_paths[0]`` and return it."""
        ogb_name = self.name.replace("_", "-")
        dataset = NodePropPredDataset(ogb_name, self.root)
        graph, labels = dataset[0]

        node_feat = graph["node_feat"]
        x = None if node_feat is None else torch.tensor(node_feat).contiguous()
        y = torch.tensor(labels.squeeze())

        src = torch.from_numpy(graph["edge_index"][0])
        dst = torch.from_numpy(graph["edge_index"][1])
        edge_index = torch.stack([src, dst], dim=0)

        edge_feat = graph["edge_feat"]
        edge_attr = edge_feat if edge_feat is None else torch.as_tensor(edge_feat)
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)

        # Append every edge in the reverse direction, then coalesce.
        # NOTE(review): ``edge_attr`` is not mirrored/coalesced along with
        # the edges, so when it is not None it may no longer line up with
        # ``edge_index`` -- confirm against callers.
        forward, backward = edge_index[0], edge_index[1]
        row = torch.cat([forward, backward])
        col = torch.cat([backward, forward])
        row, col, _ = coalesce(row, col)
        edge_index = torch.stack([row, col], dim=0)

        data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
        data.num_nodes = graph["num_nodes"]

        # split: start from all-False masks, then switch on the split indices.
        split_index = dataset.get_idx_split()
        for mask_name in ("train_mask", "val_mask", "test_mask"):
            setattr(data, mask_name, torch.zeros((data.num_nodes,), dtype=torch.bool))
        for split_key, mask_name in (("train", "train_mask"), ("test", "test_mask"), ("valid", "val_mask")):
            getattr(data, mask_name)[split_index[split_key]] = True

        torch.save(data, self.processed_paths[0])
        return data
class OGBArxivDataset(OGBNDataset):
    """``OGBNDataset`` bound to the OGB name ``ogbn-arxiv``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbn-arxiv")
class OGBProductsDataset(OGBNDataset):
    """``OGBNDataset`` bound to the OGB name ``ogbn-products``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbn-products")
class OGBProteinsDataset(OGBNDataset):
    """``OGBNDataset`` bound to ``ogbn-proteins`` with derived node features.

    ``process`` builds node features from the edge features (per-node sum
    divided by the node degree) and appends a one-hot encoding of
    ``graph["node_species"]``.
    """

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        dataset = "ogbn-proteins"
        super(OGBProteinsDataset, self).__init__(data_path, dataset)

    @property
    def edge_attr_size(self):
        """One-element list holding the edge feature width (``edge_attr.shape[1]``)."""
        return [
            self.data.edge_attr.shape[1],
        ]

    def get_loss_fn(self):
        """Return a fresh ``BCEWithLogitsLoss`` instance."""
        return BCEWithLogitsLoss()

    def get_evaluator(self):
        """Return ``wrap(y_pred, y_true)`` yielding the OGB evaluator's ``"rocauc"`` value."""
        evaluator = NodeEvaluator(name="ogbn-proteins")

        def wrap(y_pred, y_true):
            input_dict = {"y_true": y_true, "y_pred": y_pred}
            return evaluator.eval(input_dict)["rocauc"]

        return wrap

    def process(self):
        """Build the graph from OGB, save it to ``processed_paths[0]`` and return it."""
        name = self.name.replace("_", "-")
        dataset = NodePropPredDataset(name, self.root)
        graph, y = dataset[0]
        y = torch.tensor(y.squeeze())
        row, col = graph["edge_index"][0], graph["edge_index"][1]
        row = torch.from_numpy(row)
        col = torch.from_numpy(col)
        edge_attr = torch.as_tensor(graph["edge_feat"]) if "edge_feat" in graph else None

        data = Graph(x=None, edge_index=(row, col), edge_attr=edge_attr, y=y)
        data.num_nodes = graph["num_nodes"]

        # split
        split_index = dataset.get_idx_split()
        data.train_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)
        data.val_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)
        data.test_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)
        data.train_mask[split_index["train"]] = True
        data.test_mask[split_index["test"]] = True
        data.val_mask[split_index["valid"]] = True

        # Node features: sum the edge features onto the node given by the
        # first component of ``edge_index``, then divide by the degree.
        edge_attr = data.edge_attr
        deg = data.degrees()
        dst, _ = data.edge_index
        dst = dst.view(-1, 1).expand(dst.shape[0], edge_attr.shape[1])
        x = torch.zeros((data.num_nodes, edge_attr.shape[1]), dtype=torch.float32)
        x = x.scatter_add_(dim=0, index=dst, src=edge_attr)
        deg = torch.clamp(deg, min=1)  # avoid dividing by zero for isolated nodes
        x = x / deg.view(-1, 1)
        data.x = x

        node_species = torch.as_tensor(graph["node_species"])
        # Map raw species ids onto 0..K-1. The inverse index is flattened so
        # the one-hot block is 2-D and can be concatenated with ``data.x``.
        # (presumably ``node_species`` is stored as (num_nodes, 1) -- the
        # flatten is a no-op if it is already 1-D.)
        species_ids, new_index = torch.unique(node_species, return_inverse=True)
        new_index = new_index.reshape(-1)
        # K distinct species means indices 0..K-1, so ``num_classes`` must be
        # K (= max index + 1); passing the max index itself is rejected.
        one_hot_x = torch.nn.functional.one_hot(new_index, num_classes=species_ids.numel())
        data.species = node_species
        data.x = torch.cat([data.x, one_hot_x.to(data.x.dtype)], dim=1)

        torch.save(data, self.processed_paths[0])
        return data
class OGBPapers100MDataset(OGBNDataset):
    """``OGBNDataset`` bound to the OGB name ``ogbn-papers100M``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbn-papers100M")
class OGBGDataset(Dataset):
    """Graph property prediction dataset built from OGB.

    Every graph of ``GraphPropPredDataset`` is converted to a ``Graph`` at
    construction time and kept in the list ``self.data``. ``all_nodes`` and
    ``all_edges`` accumulate the node and edge counts over all graphs.
    """

    def __init__(self, root, name):
        """Load every graph of the OGB dataset ``name`` stored under ``root``.

        Args:
            root: root directory of the dataset.
            name: OGB dataset name, e.g. ``"ogbg-molhiv"``.
        """
        super(OGBGDataset, self).__init__(root)
        self.name = name
        self.dataset = GraphPropPredDataset(self.name, root)

        self.data = []
        self.all_nodes = 0
        self.all_edges = 0
        for i in range(len(self.dataset.graphs)):
            graph, label = self.dataset[i]
            data = Graph(
                # Feature entries may be missing or None; both map to None
                # instead of crashing inside ``torch.tensor``.
                x=self._optional_float_tensor(graph.get("node_feat")),
                edge_index=torch.tensor(graph["edge_index"]),
                edge_attr=self._optional_float_tensor(graph.get("edge_feat")),
                y=torch.tensor(label),
            )
            data.num_nodes = graph["num_nodes"]
            self.data.append(data)

            self.all_nodes += graph["num_nodes"]
            self.all_edges += graph["edge_index"].shape[1]

        self.transform = None

    @staticmethod
    def _optional_float_tensor(array):
        """Return ``array`` as a float tensor, or None when ``array`` is None."""
        if array is None:
            return None
        return torch.tensor(array, dtype=torch.float)

    def get_loader(self, args):
        """Return ``(train, valid, test)`` loaders built from the OGB index split.

        Args:
            args: object providing ``batch_size``.

        Only the training loader shuffles.
        """
        split_index = self.dataset.get_idx_split()
        train_loader = DataLoader(self.get_subset(split_index["train"]), batch_size=args.batch_size, shuffle=True)
        valid_loader = DataLoader(self.get_subset(split_index["valid"]), batch_size=args.batch_size, shuffle=False)
        test_loader = DataLoader(self.get_subset(split_index["test"]), batch_size=args.batch_size, shuffle=False)
        return train_loader, valid_loader, test_loader

    def get_subset(self, subset):
        """Return the list of graphs at the positions listed in ``subset``."""
        return [self.data[idx] for idx in subset]

    def get(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.data[idx]

    def _download(self):
        # Intentionally a no-op in this class.
        pass

    def _process(self):
        # Intentionally a no-op: graphs are converted in ``__init__``.
        pass

    @property
    def num_classes(self):
        """``num_classes`` of the underlying OGB dataset, as an int."""
        return int(self.dataset.num_classes)
class OGBMolbaceDataset(OGBGDataset):
    """``OGBGDataset`` bound to the OGB name ``ogbg-molbace``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbg-molbace")
class OGBMolhivDataset(OGBGDataset):
    """``OGBGDataset`` bound to the OGB name ``ogbg-molhiv``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbg-molhiv")
class OGBMolpcbaDataset(OGBGDataset):
    """``OGBGDataset`` bound to the OGB name ``ogbg-molpcba``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbg-molpcba")
class OGBPpaDataset(OGBGDataset):
    """``OGBGDataset`` bound to the OGB name ``ogbg-ppa``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``.

        ``data_path`` defaults to ``"data"``, the directory that used to be
        hard-coded here, so existing zero-argument calls behave as before.
        """
        dataset = "ogbg-ppa"
        super(OGBPpaDataset, self).__init__(data_path, dataset)
class OGBCodeDataset(OGBGDataset):
    """``OGBGDataset`` bound to the OGB name ``ogbg-code``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbg-code")
# The classes below wrap the ogbl (link property prediction) datasets.
class OGBLDataset(Dataset):
    """Single-graph link property prediction dataset built from OGB.

    The graph is converted to a ``Graph`` in the constructor and stored in
    ``self.data``; the underlying ``LinkPropPredDataset`` is kept in
    ``self.dataset`` so that ``get_edge_split`` can query it.
    """

    def __init__(self, root, name):
        """
        - name (str): name of the dataset
        - root (str): root directory to store the dataset folder
        """
        # NOTE(review): unlike the other dataset classes in this file, the
        # base ``Dataset.__init__`` is not called here -- confirm whether
        # that is intentional.
        self.name = name

        # Keep the OGB dataset on the instance: ``get_edge_split`` reads
        # ``self.dataset`` and would fail if it only lived in a local.
        self.dataset = LinkPropPredDataset(name, root)
        graph = self.dataset[0]

        x = torch.tensor(graph["node_feat"]).contiguous() if graph["node_feat"] is not None else None
        row, col = graph["edge_index"][0], graph["edge_index"][1]
        row = torch.from_numpy(row)
        col = torch.from_numpy(col)
        edge_index = torch.stack([row, col], dim=0)
        edge_attr = torch.as_tensor(graph["edge_feat"]) if graph["edge_feat"] is not None else graph["edge_feat"]
        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)

        # Append every edge in the reverse direction, then coalesce.
        # NOTE(review): ``edge_attr`` is not mirrored/coalesced along with
        # the edges, so when it is not None it may no longer line up with
        # ``edge_index`` -- confirm against callers.
        row = torch.cat([edge_index[0], edge_index[1]])
        col = torch.cat([edge_index[1], edge_index[0]])
        row, col, _ = coalesce(row, col)
        edge_index = torch.stack([row, col], dim=0)

        self.data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=None)
        self.data.num_nodes = graph["num_nodes"]

    def get(self, idx):
        """Return the single stored graph; ``idx`` must be 0."""
        assert idx == 0
        return self.data

    def get_loss_fn(self):
        """Return a fresh ``CrossEntropyLoss`` instance."""
        return CrossEntropyLoss()

    def get_evaluator(self):
        """Return a fresh ``Accuracy`` instance."""
        return Accuracy()

    def _download(self):
        # Intentionally a no-op in this class.
        pass

    @property
    def processed_file_names(self):
        """File name of the processed graph."""
        return "data_cogdl.pt"

    def _process(self):
        # Intentionally a no-op: the graph is converted in ``__init__``.
        pass

    def get_edge_split(self):
        """Return ``(train_edge, val_edge, test_edge)`` tensors.

        Each tensor is the transposed ``"edge"`` array of the matching
        entry in the OGB edge split.
        """
        idx = self.dataset.get_edge_split()
        train_edge = torch.from_numpy(idx["train"]["edge"].T)
        val_edge = torch.from_numpy(idx["valid"]["edge"].T)
        test_edge = torch.from_numpy(idx["test"]["edge"].T)
        return train_edge, val_edge, test_edge
class OGBLPpaDataset(OGBLDataset):
    """``OGBLDataset`` bound to the OGB name ``ogbl-ppa``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbl-ppa")
class OGBLCollabDataset(OGBLDataset):
    """``OGBLDataset`` bound to the OGB name ``ogbl-collab``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbl-collab")
class OGBLDdiDataset(OGBLDataset):
    """``OGBLDataset`` bound to the OGB name ``ogbl-ddi``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbl-ddi")
class OGBLCitation2Dataset(OGBLDataset):
    """``OGBLDataset`` bound to the OGB name ``ogbl-citation2``."""

    def __init__(self, data_path="data"):
        """Create the dataset under ``data_path``."""
        super().__init__(data_path, "ogbl-citation2")