import torch
from .graph_classification_dw import GraphClassificationDataWrapper
from cogdl.models.nn.patchy_san import get_single_feature
[docs]class PATCHY_SAN_DataWrapper(GraphClassificationDataWrapper):
@staticmethod
def add_args(parser):
    """Register the PATCHY-SAN data-wrapper options on ``parser``.

    The options shared by graph-classification data wrappers are added first
    through ``GraphClassificationDataWrapper.add_args``; the three
    PATCHY-SAN-specific integer options are appended afterwards.

    Args:
        parser: object exposing ``add_argument(flag, default=..., type=..., help=...)``;
            presumably an ``argparse.ArgumentParser`` -- confirm against the caller.
            With argparse, the flags below become the attributes ``num_sample``,
            ``num_neighbor`` and ``stride``, the same names ``__init__`` accepts.

    Options added:
        --num-sample (int, default 30): number of chosen vertices.
        --num-neighbor (int, default 10): number of neighbors used in constructing features.
        --stride (int, default 1): stride of chosen vertices.
    """
    # Inherit the generic graph-classification options before adding our own.
    GraphClassificationDataWrapper.add_args(parser)
    parser.add_argument("--num-sample", default=30, type=int, help="Number of chosen vertices")
    parser.add_argument(
        "--num-neighbor",
        default=10,
        type=int,
        help="Number of neighbors used in constructing features",
    )
    parser.add_argument("--stride", default=1, type=int, help="Stride of chosen vertices")
def __init__(self, dataset, num_sample, num_neighbor, stride, *args, **kwargs):
    """Initialise the wrapper and store the PATCHY-SAN hyper-parameters.

    Args:
        dataset: graph dataset, passed to the parent
            ``GraphClassificationDataWrapper`` constructor and also kept as
            ``self.dataset``.
        num_sample: number of chosen vertices (cf. ``--num-sample``);
            stored as ``self.sample``.
        num_neighbor: number of neighbors used in constructing features
            (cf. ``--num-neighbor``); stored as ``self.neighbor``.
        stride: stride of chosen vertices (cf. ``--stride``); stored as
            ``self.stride``.
        *args, **kwargs: forwarded unchanged to the parent constructor.
    """
    # Let the parent wrapper process the dataset and any remaining options first.
    super(PATCHY_SAN_DataWrapper, self).__init__(dataset, *args, **kwargs)
    # Attribute names are shorter than the parameter names (num_sample -> sample, etc.).
    self.sample = num_sample
    # NOTE(review): assigned explicitly here; whether the parent constructor
    # already sets self.dataset is not visible from this file -- confirm.
    self.dataset = dataset
    self.neighbor = num_neighbor
    self.stride = stride