In this section, we will introduce how to implement a specific Trainer for a model.

In the previous section, we introduced the implementation of different tasks. However, some models require a training paradigm that is incompatible with the training process defined by the task. For such cases, CogDL provides the Trainer abstraction to customize the training and inference mode. Taking NeighborSamplingTrainer as an example, this section shows how to define a trainer.


1. A self-defined trainer should inherit BaseTrainer and must implement the fit function to define the training and evaluation process. Parameters needed for training should be added to add_args in the model and can be retrieved from args in __init__, as sketched after the skeleton below.

class NeighborSamplingTrainer(BaseTrainer):
    def __init__(self, args):
        # ... get necessary parameters from args

    def fit(self, model, dataset):
        # ... implement the training and evaluation

    @classmethod
    def build_trainer_from_args(cls, args):
        return cls(args)

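For illustration, here is a minimal sketch of how __init__ might read these parameters from args and how the model might register them through add_args. The specific hyperparameter names (lr, weight_decay, batch_size, sample_size) and defaults are assumptions for this sketch, not a fixed interface.

class NeighborSamplingTrainer(BaseTrainer):
    def __init__(self, args):
        super(NeighborSamplingTrainer, self).__init__()
        # the hyperparameter names below are illustrative assumptions
        self.lr = args.lr
        self.weight_decay = args.weight_decay
        self.batch_size = args.batch_size
        self.sample_size = args.sample_size
        self.device = "cpu" if args.cpu else "cuda"


class GraphSAGE(BaseModel):
    @staticmethod
    def add_args(parser):
        # register the matching command-line arguments (illustrative defaults)
        parser.add_argument("--batch-size", type=int, default=128)
        parser.add_argument("--sample-size", type=int, nargs="+", default=[10, 10])
        parser.add_argument("--lr", type=float, default=0.01)
        parser.add_argument("--weight-decay", type=float, default=5e-4)
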
2. The entire training and evaluation process, including data preprocessing and setting up the optimizer, should be implemented in fit. In other words, given the model and dataset, the rest is up to you. fit accepts two parameters, model and dataset, which usually reside on the CPU; you need to move them to CUDA if you want to train on a GPU. The helper methods train and _test_step called in fit are sketched after the example below.

def fit(self, model, dataset):
    self.data = dataset[0]

    # preprocess data: build the train/test neighbor samplers
    # (keyword arguments here follow common NeighborSampler usage and may need adjusting)
    self.train_loader = NeighborSampler(
        data=self.data,
        mask=self.data.train_mask,
        sizes=self.sample_size,
        batch_size=self.batch_size,
        shuffle=True,
    )
    self.test_loader = NeighborSampler(
        data=self.data, mask=None, sizes=[-1], batch_size=self.batch_size, shuffle=False
    )
    # move model to GPU
    self.model = model.to(self.device)

    # define optimizer
    self.optimizer = torch.optim.Adam(
        self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
    )
    # training
    best_model = self.train()
    self.model = best_model
    # evaluation
    acc, loss = self._test_step()
    return dict(Acc=acc["test"], ValAcc=acc["val"])

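The train and _test_step helpers used in fit are ordinary methods of the trainer. Below is a minimal sketch of what they might look like; max_epoch, the model's node_classification_loss and inference methods, and the exact evaluation details are assumptions for illustration rather than CogDL's exact API.

import copy

def train(self):
    # hypothetical training loop over the neighbor-sampling batches
    best_model, best_val_acc = None, 0.0
    for epoch in range(self.max_epoch):
        self.model.train()
        for batch in self.train_loader:
            self.optimizer.zero_grad()
            # how a batch is fed to the model is an assumption for this sketch
            loss = self.model.node_classification_loss(batch)
            loss.backward()
            self.optimizer.step()
        acc, _ = self._test_step()
        if acc["val"] > best_val_acc:
            best_val_acc = acc["val"]
            best_model = copy.deepcopy(self.model)
    return best_model

def _test_step(self):
    # hypothetical evaluation: full-neighborhood inference, then accuracy/loss per split
    self.model.eval()
    with torch.no_grad():
        logits = self.model.inference(self.data.x, self.test_loader).cpu()
    acc, loss = {}, {}
    for split in ("val", "test"):
        mask = getattr(self.data, f"{split}_mask")
        pred = logits[mask].argmax(dim=1)
        acc[split] = (pred == self.data.y[mask]).float().mean().item()
        loss[split] = torch.nn.functional.cross_entropy(logits[mask], self.data.y[mask]).item()
    return acc, loss
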
3. To have a model trained with the trainer, we need to assign the trainer to the model. In CogDL, a model must implement get_trainer as a static method if it has a customized training process. GraphSAGE depends on NeighborSamplingTrainer, so the following code should exist in its implementation (a sketch of how the framework then dispatches to the trainer follows the snippet).

@staticmethod
def get_trainer(taskType, args):
    return NeighborSamplingTrainer

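With get_trainer in place, the task does not need any special handling: it can query the model for a trainer and delegate the whole training run to it. Roughly, the dispatch looks like this (a sketch, not CogDL's exact code):

# sketch of how a task might dispatch to the model-specific trainer
trainer_cls = model.get_trainer(None, args)
if trainer_cls is not None:
    trainer = trainer_cls.build_trainer_from_args(args)
    result = trainer.fit(model, dataset)  # e.g. {"Acc": ..., "ValAcc": ...}
else:
    pass  # fall back to the task's default training loop
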
The details of training and evaluation are similar to the implementation in Tasks. The overall implementation of trainers can be found in the CogDL repository.