import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from cogdl.utils import spmm
from .. import BaseModel


def mix_hidden_state(feat, target, train_index, alpha):
    # Mixup on the rows of ``feat`` indexed by ``train_index``: draw
    # lambda ~ Beta(alpha, alpha) (lambda = 1 when alpha <= 0, i.e. no mixing)
    # and interpolate each training row with a randomly permuted training row.
    if alpha > 0:
        lamb = np.random.beta(alpha, alpha)
    else:
        lamb = 1
    permuted_index = train_index[torch.randperm(train_index.size(0))]
    feat[train_index] = lamb * feat[train_index] + (1 - lamb) * feat[permuted_index]
    return feat, target[train_index], target[permuted_index], lamb
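# Illustrative only (a hypothetical call, not part of the original module):
#   feat, t, t_mix, lamb = mix_hidden_state(torch.eye(4), torch.eye(4),
#                                           torch.arange(4), alpha=1.0)
# Each mixed row equals lamb * original + (1 - lamb) * a permuted row, and
# the caller recombines the labels as lamb * t + (1 - lamb) * t_mix.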


class GCNConv(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNConv, self).__init__()
        self.weight = nn.Linear(in_features=in_feats, out_features=out_feats)
        self.edge_index = None
        self.edge_attr = None

    def forward(self, graph, x):
        # Linear transform followed by sparse propagation over the graph.
        h = self.weight(x)
        h = spmm(graph, h)
        return h

    def forward_aux(self, x):
        # Auxiliary (mixup) path: linear transform only, no propagation.
        return self.weight(x)


class GCNMix(BaseModel):
    @staticmethod
    def add_args(parser):
        parser.add_argument("--dropout", type=float, default=0.5)
        parser.add_argument("--hidden-size", type=int, default=64)
        parser.add_argument("--alpha", type=float, default=1.0)
        parser.add_argument("--k", type=int, default=10)
        parser.add_argument("--temperature", type=float, default=0.1)
        # parser.add_argument("--rampup-starts", type=int, default=500)
        # parser.add_argument("--rampup-ends", type=int, default=1000)
        # parser.add_argument("--mixup-consistency", type=float, default=10.0)
        # parser.add_argument("--ema-decay", type=float, default=0.999)
        # parser.add_argument("--tau", type=float, default=1.0)
    @classmethod
    def build_model_from_args(cls, args):
        return cls(
            in_feat=args.num_features,
            hidden_size=args.hidden_size,
            num_classes=args.num_classes,
            k=args.k,
            temperature=args.temperature,
            alpha=args.alpha,
            dropout=args.dropout,
        )

    def __init__(self, in_feat, hidden_size, num_classes, k, temperature, alpha, dropout):
        super(GCNMix, self).__init__()
        self.dropout = dropout
        self.alpha = alpha
        self.k = k
        self.temperature = temperature
        self.input_gnn = GCNConv(in_feat, hidden_size)
        self.hidden_gnn = GCNConv(hidden_size, num_classes)
        self.loss_f = nn.BCELoss()

    def forward(self, graph):
        graph.sym_norm()
        x = graph.x
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = self.input_gnn(graph, h)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.hidden_gnn(graph, h)
        return h

    def forward_aux(self, x, label, train_index, mix_hidden=True, layer_mix=1):
        h = F.dropout(x, p=self.dropout, training=self.training)
        assert layer_mix in (0, 1)
        if layer_mix == 0:
            # Mix raw input features before the first layer.
            h, target, target_mix, lamb = mix_hidden_state(h, label, train_index, self.alpha)
        h = self.input_gnn.forward_aux(h)
        h = F.relu(h)
        if layer_mix == 1:
            # Mix hidden representations after the first layer.
            h, target, target_mix, lamb = mix_hidden_state(h, label, train_index, self.alpha)
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.hidden_gnn.forward_aux(h)
        target_label = lamb * target + (1 - lamb) * target_mix
        return h, target_label

    def predict_noise(self, data, tau=1):
        # Temperature-scaled logits; tau < 1 sharpens the output distribution.
        out = self.forward(data) / tau
        return out
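

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original module. It
    # exercises only the mixup path: forward_aux uses just the linear parts
    # of the two layers, so no cogdl Graph object is required. All shapes
    # below are illustrative assumptions; note the relative import of
    # BaseModel above means this only runs in package context
    # (e.g. via ``python -m``), not as a standalone script.
    torch.manual_seed(0)
    model = GCNMix(
        in_feat=8, hidden_size=16, num_classes=3, k=10, temperature=0.1, alpha=1.0, dropout=0.5
    )
    x = torch.randn(20, 8)
    label = F.one_hot(torch.randint(0, 3, (20,)), num_classes=3).float()
    train_index = torch.arange(10)
    out, soft_target = model.forward_aux(x, label, train_index)
    # Logits for all 20 nodes, mixed soft labels for the 10 training nodes:
    print(out.shape, soft_target.shape)  # torch.Size([20, 3]) torch.Size([10, 3])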