Source code for cogdl.wrappers.model_wrapper.node_classification.mvgrl_mw

import torch
import torch.nn as nn

from .. import UnsupervisedModelWrapper
from cogdl.wrappers.tools.wrapper_utils import evaluate_node_embeddings_using_logreg


[docs]class MVGRLModelWrapper(UnsupervisedModelWrapper): def __init__(self, model, optimizer_cfg): super(MVGRLModelWrapper, self).__init__() self.model = model self.optimizer_cfg = optimizer_cfg self.loss_f = nn.BCEWithLogitsLoss()
[docs] def train_step(self, subgraph): graph = subgraph logits = self.model(graph) labels = torch.zeros_like(logits) num_outs = logits.shape[1] labels[:, : num_outs // 2] = 1 loss = self.loss_f(logits, labels) return loss
[docs] def test_step(self, graph): with torch.no_grad(): pred = self.model(graph) y = graph.y result = evaluate_node_embeddings_using_logreg(pred, y, graph.train_mask, graph.test_mask) self.note("test_acc", result)
[docs] def setup_optimizer(self): cfg = self.optimizer_cfg return torch.optim.Adam(self.parameters(), lr=cfg["lr"], weight_decay=cfg["weight_decay"])