diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index dc562df1f..421997946 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -25,7 +25,7 @@ from llama_stack.providers.utils.common.data_schema_validator import ( from .config import LlmAsJudgeScoringConfig from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn -LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] +LLM_JUDGE_FN = LlmAsJudgeScoringFn class LlmAsJudgeScoringImpl( @@ -46,11 +46,10 @@ class LlmAsJudgeScoringImpl( self.scoring_fn_id_impls = {} async def initialize(self) -> None: - for fn in LLM_JUDGE_FNS: - impl = fn(inference_api=self.inference_api) - for fn_defs in impl.get_supported_scoring_fn_defs(): - self.scoring_fn_id_impls[fn_defs.identifier] = impl - self.llm_as_judge_fn = impl + impl = LLM_JUDGE_FN(inference_api=self.inference_api) + for fn_defs in impl.get_supported_scoring_fn_defs(): + self.scoring_fn_id_impls[fn_defs.identifier] = impl + self.llm_as_judge_fn = impl async def shutdown(self) -> None: ... @@ -67,7 +66,8 @@ class LlmAsJudgeScoringImpl( return scoring_fn_defs_list async def register_scoring_function(self, function_def: ScoringFn) -> None: - raise NotImplementedError("Register scoring function not implemented yet") + self.llm_as_judge_fn.register_scoring_fn_def(function_def) + # raise NotImplementedError("Register scoring function not implemented yet") async def score_batch( self,