Source code for cogdl.models.supervised_model

from abc import ABC, abstractmethod
from typing import Any, Optional, Type
from typing import TYPE_CHECKING

from cogdl.models.base_model import BaseModel

if TYPE_CHECKING:
    # Imported only for type checking, to avoid a circular import at runtime.
    from cogdl.trainers.supervised_model_trainer import (
        SupervisedHomogeneousNodeClassificationTrainer,
        SupervisedHeterogeneousNodeClassificationTrainer,
    )


class SupervisedModel(BaseModel, ABC):
    """Abstract base class for models that expose a ``loss`` method.

    Concrete subclasses must override :meth:`loss`; because it is declared
    with ``@abstractmethod``, a subclass that does not override it cannot be
    instantiated.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Return the loss for ``data``.

        The argument and return types are intentionally left open
        (``Any``); each subclass defines what it accepts and returns.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError
class SupervisedHeterogeneousNodeClassificationModel(BaseModel, ABC):
    """Abstract base class for supervised heterogeneous node classification.

    Subclasses must override :meth:`loss`. :meth:`evaluate` and
    :meth:`get_trainer` have default implementations and may be overridden.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Return the loss for ``data``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def evaluate(self, data: Any, nodes: Any, targets: Any) -> Any:
        """Evaluate the model on ``nodes`` of ``data`` against ``targets``.

        Unlike :meth:`loss`, this method is not marked abstract, so it does
        not by itself prevent a subclass from being instantiated; the default
        implementation simply fails when it is called.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    @staticmethod
    def get_trainer(
        taskType: Any,
        args: Any,
    ) -> "Optional[Type[SupervisedHeterogeneousNodeClassificationTrainer]]":
        """Return the trainer class to use for this model, if any.

        The default implementation ignores both arguments and returns
        ``None``.
        NOTE(review): ``None`` presumably means "no model-specific trainer";
        confirm against the code that calls ``get_trainer``.

        The return annotation is a string because the trainer class is only
        imported under ``TYPE_CHECKING``.
        """
        return None
class SupervisedHomogeneousNodeClassificationModel(BaseModel, ABC):
    """Abstract base class for supervised homogeneous node classification.

    Subclasses must override both :meth:`loss` and :meth:`predict`;
    :meth:`get_trainer` has a default implementation and may be overridden.
    """

    @abstractmethod
    def loss(self, data: Any) -> Any:
        """Return the loss for ``data``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, data: Any) -> Any:
        """Return the model's prediction for ``data``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    @staticmethod
    def get_trainer(
        taskType: Any,
        args: Any,
    ) -> "Optional[Type[SupervisedHomogeneousNodeClassificationTrainer]]":
        """Return the trainer class to use for this model, if any.

        The default implementation ignores both arguments and returns
        ``None``.
        NOTE(review): ``None`` presumably means "no model-specific trainer";
        confirm against the code that calls ``get_trainer``.

        The return annotation is a string because the trainer class is only
        imported under ``TYPE_CHECKING``.
        """
        return None