Source code for cogdl.models.base_model

from typing import Optional, Type, Any

import torch
import torch.nn as nn


class BaseModel(nn.Module):
    """Common base class for models built on ``torch.nn.Module``.

    Subclasses are expected to override :meth:`build_model_from_args` and
    :meth:`forward`; both raise ``NotImplementedError`` here. Every instance
    records its concrete class name in ``model_name`` and starts with
    ``loss_fn`` and ``evaluator`` set to ``None``.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # The base model defines no arguments of its own.
        pass

    @classmethod
    def build_model_from_args(cls, args):
        """Build a new model instance.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Models must implement the build_model_from_args method")

    def __init__(self):
        super().__init__()
        # Name of the concrete (sub)class, not necessarily "BaseModel".
        self.model_name = self.__class__.__name__
        self.loss_fn = None
        self.evaluator = None

    def _forward_unimplemented(self, *input: Any) -> None:
        # No-op override of the nn.Module placeholder of the same name.
        # NOTE(review): the original comment only says "abc warning" --
        # presumably this silences abstract-method warnings; confirm.
        pass

    def forward(self, *args):
        """Run the model; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def predict(self, data):
        """Return the result of ``self.forward(data)``."""
        return self.forward(data)

    @property
    def device(self):
        """Device of the first parameter of the model.

        If the model has no parameters, the device of its first buffer is
        returned instead; with neither parameters nor buffers the CPU device
        is returned. The fallbacks avoid leaking a bare ``StopIteration``
        out of the property for parameter-free models.
        """
        first = next(self.parameters(), None)
        if first is None:
            first = next(self.buffers(), None)
        if first is None:
            return torch.device("cpu")
        return first.device

    def set_loss_fn(self, loss_fn):
        """Store ``loss_fn`` on the model as ``self.loss_fn``."""
        self.loss_fn = loss_fn