Supervised Graph Classification

In this section, we will introduce the implementation of the graph classification task.

Task Design

  1. Set up the “SupervisedGraphClassification” class, which has three task-specific parameters.

    • degree-feature: Use the one-hot node degree as the node feature, for datasets such as imdb-binary and imdb-multi, which don’t have node features.

    • gamma: Multiplicative factor of learning rate decay.

    • lr: Learning rate.
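For intuition, gamma is typically consumed by a step learning-rate scheduler. A minimal sketch (the 50-epoch step size is an assumption for illustration, not a value taken from the task):

```python
def decayed_lr(lr, gamma, epoch, step_size=50):
    """Learning rate after `epoch` epochs of multiplicative step decay:
    the rate is multiplied by gamma once every step_size epochs."""
    return lr * gamma ** (epoch // step_size)

print(decayed_lr(0.001, 0.5, 100))  # 0.001 * 0.5**2
```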

  2. Build the dataset and convert it to a list of Data as defined in CogDL. In particular, we reformat the data according to the input format of specific models; generate_data is implemented to perform this conversion.

dataset = build_dataset(args)
self.data = self.generate_data(dataset, args)

def generate_data(self, dataset, args):
    if "ModelNet" in str(type(dataset).__name__):
        train_set, test_set = dataset.get_all()
        args.num_features = 3
        return {"train": train_set, "test": test_set}
    if isinstance(dataset[0], Data):
        return dataset
    datalist = []
    for idata in dataset:
        data = Data()
        for key in idata.keys:
            data[key] = idata[key]
        datalist.append(data)
    if args.degree_feature:
        datalist = node_degree_as_feature(datalist)
        args.num_features = datalist[0].num_features
    return datalist
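The degree-feature branch above relies on node_degree_as_feature to synthesize features for attribute-free graphs. A dependency-free sketch of what that helper is assumed to compute (the real CogDL helper operates on Data objects rather than edge lists):

```python
def one_hot_degrees(edges, num_nodes, max_degree):
    """One-hot node degree as node features. Edges are directed pairs,
    so an undirected edge appears once in each direction; degrees above
    max_degree are clamped into the last bucket."""
    degree = [0] * num_nodes
    for u, _v in edges:
        degree[u] += 1
    feats = []
    for d in degree:
        row = [0.0] * (max_degree + 1)
        row[min(d, max_degree)] = 1.0
        feats.append(row)
    return feats

# Triangle graph: every node has degree 2.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 0), (0, 2)]
print(one_hot_degrees(edges, num_nodes=3, max_degree=3))
```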
  3. Then we build the model and call train to train it.

def train(self):
    for epoch in epoch_iter:
        self._train_step()
        val_acc, val_loss = self._test_step(split="valid")
        # ...
    return dict(Acc=test_acc)

def _train_step(self):
    loss_n = 0
    for batch in self.train_loader:
        batch = batch.to(self.device)
        self.optimizer.zero_grad()
        output, loss = self.model(batch)
        loss_n += loss.item()
        loss.backward()
        self.optimizer.step()

def _test_step(self, split):
    """split in ['train', 'test', 'valid']"""
    # ...
    return acc, loss
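As a sketch of the metric _test_step is assumed to report (the exact evaluation code is elided above): take the argmax of the class scores and measure the fraction of correct predictions.

```python
def evaluate_split(scores, labels):
    """Accuracy over one split: argmax each row of class scores and
    compare against the ground-truth labels."""
    preds = [max(range(len(s)), key=s.__getitem__) for s in scores]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)

print(evaluate_split([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 1, 1]))  # 2 of 3 correct
```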

The overall implementation of GraphClassification can be found in the CogDL source code.

Create a model

To create a model for the graph classification task, the following functions have to be implemented.

  1. add_args(parser): add the hyper-parameters the model needs to the argument parser.

def add_args(parser):
     parser.add_argument("--hidden-size", type=int, default=128)
     parser.add_argument("--num-layers", type=int, default=2)
     parser.add_argument("--lr", type=float, default=0.001)
     # ...
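These registered arguments are later filled in from the command line. A quick self-contained usage example (note that argparse maps --hidden-size to args.hidden_size):

```python
import argparse

def add_args(parser):
    # The same hyper-parameters registered above.
    parser.add_argument("--hidden-size", type=int, default=128)
    parser.add_argument("--num-layers", type=int, default=2)
    parser.add_argument("--lr", type=float, default=0.001)

parser = argparse.ArgumentParser()
add_args(parser)
args = parser.parse_args(["--hidden-size", "64"])
print(args.hidden_size, args.num_layers, args.lr)  # 64 2 0.001
```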
  2. build_model_from_args(cls, args): this function is called by the task to build the model.

  3. split_dataset(cls, dataset, args): split the data into train/validation/test sets and return the corresponding dataloaders according to the requirements of the model.

def split_dataset(cls, dataset, args):
    train_size = int(len(dataset) * args.train_ratio)
    test_size = int(len(dataset) * args.test_ratio)
    bs = args.batch_size
    train_loader = DataLoader(dataset[:train_size], batch_size=bs)
    test_loader = DataLoader(dataset[-test_size:], batch_size=bs)
    if args.train_ratio + args.test_ratio < 1:
        valid_loader = DataLoader(dataset[train_size:-test_size], batch_size=bs)
    else:
        valid_loader = test_loader
    return train_loader, valid_loader, test_loader
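The slicing above can be illustrated on plain indices. A sketch mirroring the ratio-based split: the first train_ratio of samples train, the last test_ratio test, and anything in between validates (falling back to the test split when the two ratios sum to 1):

```python
def split_indices(n, train_ratio, test_ratio):
    """Index ranges matching split_dataset's slicing."""
    train_size = int(n * train_ratio)
    test_size = int(n * test_ratio)
    train = list(range(train_size))
    test = list(range(n - test_size, n))
    if train_ratio + test_ratio < 1:
        valid = list(range(train_size, n - test_size))
    else:
        valid = test  # no held-out validation set; reuse test
    return train, valid, test

train, valid, test = split_indices(10, 0.7, 0.1)
print(len(train), len(valid), len(test))  # 7 2 1
```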
  4. forward: forward propagation; the return value should be (prediction, loss) for training and (prediction, None) for testing. The input parameter of forward is a Batch instance, which holds a mini-batch of graphs: node features x, edge_index, the graph-membership vector batch, and labels y.

def forward(self, batch):
    h = batch.x
    layer_rep = [h]
    for i in range(self.num_layers - 1):
        h = self.gin_layers[i](h, batch.edge_index)
        h = self.batch_norm[i](h)
        h = F.relu(h)
        layer_rep.append(h)

    final_score = 0
    for i in range(self.num_layers):
        pooled = scatter_add(layer_rep[i], batch.batch, dim=0)
        final_score += self.dropout(self.linear_prediction[i](pooled))
    final_score = F.softmax(final_score, dim=-1)
    if batch.y is not None:
        loss = self.loss(final_score, batch.y)
        return final_score, loss
    return final_score, None
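The scatter_add(layer_rep[i], batch.batch, dim=0) call sums node representations per graph, using batch.batch to map each node to its graph. A dependency-free sketch of that sum pooling:

```python
def sum_pool(node_feats, batch):
    """Sum-pool node feature rows into per-graph vectors: batch[i] is
    the index of the graph that node i belongs to."""
    num_graphs = max(batch) + 1
    dim = len(node_feats[0])
    pooled = [[0.0] * dim for _ in range(num_graphs)]
    for feat, g in zip(node_feats, batch):
        for j, v in enumerate(feat):
            pooled[g][j] += v
    return pooled

# Nodes 0 and 1 belong to graph 0; node 2 belongs to graph 1.
print(sum_pool([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]], [0, 0, 1]))  # [[1.0, 1.0], [2.0, 2.0]]
```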


To run GraphClassification, we can use the following command:

python scripts/ --task graph_classification --dataset proteins --model gin diffpool sortpool dgcnn --seed 0 1

Then we get experimental results like this:



(Results table: Acc reported for the variants (‘proteins’, ‘gin’), (‘proteins’, ‘diffpool’), (‘proteins’, ‘sortpool’), (‘proteins’, ‘dgcnn’), and (‘proteins’, ‘patchy_san’).)