model wrappers
Node Classification
- class cogdl.wrappers.model_wrapper.node_classification.DGIModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper
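DGIModelWrapper is an unsupervised wrapper. For orientation, here is a minimal sketch of the Deep Graph Infomax objective it is named after: corrupt the input by shuffling feature rows, then train a bilinear discriminator to tell real node embeddings from corrupted ones against a graph summary. This is an illustration under assumptions, not CogDL's code; `corrupt`, `dgi_loss`, and the linear stand-in encoder (which ignores edges, unlike a real GNN encoder) are hypothetical.

```python
import torch
import torch.nn.functional as F


def corrupt(x: torch.Tensor, generator: torch.Generator = None) -> torch.Tensor:
    # Corruption function: shuffle node feature rows, leaving the graph untouched.
    perm = torch.randperm(x.size(0), generator=generator)
    return x[perm]


def dgi_loss(h_pos: torch.Tensor, h_neg: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Readout: graph summary vector s = sigmoid(mean of real node embeddings).
    s = torch.sigmoid(h_pos.mean(dim=0))
    # Bilinear discriminator D(h_i, s) = h_i^T W s, one logit per node.
    pos_logits = h_pos @ weight @ s
    neg_logits = h_neg @ weight @ s
    logits = torch.cat([pos_logits, neg_logits])
    # Real nodes are labelled 1, corrupted nodes 0.
    labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
    return F.binary_cross_entropy_with_logits(logits, labels)


torch.manual_seed(0)
x = torch.randn(6, 4)    # 6 nodes, 4 features
enc = torch.randn(4, 3)  # hypothetical stand-in for a GNN encoder
w = torch.randn(3, 3)    # discriminator weight
h_pos = torch.relu(x @ enc)
h_neg = torch.relu(corrupt(x) @ enc)
loss = dgi_loss(h_pos, h_neg, w)
```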
- class cogdl.wrappers.model_wrapper.node_classification.GCNMixModelWrapper(model, optimizer_cfg, temperature, rampup_starts, rampup_ends, mixup_consistency, ema_decay, tau, k)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
GCNMixModelWrapper calls forward_aux on the model. forward_aux is similar to forward but skips the spmm (sparse matrix multiplication) operation.
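The ema_decay, rampup_starts, rampup_ends, and mixup_consistency parameters point to mean-teacher style training. A minimal sketch, assuming an exponential-moving-average teacher and the exp(-5(1 - t)^2) ramp-up that is common in mean-teacher methods; the helper names are hypothetical and the exact schedule inside the wrapper may differ.

```python
import math

import torch


def sigmoid_rampup(epoch: int, start: int, end: int) -> float:
    # Weight grows from 0 to 1 between `start` and `end` following exp(-5 (1 - t)^2).
    if epoch < start:
        return 0.0
    if epoch >= end:
        return 1.0
    t = (epoch - start) / (end - start)
    return math.exp(-5.0 * (1.0 - t) ** 2)


@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, decay: float) -> None:
    # teacher <- decay * teacher + (1 - decay) * student, parameter by parameter.
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(decay).add_(s_param, alpha=1.0 - decay)


# Hypothetical values: the consistency term is scaled by mixup_consistency * ramp-up weight.
mixup_consistency, rampup_starts, rampup_ends = 10.0, 50, 150
weight_at_100 = mixup_consistency * sigmoid_rampup(100, rampup_starts, rampup_ends)

student, teacher = torch.nn.Linear(3, 2), torch.nn.Linear(3, 2)
teacher.load_state_dict(student.state_dict())  # teacher starts as a copy of the student
ema_update(teacher, student, decay=0.99)
```

Because the teacher starts as an exact copy, one EMA step leaves it unchanged; the two models only drift apart once the student's weights are updated by the optimizer.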
- class cogdl.wrappers.model_wrapper.node_classification.GRACEModelWrapper(model, optimizer_cfg, tau, drop_feature_rates, drop_edge_rates, batch_fwd, proj_hidden_size)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper
- prop(graph: cogdl.data.data.Graph, x: torch.Tensor, drop_feature_rate: float = 0.0, drop_edge_rate: float = 0.0)
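prop takes a drop_feature_rate and a drop_edge_rate, which suggests it builds a stochastically augmented view before propagating. A minimal sketch of those two augmentations, assuming a plain 2 x E `edge_index` tensor instead of a `cogdl.data.data.Graph`; `drop_feature` and `drop_edge` are hypothetical helpers, not CogDL functions.

```python
import torch


def drop_feature(x: torch.Tensor, rate: float, generator: torch.Generator = None) -> torch.Tensor:
    # Mask whole feature dimensions: each column is zeroed with probability `rate`,
    # and the same columns are zeroed for every node.
    mask = torch.rand(x.size(1), generator=generator) < rate
    out = x.clone()
    out[:, mask] = 0.0
    return out


def drop_edge(edge_index: torch.Tensor, rate: float, generator: torch.Generator = None) -> torch.Tensor:
    # edge_index is assumed to be a 2 x E tensor; each edge survives with probability 1 - rate.
    keep = torch.rand(edge_index.size(1), generator=generator) >= rate
    return edge_index[:, keep]


g = torch.Generator().manual_seed(0)
x = torch.randn(4, 5)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
# Two stochastic views of the same graph, one per entry of drop_feature_rates / drop_edge_rates.
view1 = (drop_feature(x, 0.3, g), drop_edge(edge_index, 0.2, g))
view2 = (drop_feature(x, 0.4, g), drop_edge(edge_index, 0.4, g))
```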
- class cogdl.wrappers.model_wrapper.node_classification.GrandModelWrapper(model, optimizer_cfg, sample=2, temperature=0.5, lmbda=0.5)
Bases: cogdl.wrappers.model_wrapper.node_classification.node_classification_mw.NodeClfModelWrapper
- sample: int
Number of augmentations used for the consistency loss.
- temperature: float
Temperature used to sharpen predictions.
- lmbda: float
Proportion of the consistency loss on unlabelled data.
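The three parameters above fit together as follows: run `sample` augmentations, average their predictions, sharpen the average with `temperature`, and penalise each augmentation's distance to that sharpened target, weighted by `lmbda`. A minimal sketch under that reading (squared L2 distance, as in the GRAND paper); `sharpen` and `consistency_loss` are hypothetical names and the supervised loss is a placeholder.

```python
from typing import List

import torch


def sharpen(p: torch.Tensor, temperature: float) -> torch.Tensor:
    # p_ij^(1/T) renormalised per row; a temperature below 1 makes the distribution peakier.
    p = p.pow(1.0 / temperature)
    return p / p.sum(dim=1, keepdim=True)


def consistency_loss(probs: List[torch.Tensor], temperature: float) -> torch.Tensor:
    # probs holds one (N, C) softmax output per augmentation, i.e. `sample` tensors.
    avg = torch.stack(probs).mean(dim=0)
    target = sharpen(avg, temperature).detach()  # no gradient through the target
    # Squared L2 distance to the sharpened average, averaged over nodes and augmentations.
    return sum(((p - target) ** 2).sum(dim=1).mean() for p in probs) / len(probs)


sample, temperature, lmbda = 2, 0.5, 0.5  # the wrapper's documented defaults
torch.manual_seed(0)
probs = [torch.softmax(torch.randn(5, 3), dim=1) for _ in range(sample)]
sup_loss = torch.tensor(1.0)  # placeholder for the supervised loss on labelled nodes
total = sup_loss + lmbda * consistency_loss(probs, temperature)
```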
- class cogdl.wrappers.model_wrapper.node_classification.MVGRLModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper
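MVGRL contrasts a graph's adjacency view with a diffusion view. As an illustration of where that second view can come from, here is a dense personalized-PageRank diffusion, S = alpha (I - (1 - alpha) D^{-1/2} A D^{-1/2})^{-1}. The helper name and alpha = 0.2 are illustrative assumptions; this is not CogDL's preprocessing code.

```python
import torch


def ppr_diffusion(adj: torch.Tensor, alpha: float) -> torch.Tensor:
    # S = alpha * (I - (1 - alpha) * D^{-1/2} A D^{-1/2})^{-1}, computed densely.
    n = adj.size(0)
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # isolated nodes get zero weight
    norm_adj = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    return alpha * torch.linalg.inv(torch.eye(n) - (1 - alpha) * norm_adj)


adj = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
s = ppr_diffusion(adj, alpha=0.2)  # alpha = 0.2 is an illustrative choice
```

The dense inverse costs O(n^3), so it only suits small graphs; larger graphs need a sparse or approximate diffusion.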
- class cogdl.wrappers.model_wrapper.node_classification.SelfAuxiliaryModelWrapper(model, optimizer_cfg, auxiliary_task, dropedge_rate, mask_ratio, sampling)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper
- class cogdl.wrappers.model_wrapper.node_classification.GraphSAGEModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
- class cogdl.wrappers.model_wrapper.node_classification.UnsupGraphSAGEModelWrapper(model, optimizer_cfg, walk_length, negative_samples)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.UnsupervisedModelWrapper
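walk_length and negative_samples match the unsupervised GraphSAGE loss: nodes that co-occur on short random walks should have similar embeddings, while randomly drawn negatives should not. A minimal sketch, assuming walk_length counts steps and that negatives are already sampled; `random_walk` and `unsup_sage_loss` are hypothetical helpers.

```python
import random
from typing import Dict, List

import torch
import torch.nn.functional as F


def random_walk(adj_list: Dict[int, List[int]], start: int, walk_length: int, rng: random.Random) -> List[int]:
    # Uniform random walk of `walk_length` steps (so up to walk_length + 1 nodes).
    walk = [start]
    for _ in range(walk_length):
        nbrs = adj_list[walk[-1]]
        if not nbrs:
            break
        walk.append(rng.choice(nbrs))
    return walk


def unsup_sage_loss(z_u: torch.Tensor, z_pos: torch.Tensor, z_neg: torch.Tensor) -> torch.Tensor:
    # z_u, z_pos: (B, d) anchors and nodes co-occurring with them on walks.
    # z_neg: (B, Q, d), with Q = negative_samples nodes drawn per anchor.
    pos = F.logsigmoid((z_u * z_pos).sum(dim=-1))
    neg = F.logsigmoid(-(z_neg @ z_u.unsqueeze(-1)).squeeze(-1)).sum(dim=-1)
    return -(pos + neg).mean()


adj_list = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
walk = random_walk(adj_list, start=0, walk_length=4, rng=random.Random(0))
torch.manual_seed(0)
loss = unsup_sage_loss(torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 5, 16))
```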
- class cogdl.wrappers.model_wrapper.node_classification.M3SModelWrapper(model, optimizer_cfg, n_cluster, num_new_labels)
Bases: cogdl.wrappers.model_wrapper.node_classification.node_classification_mw.NodeClfModelWrapper
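M3S is a multi-stage self-training method, and num_new_labels suggests a cap on how many pseudo-labels are added per stage. A minimal sketch of one plausible selection step, assuming the cap applies per class and confidence is the maximum softmax probability. The clustering step governed by n_cluster is not shown, and `select_pseudo_labels` is a hypothetical helper.

```python
import torch


def select_pseudo_labels(probs: torch.Tensor, labeled_mask: torch.Tensor, num_new_labels: int):
    # For each class, pick up to `num_new_labels` unlabelled nodes predicted as that class
    # with the highest confidence. Returns (node indices, pseudo-labels).
    conf, pred = probs.max(dim=1)
    chosen = []
    for c in range(probs.size(1)):
        cand = ((pred == c) & ~labeled_mask).nonzero(as_tuple=True)[0]
        if cand.numel() == 0:
            continue
        k = min(num_new_labels, cand.numel())
        chosen.append(cand[conf[cand].topk(k).indices])
    idx = torch.cat(chosen) if chosen else torch.empty(0, dtype=torch.long)
    return idx, pred[idx]


probs = torch.tensor([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.3, 0.7], [0.95, 0.05]])
labeled_mask = torch.tensor([False, False, False, False, True])
idx, new_labels = select_pseudo_labels(probs, labeled_mask, num_new_labels=1)
```

Node 4 is the most confident class-0 prediction but is already labelled, so node 0 is chosen for class 0 and node 2 for class 1.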
- class cogdl.wrappers.model_wrapper.node_classification.NetworkEmbeddingModelWrapper(model, num_shuffle=1, training_percents=[0.1], enhance=None, max_evals=10, num_workers=1)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper
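The num_shuffle and training_percents parameters suggest that fixed embeddings are evaluated over repeated random train/test splits at given training fractions. Assuming that reading, split generation could look like the sketch below. `shuffle_splits` is hypothetical, and the classifier fit on each split is not shown.

```python
from typing import Dict, List, Tuple

import torch


def shuffle_splits(
    num_nodes: int, training_percents: List[float], num_shuffle: int, seed: int = 0
) -> Dict[float, List[Tuple[torch.Tensor, torch.Tensor]]]:
    # For each training fraction, draw `num_shuffle` random train/test partitions of the nodes.
    g = torch.Generator().manual_seed(seed)
    splits = {}
    for pct in training_percents:
        n_train = int(pct * num_nodes)
        runs = []
        for _ in range(num_shuffle):
            perm = torch.randperm(num_nodes, generator=g)
            runs.append((perm[:n_train], perm[n_train:]))
        splits[pct] = runs
    return splits


splits = shuffle_splits(num_nodes=10, training_percents=[0.1, 0.5], num_shuffle=3)
```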
- class cogdl.wrappers.model_wrapper.node_classification.NodeClfModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
- class cogdl.wrappers.model_wrapper.node_classification.CorrectSmoothModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.node_classification.node_classification_mw.NodeClfModelWrapper
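Correct and Smooth post-processes a base predictor's outputs with two label-propagation passes. The first spreads residual errors ("correct") and the second spreads the corrected predictions ("smooth"). A minimal sketch of the propagation step the two passes share, assuming a dense adjacency matrix and symmetric normalisation. The residual scaling used in the correct step and the choice of alpha are omitted, and `propagate` is a hypothetical helper.

```python
import torch


def propagate(adj: torch.Tensor, y0: torch.Tensor, alpha: float, num_iters: int) -> torch.Tensor:
    # Iterate Y <- (1 - alpha) * Y0 + alpha * S Y with S = D^{-1/2} A D^{-1/2}.
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
    s = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    y = y0.clone()
    for _ in range(num_iters):
        y = (1 - alpha) * y0 + alpha * (s @ y)
    return y


adj = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
# Smooth step: node 0 carries a known one-hot label for class 0, node 1 starts empty.
y0 = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
smoothed = propagate(adj, y0, alpha=0.5, num_iters=1)
```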
- class cogdl.wrappers.model_wrapper.node_classification.PPRGoModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
Graph Classification
- class cogdl.wrappers.model_wrapper.graph_classification.GraphClassificationModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
- class cogdl.wrappers.model_wrapper.graph_classification.GraphEmbeddingModelWrapper(model)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper
Pretraining
- class cogdl.wrappers.model_wrapper.pretraining.GCCModelWrapper(model, optimizer_cfg, nce_k, nce_t, momentum, output_size, finetune=False, num_classes=1, num_shuffle=10, save_model_path='saved', load_model_path='', freeze=False, pretrain=False)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
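GCC pretrains with a MoCo-style contrastive objective. Under that assumption, nce_k is the number of queued negative keys, nce_t is the softmax temperature, and momentum controls a slowly updated key encoder (that encoder update is not shown). A minimal sketch of the InfoNCE loss and the FIFO key queue; `info_nce` and `enqueue` are hypothetical helpers, not CogDL's implementation.

```python
import torch
import torch.nn.functional as F


def info_nce(q: torch.Tensor, k: torch.Tensor, queue: torch.Tensor, nce_t: float) -> torch.Tensor:
    # q, k: (B, d) L2-normalised query and key embeddings of two views of the same subgraph.
    # queue: (K, d) keys from earlier batches used as negatives; K plays the role of nce_k.
    l_pos = (q * k).sum(dim=1, keepdim=True)  # (B, 1)
    l_neg = q @ queue.t()                     # (B, K)
    logits = torch.cat([l_pos, l_neg], dim=1) / nce_t
    labels = torch.zeros(q.size(0), dtype=torch.long)  # the positive sits in column 0
    return F.cross_entropy(logits, labels)


def enqueue(queue: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # FIFO update: newest keys go in front, the oldest fall off so the size stays K.
    return torch.cat([k.detach(), queue], dim=0)[: queue.size(0)]


q = torch.tensor([[1.0, 0.0]])
loss = info_nce(q, q, queue=torch.tensor([[-1.0, 0.0]]), nce_t=0.1)
queue = enqueue(torch.zeros(5, 2), torch.ones(2, 2))
```

When the query matches its key and points away from every queued negative, the logits are 10 and -10 and the loss is close to zero.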
Link Prediction
- class cogdl.wrappers.model_wrapper.link_prediction.EmbeddingLinkPredictionModelWrapper(model)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper
- class cogdl.wrappers.model_wrapper.link_prediction.GNNKGLinkPredictionModelWrapper(model, optimizer_cfg, score_func)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
- class cogdl.wrappers.model_wrapper.link_prediction.GNNLinkPredictionModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
Heterogeneous
- class cogdl.wrappers.model_wrapper.heterogeneous.HeterogeneousEmbeddingModelWrapper(model, hidden_size=200)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper
- static add_args(parser: argparse.ArgumentParser)
Add task-specific arguments to the parser.
- class cogdl.wrappers.model_wrapper.heterogeneous.HeterogeneousGNNModelWrapper(model, optimizer_cfg)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
- class cogdl.wrappers.model_wrapper.heterogeneous.MultiplexEmbeddingModelWrapper(model, hidden_size=200, eval_type='all')
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper
- static add_args(parser: argparse.ArgumentParser)
Add task-specific arguments to the parser.
Clustering
- class cogdl.wrappers.model_wrapper.clustering.AGCModelWrapper(model, optimizer_cfg, num_clusters, cluster_method='kmeans', evaluation='full', max_iter=5)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.EmbeddingModelWrapper
- class cogdl.wrappers.model_wrapper.clustering.DAEGCModelWrapper(model, optimizer_cfg, num_clusters, cluster_method='kmeans', evaluation='full', T=5)
Bases: cogdl.wrappers.model_wrapper.base_model_wrapper.ModelWrapper
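DAEGC refines embeddings with a DEC-style self-training clustering objective. It computes soft assignments Q of nodes to num_clusters centres, sharpens them into a target distribution P, and minimises KL(P || Q). A minimal sketch, assuming the centres come from an initial k-means run (cluster_method='kmeans', not shown) and that T is the interval, in epochs, at which P is recomputed. The helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def soft_assignment(z: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    # Student's t kernel (one degree of freedom) between embeddings and cluster centres.
    q = 1.0 / (1.0 + torch.cdist(z, centers).pow(2))
    return q / q.sum(dim=1, keepdim=True)


def target_distribution(q: torch.Tensor) -> torch.Tensor:
    # Square the assignments, divide by soft cluster frequencies, then renormalise rows.
    weight = q.pow(2) / q.sum(dim=0)
    return weight / weight.sum(dim=1, keepdim=True)


z = torch.tensor([[0.0, 0.0], [4.0, 0.0]])
centers = torch.tensor([[0.0, 0.0], [4.0, 0.0]])
q = soft_assignment(z, centers)
p = target_distribution(q).detach()  # held fixed between refreshes every T epochs
# Clustering loss KL(P || Q); F.kl_div expects log-probabilities as its first argument.
loss = F.kl_div(q.log(), p, reduction="batchmean")
```

Squaring Q pushes each node's target towards its most likely cluster (here 17/18 becomes 289/290). Dividing by the cluster frequencies keeps large clusters from dominating the target.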