diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
index 0a3fe877f..9991c5502 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py
@@ -4,12 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics

 from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn

 from .fn_defs.equality import equality
@@ -43,22 +42,3 @@ class EqualityScoringFn(BaseScoringFn):
         return {
             "score": score,
         }
-
-    async def aggregate(
-        self,
-        scoring_results: List[ScoringResultRow],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> Dict[str, Any]:
-        params = self.supported_fn_defs_registry[scoring_fn_identifier].params
-        if scoring_params is not None:
-            params = scoring_params
-
-        aggregation_functions = []
-        if (
-            params
-            and hasattr(params, "aggregation_functions")
-            and params.aggregation_functions
-        ):
-            aggregation_functions.extend(params.aggregation_functions)
-        return aggregate_metrics(scoring_results, aggregation_functions)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
index 326d090f8..552f34d46 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py
@@ -5,11 +5,10 @@
 # the root directory of this source tree.

 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
 from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn

 from .fn_defs.regex_parser_multiple_choice_answer import (
@@ -61,22 +60,3 @@ class RegexParserScoringFn(BaseScoringFn):
         return {
             "score": score,
         }
-
-    async def aggregate(
-        self,
-        scoring_results: List[ScoringResultRow],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> Dict[str, Any]:
-        params = self.supported_fn_defs_registry[scoring_fn_identifier].params
-        if scoring_params is not None:
-            params = scoring_params
-
-        aggregation_functions = []
-        if (
-            params
-            and hasattr(params, "aggregation_functions")
-            and params.aggregation_functions
-        ):
-            aggregation_functions.extend(params.aggregation_functions)
-        return aggregate_metrics(scoring_results, aggregation_functions)
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
index e813cbe6d..29ae12e44 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py
@@ -4,11 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
 from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn

 from .fn_defs.subset_of import subset_of
@@ -37,22 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn):
         return {
             "score": score,
         }
-
-    async def aggregate(
-        self,
-        scoring_results: List[ScoringResultRow],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> Dict[str, Any]:
-        params = self.supported_fn_defs_registry[scoring_fn_identifier].params
-        if scoring_params is not None:
-            params = scoring_params
-
-        aggregation_functions = []
-        if (
-            params
-            and hasattr(params, "aggregation_functions")
-            and params.aggregation_functions
-        ):
-            aggregation_functions.extend(params.aggregation_functions)
-        return aggregate_metrics(scoring_results, aggregation_functions)
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
index 4320cfdc2..f5baf801d 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
@@ -5,15 +5,13 @@
 # the root directory of this source tree.

 import re

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, Optional

 from llama_stack.apis.inference.inference import Inference

 from llama_stack.apis.scoring import ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
-
 from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn

 from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
@@ -47,7 +45,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):

         # override params if scoring_params is provided
         if scoring_params is not None:
-            fn_def.params = scoring_params
+            for attr in scoring_params.__dict__:
+                override_attr = getattr(scoring_params, attr)
+                if override_attr is not None:
+                    setattr(fn_def.params, attr, override_attr)

         assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
@@ -90,22 +91,3 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             "score": judge_rating,
             "judge_feedback": content,
         }
-
-    async def aggregate(
-        self,
-        scoring_results: List[ScoringResultRow],
-        scoring_fn_identifier: Optional[str] = None,
-        scoring_params: Optional[ScoringFnParams] = None,
-    ) -> Dict[str, Any]:
-        params = self.supported_fn_defs_registry[scoring_fn_identifier].params
-        if scoring_params is not None:
-            params = scoring_params
-
-        aggregation_functions = []
-        if (
-            params
-            and hasattr(params, "aggregation_functions")
-            and params.aggregation_functions
-        ):
-            aggregation_functions.extend(params.aggregation_functions)
-        return aggregate_metrics(scoring_results, aggregation_functions)
diff --git a/llama_stack/providers/utils/scoring/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py
index 26ba816c4..2db77fd2b 100644
--- a/llama_stack/providers/utils/scoring/base_scoring_fn.py
+++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional

 from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
 from llama_stack.apis.scoring_functions import ScoringFn
+from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics


 class BaseScoringFn(ABC):
@@ -44,14 +45,27 @@
     ) -> ScoringResultRow:
         raise NotImplementedError()

-    @abstractmethod
     async def aggregate(
         self,
         scoring_results: List[ScoringResultRow],
         scoring_fn_identifier: Optional[str] = None,
         scoring_params: Optional[ScoringFnParams] = None,
     ) -> Dict[str, Any]:
-        raise NotImplementedError()
+        params = self.supported_fn_defs_registry[scoring_fn_identifier].params
+        if scoring_params is not None:
+            if params is None:
+                params = scoring_params
+            else:
+                params.aggregation_functions = scoring_params.aggregation_functions
+
+        aggregation_functions = []
+        if (
+            params
+            and hasattr(params, "aggregation_functions")
+            and params.aggregation_functions
+        ):
+            aggregation_functions.extend(params.aggregation_functions)
+        return aggregate_metrics(scoring_results, aggregation_functions)

     async def score(
         self,
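
The sketch below illustrates the shape of the consolidation this patch makes: each scoring function previously carried an identical aggregate() override, and after the change row-level scoring stays abstract while a single shared aggregate() lives on BaseScoringFn. This is a minimal, self-contained sketch, not the real llama_stack code: the simplified aggregate_metrics helper, the "average" aggregation, and the ExactMatchScoringFn class are hypothetical stand-ins, and the signatures are reduced versions of the ones in the diff.

# Hypothetical, simplified stand-ins; not imports from llama_stack.
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


def aggregate_metrics(rows: List[Dict[str, Any]], fns: List[str]) -> Dict[str, Any]:
    # Stand-in for the shared aggregation helper: only an assumed
    # "average" aggregation over the per-row "score" field is supported here.
    results: Dict[str, Any] = {}
    for fn in fns:
        if fn == "average":
            scores = [row["score"] for row in rows]
            results["average"] = sum(scores) / len(scores) if scores else 0.0
    return results


class BaseScoringFn(ABC):
    """Row-level scoring stays abstract; aggregation is shared by the base class."""

    @abstractmethod
    async def score_row(self, input_row: Dict[str, Any]) -> Dict[str, Any]: ...

    async def aggregate(
        self,
        scoring_results: List[Dict[str, Any]],
        aggregation_functions: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        # One shared implementation instead of a copy in every subclass.
        return aggregate_metrics(scoring_results, aggregation_functions or [])


class ExactMatchScoringFn(BaseScoringFn):
    # A concrete scorer now only defines how to score a single row.
    async def score_row(self, input_row: Dict[str, Any]) -> Dict[str, Any]:
        match = input_row["generated_answer"] == input_row["expected_answer"]
        return {"score": 1.0 if match else 0.0}

In this simplified model, awaiting ExactMatchScoringFn().score_row(...) over a batch and then awaiting aggregate(rows, ["average"]) yields the aggregated metrics, which mirrors how the diff routes every provider's aggregation through the base class while LlmAsJudgeScoringFn keeps only its params-override logic.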