Data Wrappers

Node Classification

class cogdl.wrappers.data_wrapper.node_classification.ClusterWrapper(dataset, method='metis', batch_size=20, n_cluster=100)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

static add_args(parser)
get_train_dataset()

Return the wrapped dataset for a specific use case. For example, cluster_dw returns a ClusteredDataset for DDP training.
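In DDP (DistributedDataParallel) training, each process usually loads a disjoint share of the training data. The sketch below illustrates that idea with a round-robin split, similar in spirit to PyTorch's DistributedSampler. shard_indices is a hypothetical helper, not part of CogDL, and this is not how ClusteredDataset is implemented.

```python
from typing import List


def shard_indices(num_items: int, rank: int, world_size: int) -> List[int]:
    """Hypothetical helper: split item indices across DDP ranks round-robin.

    Indices wrap around so every rank receives the same number of items,
    which keeps ranks from running out of batches at different times.
    """
    if num_items <= 0:
        raise ValueError("num_items must be positive")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    per_rank = -(-num_items // world_size)  # ceiling division
    total = per_rank * world_size
    padded = [i % num_items for i in range(total)]  # wrap-around padding
    return padded[rank::world_size]


# Each of 3 processes would load only its own shard of 10 items.
for r in range(3):
    print(r, shard_indices(10, r, 3))
```

Equal-sized shards matter because DDP ranks synchronize on every step. If one rank had fewer batches, it would stop early and leave the others waiting.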

test_wrapper()
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.

val_wrapper()
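ClusterWrapper follows the Cluster-GCN idea: partition the graph into n_cluster parts (with METIS by default), then train on subgraphs built from several parts at a time. The sketch below assumes batch_size counts the clusters merged into each mini-batch. It also replaces METIS with a random partition (random_partition, a hypothetical stand-in) so that it runs with the standard library only.

```python
import random
from typing import Iterator, List, Tuple

Edge = Tuple[int, int]


def random_partition(num_nodes: int, n_cluster: int, seed: int = 0) -> List[List[int]]:
    """Stand-in for METIS: assign each node to one of n_cluster parts at random."""
    rng = random.Random(seed)
    parts: List[List[int]] = [[] for _ in range(n_cluster)]
    for node in range(num_nodes):
        parts[rng.randrange(n_cluster)].append(node)
    return parts


def cluster_batches(edges: List[Edge], parts: List[List[int]], batch_size: int,
                    seed: int = 0) -> Iterator[Tuple[List[int], List[Edge]]]:
    """Yield (nodes, induced_edges), merging batch_size random clusters per batch."""
    order = list(range(len(parts)))
    random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        nodes = sorted(n for c in order[start:start + batch_size] for n in parts[c])
        node_set = set(nodes)
        # Keep only edges whose two endpoints both fall inside this batch.
        sub_edges = [(u, v) for u, v in edges if u in node_set and v in node_set]
        yield nodes, sub_edges


ring = [(i, (i + 1) % 10) for i in range(10)]
parts = random_partition(10, n_cluster=4)
for nodes, sub_edges in cluster_batches(ring, parts, batch_size=2):
    print(nodes, sub_edges)
```

Merging several random clusters per batch restores some of the edges between clusters that a single partition would drop. A real METIS partition would keep far more edges inside each part than this random stand-in does.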
class cogdl.wrappers.data_wrapper.node_classification.GraphSAGEDataWrapper(dataset, batch_size: int, sample_size: list)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

static add_args(parser)
get_train_dataset()

Return the wrapped dataset for a specific use case. For example, cluster_dw returns a ClusteredDataset for DDP training.

test_wrapper()
train_transform(batch)
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.

val_transform(batch)
val_wrapper()
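GraphSAGEDataWrapper takes a per-layer sample_size list. Below is a minimal sketch of uniform layer-wise neighbor sampling in the GraphSAGE style. It assumes sample_size[k] is the fan-out used at hop k + 1. sample_hops is a hypothetical name, and CogDL's own sampler may order or store the hops differently.

```python
import random
from typing import Dict, List, Sequence


def sample_hops(adj: Dict[int, List[int]], seeds: Sequence[int],
                sample_size: Sequence[int], seed: int = 0) -> List[List[int]]:
    """Uniform layer-wise neighbor sampling in the style of GraphSAGE.

    hops[0] is the seed batch; hops[k] holds the nodes whose features are
    needed to compute hops[k - 1]. sample_size[k] caps neighbors per node.
    """
    rng = random.Random(seed)
    frontier = list(dict.fromkeys(seeds))  # de-duplicate, keep order
    hops = [frontier]
    for fanout in sample_size:
        nxt = dict.fromkeys(frontier)  # keep targets: their own features are reused
        for node in frontier:
            neigh = adj.get(node, [])
            picked = neigh if len(neigh) <= fanout else rng.sample(neigh, fanout)
            for n in picked:
                nxt[n] = None
        frontier = list(nxt)
        hops.append(frontier)
    return hops


adj = {0: [1, 2, 3, 4], 1: [0, 5], 2: [0], 3: [0, 6], 4: [0], 5: [1], 6: [3]}
print(sample_hops(adj, seeds=[0], sample_size=[2, 2]))
```

Capping the fan-out per hop bounds the cost of a mini-batch. With a cap of f per hop, a single seed node pulls in at most (1 + f)^len(sample_size) nodes, no matter how large the graph is.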
class cogdl.wrappers.data_wrapper.node_classification.M3SDataWrapper(dataset, label_rate, approximate, alpha)

Bases: cogdl.wrappers.data_wrapper.node_classification.node_classification_dw.FullBatchNodeClfDataWrapper

static add_args(parser)
get_dataset()
post_stage(stage, model_w_out)

Processing after each run

pre_stage(stage, model_w_out)

Processing before each run

pre_transform()

Data preprocessing before all runs.
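The three hooks above run at different points. pre_transform runs once, before all runs; pre_stage and post_stage run before and after each run. The sketch below shows that ordering with a hypothetical ToyStagedWrapper class and run_stages driver. Its post_stage adds pseudo-labels using a plain confidence threshold, which simplifies M3S. Treating model_w_out as a dict of predictions is also an assumption.

```python
from typing import Callable, Dict, List, Tuple

Predictions = Dict[int, Tuple[int, float]]  # node -> (label, confidence)


class ToyStagedWrapper:
    """Hypothetical stand-in that logs the hook order described above."""

    def __init__(self, labels: Dict[int, int], threshold: float = 0.9):
        self.labels = dict(labels)  # node -> label; grows with pseudo-labels
        self.threshold = threshold
        self.log: List[str] = []

    def pre_transform(self) -> None:
        self.log.append("pre_transform")

    def pre_stage(self, stage: int, model_w_out) -> None:
        self.log.append(f"pre_stage:{stage}")

    def post_stage(self, stage: int, model_w_out: Predictions) -> None:
        # Simplified self-training: adopt confident predictions as labels.
        # (M3S itself filters pseudo-labels with a clustering-based check.)
        for node, (label, conf) in model_w_out.items():
            if node not in self.labels and conf >= self.threshold:
                self.labels[node] = label
        self.log.append(f"post_stage:{stage}")


def run_stages(wrapper: ToyStagedWrapper, num_stages: int,
               train_and_predict: Callable[[Dict[int, int]], Predictions]) -> None:
    wrapper.pre_transform()                        # once, before all runs
    for stage in range(num_stages):
        wrapper.pre_stage(stage, None)             # before each run
        preds = train_and_predict(wrapper.labels)  # stands in for a full run
        wrapper.post_stage(stage, preds)           # after each run


w = ToyStagedWrapper({2: 1})
run_stages(w, 2, lambda labels: {0: (1, 0.95), 1: (0, 0.5), 2: (1, 0.99)})
print(w.log)
print(w.labels)
```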

class cogdl.wrappers.data_wrapper.node_classification.NetworkEmbeddingDataWrapper(dataset)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

test_wrapper()
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.

class cogdl.wrappers.data_wrapper.node_classification.FullBatchNodeClfDataWrapper(dataset)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

pre_transform()

Data preprocessing before all runs.

test_wrapper()
train_wrapper() → cogdl.data.data.Graph
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.
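train_wrapper() may return a DataLoader, a single cogdl.Graph (as in this full-batch wrapper), or a list of either, so a consumer has to tell them apart. The sketch below shows one reading of "not traversed": a DataLoader is iterated batch by batch, while a Graph is passed along whole as a single batch. iter_batches and FakeLoader are hypothetical, and this is not CogDL's trainer code.

```python
from typing import Any, Iterator


def iter_batches(output: Any, loader_type: type) -> Iterator[Any]:
    """Yield training batches from a train_wrapper()-style return value.

    In practice loader_type would be torch.utils.data.DataLoader; it is a
    parameter here only so that the sketch runs without PyTorch installed.
    """
    if isinstance(output, list) and not isinstance(output, loader_type):
        items = output    # case 3: a list of DataLoader or Graph
    else:
        items = [output]  # cases 1 and 2: a single object
    for item in items:
        if isinstance(item, loader_type):
            yield from item  # DataLoaders are traversed batch by batch
        else:
            yield item       # anything else (e.g. a whole Graph) is one batch


class FakeLoader:
    """Minimal iterable standing in for a DataLoader."""

    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)


print(list(iter_batches([FakeLoader(["b0", "b1"]), "full_graph"], FakeLoader)))
```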

val_wrapper()
class cogdl.wrappers.data_wrapper.node_classification.PPRGoDataWrapper(dataset, topk, alpha=0.2, norm='sym', batch_size=512, eps=0.0001, test_batch_size=-1)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

static add_args(parser)
test_wrapper()
train_wrapper()
batch: tuple(x, targets, ppr_scores, y)

  x: shape=(b, num_features)

  targets: shape=(num_edges_of_batch,)

  ppr_scores: shape=(num_edges_of_batch,)

  y: shape=(b, num_classes)

val_wrapper()
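PPRGo replaces repeated message passing with precomputed, sparsified personalized PageRank (PPR) scores; the wrapper's topk, alpha and eps parameters configure that step. Below is a plain-Python sketch of the push-based approximation. It assumes alpha is the teleport (restart) probability and eps is the residual tolerance per unit of degree. approx_ppr_topk is hypothetical, and the norm option is not modeled.

```python
from collections import deque
from typing import Dict, List


def approx_ppr_topk(adj: Dict[int, List[int]], source: int, alpha: float = 0.2,
                    eps: float = 1e-4, topk: int = 32) -> Dict[int, float]:
    """Approximate PPR scores from `source` with the push method.

    A node is pushed while its residual is at least eps * degree. Each push
    moves alpha of the residual into the estimate p and spreads the rest
    evenly over the node's neighbors. Assumes every node has a neighbor.
    """
    p: Dict[int, float] = {}
    r: Dict[int, float] = {source: 1.0}
    queue = deque([source])
    in_queue = {source}
    while queue:
        u = queue.popleft()
        in_queue.discard(u)
        res = r.pop(u, 0.0)
        p[u] = p.get(u, 0.0) + alpha * res
        share = (1 - alpha) * res / len(adj[u])
        for v in adj[u]:
            r[v] = r.get(v, 0.0) + share
            if v not in in_queue and r[v] >= eps * len(adj[v]):
                queue.append(v)
                in_queue.add(v)
    # Sparsify: keep only the topk highest-scoring nodes.
    top = sorted(p.items(), key=lambda kv: kv[1], reverse=True)[:topk]
    return dict(top)


triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
print(approx_ppr_topk(triangle, source=0, topk=2))
```

Because p plus the residuals always sums to 1, the estimates never overshoot the exact scores. Lowering eps tightens the approximation, but it costs more pushes.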
class cogdl.wrappers.data_wrapper.node_classification.SAGNDataWrapper(dataset, batch_size, label_nhop, threshold, nhop)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

static add_args(parser)
post_stage_wrapper()
pre_stage(stage, model_w_out)

Processing before each run

pre_stage_transform(batch)
pre_transform()

Data preprocessing before all runs.

test_transform(batch)
test_wrapper()
train_transform(batch)
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.

val_transform(batch)
val_wrapper()
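The nhop parameter of SAGNDataWrapper points to SAGN-style precomputation of multi-hop features [X, A_hat X, ..., A_hat^nhop X] before training. The sketch below computes those hops on an undirected graph with a GCN-style symmetrically normalized adjacency that includes self-loops. Both the normalization and the helper name propagate_hops are assumptions, not CogDL's implementation.

```python
from typing import Dict, List

Matrix = List[List[float]]


def propagate_hops(adj: Dict[int, List[int]], x: Matrix, nhop: int) -> List[Matrix]:
    """Return [X, A_hat X, ..., A_hat^nhop X] with A_hat = D^-1/2 (A + I) D^-1/2.

    adj must be symmetric (undirected); rows of x are node feature vectors.
    """
    n = len(x)
    neigh = {u: set(adj.get(u, [])) | {u} for u in range(n)}  # add self-loops
    deg = {u: len(neigh[u]) for u in range(n)}
    hops = [x]
    for _ in range(nhop):
        prev = hops[-1]
        nxt = []
        for u in range(n):
            row = [0.0] * len(prev[0])
            for v in neigh[u]:
                w = 1.0 / (deg[u] * deg[v]) ** 0.5  # symmetric normalization
                for j, val in enumerate(prev[v]):
                    row[j] += w * val
            nxt.append(row)
        hops.append(nxt)
    return hops


hops = propagate_hops({0: [1], 1: [0]}, [[1.0], [0.0]], nhop=2)
print(hops)
```

Because the hops are computed once, training only needs a model over fixed per-hop features, which is what makes this family of methods scalable. Applying the same routine to one-hot training labels would plausibly give the label_nhop label features. How threshold selects pseudo-labels in CogDL is not shown here.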

Graph Classification

class cogdl.wrappers.data_wrapper.graph_classification.GraphClassificationDataWrapper(dataset, degree_node_features=False, batch_size=32, train_ratio=0.5, test_ratio=0.3)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

static add_args(parser)
setup_node_features()
test_wrapper()
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.

val_wrapper()
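The defaults are train_ratio=0.5 and test_ratio=0.3, so a natural reading is that the remaining 20% of graphs form the validation set. Below is a minimal sketch of such a split. split_graph_indices is hypothetical, and its rounding and shuffling details are assumptions.

```python
import random
from typing import List, Tuple


def split_graph_indices(num_graphs: int, train_ratio: float = 0.5, test_ratio: float = 0.3,
                        seed: int = 0) -> Tuple[List[int], List[int], List[int]]:
    """Shuffle graph indices and cut them into (train, val, test) lists.

    The validation share is assumed to be the remainder,
    1 - train_ratio - test_ratio.
    """
    if train_ratio + test_ratio > 1:
        raise ValueError("train_ratio + test_ratio must not exceed 1")
    idx = list(range(num_graphs))
    random.Random(seed).shuffle(idx)
    n_train = int(num_graphs * train_ratio)
    n_test = int(num_graphs * test_ratio)
    train = idx[:n_train]
    test = idx[n_train:n_train + n_test]
    val = idx[n_train + n_test:]
    return train, val, test


train, val, test = split_graph_indices(100)
print(len(train), len(val), len(test))
```

Each split would then be grouped into mini-batches of batch_size graphs.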
class cogdl.wrappers.data_wrapper.graph_classification.GraphEmbeddingDataWrapper(dataset, degree_node_features=False)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

static add_args(parser)
pre_transform()

Data preprocessing before all runs.

test_wrapper()
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.
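Setting degree_node_features=True is common for datasets whose graphs carry no node attributes: each node gets a one-hot encoding of its degree. Below is a sketch under that assumption. Degrees are clipped at max_degree, a hypothetical cap that is not a CogDL parameter.

```python
from typing import List, Tuple


def one_hot_degree(num_nodes: int, edges: List[Tuple[int, int]], max_degree: int) -> List[List[float]]:
    """One-hot degree features for an undirected graph; edges are listed once."""
    deg = [0] * num_nodes
    for u, v in edges:
        deg[u] += 1
        deg[v] += 1
    feats = []
    for d in deg:
        row = [0.0] * (max_degree + 1)
        row[min(d, max_degree)] = 1.0  # degrees above the cap share the last slot
        feats.append(row)
    return feats


print(one_hot_degree(3, [(0, 1), (1, 2)], max_degree=2))
```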

class cogdl.wrappers.data_wrapper.graph_classification.InfoGraphDataWrapper(dataset, degree_node_features=False, batch_size=32, train_ratio=0.5, test_ratio=0.3)

Bases: cogdl.wrappers.data_wrapper.graph_classification.graph_classification_dw.GraphClassificationDataWrapper

test_wrapper()
class cogdl.wrappers.data_wrapper.graph_classification.PATCHY_SAN_DataWrapper(dataset, num_sample, num_neighbor, stride, *args, **kwargs)

Bases: cogdl.wrappers.data_wrapper.graph_classification.graph_classification_dw.GraphClassificationDataWrapper

static add_args(parser)
pre_transform()

Data preprocessing before all runs.
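The parameters num_sample, num_neighbor and stride appear to correspond to PATCHY-SAN's width w, receptive-field size k and stride s. Below is a simplified sketch of node-sequence selection and neighborhood assembly. Nodes are ranked by degree and neighbors are collected in BFS order, whereas the original method ranks them with a graph labeling procedure. receptive_fields and the pad value are hypothetical.

```python
from collections import deque
from typing import Dict, List


def receptive_fields(adj: Dict[int, List[int]], num_sample: int, num_neighbor: int,
                     stride: int, pad: int = -1) -> List[List[int]]:
    """Pick num_sample center nodes and a num_neighbor-node field around each."""
    # Rank nodes by degree (descending), breaking ties by node id.
    order = sorted(adj, key=lambda u: (-len(adj[u]), u))
    centers = order[::stride][:num_sample]
    fields = []
    for c in centers:
        field, visited, queue = [c], {c}, deque([c])
        # Breadth-first expansion until enough candidates are collected.
        while queue and len(field) < num_neighbor:
            u = queue.popleft()
            for v in sorted(adj[u]):
                if v not in visited:
                    visited.add(v)
                    field.append(v)
                    queue.append(v)
        field = field[:num_neighbor]
        fields.append(field + [pad] * (num_neighbor - len(field)))
    # Too few centers: pad with all-dummy fields, as PATCHY-SAN does.
    fields += [[pad] * num_neighbor for _ in range(num_sample - len(fields))]
    return fields


star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
print(receptive_fields(star, num_sample=2, num_neighbor=3, stride=1))
```

The output always has shape num_sample x num_neighbor, so every graph becomes a fixed-size grid that a 1D convolution can consume.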

Pretraining

class cogdl.wrappers.data_wrapper.pretraining.GCCDataWrapper(dataset, batch_size, finetune=False, num_workers=4, rw_hops=64, subgraph_size=128, restart_prob=0.8, positional_embedding_size=128, task='node_classification')

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

static add_args(parser)
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.
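GCC samples subgraphs around a node using a random walk with restart, which is what restart_prob, rw_hops and subgraph_size configure. The sketch below assumes rw_hops bounds the walk length and subgraph_size bounds the number of distinct nodes collected. rwr_subgraph is hypothetical, CogDL's exact use of these parameters may differ, and positional_embedding_size is not modeled.

```python
import random
from typing import Dict, List, Set


def rwr_subgraph(adj: Dict[int, List[int]], start: int, restart_prob: float = 0.8,
                 max_steps: int = 64, max_nodes: int = 128, seed: int = 0) -> Set[int]:
    """Collect the nodes visited by a random walk with restart from `start`."""
    rng = random.Random(seed)
    visited = {start}
    cur = start
    for _ in range(max_steps):
        if len(visited) >= max_nodes:
            break
        if rng.random() < restart_prob or not adj[cur]:
            cur = start  # jump back to the seed node
        else:
            cur = rng.choice(adj[cur])
        visited.add(cur)
    return visited


ring20 = {i: [(i - 1) % 20, (i + 1) % 20] for i in range(20)}
print(sorted(rwr_subgraph(ring20, 0, max_nodes=5)))
```

A high restart probability keeps the walk close to the seed node, so the sampled subgraph describes that node's local structure rather than distant parts of the graph.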

Heterogeneous

class cogdl.wrappers.data_wrapper.heterogeneous.HeterogeneousEmbeddingDataWrapper(dataset)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

test_wrapper()
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.

class cogdl.wrappers.data_wrapper.heterogeneous.HeterogeneousGNNDataWrapper(dataset)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

test_wrapper()
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.

val_wrapper()
class cogdl.wrappers.data_wrapper.heterogeneous.MultiplexEmbeddingDataWrapper(dataset)

Bases: cogdl.wrappers.data_wrapper.base_data_wrapper.DataWrapper

test_wrapper()
train_wrapper()
Returns

  1. DataLoader

  2. cogdl.Graph

  3. list of DataLoader or Graph

Any data format other than DataLoader will not be traversed.