model wrappers

Node Classification

class cogdl.wrappers.model_wrapper.node_classification.DGIModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper

static add_args(parser)[source]
static augment(graph)[source]
setup_optimizer()[source]
test_step(graph)[source]
train_step(subgraph)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.GCNMixModelWrapper(model, optimizer_cfg, temperature, rampup_starts, rampup_ends, mixup_consistency, ema_decay, tau, k)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

GCNMixModelWrapper calls forward_aux in the model; forward_aux is similar to forward but ignores the spmm operation.

static add_args(parser)[source]
setup_optimizer()[source]
test_step(subgraph)[source]
train_step(subgraph)[source]
training: bool
update_aux(data, vector_labels, train_index)[source]
update_soft(graph)[source]
val_step(subgraph)[source]
class cogdl.wrappers.model_wrapper.node_classification.GRACEModelWrapper(model, optimizer_cfg, tau, drop_feature_rates, drop_edge_rates, batch_fwd, proj_hidden_size)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper

static add_args(parser)[source]
batched_loss(z1: torch.Tensor, z2: torch.Tensor, batch_size: int)[source]
contrastive_loss(z1: torch.Tensor, z2: torch.Tensor)[source]
prop(graph: cogdl.data.data.Graph, x: torch.Tensor, drop_feature_rate: float = 0.0, drop_edge_rate: float = 0.0)[source]
setup_optimizer()[source]
test_step(graph)[source]
train_step(subgraph)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.GrandModelWrapper(model, optimizer_cfg, sample=2, temperature=0.5, lmbda=0.5)[source]

Bases: cogdl.wrappers.model_wrapper.node_classification.node_classification_mw.NodeClfModelWrapper

sample : int

Number of augmentations for consistency loss

temperature : float

Temperature to sharpen predictions.

lmbda : float

Proportion of consistency loss of unlabelled data

static add_args(parser)[source]
consistency_loss(logps, train_mask)[source]
train_step(batch)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.MVGRLModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper

setup_optimizer()[source]
test_step(graph)[source]
train_step(subgraph)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.SelfAuxiliaryModelWrapper(model, optimizer_cfg, auxiliary_task, dropedge_rate, mask_ratio, sampling)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper

static add_args(parser)[source]
generate_virtual_labels(data)[source]
pre_stage(stage, data_w)[source]
setup_optimizer()[source]
test_step(graph)[source]
train_step(subgraph)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.GraphSAGEModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

setup_optimizer()[source]
test_step(batch)[source]
train_step(batch)[source]
training: bool
val_step(batch)[source]
class cogdl.wrappers.model_wrapper.node_classification.UnsupGraphSAGEModelWrapper(model, optimizer_cfg, walk_length, negative_samples)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper

static add_args(parser)[source]
setup_optimizer()[source]
test_step(batch)[source]
train_step(batch)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.M3SModelWrapper(model, optimizer_cfg, n_cluster, num_new_labels)[source]

Bases: cogdl.wrappers.model_wrapper.node_classification.node_classification_mw.NodeClfModelWrapper

