import torch
import torch.nn.functional as F
from .. import BaseModel, register_model
from cogdl.utils import add_remaining_self_loops, spmm
from .mlp import MLP
[docs]@register_model("ppnp")
class PPNP(BaseModel):
[docs] @staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument("--num-features", type=int)
parser.add_argument("--num-classes", type=int)
parser.add_argument("--hidden-size", type=int, default=64)
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--propagation-type", type=str, default="appnp")
parser.add_argument("--alpha", type=float, default=0.1)
parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument("--num-iterations", type=int, default=10) # only for appnp
# fmt: on
[docs] @classmethod
def build_model_from_args(cls, args):
return cls(
args.num_features,
args.hidden_size,
args.num_classes,
args.num_layers,
args.dropout,
args.propagation_type,
args.alpha,
args.num_iterations,
)
def __init__(self, nfeat, nhid, nclass, num_layers, dropout, propagation, alpha, niter, cache=True):
super(PPNP, self).__init__()
# GCN as a prediction and then apply the personalized page rank on the results
self.nn = MLP(nfeat, nclass, nhid, num_layers, dropout)
if propagation not in ("appnp", "ppnp"):
print("Invalid propagation type, using default appnp")
propagation = "appnp"
self.propagation = propagation
self.alpha = alpha
self.niter = niter
self.dropout = dropout
self.vals = None # speedup for ppnp
self.use_cache = cache
self.cache = dict()
def _calculate_A_hat(self, x, edge_index):
device = x.device
edge_attr = torch.ones(edge_index.shape[1]).to(device)
edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, 1, x.shape[0])
deg = spmm(edge_index, edge_attr, torch.ones(x.shape[0], 1).to(device)).squeeze()
deg_sqrt = deg.pow(-1 / 2)
edge_attr = deg_sqrt[edge_index[1]] * edge_attr * deg_sqrt[edge_index[0]]
return edge_index, edge_attr
[docs] def forward(self, x, adj):
def get_ready_format(input, edge_index, edge_attr=None):
if edge_attr is None:
edge_attr = torch.ones(edge_index.shape[1]).float().to(input.device)
adj = torch.sparse_coo_tensor(
edge_index,
edge_attr,
(input.shape[0], input.shape[0]),
).to(input.device)
return adj
if self.use_cache:
flag = str(adj.shape[1])
if flag not in self.cache:
edge_index, edge_attr = self._calculate_A_hat(x, adj)
self.cache[flag] = (edge_index, edge_attr)
else:
edge_index, edge_attr = self.cache[flag]
else:
edge_index, edge_attr = self._calculate_A_hat(x, adj)
# get prediction
x = F.dropout(x, p=self.dropout, training=self.training)
local_preds = self.nn.forward(x)
# apply personalized pagerank
if self.propagation == "ppnp":
if self.vals is None:
self.vals = self.alpha * torch.inverse(
torch.eye(x.shape[0]).to(x.device) - (1 - self.alpha) * get_ready_format(x, edge_index, edge_attr)
)
final_preds = F.dropout(self.vals) @ local_preds
else: # appnp
preds = local_preds
edge_attr = F.dropout(edge_attr, p=self.dropout, training=self.training)
for _ in range(self.niter):
new_features = spmm(edge_index, edge_attr, preds)
preds = (1 - self.alpha) * new_features + self.alpha * local_preds
final_preds = preds
return final_preds
[docs] def predict(self, data):
return self.forward(data.x, data.edge_index)