From f55b19e0d0b885875d765b9d902d4ff0e6cc87d6 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 5 Mar 2025 09:53:40 -0800 Subject: [PATCH] update registration --- .../inline/scoring/llm_as_judge/scoring.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index aa987283c..5b1715d9f 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -43,22 +43,17 @@ class LlmAsJudgeScoringImpl( self.datasetio_api = datasetio_api self.datasets_api = datasets_api self.inference_api = inference_api - self.scoring_fn_id_impls = {} async def initialize(self) -> None: impl = LLM_JUDGE_FN(inference_api=self.inference_api) - for fn_defs in impl.get_supported_scoring_fn_defs(): - self.scoring_fn_id_impls[fn_defs.identifier] = impl - self.llm_as_judge_fn = impl + self.llm_as_judge_fn = impl async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFn]: - scoring_fn_defs_list = [ - fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs() - ] + scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs() - for f in scoring_fn_defs_list: + for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs(): assert f.identifier.startswith("llm-as-judge"), ( "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! " ) @@ -101,9 +96,7 @@ class LlmAsJudgeScoringImpl( ) -> ScoreResponse: res = {} for scoring_fn_id in scoring_functions.keys(): - if scoring_fn_id not in self.scoring_fn_id_impls: - raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] + scoring_fn = self.llm_as_judge_fn scoring_fn_params = scoring_functions.get(scoring_fn_id, None) score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params) agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)