# Source code for cogdl.models.base_model

from typing import Optional, Type, Any
import torch.nn as nn

from cogdl.trainers.base_trainer import BaseTrainer


class BaseModel(nn.Module):
    """Common base class for models built on ``torch.nn.Module``.

    Subclasses implement :meth:`forward` and :meth:`build_model_from_args`.
    They may register command-line options via :meth:`add_args`. They receive
    a loss function through :meth:`set_loss_fn`, which is used by the default
    :meth:`node_classification_loss` and :meth:`graph_classification_loss`
    helpers.

    Attributes:
        device: Value last passed to :meth:`set_device` (``""`` until then).
            It is only stored; this class does not move parameters with it.
        loss_fn: Callable ``loss_fn(prediction, target)`` used by the loss
            helpers; ``None`` until :meth:`set_loss_fn` is called.
        evaluator: Initialised to ``None``; not read anywhere in this class.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser.

        The base implementation adds nothing; subclasses override it.
        ``parser`` is presumably an ``argparse.ArgumentParser``; confirm
        against callers.
        """
        pass

    @classmethod
    def build_model_from_args(cls, args):
        """Build a new model instance from parsed arguments.

        Raises:
            NotImplementedError: always in the base class; every subclass
                must override this.
        """
        raise NotImplementedError("Models must implement the build_model_from_args method")

    def __init__(self):
        super().__init__()
        self.device = ""
        self.loss_fn = None
        self.evaluator = None

    def _forward_unimplemented(self, *input: Any) -> None:
        # Original note: "abc warning". This overrides nn.Module's placeholder
        # of the same name, presumably to silence a warning about it; the
        # signature mirrors the parent's, so ``*input`` is kept as-is.
        pass

    def forward(self, *args):
        """Run the model; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def predict(self, data):
        """Return the model output for ``data``; by default ``self.forward(data)``."""
        return self.forward(data)

    def _require_loss_fn(self):
        """Return ``self.loss_fn``, failing with a clear message if it is unset."""
        if self.loss_fn is None:
            # Same exception type that calling ``None(...)`` would raise, but
            # raised before the (possibly expensive) forward pass and with an
            # actionable message.
            raise TypeError("loss_fn is not set; call set_loss_fn() before computing a loss")
        return self.loss_fn

    def node_classification_loss(self, data, mask=None):
        """Compute the loss of ``forward(data)`` on the entries selected by ``mask``.

        Args:
            data: Object exposing ``y`` and, when ``mask`` is omitted,
                ``train_mask``; it is passed unchanged to :meth:`forward`.
            mask: Index or mask applied to both the predictions and
                ``data.y``; defaults to ``data.train_mask``.

        Returns:
            ``self.loss_fn(pred[mask], data.y[mask])``.

        Raises:
            TypeError: if no loss function has been set.
        """
        loss_fn = self._require_loss_fn()
        if mask is None:
            mask = data.train_mask
        pred = self.forward(data)
        return loss_fn(pred[mask], data.y[mask])

    def graph_classification_loss(self, batch):
        """Compute ``self.loss_fn(forward(batch), batch.y)``.

        Raises:
            TypeError: if no loss function has been set.
        """
        loss_fn = self._require_loss_fn()
        pred = self.forward(batch)
        return loss_fn(pred, batch.y)

    @staticmethod
    def get_trainer(args=None) -> "Optional[Type[BaseTrainer]]":
        """Return a custom trainer class for this model, or ``None``.

        The base implementation always returns ``None``; subclasses may
        override it to return a ``BaseTrainer`` subclass.
        """
        # The annotation is quoted so it is not evaluated at class-definition time.
        return None

    def set_device(self, device):
        """Record ``device`` on the model (no tensors or parameters are moved)."""
        self.device = device

    def set_loss_fn(self, loss_fn):
        """Set the callable used by the loss helpers as ``loss_fn(pred, target)``."""
        self.loss_fn = loss_fn