From 0096c1a6fcadc20207c57ac5e1234f64b214ab8b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 19 Dec 2024 11:49:49 -0800 Subject: [PATCH] refactor base scoring fn v.s. registerable scoring fn --- .../basic/scoring_fn/equality_scoring_fn.py | 4 +- .../scoring_fn/llm_as_judge_scoring_fn.py | 4 +- .../utils/scoring/base_scoring_fn.py | 43 ++++++++++++++++++- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index 9991c5502..9b0566228 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -9,12 +9,12 @@ from typing import Any, Dict, Optional from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from .fn_defs.equality import equality -class EqualityScoringFn(BaseScoringFn): +class EqualityScoringFn(RegisteredBaseScoringFn): """ A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. """ diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index 00ea53c8f..027709f74 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -12,14 +12,14 @@ from llama_stack.apis.inference.inference import Inference from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams -from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa from .fn_defs.llm_as_judge_base import llm_as_judge_base -class LlmAsJudgeScoringFn(BaseScoringFn): +class LlmAsJudgeScoringFn(RegisteredBaseScoringFn): """ A scoring_fn that assigns """ diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py index 2db77fd2b..e0e557374 100644 --- a/llama_stack/providers/utils/scoring/base_scoring_fn.py +++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py @@ -13,12 +13,51 @@ from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metr class BaseScoringFn(ABC): """ - Base interface class for all native scoring_fns. - Each scoring_fn needs to implement the following methods: + Base interface class for Scoring Functions. + Each scoring function needs to implement the following methods: - score_row(self, row) - aggregate(self, scoring_fn_results) """ + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __str__(self) -> str: + return self.__class__.__name__ + + @abstractmethod + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, + ) -> ScoringResultRow: + raise NotImplementedError() + + @abstractmethod + async def aggregate( + self, + scoring_results: List[ScoringResultRow], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, + ) -> Dict[str, Any]: + raise NotImplementedError() + + @abstractmethod + async def score( + self, + input_rows: List[Dict[str, Any]], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, + ) -> List[ScoringResultRow]: + raise NotImplementedError() + + +class RegisteredBaseScoringFn(BaseScoringFn): + """ + Interface for native scoring functions that are registered in LlamaStack. + """ + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.supported_fn_defs_registry = {}