diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
index 31ccae5b8..3168a5282 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -57,6 +57,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             await impl.initialize()
             for fn_defs in impl.get_supported_scoring_fn_defs():
                 self.scoring_fn_id_impls[fn_defs.identifier] = impl
+            self.llm_as_judge_fn = impl

     async def shutdown(self) -> None: ...

@@ -68,9 +69,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         ]

     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        raise NotImplementedError(
-            "Dynamically registering scoring functions is not supported"
-        )
+        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
+        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn

     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
index bf3b8de17..559ef41c9 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -43,7 +43,7 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         ), "LLM Judge prompt_template not found."
         assert (
             fn_def.context.judge_score_regex is not None
-        ), "LLM Judge prompt_template not found."
+        ), "LLM Judge judge_score_regex not found."

         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
diff --git a/llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py b/llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py
deleted file mode 100644
index f31d81060..000000000
--- a/llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any
-
-from .config import BraintrustScoringConfig
-
-
-async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
-    pass
-    # from .braintrust import VLLMInferenceImpl
-
-    # impl = VLLMInferenceImpl(config)
-    # await impl.initialize()
-    # return impl
diff --git a/llama_stack/providers/impls/third_party/scoring/braintrust/config.py b/llama_stack/providers/impls/third_party/scoring/braintrust/config.py
deleted file mode 100644
index c720c9f67..000000000
--- a/llama_stack/providers/impls/third_party/scoring/braintrust/config.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-from llama_stack.apis.eval import *  # noqa: F401, F403
-
-
-class BraintrustScoringConfig(BaseModel): ...
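Note: the register_scoring_function change above routes every dynamically registered definition to the LLM-as-judge impl and records it in scoring_fn_id_impls. A minimal sketch of that registry-and-dispatch pattern follows; it is an illustration only, and the class and method names below (ScoringDispatcher, score_row on the impl) are hypothetical, not part of the patch.

# Hedged sketch of the identifier -> impl registry that the patch extends.
from typing import Any, Dict


class ScoringDispatcher:
    def __init__(self) -> None:
        # Maps a scoring function identifier to the provider impl that owns it.
        self.scoring_fn_id_impls: Dict[str, Any] = {}

    def register(self, identifier: str, impl: Any) -> None:
        # Mirrors the patched register_scoring_function: remember which impl
        # (here, the LLM-as-judge one) handles the newly registered identifier.
        self.scoring_fn_id_impls[identifier] = impl

    async def score_row(self, identifier: str, row: dict) -> Any:
        # Scoring looks the identifier up and delegates; unknown ids fail fast.
        impl = self.scoring_fn_id_impls.get(identifier)
        if impl is None:
            raise ValueError(f"Unknown scoring function: {identifier}")
        return await impl.score_row(row, identifier)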
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index dc6faaffc..667be1bd5 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -77,7 +77,6 @@ async def test_eval(eval_settings):

     eval_response = await eval_impl.job_result(response.job_id)

-    print(eval_response)
     assert eval_response is not None
     assert len(eval_response.generations) == 5
     assert "meta-reference::subset_of" in eval_response.scores
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 52904ac1e..86deecc71 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -55,6 +55,47 @@ async def test_scoring_functions_list(scoring_settings):
     assert "meta-reference::llm_as_judge_8b_correctness" in function_ids


+@pytest.mark.asyncio
+async def test_scoring_functions_register(scoring_settings):
+    scoring_impl = scoring_settings["scoring_impl"]
+    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
+    datasets_impl = scoring_settings["datasets_impl"]
+    test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: """
+    # register the scoring function
+    await scoring_functions_impl.register_scoring_function(
+        ScoringFnDefWithProvider(
+            identifier="meta-reference::llm_as_judge_8b_random",
+            description="Llm As Judge Scoring Function",
+            parameters=[],
+            return_type=NumberType(),
+            context=LLMAsJudgeContext(
+                prompt_template=test_prompt,
+                judge_model="Llama3.1-8B-Instruct",
+                judge_score_regex=[r"Number: (\d+)"],
+            ),
+            provider_id="test-meta",
+        )
+    )
+
+    scoring_functions = await scoring_functions_impl.list_scoring_functions()
+    assert isinstance(scoring_functions, list)
+    assert len(scoring_functions) > 0
+    function_ids = [f.identifier for f in scoring_functions]
+    assert "meta-reference::llm_as_judge_8b_random" in function_ids
+
+    # test score using newly registered scoring function
+    await register_dataset(datasets_impl)
+    response = await datasets_impl.list_datasets()
+    assert len(response) == 1
+    response = await scoring_impl.score_batch(
+        dataset_id=response[0].identifier,
+        scoring_functions=[
+            "meta-reference::llm_as_judge_8b_random",
+        ],
+    )
+    assert "meta-reference::llm_as_judge_8b_random" in response.results
+
+
 @pytest.mark.asyncio
 async def test_scoring_score(scoring_settings):
     scoring_impl = scoring_settings["scoring_impl"]
@@ -73,7 +114,6 @@ async def test_scoring_score(scoring_settings):
         ],
     )

-    print(response)
     assert len(response.results) == 3
     assert "meta-reference::equality" in response.results
     assert "meta-reference::subset_of" in response.results
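Note: judge_score_regex (asserted in llm_as_judge_scoring_fn.py and exercised in test_scoring_functions_register) is a list of regex patterns used to pull a numeric score out of the judge model's completion. A standalone illustration of that extraction follows; it is a sketch of the general pattern, not the repository's exact implementation.

# Hedged sketch: applies score regexes such as [r"Number: (\d+)"] to judge output.
import re
from typing import List, Optional


def extract_judge_score(judge_output: str, judge_score_regex: List[str]) -> Optional[int]:
    # Try each pattern in order; return the first captured group as an int.
    for pattern in judge_score_regex:
        match = re.search(pattern, judge_output)
        if match:
            return int(match.group(1))
    # No pattern matched the judge's completion.
    return None


# Example with the regex registered in the new test:
assert extract_judge_score("Number: 7", [r"Number: (\d+)"]) == 7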