From 7b8748c53ed5eb2e89011ccfe9371e1231e5203b Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 28 Oct 2024 14:08:42 -0700 Subject: [PATCH] [Evals API][6/n] meta-reference llm as judge, registration for ScoringFnDefs (#330) * wip scoring refactor * llm as judge, move folders * test full generation + eval * extract score regex to llm context * remove prints, cleanup braintrust in this branch * change json -> class * remove initialize * address nits * check identifier prefix * udpate MANIFEST --- .../scoring_functions/scoring_functions.py | 4 + .../impls/meta_reference/eval/eval.py | 3 + .../impls/meta_reference/scoring/__init__.py | 4 +- .../impls/meta_reference/scoring/scoring.py | 53 +++++++---- .../scoring/scoring_fn/base_scoring_fn.py | 32 +++++-- .../scoring/scoring_fn/common.py | 12 +++ .../scoring/scoring_fn/equality_scoring_fn.py | 26 ++++-- .../scoring/scoring_fn/fn_defs/__init__.py | 5 ++ .../scoring/scoring_fn/fn_defs/equality.py | 16 ++++ .../fn_defs/llm_as_judge_8b_correctness.py | 36 ++++++++ .../scoring/scoring_fn/fn_defs/subset_of.py | 16 ++++ .../scoring_fn/llm_as_judge_scoring_fn.py | 89 +++++++++++++++++++ .../scoring_fn/subset_of_scoring_fn.py | 30 ++++--- llama_stack/providers/registry/scoring.py | 1 + .../tests/datasetio/test_datasetio.py | 1 + .../tests/eval/provider_config_example.yaml | 4 + llama_stack/providers/tests/eval/test_eval.py | 8 +- .../scoring/provider_config_example.yaml | 5 ++ .../providers/tests/scoring/test_scoring.py | 61 +++++++++++-- tests/examples/evals-tgi-run.yaml | 4 + 20 files changed, 360 insertions(+), 50 deletions(-) create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index e140430ac..6b90aea6d 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -26,6 +26,10 @@ class Parameter(BaseModel): class LLMAsJudgeContext(BaseModel): judge_model: str prompt_template: Optional[str] = None + judge_score_regex: Optional[List[str]] = Field( + description="Regex to extract the score from the judge response", + default=None, + ) @json_schema_type diff --git a/llama_stack/providers/impls/meta_reference/eval/eval.py b/llama_stack/providers/impls/meta_reference/eval/eval.py index d675e40eb..3aec6170f 100644 --- a/llama_stack/providers/impls/meta_reference/eval/eval.py +++ b/llama_stack/providers/impls/meta_reference/eval/eval.py @@ -18,6 +18,7 @@ from .config import MetaReferenceEvalConfig class ColumnName(Enum): + input_query = "input_query" expected_answer = "expected_answer" chat_completion_input = "chat_completion_input" completion_input = "completion_input" @@ -53,10 +54,12 @@ class MetaReferenceEvalImpl(Eval): expected_schemas = [ { + ColumnName.input_query.value: StringType(), ColumnName.expected_answer.value: StringType(), ColumnName.chat_completion_input.value: ChatCompletionInputType(), }, { + ColumnName.input_query.value: StringType(), ColumnName.expected_answer.value: 
StringType(), ColumnName.completion_input.value: CompletionInputType(), }, diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py index 69d9b543a..002f74e86 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/__init__.py +++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py @@ -16,6 +16,8 @@ async def get_provider_impl( ): from .scoring import MetaReferenceScoringImpl - impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) + impl = MetaReferenceScoringImpl( + config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference] + ) await impl.initialize() return impl diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index b1d561533..41b24a512 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -11,24 +11,25 @@ from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 - +from llama_stack.apis.inference.inference import Inference from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import ( EqualityScoringFn, ) +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import ( + LlmAsJudgeScoringFn, +) + from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import ( SubsetOfScoringFn, ) from .config import MetaReferenceScoringConfig -SUPPORTED_SCORING_FNS = [ - EqualityScoringFn, - SubsetOfScoringFn, -] +FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn] -SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS} +LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): @@ -37,22 +38,44 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): config: MetaReferenceScoringConfig, datasetio_api: DatasetIO, datasets_api: Datasets, + inference_api: Inference, ) -> None: self.config = config self.datasetio_api = datasetio_api self.datasets_api = datasets_api + self.inference_api = inference_api + self.scoring_fn_id_impls = {} - async def initialize(self) -> None: ... + async def initialize(self) -> None: + for x in FIXED_FNS: + impl = x() + for fn_defs in impl.get_supported_scoring_fn_defs(): + self.scoring_fn_id_impls[fn_defs.identifier] = impl + for x in LLM_JUDGE_FNS: + impl = x(inference_api=self.inference_api) + for fn_defs in impl.get_supported_scoring_fn_defs(): + self.scoring_fn_id_impls[fn_defs.identifier] = impl + self.llm_as_judge_fn = impl async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFnDef]: - return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS] + scoring_fn_defs_list = [ + fn_def + for impl in self.scoring_fn_id_impls.values() + for fn_def in impl.get_supported_scoring_fn_defs() + ] + + for f in scoring_fn_defs_list: + assert f.identifier.startswith( + "meta-reference" + ), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! 
" + + return scoring_fn_defs_list async def register_scoring_function(self, function_def: ScoringFnDef) -> None: - raise NotImplementedError( - "Dynamically registering scoring functions is not supported" - ) + self.llm_as_judge_fn.register_scoring_fn_def(function_def) + self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) @@ -99,11 +122,11 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) -> ScoreResponse: res = {} for scoring_fn_id in scoring_functions: - if scoring_fn_id not in SCORER_REGISTRY: + if scoring_fn_id not in self.scoring_fn_id_impls: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - scoring_fn = SCORER_REGISTRY[scoring_fn_id]() - score_results = scoring_fn.score(input_rows) - agg_results = scoring_fn.aggregate(score_results) + scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] + score_results = await scoring_fn.score(input_rows, scoring_fn_id) + agg_results = await scoring_fn.aggregate(score_results) res[scoring_fn_id] = ScoringResult( score_rows=score_results, aggregated_results=agg_results, diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py index 952d46bb2..cbd875be6 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py @@ -17,21 +17,41 @@ class BaseScoringFn(ABC): - aggregate(self, scoring_fn_results) """ - scoring_function_def: ScoringFnDef - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + self.supported_fn_defs_registry = {} def __str__(self) -> str: return self.__class__.__name__ + def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]: + return [x for x in self.supported_fn_defs_registry.values()] + + def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None: + if scoring_fn_def.identifier in self.supported_fn_defs_registry: + raise ValueError( + f"Scoring function def with identifier {scoring_fn_def.identifier} already exists." 
+ ) + self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def + @abstractmethod - def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: + async def score_row( + self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None + ) -> ScoringResultRow: raise NotImplementedError() @abstractmethod - def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + async def aggregate( + self, scoring_results: List[ScoringResultRow] + ) -> Dict[str, Any]: raise NotImplementedError() - def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]: - return [self.score_row(input_row) for input_row in input_rows] + async def score( + self, + input_rows: List[Dict[str, Any]], + scoring_fn_identifier: Optional[str] = None, + ) -> List[ScoringResultRow]: + return [ + await self.score_row(input_row, scoring_fn_identifier) + for input_row in input_rows + ] diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py index 52eabea2e..25bac5edc 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py @@ -3,10 +3,13 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from pathlib import Path from typing import Any, Dict, List from llama_stack.apis.scoring import ScoringResultRow +FN_DEFS_PATH = Path(__file__).parent / "fn_defs" + def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: num_correct = sum(result["score"] for result in scoring_results) @@ -17,3 +20,12 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any "num_correct": num_correct, "num_total": len(scoring_results), } + + +def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + return { + "average": sum( + result["score"] for result in scoring_results if result["score"] is not None + ) + / len([_ for _ in scoring_results if _["score"] is not None]), + } diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index cce0f948a..556436286 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -10,24 +10,32 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_ from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 + from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( aggregate_accuracy, ) +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import ( + equality, +) + class EqualityScoringFn(BaseScoringFn): """ A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. 
""" - scoring_function_def = ScoringFnDef( - identifier="equality", - description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", - parameters=[], - return_type=NumberType(), - ) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.supported_fn_defs_registry = { + equality.identifier: equality, + } - def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = "equality", + ) -> ScoringResultRow: assert "expected_answer" in input_row, "Expected answer not found in input row." assert ( "generated_answer" in input_row @@ -40,5 +48,7 @@ class EqualityScoringFn(BaseScoringFn): "score": score, } - def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + async def aggregate( + self, scoring_results: List[ScoringResultRow] + ) -> Dict[str, Any]: return aggregate_accuracy(scoring_results) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py new file mode 100644 index 000000000..99fa6cc3a --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.common.type_system import NumberType +from llama_stack.apis.scoring_functions import ScoringFnDef + + +equality = ScoringFnDef( + identifier="meta-reference::equality", + description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", + parameters=[], + return_type=NumberType(), +) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py new file mode 100644 index 000000000..20a67edc7 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.scoring_functions import * # noqa: F401, F403 +from llama_stack.apis.scoring import * # noqa: F401, F403 +from llama_stack.apis.common.type_system import NumberType + +JUDGE_PROMPT = """ +You will be given a question, a expected_answer, and a system_answer. +Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question. 
+Give your answer as an integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as an integer between 0 and 5)
+Now here are the question, expected_answer, and system_answer.
+Question: {input_query}
+Expected Answer: {expected_answer}
+System Answer: {generated_answer}
+Feedback:::
+Total rating:
+"""
+
+llm_as_judge_8b_correctness = ScoringFnDef(
+    identifier="meta-reference::llm_as_judge_8b_correctness",
+    description="Llm As Judge Scoring Function",
+    parameters=[],
+    return_type=NumberType(),
+    context=LLMAsJudgeContext(
+        prompt_template=JUDGE_PROMPT,
+        judge_model="Llama3.1-8B-Instruct",
+        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+    ),
+)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py
new file mode 100644
index 000000000..5a3e2e8fb
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import ScoringFnDef
+
+
+subset_of = ScoringFnDef(
+    identifier="meta-reference::subset_of",
+    description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+    parameters=[],
+    return_type=NumberType(),
+)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
new file mode 100644
index 000000000..5a5ce2550
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.inference.inference import Inference
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+import re
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_average,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
+    llm_as_judge_8b_correctness,
+)
+
+
+class LlmAsJudgeScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that asks an LLM judge to rate a generated_answer against the
+    expected_answer for a given input_query, extracting a numeric score from the
+    judge's response via the configured judge_score_regex patterns.
+    """
+
+    def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
+        super().__init__(*arg, **kwargs)
+        self.inference_api = inference_api
+        self.supported_fn_defs_registry = {
+            llm_as_judge_8b_correctness.identifier: llm_as_judge_8b_correctness,
+        }
+
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> ScoringResultRow:
+        assert (
+            scoring_fn_identifier is not None
+        ), "Scoring function identifier not found."
+        fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
+        assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
+        assert (
+            fn_def.context.prompt_template is not None
+        ), "LLM Judge prompt_template not found."
+        assert (
+            fn_def.context.judge_score_regex is not None
+        ), "LLM Judge judge_score_regex not found."
+
+        input_query = input_row["input_query"]
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+
+        judge_input_msg = fn_def.context.prompt_template.format(
+            input_query=input_query,
+            expected_answer=expected_answer,
+            generated_answer=generated_answer,
+        )
+
+        judge_response = await self.inference_api.chat_completion(
+            model=fn_def.context.judge_model,
+            messages=[
+                {
+                    "role": "user",
+                    "content": judge_input_msg,
+                }
+            ],
+        )
+        content = judge_response.completion_message.content
+        rating_regexs = fn_def.context.judge_score_regex
+
+        judge_rating = None
+        for regex in rating_regexs:
+            match = re.search(regex, content)
+            if match:
+                judge_rating = int(match.group(1))
+                break
+
+        return {
+            "score": judge_rating,
+            "judge_feedback": content,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_average(scoring_results)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
index c7ee68e26..fcef2ead7 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
@@ -14,25 +14,27 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import
     aggregate_accuracy,
 )

+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
+    subset_of,
+)
+

 class SubsetOfScoringFn(BaseScoringFn):
     """
     A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
""" - scoring_function_def = ScoringFnDef( - identifier="subset_of", - description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", - parameters=[], - return_type=NumberType(), - ) - - def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: - assert "expected_answer" in input_row, "Expected answer not found in input row." - assert ( - "generated_answer" in input_row - ), "Generated answer not found in input row." + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.supported_fn_defs_registry = { + subset_of.identifier: subset_of, + } + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = "subset_of", + ) -> ScoringResultRow: expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] score = 1.0 if expected_answer in generated_answer else 0.0 @@ -40,5 +42,7 @@ class SubsetOfScoringFn(BaseScoringFn): "score": score, } - def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + async def aggregate( + self, scoring_results: List[ScoringResultRow] + ) -> Dict[str, Any]: return aggregate_accuracy(scoring_results) diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index 4543449b4..06983cdee 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]: api_dependencies=[ Api.datasetio, Api.datasets, + Api.inference, ], ), ] diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 9bd80f94d..743e191d4 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -70,6 +70,7 @@ async def register_dataset( if for_generation: dataset_schema = { "expected_answer": StringType(), + "input_query": StringType(), "chat_completion_input": ChatCompletionInputType(), } else: diff --git a/llama_stack/providers/tests/eval/provider_config_example.yaml b/llama_stack/providers/tests/eval/provider_config_example.yaml index 1576d2ef0..38f7512f1 100644 --- a/llama_stack/providers/tests/eval/provider_config_example.yaml +++ b/llama_stack/providers/tests/eval/provider_config_example.yaml @@ -16,3 +16,7 @@ providers: provider_type: remote::tgi config: url: http://127.0.0.1:5009 + - provider_id: test-tgi-2 + provider_type: remote::tgi + config: + url: http://127.0.0.1:5010 diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 6b0d99a22..667be1bd5 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -65,7 +65,10 @@ async def test_eval(eval_settings): model="Llama3.2-1B-Instruct", sampling_params=SamplingParams(), ), - scoring_functions=["subset_of"], + scoring_functions=[ + "meta-reference::subset_of", + "meta-reference::llm_as_judge_8b_correctness", + ], ) assert response.job_id == "0" job_status = await eval_impl.job_status(response.job_id) @@ -76,4 +79,5 @@ async def test_eval(eval_settings): assert eval_response is not None assert len(eval_response.generations) == 5 - assert "subset_of" in eval_response.scores + assert "meta-reference::subset_of" in eval_response.scores + assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml 
b/llama_stack/providers/tests/scoring/provider_config_example.yaml
index 9a8895149..9cf5713c1 100644
--- a/llama_stack/providers/tests/scoring/provider_config_example.yaml
+++ b/llama_stack/providers/tests/scoring/provider_config_example.yaml
@@ -7,3 +7,8 @@ providers:
   - provider_id: test-meta
     provider_type: meta-reference
     config: {}
+  inference:
+  - provider_id: tgi0
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5009
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 2218faa54..86deecc71 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -33,7 +33,9 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test

 @pytest_asyncio.fixture(scope="session")
 async def scoring_settings():
-    impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
+    impls = await resolve_impls_for_test(
+        Api.scoring, deps=[Api.datasetio, Api.inference]
+    )
     return {
         "scoring_impl": impls[Api.scoring],
         "scoring_functions_impl": impls[Api.scoring_functions],
@@ -48,7 +50,50 @@ async def test_scoring_functions_list(scoring_settings):
     assert isinstance(scoring_functions, list)
     assert len(scoring_functions) > 0
     function_ids = [f.identifier for f in scoring_functions]
-    assert "equality" in function_ids
+    assert "meta-reference::equality" in function_ids
+    assert "meta-reference::subset_of" in function_ids
+    assert "meta-reference::llm_as_judge_8b_correctness" in function_ids
+
+
+@pytest.mark.asyncio
+async def test_scoring_functions_register(scoring_settings):
+    scoring_impl = scoring_settings["scoring_impl"]
+    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
+    datasets_impl = scoring_settings["datasets_impl"]
+    test_prompt = """Output a number between 0 and 10.
Your answer must match the format \n Number: """ + # register the scoring function + await scoring_functions_impl.register_scoring_function( + ScoringFnDefWithProvider( + identifier="meta-reference::llm_as_judge_8b_random", + description="Llm As Judge Scoring Function", + parameters=[], + return_type=NumberType(), + context=LLMAsJudgeContext( + prompt_template=test_prompt, + judge_model="Llama3.1-8B-Instruct", + judge_score_regex=[r"Number: (\d+)"], + ), + provider_id="test-meta", + ) + ) + + scoring_functions = await scoring_functions_impl.list_scoring_functions() + assert isinstance(scoring_functions, list) + assert len(scoring_functions) > 0 + function_ids = [f.identifier for f in scoring_functions] + assert "meta-reference::llm_as_judge_8b_random" in function_ids + + # test score using newly registered scoring function + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert len(response) == 1 + response = await scoring_impl.score_batch( + dataset_id=response[0].identifier, + scoring_functions=[ + "meta-reference::llm_as_judge_8b_random", + ], + ) + assert "meta-reference::llm_as_judge_8b_random" in response.results @pytest.mark.asyncio @@ -62,8 +107,14 @@ async def test_scoring_score(scoring_settings): response = await scoring_impl.score_batch( dataset_id=response[0].identifier, - scoring_functions=["equality"], + scoring_functions=[ + "meta-reference::equality", + "meta-reference::subset_of", + "meta-reference::llm_as_judge_8b_correctness", + ], ) - assert len(response.results) == 1 - assert "equality" in response.results + assert len(response.results) == 3 + assert "meta-reference::equality" in response.results + assert "meta-reference::subset_of" in response.results + assert "meta-reference::llm_as_judge_8b_correctness" in response.results diff --git a/tests/examples/evals-tgi-run.yaml b/tests/examples/evals-tgi-run.yaml index e63523889..e98047654 100644 --- a/tests/examples/evals-tgi-run.yaml +++ b/tests/examples/evals-tgi-run.yaml @@ -33,6 +33,10 @@ providers: provider_type: remote::tgi config: url: http://127.0.0.1:5009 + - provider_id: tgi1 + provider_type: remote::tgi + config: + url: http://127.0.0.1:5010 memory: - provider_id: meta-reference provider_type: meta-reference
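Note (not part of the patch): the snippet below is an illustrative, self-contained sketch of the judge_score_regex convention this patch introduces. The LlmAsJudgeScoringFn formats the prompt_template with input_query / expected_answer / generated_answer, sends it to the judge model, then tries each regex in order and parses the first captured group as the integer score. The regex list mirrors llm_as_judge_8b_correctness; the judge response string and the extract_judge_score helper are fabricated here purely for demonstration.

import re
from typing import List, Optional

# Patterns as configured in the llm_as_judge_8b_correctness fn_def above.
JUDGE_SCORE_REGEX = [r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"]


def extract_judge_score(content: str, patterns: List[str]) -> Optional[int]:
    """Return the first captured integer rating, or None if no pattern matches."""
    for pattern in patterns:
        match = re.search(pattern, content)
        if match:
            return int(match.group(1))
    return None


# Fabricated judge output, shaped like the "Feedback::: / Total rating:" format
# that JUDGE_PROMPT asks the judge model to produce.
fake_judge_response = "Feedback:::\nTotal rating: 4"
assert extract_judge_score(fake_judge_response, JUDGE_SCORE_REGEX) == 4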