from typing import Optional, Type, Any

import torch
import torch.nn as nn
class BaseModel(nn.Module):
    """Common base class for models built on ``torch.nn.Module``.

    Subclasses are expected to override :meth:`forward` and
    :meth:`build_model_from_args`; the base versions raise
    ``NotImplementedError``. Each instance also carries a ``loss_fn`` and an
    ``evaluator`` slot, both ``None`` until a caller sets them.
    """

    @staticmethod
    def add_args(parser: Any) -> None:
        """Add model-specific arguments to the parser.

        The base implementation registers nothing, so models without extra
        options do not need to override it.
        """

    @classmethod
    def build_model_from_args(cls, args: Any) -> "BaseModel":
        """Build a new model instance from parsed arguments.

        Raises:
            NotImplementedError: always, in the base class; concrete models
                must override this.
        """
        raise NotImplementedError("Models must implement the build_model_from_args method")

    def __init__(self) -> None:
        super().__init__()
        # Name of the concrete subclass, not "BaseModel", when subclassed.
        self.model_name = self.__class__.__name__
        # Optional hooks populated by callers (e.g. via set_loss_fn).
        self.loss_fn = None
        self.evaluator = None

    def _forward_unimplemented(self, *input: Any) -> None:  # abc warning
        # NOTE(review): overrides torch's placeholder of the same name,
        # presumably to silence an IDE/abstract-method warning; the real
        # contract is forward() below. TODO confirm.
        pass

    def forward(self, *args: Any) -> Any:
        """Run the model on ``args``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def predict(self, data: Any) -> Any:
        """Return the model output for ``data`` by delegating to :meth:`forward`."""
        return self.forward(data)

    @property
    def device(self) -> torch.device:
        """Device holding the model's first parameter.

        If the model has no parameters, the first buffer's device is used;
        if it has neither, CPU is returned. This avoids the bare
        ``StopIteration`` that an unguarded ``next()`` would raise.
        """
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        if tensor is None:
            # Nothing allocated yet; CPU is the safe default location.
            return torch.device("cpu")
        return tensor.device

    def set_loss_fn(self, loss_fn: Any) -> None:
        """Store ``loss_fn`` on the model as ``self.loss_fn``."""
        self.loss_fn = loss_fn