tasks

Base Task

class cogdl.tasks.base_task.BaseTask(args)

Bases: abc.ABC

static add_args(parser: argparse.ArgumentParser)

Add task-specific arguments to the parser.
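Task classes expose their options through the static `add_args` hook, which receives a shared `argparse.ArgumentParser` and registers extra flags on it before parsing. The sketch below shows that pattern with plain `argparse` only. `MyTask`, `build_parser`, and the `--eval-interval` / `--save-dir` flags are hypothetical, and the class deliberately does not subclass the real `BaseTask` so the snippet runs without CogDL installed.

```python
import argparse


class MyTask:
    """Hypothetical stand-in for a BaseTask subclass (not part of CogDL)."""

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """Add task-specific arguments to the parser."""
        # Hypothetical flags, chosen only to illustrate the hook.
        parser.add_argument("--eval-interval", type=int, default=10)
        parser.add_argument("--save-dir", type=str, default="./embeddings")

    def __init__(self, args):
        # The task reads its options back from the parsed namespace.
        self.eval_interval = args.eval_interval
        self.save_dir = args.save_dir


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="my_task")
    # Let the task extend the shared parser before parsing.
    MyTask.add_args(parser)
    return parser


args = build_parser().parse_args(["--eval-interval", "5"])
task = MyTask(args)
print(task.eval_interval, task.save_dir)  # 5 ./embeddings
```

Because `add_args` is static, the flags can be registered before any task object exists, which is what allows a single command line to configure whichever task is selected.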

get_trainer(model, args)
load_from_pretrained()
save_checkpoint()
set_evaluator(dataset)
set_loss_fn(dataset)
train()
class cogdl.tasks.base_task.LoadFrom

Bases: abc.ABCMeta

Node Classification

class cogdl.tasks.node_classification.NodeClassification(args, dataset=None, model: Optional[cogdl.models.supervised_model.SupervisedHomogeneousNodeClassificationModel] = None)

Bases: cogdl.tasks.base_task.BaseTask

Node classification task.

static add_args(parser: argparse.ArgumentParser)

Add task-specific arguments to the parser.

train()

Unsupervised Node Classification

class cogdl.tasks.unsupervised_node_classification.TopKRanker(estimator, *, n_jobs=None)

Bases: sklearn.multiclass.OneVsRestClassifier

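predict(X, top_k_list)

Predict multi-class targets using the underlying estimators, keeping the top-k labels for each sample.

Parameters:
X : (sparse) array-like of shape (n_samples, n_features)
Data.
top_k_list : list of int of length n_samples
Number of labels to predict for each sample.
Returns:
y : (sparse) array-like of shape (n_samples,) or (n_samples, n_classes)
Predicted multi-class targets.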
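Unlike the inherited `OneVsRestClassifier.predict`, this method takes a per-sample label count. The sketch below reproduces only the selection step in plain Python, assuming the per-class probabilities have already been computed (for example with `predict_proba`) and that `top_k_list[i]` is the number of labels to keep for sample `i`. `top_k_labels` is a hypothetical helper, not part of CogDL. In multi-label node classification, k is typically set to the number of ground-truth labels of each node.

```python
from typing import List, Sequence


def top_k_labels(probs: Sequence[Sequence[float]], classes: Sequence[int],
                 top_k_list: Sequence[int]) -> List[List[int]]:
    """Return, for each sample, its k highest-scoring classes.

    probs[i][j] is the score of classes[j] for sample i, and
    top_k_list[i] is how many labels to keep for sample i.
    """
    if len(probs) != len(top_k_list):
        raise ValueError("need exactly one k per sample")
    predictions = []
    for row, k in zip(probs, top_k_list):
        # Rank column indices by descending score and keep the first k.
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        predictions.append(sorted(classes[j] for j in ranked[:k]))
    return predictions


probs = [
    [0.1, 0.7, 0.2],
    [0.5, 0.1, 0.4],
]
print(top_k_labels(probs, classes=[0, 1, 2], top_k_list=[1, 2]))
# [[1], [0, 2]]
```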
class cogdl.tasks.unsupervised_node_classification.UnsupervisedNodeClassification(args, dataset=None, model=None)

Bases: cogdl.tasks.base_task.BaseTask

Unsupervised node classification task.

static add_args(parser: argparse.ArgumentParser)

Add task-specific arguments to the parser.

enhance_emb(G, embs)
save_emb(embs)
train()

Heterogeneous Node Classification

class cogdl.tasks.heterogeneous_node_classification.HeterogeneousNodeClassification(args, dataset=None, model=None)

Bases: cogdl.tasks.base_task.BaseTask

Heterogeneous node classification task.

static add_args(_: argparse.ArgumentParser)

Add task-specific arguments to the parser.

train()

Multiplex Node Classification

class cogdl.tasks.multiplex_node_classification.MultiplexNodeClassification(args, dataset=None, model=None)

Bases: cogdl.tasks.base_task.BaseTask

Multiplex node classification task.

static add_args(parser: argparse.ArgumentParser)

Add task-specific arguments to the parser.

train()

Graph Classification

class cogdl.tasks.graph_classification.GraphClassification(args, dataset=None, model=None)

Bases: cogdl.tasks.base_task.BaseTask

Supervised graph classification task.

static add_args(parser: argparse.ArgumentParser)

Add task-specific arguments to the parser.

generate_data(dataset, args)
train()
cogdl.tasks.graph_classification.node_degree_as_feature(data)

Set each node's feature to the one-hot encoding of its degree.

Parameters:
data – a list of class Data
Returns:
a list of class Data
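Degree one-hot features are a common fallback for graph-classification datasets that ship without node attributes. The sketch below shows the idea in plain Python on hypothetical dict-based graphs (keys `num_nodes`, `edges`, `x`) instead of CogDL's `Data` class. Counting degree from the source end of each edge, and sharing one encoding width across all graphs so every feature matrix has the same number of columns, are assumptions of this sketch rather than a description of the library code.

```python
from typing import List


def node_degree_one_hot(graphs: List[dict]) -> List[dict]:
    """Hypothetical sketch: one-hot encode node degrees as features.

    Each graph is a dict with "num_nodes" and "edges" (a list of (u, v)
    pairs, both directions listed for undirected graphs); an "x" key
    holding the feature rows is added in place.
    """
    # 1) Count each node's degree from the source end of every edge.
    degrees = []
    for g in graphs:
        deg = [0] * g["num_nodes"]
        for u, _v in g["edges"]:
            deg[u] += 1
        degrees.append(deg)

    # 2) Share one width across all graphs so every feature matrix has
    #    the same number of columns (max degree + 1 to include degree 0).
    width = max((max(d) for d in degrees if d), default=0) + 1

    # 3) Row i of x is 1 at column deg[i] and 0 elsewhere.
    for g, deg in zip(graphs, degrees):
        g["x"] = [[1 if j == d else 0 for j in range(width)] for d in deg]
    return graphs


path = {"num_nodes": 3, "edges": [(0, 1), (1, 0), (1, 2), (2, 1)]}
pair = {"num_nodes": 2, "edges": [(0, 1), (1, 0)]}
node_degree_one_hot([path, pair])
print(path["x"])  # [[0, 1, 0], [0, 0, 1], [0, 1, 0]]
print(pair["x"])  # [[0, 1, 0], [0, 1, 0]]
```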

cogdl.tasks.graph_classification.uniform_node_feature(data)

Set every node's feature to the same vector.
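When every node carries the same feature vector, nodes cannot be told apart by their attributes, so a model has to rely on graph structure alone. The sketch below illustrates this on hypothetical dict-based graphs (a `num_nodes` key, with features written to `x`). The vector dimension `dim`, the seeded random initialization, and the helper name `uniform_node_features` are all assumptions, not CogDL's implementation.

```python
import random
from typing import List


def uniform_node_features(graphs: List[dict], dim: int = 2,
                          seed: int = 0) -> List[dict]:
    """Hypothetical sketch: give every node the same feature vector."""
    rng = random.Random(seed)
    # Draw a single vector once; every node of every graph gets a copy.
    shared = [rng.random() for _ in range(dim)]
    for g in graphs:
        g["x"] = [list(shared) for _ in range(g["num_nodes"])]
    return graphs


graphs = uniform_node_features([{"num_nodes": 3}, {"num_nodes": 1}])
rows = graphs[0]["x"] + graphs[1]["x"]
print(len(rows), all(r == rows[0] for r in rows))  # 4 True
```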

Unsupervised Graph Classification

class cogdl.tasks.unsupervised_graph_classification.UnsupervisedGraphClassification(args, dataset=None, model=None)

Bases: cogdl.tasks.base_task.BaseTask

Unsupervised graph classification task.

static add_args(parser: argparse.ArgumentParser)

Add task-specific arguments to the parser.

save_emb(embs)
train()

Attributed Graph Clustering

class cogdl.tasks.attributed_graph_clustering.AttributedGraphClustering(args, dataset=None, _=None)

Bases: cogdl.tasks.base_task.BaseTask

Attributed graph clustering task.

static add_args(parser: argparse.ArgumentParser)

Add task-specific arguments to the parser.

train() → Dict[str, float]

Pretrain

class cogdl.tasks.pretrain.PretrainTask(args)

Bases: cogdl.tasks.base_task.BaseTask

static add_args(_: argparse.ArgumentParser)

Add task-specific arguments to the parser.

train()

Task Module

cogdl.tasks.build_task(args, dataset=None, model=None)
cogdl.tasks.register_task(name)

New task types can be added to cogdl with the register_task() function decorator.

For example:

from cogdl.tasks import register_task
from cogdl.tasks.base_task import BaseTask

@register_task('node_classification')
class NodeClassification(BaseTask):
    ...
Args:
name (str): the name of the task
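The decorator follows a name-to-class registry pattern. `register_task(name)` records the decorated class, and a builder such as `build_task(args, dataset=None, model=None)` can later look the class up by name and instantiate it. The self-contained sketch below mirrors that pattern without importing CogDL. `TASK_REGISTRY`, `ToyTask`, the duplicate-name check, and the assumption that the task name is read from `args.task` are illustrative choices, not a copy of the library code.

```python
import argparse
from typing import Callable, Dict

# Hypothetical registry mapping task names to task classes.
TASK_REGISTRY: Dict[str, type] = {}


def register_task(name: str) -> Callable[[type], type]:
    """Decorator that records a task class under `name`."""

    def register_task_cls(cls: type) -> type:
        if name in TASK_REGISTRY:
            raise ValueError(f"Cannot register duplicate task ({name})")
        TASK_REGISTRY[name] = cls
        return cls  # return the class unchanged so it stays usable

    return register_task_cls


def build_task(args, dataset=None, model=None):
    """Instantiate the task class registered under args.task."""
    return TASK_REGISTRY[args.task](args, dataset=dataset, model=model)


@register_task("toy_task")
class ToyTask:
    def __init__(self, args, dataset=None, model=None):
        self.args = args
        self.dataset = dataset
        self.model = model


task = build_task(argparse.Namespace(task="toy_task"), dataset="cora")
print(type(task).__name__, task.dataset)  # ToyTask cora
```

Because the decorator returns the class unchanged, registration is a side effect of importing the module that defines the task. The builder therefore needs only a string name to find the right class.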