feat(1/n): scoring function registration for llm-as-judge (#1405)

# What does this PR do?

- Add the ability to register an llm-as-judge scoring function with custom
judge prompts / params (see the Python sketch below).
- Closes https://github.com/meta-llama/llama-stack/issues/1395
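
For reference, the equivalent registration through the Python client might look like the sketch below. This assumes the `llama-stack-client` Python SDK mirrors the CLI flags as keyword arguments; the `base_url` is a placeholder.

```python
from llama_stack_client import LlamaStackClient

# Placeholder endpoint; point at wherever the stack is running.
client = LlamaStackClient(base_url="http://localhost:8321")

# Register a custom llm-as-judge scoring function with its own judge prompt/params.
client.scoring_functions.register(
    scoring_fn_id="llm-as-judge::my-prompt",
    description="my custom judge",
    return_type={"type": "string"},
    provider_id="llm-as-judge",
    provider_scoring_fn_id="my-prompt",
    params={
        "type": "llm_as_judge",
        "judge_model": "meta-llama/Llama-3.2-3B-Instruct",
        "prompt_template": "always output 1.0",
    },
)
```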


## Test Plan
**Via CLI**
```
llama-stack-client scoring_functions register \
--scoring-fn-id "llm-as-judge::my-prompt" \
--description "my custom judge" \
--return-type '{"type": "string"}' \
--provider-id "llm-as-judge" \
--provider-scoring-fn-id "my-prompt" \
--params '{"type": "llm_as_judge", "judge_model": "meta-llama/Llama-3.2-3B-Instruct", "prompt_template": "always output 1.0"}'
```
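
Once registered, the function should behave like any built-in scoring function. Below is a minimal sketch of exercising it from Python, reusing the `client` from the registration sketch above; it assumes a `client.scoring.score(input_rows=..., scoring_functions=...)` endpoint where `None` params fall back to those stored at registration, and the row fields are illustrative.

```python
# Illustrative rows; the exact column names depend on the judge prompt template.
rows = [
    {
        "input_query": "What is 2 + 2?",
        "generated_answer": "4",
        "expected_answer": "4",
    }
]

# None -> use the params stored when the function was registered.
response = client.scoring.score(
    input_rows=rows,
    scoring_functions={"llm-as-judge::my-prompt": None},
)
print(response.results["llm-as-judge::my-prompt"])
```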

<img width="1373" alt="image"
src="https://github.com/user-attachments/assets/7c6fc0ae-64fe-4581-8927-a9d8d746bd72"
/>

- Unit tests will be addressed in
https://github.com/meta-llama/llama-stack/issues/1396


```diff
@@ -25,7 +25,7 @@ from llama_stack.providers.utils.common.data_schema_validator import (
 from .config import LlmAsJudgeScoringConfig
 from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 
-LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
+LLM_JUDGE_FN = LlmAsJudgeScoringFn
 
 
 class LlmAsJudgeScoringImpl(
@@ -43,23 +43,17 @@ class LlmAsJudgeScoringImpl(
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
         self.inference_api = inference_api
-        self.scoring_fn_id_impls = {}
 
     async def initialize(self) -> None:
-        for fn in LLM_JUDGE_FNS:
-            impl = fn(inference_api=self.inference_api)
-            for fn_defs in impl.get_supported_scoring_fn_defs():
-                self.scoring_fn_id_impls[fn_defs.identifier] = impl
-            self.llm_as_judge_fn = impl
+        impl = LLM_JUDGE_FN(inference_api=self.inference_api)
+        self.llm_as_judge_fn = impl
 
     async def shutdown(self) -> None: ...
 
     async def list_scoring_functions(self) -> List[ScoringFn]:
-        scoring_fn_defs_list = [
-            fn_def for impl in self.scoring_fn_id_impls.values() for fn_def in impl.get_supported_scoring_fn_defs()
-        ]
+        scoring_fn_defs_list = self.llm_as_judge_fn.get_supported_scoring_fn_defs()
 
-        for f in scoring_fn_defs_list:
+        for f in self.llm_as_judge_fn.get_supported_scoring_fn_defs():
             assert f.identifier.startswith("llm-as-judge"), (
                 "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
             )
@@ -67,7 +61,7 @@ class LlmAsJudgeScoringImpl(
         return scoring_fn_defs_list
 
     async def register_scoring_function(self, function_def: ScoringFn) -> None:
-        raise NotImplementedError("Register scoring function not implemented yet")
+        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
 
     async def score_batch(
         self,
@@ -102,9 +96,7 @@ class LlmAsJudgeScoringImpl(
     ) -> ScoreResponse:
         res = {}
         for scoring_fn_id in scoring_functions.keys():
-            if scoring_fn_id not in self.scoring_fn_id_impls:
-                raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
-            scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
+            scoring_fn = self.llm_as_judge_fn
             scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(input_rows, scoring_fn_id, scoring_fn_params)
             agg_results = await scoring_fn.aggregate(score_results, scoring_fn_id, scoring_fn_params)
```
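
For context, `register_scoring_fn_def` is assumed to land the definition in a per-fn registry on the scoring-fn object, roughly as sketched below; the registry attribute name and duplicate check are assumptions, not part of this diff.

```python
from typing import Dict, List

from llama_stack.apis.scoring_functions import ScoringFn


class RegisteredBaseScoringFn:
    """Assumed shape of the base class behind LlmAsJudgeScoringFn.

    A registry of ScoringFn definitions keyed by identifier lets
    list_scoring_functions() and score() in the diff above resolve a
    registered judge prompt / params from a scoring_fn_id.
    """

    def __init__(self, *args, **kwargs) -> None:
        self.supported_fn_defs_registry: Dict[str, ScoringFn] = {}

    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
        return list(self.supported_fn_defs_registry.values())

    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
        # Reject duplicates so re-registering doesn't silently clobber params.
        if scoring_fn.identifier in self.supported_fn_defs_registry:
            raise ValueError(f"Scoring fn {scoring_fn.identifier} already registered.")
        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
```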