static add_args(parser)[source]
pre_stage(stage, data_w: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.NetworkEmbeddingModelWrapper(model, num_shuffle=1, training_percents=[0.1], enhance=None, max_evals=10, num_workers=1)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper

static add_args(parser)[source]
test_step(batch)[source]
train_step(batch)[source]
training: bool
class cogdl.wrappers.model_wrapper.node_classification.NodeClfModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

set_early_stopping()[source]
Returns

  1. str, the monitoring metric

  2. tuple(str, str), that is, (the monitoring metric, small or big). The second element indicates whether a smaller or a bigger value of the metric is better.

setup_optimizer()[source]
test_step(batch)[source]
train_step(subgraph)[source]
training: bool
val_step(subgraph)[source]
class cogdl.wrappers.model_wrapper.node_classification.CorrectSmoothModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.node_classification.node_classification_mw.NodeClfModelWrapper

static add_args(parser)[source]
test_step(batch)[source]
training: bool
val_step(subgraph)[source]
class cogdl.wrappers.model_wrapper.node_classification.PPRGoModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

setup_optimizer()[source]
test_step(batch)[source]
train_step(batch)[source]
training: bool
val_step(batch)[source]
class cogdl.wrappers.model_wrapper.node_classification.SAGNModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

pre_stage(stage, data_w)[source]
setup_optimizer()[source]
test_step(batch)[source]
train_step(batch)[source]
training: bool
val_step(batch)[source]

Graph Classification

class cogdl.wrappers.model_wrapper.graph_classification.GraphClassificationModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

setup_optimizer()[source]
test_step(batch)[source]
train_step(batch)[source]
training: bool
val_step(batch)[source]
class cogdl.wrappers.model_wrapper.graph_classification.GraphEmbeddingModelWrapper(model)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper

test_step(batch)[source]
train_step(batch)[source]
training: bool
class cogdl.wrappers.model_wrapper.graph_classification.InfoGraphModelWrapper(model, optimizer_cfg, sup=False)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

static add_args(parser)[source]
static mi_loss(pos_mask, neg_mask, mi, pos_div, neg_div)[source]
setup_optimizer()[source]
sup_loss(pred, batch)[source]
test_step(dataset)[source]
train_step(batch)[source]
training: bool
unsup_loss(graph_feat, node_feat, batch)[source]

Pretraining

class cogdl.wrappers.model_wrapper.pretraining.GCCModelWrapper(model, optimizer_cfg, nce_k, nce_t, momentum, output_size, finetune=False, num_classes=1, model_path='gcc_pretrain.pt')[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

static add_args(parser)[source]
load_checkpoint(path)[source]
post_stage(stage, data_w)[source]
pre_stage(stage, data_w)[source]
save_checkpoint(path)[source]
setup_optimizer()[source]
train_step(batch)[source]
train_step_finetune(batch)[source]
train_step_pretraining(batch)[source]
training: bool

Heterogeneous

class cogdl.wrappers.model_wrapper.heterogeneous.HeterogeneousEmbeddingModelWrapper(model, hidden_size=200)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper

static add_args(parser: argparse.ArgumentParser)[source]

Add task-specific arguments to the parser.

test_step(batch)[source]
train_step(batch)[source]
training: bool
class cogdl.wrappers.model_wrapper.heterogeneous.HeterogeneousGNNModelWrapper(model, optimizer_cfg)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

setup_optimizer()[source]
test_step(batch)[source]
train_step(batch)[source]
training: bool
val_step(batch)[source]
class cogdl.wrappers.model_wrapper.heterogeneous.MultiplexEmbeddingModelWrapper(model, hidden_size=200, eval_type='all')[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper

static add_args(parser: argparse.ArgumentParser)[source]

Add task-specific arguments to the parser.

test_step(batch)[source]
train_step(batch)[source]
training: bool

Clustering

class cogdl.wrappers.model_wrapper.clustering.AGCModelWrapper(model, optimizer_cfg, num_clusters, cluster_method='kmeans', evaluation='full', max_iter=5)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper

static add_args(parser)[source]
test_step(batch)[source]
train_step(graph)[source]
training: bool
class cogdl.wrappers.model_wrapper.clustering.DAEGCModelWrapper(model, optimizer_cfg, num_clusters, cluster_method='kmeans', evaluation='full', T=5)[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

static add_args(parser)[source]
cluster_loss(P, Q)[source]
getP(Q)[source]
getQ(z, cluster_center)[source]
post_stage(stage, data_w)[source]
pre_stage(stage, data_w)[source]
recon_loss(z, adj)[source]
setup_optimizer()[source]
test_step(subgraph)[source]
train_step(subgraph)[source]
training: bool
class cogdl.wrappers.model_wrapper.clustering.GAEModelWrapper(model, optimizer_cfg, num_clusters, cluster_method='kmeans', evaluation='full')[source]

Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper

static add_args(parser)[source]
pre_stage(stage, data_w)[source]
setup_optimizer()[source]
test_step(subgraph)[source]
train_step(subgraph)[source]
training: bool