refactor base scoring fn v.s. registerable scoring fn

This commit is contained in:
Xi Yan 2024-12-19 11:49:49 -08:00
parent 03607a68c7
commit 0096c1a6fc
3 changed files with 45 additions and 6 deletions

View file

@ -13,12 +13,51 @@ from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metr
class BaseScoringFn(ABC):
"""
Base interface class for all native scoring_fns.
Each scoring_fn needs to implement the following methods:
Base interface class for Scoring Functions.
Each scoring function needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __str__(self) -> str:
return self.__class__.__name__
@abstractmethod
async def score_row(
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
raise NotImplementedError()
@abstractmethod
async def score(
self,
input_rows: List[Dict[str, Any]],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> List[ScoringResultRow]:
raise NotImplementedError()
class RegisteredBaseScoringFn(BaseScoringFn):
"""
Interface for native scoring functions that are registered in LlamaStack.
"""
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.supported_fn_defs_registry = {}