diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
index dc562df1f..5b1715d9f 100644
--- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
@@ -25,7 +25,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 
-LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
+LLM_JUDGE_FN = LlmAsJudgeScoringFn
 
 
 class LlmAsJudgeScoringImpl(
@@ -43,23 +43,17 @@ class LlmAsJudgeScoringImpl(
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
         self.inference_api = inference_api
-        self.scoring_fn_id_impls = {}
 
     async def initialize(self) -> None:
-        for fn in LLM_JUDGE_FNS:
-            impl = fn(inference_api=self.inference_api)
-            for fn_defs in impl.get_supported_scoring_fn_defs():
-                self.scoring_fn_id_impls[fn_defs.identifier] = impl
-            self.llm_as_judge_fn = impl
+        impl = LLM_JUDGE_FN(inference_api=self.inference_api)
+        self.llm_as_judge_fn = impl
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFn]:
-        scoring_fn_defs_list = [
-            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
-        ]
+        scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
 
-        for f in scoring_fn_defs_list:
+        for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
             assert f.identifier.startswith("llm-as-judge"), (
                 "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
             )
@@ -67,7 +61,7 @@ class LlmAsJudgeScoringImpl(
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
-        raise NotImplementedError("Register scoring function not implemented yet")
+        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
 
     async def score_batch(
         self,
@@ -102,9 +96,7 @@ class LlmAsJudgeScoringImpl(
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():
-            if scoring_fn_id not in self.scoring_fn_id_impls:
-                raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
+            scoring_fn = self.llm_as_judge_fn
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
             agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
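
The diff collapses the per-identifier registry into a single judge implementation, so the score loop now routes every requested scoring_fn_id to the same LlmAsJudgeScoringFn instance. Below is a minimal standalone sketch of that dispatch pattern, not code from the PR: FakeJudgeFn, the example rows, and the result-dict shape are assumptions (the hunk ends before showing how res is populated); only the method names score, aggregate, and the loop structure come from the diff above.

# sketch of the single-judge dispatch introduced by the refactor
import asyncio


class FakeJudgeFn:
    """Stand-in with the same call shape as the judge scoring fn in the diff."""

    async def score(self, input_rows, scoring_fn_id, params):
        # Pretend every row gets a perfect judge score.
        return [{"score": 1.0} for _ in input_rows]

    async def aggregate(self, score_results, scoring_fn_id, params):
        scores = [r["score"] for r in score_results]
        return {"average": sum(scores) / len(scores)}


async def score_all(judge_fn, input_rows, scoring_functions):
    # Mirrors the simplified loop: one judge impl serves every scoring_fn_id,
    # with no per-id lookup or "not supported" check.
    res = {}
    for scoring_fn_id, params in scoring_functions.items():
        score_results = await judge_fn.score(input_rows, scoring_fn_id, params)
        agg_results = await judge_fn.aggregate(score_results, scoring_fn_id, params)
        # Result packaging is not shown in the hunk; this dict shape is illustrative.
        res[scoring_fn_id] = {"score_rows": score_results, "aggregated_results": agg_results}
    return res


if __name__ == "__main__":
    rows = [{"input_query": "2 + 2?", "generated_answer": "4", "expected_answer": "4"}]
    out = asyncio.run(score_all(FakeJudgeFn(), rows, {"llm-as-judge::base": None}))
    print(out)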