cogdl.trainers.gpt_gnn_trainer

Module Contents

Functions

node_classification_sample(args, target_type, seed, nodes, time_range)

Sub-graph sampling and label preparation for node classification.

prepare_data(args, graph, target_type, train_target_nodes, valid_target_nodes, pool)

Sample and prepare training and validation data using multi-process parallelization.

cogdl.trainers.gpt_gnn_trainer.graph_pool

cogdl.trainers.gpt_gnn_trainer.node_classification_sample(args, target_type, seed, nodes, time_range)

Sub-graph sampling and label preparation for node classification: (1) Sample batch_size output nodes (papers) together with their times.
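Step (1) can be illustrated with a minimal, hypothetical sketch. This is not cogdl's implementation. The sketch assumes the candidate nodes arrive as a plain dict that maps a node id to its timestamp (for example, a publication year). The helper name sample_output_nodes and that input format are both assumptions.

Seeding a private random.Random instance makes each batch reproducible from its seed argument. This is useful when many batches are drawn independently.

```python
import random
from typing import Dict, List, Tuple


def sample_output_nodes(node_times: Dict[int, int], batch_size: int, seed: int) -> List[Tuple[int, int]]:
    """Hypothetical sketch of step (1): draw up to batch_size distinct target
    nodes (e.g. papers) without replacement and pair each with its time."""
    rng = random.Random(seed)  # private RNG: same seed -> same batch
    ids = sorted(node_times)  # fixed order so the draw depends only on the seed
    chosen = rng.sample(ids, min(batch_size, len(ids)))
    return [(node_id, node_times[node_id]) for node_id in chosen]


# Usage: four papers keyed by id, with publication years as their "time".
papers = {10: 2015, 11: 2016, 12: 2017, 13: 2018}
batch = sample_output_nodes(papers, batch_size=2, seed=42)
print(batch)
```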

cogdl.trainers.gpt_gnn_trainer.prepare_data(args, graph, target_type, train_target_nodes, valid_target_nodes, pool)

Sample and prepare training and validation data using multi-process parallelization.
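The submit-then-collect pattern behind pool-based data preparation can be sketched with the standard library. This is a hypothetical illustration, not cogdl's code. The following are all assumptions:

- the names prepare_jobs and toy_sample;
- the n_batch and batch_size parameters;
- submitting exactly one validation job per call.

Each job receives a fresh random seed, so workers draw different batches.

The usage passes a multiprocessing.pool.ThreadPool, which exposes the same apply_async interface as multiprocessing.Pool. A real process pool additionally requires the sampling function to be picklable.

```python
import random
from multiprocessing.pool import ThreadPool


def toy_sample(seed, nodes, batch_size):
    # Hypothetical stand-in for node_classification_sample: a seeded draw of node ids.
    rng = random.Random(seed)
    return rng.sample(nodes, min(batch_size, len(nodes)))


def prepare_jobs(pool, train_nodes, valid_nodes, n_batch, batch_size, sample_fn=toy_sample):
    """Submit n_batch training jobs plus one validation job; return AsyncResult handles."""
    jobs = []
    for _ in range(n_batch):
        seed = random.randint(0, 2**31 - 1)  # fresh random seed per job
        jobs.append(pool.apply_async(sample_fn, args=(seed, train_nodes, batch_size)))
    seed = random.randint(0, 2**31 - 1)
    jobs.append(pool.apply_async(sample_fn, args=(seed, valid_nodes, batch_size)))
    return jobs


# Usage: jobs run in the background; .get() blocks until each result is ready.
with ThreadPool(processes=2) as pool:
    jobs = prepare_jobs(pool, list(range(100)), list(range(100, 120)), n_batch=3, batch_size=8)
    results = [job.get() for job in jobs]
train_batches, valid_batch = results[:-1], results[-1]
```

Returning the AsyncResult handles, rather than finished batches, lets a caller overlap sampling for the next epoch with training on the current one.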

class cogdl.trainers.gpt_gnn_trainer.GPT_GNNHomogeneousTrainer(args)

Bases: cogdl.trainers.supervised_trainer.SupervisedHomogeneousNodeClassificationTrainer

fit(self, model: cogdl.models.supervised_model.SupervisedHeterogeneousNodeClassificationModel, dataset: cogdl.data.Dataset) → None
classmethod build_trainer_from_args(cls, args)
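build_trainer_from_args is a classmethod, the usual Python idiom for an alternate constructor. Because it receives the class as cls, calling it on a subclass builds an instance of that subclass.

The sketch below uses hypothetical stand-in classes that mirror only the documented signatures. It is not cogdl's verified implementation. The factory body (forwarding args to the constructor) and the option names in the Namespace are assumptions.

```python
from argparse import Namespace


class TrainerSketch:
    """Hypothetical stand-in mirroring GPT_GNNHomogeneousTrainer's documented signatures."""

    def __init__(self, args):
        self.args = args

    @classmethod
    def build_trainer_from_args(cls, args):
        # Assumption: the factory just forwards the parsed args to __init__.
        return cls(args)

    def fit(self, model, dataset) -> None:
        # A real trainer would train here; this stub only records its inputs.
        self.last_fit = (model, dataset)


class CustomTrainerSketch(TrainerSketch):
    """Hypothetical subclass: inherits the factory unchanged."""


args = Namespace(batch_size=256, max_epoch=10)  # hypothetical option names
trainer = CustomTrainerSketch.build_trainer_from_args(args)
print(type(trainer).__name__)  # CustomTrainerSketch
```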
class cogdl.trainers.gpt_gnn_trainer.GPT_GNNHeterogeneousTrainer(model, dataset)

Bases: cogdl.trainers.supervised_trainer.SupervisedHeterogeneousNodeClassificationTrainer

fit(self) → None
evaluate(self, data: Any, nodes: Any, targets: Any) → Any
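GPT_GNNHeterogeneousTrainer receives the model and dataset in its constructor, so fit() takes no arguments. Judging from its parameter names, evaluate(data, nodes, targets) presumably scores a set of nodes against their targets.

The stub below mirrors only those documented signatures. The return value of evaluate is documented only as Any. The following are hypothetical choices for illustration:

- the accuracy metric computed here;
- the callable toy model;
- the class name HeteroTrainerSketch.

```python
from typing import Any, Callable


class HeteroTrainerSketch:
    """Hypothetical stand-in mirroring GPT_GNNHeterogeneousTrainer's documented signatures."""

    def __init__(self, model: Callable, dataset: Any):
        # Model and dataset are bound once, at construction time.
        self.model = model
        self.dataset = dataset

    def fit(self) -> None:
        # Training would read self.model and self.dataset; no arguments needed.
        pass

    def evaluate(self, data, nodes, targets) -> float:
        # Assumption: report the accuracy of the model's predictions on `nodes`.
        preds = [self.model(data, node) for node in nodes]
        correct = sum(int(p == t) for p, t in zip(preds, targets))
        return correct / len(targets)


# Usage with a toy "model" that predicts node_id % 2 as the label.
trainer = HeteroTrainerSketch(model=lambda data, node: node % 2, dataset=[])
trainer.fit()
acc = trainer.evaluate(data=None, nodes=[0, 1, 2, 3], targets=[0, 1, 1, 1])
print(acc)  # 0.75
```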