from abc import ABC, abstractmethod
from typing import Any, Optional, Type
from typing import TYPE_CHECKING
from cogdl.models.base_model import BaseModel
if TYPE_CHECKING:
    # Trick to resolve a circular import: these trainer classes are imported only for static type checking.
from cogdl.trainers.supervised_model_trainer import (
SupervisedHomogeneousNodeClassificationTrainer,
SupervisedHeterogeneousNodeClassificationTrainer,
)
class SupervisedModel(BaseModel, ABC):
    """Abstract base class for supervised models.

    Subclasses must override :meth:`loss`, which is declared with
    ``@abstractmethod``.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Return the loss for ``data``.

        The meaning and type of both the input and the returned value are
        defined by the concrete subclass (both are annotated ``Any`` here).

        Raises:
            NotImplementedError: the base implementation always raises.
        """
        raise NotImplementedError()
class SupervisedHeterogeneousNodeClassificationModel(BaseModel, ABC):
    """Abstract base for the supervised heterogeneous node-classification model family.

    Subclasses must override :meth:`loss`. :meth:`evaluate` and
    :meth:`get_trainer` are optional hooks: ``evaluate`` raises unless it is
    overridden, and ``get_trainer`` returns ``None`` by default.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Return the loss for ``data``; concrete subclasses define it.

        Raises:
            NotImplementedError: the base implementation always raises.
        """
        raise NotImplementedError()

    def evaluate(self, data: Any, nodes: Any, targets: Any) -> Any:
        """Optional evaluation hook; not abstract, so subclasses may omit it.

        NOTE(review): ``nodes`` / ``targets`` presumably select the nodes to
        score and their reference labels. Confirm against the callers.

        Raises:
            NotImplementedError: unless a subclass overrides this method.
        """
        raise NotImplementedError()

    @staticmethod
    def get_trainer(
        taskType: Any,
        args: Any,
    ) -> "Optional[Type[SupervisedHeterogeneousNodeClassificationTrainer]]":
        """Return a trainer class to use for this model, or ``None``.

        The base implementation ignores both arguments and returns ``None``.
        The camelCase ``taskType`` parameter name is kept for compatibility
        with existing callers.
        NOTE(review): ``None`` presumably tells the caller to use its default
        trainer. Confirm.
        """
        return None
class SupervisedHomogeneousNodeClassificationModel(BaseModel, ABC):
    """Abstract base for the supervised homogeneous node-classification model family.

    Subclasses must override both :meth:`loss` and :meth:`predict`.
    :meth:`get_trainer` is an optional hook that returns ``None`` by default.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Return the loss for ``data``; concrete subclasses define it.

        Raises:
            NotImplementedError: the base implementation always raises.
        """
        raise NotImplementedError()

    @abstractmethod
    def predict(self, data: Any) -> Any:
        """Return the model's predictions for ``data``; concrete subclasses define it.

        Raises:
            NotImplementedError: the base implementation always raises.
        """
        raise NotImplementedError()

    @staticmethod
    def get_trainer(taskType: Any, args: Any) -> "Optional[Type[SupervisedHomogeneousNodeClassificationTrainer]]":
        """Return a trainer class to use for this model, or ``None``.

        The base implementation ignores both arguments and returns ``None``.
        The camelCase ``taskType`` parameter name is kept for compatibility
        with existing callers.
        NOTE(review): ``None`` presumably tells the caller to use its default
        trainer. Confirm.
        """
        return None