"""Source code for cogdl.models.nn.sgc: the ``sgc`` model built on ``SGCLayer``."""

from cogdl.layers import SGCLayer

from .. import BaseModel


class sgc(BaseModel):
    """Single-layer graph model wrapping one ``SGCLayer``.

    ``forward`` first calls ``graph.sym_norm()`` on the input graph, then feeds
    the graph and its node features ``graph.x`` through ``SGCLayer``.
    NOTE(review): presumably SGC = Simplified Graph Convolution and
    ``sym_norm`` applies symmetric adjacency normalization; both are inferred
    from names only. Confirm against ``cogdl.layers.SGCLayer`` and the graph
    class.
    """

    @staticmethod
    def add_args(parser):
        """Register this model's command-line options on ``parser``.

        Adds two integer options, ``--num-features`` and ``--num-classes``.
        They are read back by ``build_model_from_args``.
        """
        parser.add_argument("--num-features", type=int)
        parser.add_argument("--num-classes", type=int)

    @classmethod
    def build_model_from_args(cls, args):
        """Construct the model from a parsed-arguments namespace.

        Maps ``args.num_features`` to ``in_feats`` and ``args.num_classes``
        to ``out_feats``.
        """
        return cls(in_feats=args.num_features, out_feats=args.num_classes)

    def __init__(self, in_feats, out_feats):
        """Create the model.

        Args:
            in_feats: input feature size passed to ``SGCLayer``.
            out_feats: output size passed to ``SGCLayer``.
        """
        super().__init__()

        self.nn = SGCLayer(in_feats, out_feats)
        # Not read anywhere in this class. Kept because code elsewhere
        # (e.g. BaseModel or a trainer) may rely on the attribute; verify
        # before removing.
        self.cache = dict()

    def forward(self, graph):
        """Normalize ``graph`` and return the ``SGCLayer`` output.

        NOTE(review): ``graph.sym_norm()`` is called on the caller's object on
        every forward pass and likely modifies it in place. Confirm it is
        idempotent before relying on repeated calls.
        """
        graph.sym_norm()
        return self.nn(graph, graph.x)

    def predict(self, data):
        """Inference entry point; identical to ``forward(data)``."""
        return self.forward(data)