From bf8bc7a781e562b04a115ea412d4b1a7385c0b9c Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Fri, 25 Oct 2024 15:03:03 -0700
Subject: [PATCH] wip scoring refactor

---
 .../impls/meta_reference/scoring/scoring.py   |  4 +-
 .../scoring/scoring_fn/base_scoring_fn.py     | 19 ++++--
 .../scoring/scoring_fn/equality_scoring_fn.py | 10 ++-
 .../scoring_fn/llm_as_judge_scoring_fn.py     | 61 +++++++++++++++++++
 .../scoring_fn/subset_of_scoring_fn.py        | 10 ++-
 .../scoring/braintrust/__init__.py            | 18 ++++++
 .../third_party/scoring/braintrust/config.py  |  9 +++
 llama_stack/providers/registry/scoring.py     |  1 +
 .../scoring/provider_config_example.yaml      |  5 ++
 .../providers/tests/scoring/test_scoring.py   |  9 ++-
 tests/examples/evals-tgi-run.yaml             |  4 ++
 11 files changed, 137 insertions(+), 13 deletions(-)
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
 create mode 100644 llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py
 create mode 100644 llama_stack/providers/impls/third_party/scoring/braintrust/config.py

diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
index b1d561533..a41209520 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py
@@ -102,8 +102,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             if scoring_fn_id not in SCORER_REGISTRY:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = SCORER_REGISTRY[scoring_fn_id]()
-            score_results = scoring_fn.score(input_rows)
-            agg_results = scoring_fn.aggregate(score_results)
+            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,
                 aggregated_results=agg_results,
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py
index 952d46bb2..075684976 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py
@@ -26,12 +26,23 @@ class BaseScoringFn(ABC):
         return self.__class__.__name__
 
     @abstractmethod
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+    ) -> ScoringResultRow:
         raise NotImplementedError()
 
     @abstractmethod
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         raise NotImplementedError()
 
-    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
-        return [self.score_row(input_row) for input_row in input_rows]
+    async def score(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_fn_identifier: Optional[str] = None,
+    ) -> List[ScoringResultRow]:
+        return [
+            await self.score_row(input_row, scoring_fn_identifier)
+            for input_row in input_rows
+        ]
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
index cce0f948a..dae619ee8 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
@@ -27,7 +27,11 @@ class EqualityScoringFn(BaseScoringFn):
         return_type=NumberType(),
     )
 
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "equality",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -40,5 +44,7 @@
             "score": score,
         }
 
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
new file mode 100644
index 000000000..20546af50
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import (
+    BaseScoringFn,
+)
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import *  # noqa: F403
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
+    aggregate_accuracy,
+)
+
+JUDGE_PROMPT = """
+You will be given a question, an expected_answer, and a system_answer.
+Your task is to provide a 'total rating' scoring how well the system_answer answers the question compared with the ground truth in expected_answer, in terms of factual correctness.
+Give your answer as an integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as an int between 0 and 5)
+Now here are the question, expected_answer, system_answer.
+Question: {question}
+Expected Answer: {expected_answer}
+System Answer: {answer}
+Feedback:::
+Total rating:
+"""
+
+
+class LlmAsJudgeScoringFn(BaseScoringFn):
+    """
+    A scoring_fn that uses an LLM as a judge to score generated answers against expected answers.
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.scoring_fn_def_registry = {}
+
+    def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None:
+        self.scoring_fn_def_registry[scoring_fn_def.identifier] = scoring_fn_def
+
+    async def score_row(self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None) -> ScoringResultRow:
+        assert "expected_answer" in input_row, "Expected answer not found in input row."
+        assert (
+            "generated_answer" in input_row
+        ), "Generated answer not found in input row."
+
+        expected_answer = input_row["expected_answer"]
+        generated_answer = input_row["generated_answer"]
+        score = 1.0 if expected_answer == generated_answer else 0.0
+        return {
+            "score": score,
+        }
+
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        return aggregate_accuracy(scoring_results)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
index c7ee68e26..68ff8e5a0 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
@@ -27,7 +27,11 @@ class SubsetOfScoringFn(BaseScoringFn):
         return_type=NumberType(),
     )
 
-    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = "subset_of",
+    ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (
             "generated_answer" in input_row
@@ -40,5 +44,7 @@
             "score": score,
         }
 
-    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
         return aggregate_accuracy(scoring_results)
diff --git a/llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py b/llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py
new file mode 100644
index 000000000..f31d81060
--- /dev/null
+++ b/llama_stack/providers/impls/third_party/scoring/braintrust/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from .config import BraintrustScoringConfig
+
+
+async def get_provider_impl(config: BraintrustScoringConfig, _deps) -> Any:
+    pass
+    # from .braintrust import VLLMInferenceImpl
+
+    # impl = VLLMInferenceImpl(config)
+    # await impl.initialize()
+    # return impl
diff --git a/llama_stack/providers/impls/third_party/scoring/braintrust/config.py b/llama_stack/providers/impls/third_party/scoring/braintrust/config.py
new file mode 100644
index 000000000..c720c9f67
--- /dev/null
+++ b/llama_stack/providers/impls/third_party/scoring/braintrust/config.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from llama_stack.apis.eval import *  # noqa: F401, F403
+
+
+class BraintrustScoringConfig(BaseModel): ...
diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py
index 4543449b4..06983cdee 100644
--- a/llama_stack/providers/registry/scoring.py
+++ b/llama_stack/providers/registry/scoring.py
@@ -20,6 +20,7 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.datasetio,
                 Api.datasets,
+                Api.inference,
             ],
         ),
     ]
diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml
index 9a8895149..9cf5713c1 100644
--- a/llama_stack/providers/tests/scoring/provider_config_example.yaml
+++ b/llama_stack/providers/tests/scoring/provider_config_example.yaml
@@ -7,3 +7,8 @@ providers:
   - provider_id: test-meta
     provider_type: meta-reference
     config: {}
+  inference:
+  - provider_id: tgi0
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5009
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 2218faa54..7806c4483 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -33,7 +33,9 @@ from llama_stack.providers.tests.resolver import resolve_impls_for_test
 
 @pytest_asyncio.fixture(scope="session")
 async def scoring_settings():
-    impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
+    impls = await resolve_impls_for_test(
+        Api.scoring, deps=[Api.datasetio, Api.inference]
+    )
     return {
         "scoring_impl": impls[Api.scoring],
        "scoring_functions_impl": impls[Api.scoring_functions],
@@ -62,8 +64,9 @@ async def test_scoring_score(scoring_settings):
 
     response = await scoring_impl.score_batch(
         dataset_id=response[0].identifier,
-        scoring_functions=["equality"],
+        scoring_functions=["equality", "subset_of"],
     )
-    assert len(response.results) == 1
+    assert len(response.results) == 2
     assert "equality" in response.results
+    assert "subset_of" in response.results
 
diff --git a/tests/examples/evals-tgi-run.yaml b/tests/examples/evals-tgi-run.yaml
index e63523889..e98047654 100644
--- a/tests/examples/evals-tgi-run.yaml
+++ b/tests/examples/evals-tgi-run.yaml
@@ -33,6 +33,10 @@ providers:
     provider_type: remote::tgi
     config:
       url: http://127.0.0.1:5009
+  - provider_id: tgi1
+    provider_type: remote::tgi
+    config:
+      url: http://127.0.0.1:5010
   memory:
   - provider_id: meta-reference
     provider_type: meta-reference
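
Note from review (not part of the patch): in this WIP, LlmAsJudgeScoringFn.score_row still falls back to an exact-match check, so the new Api.inference dependency and JUDGE_PROMPT are not wired together yet. Below is a minimal sketch of how the judge step might eventually look. It is an assumption, not the patch's implementation: the judge_fn callable stands in for whatever inference call ends up being used, and the "input_query" column name and the 0-5 rating parsing are likewise assumptions.

import re
from typing import Any, Awaitable, Callable, Dict

from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import (
    JUDGE_PROMPT,
)


async def judge_score_row(
    judge_fn: Callable[[str], Awaitable[str]],
    input_row: Dict[str, Any],
) -> Dict[str, Any]:
    # Fill the JUDGE_PROMPT template from the patch with the row under evaluation.
    # "input_query" is an assumed column name for the original question.
    prompt = JUDGE_PROMPT.format(
        question=input_row["input_query"],
        expected_answer=input_row["expected_answer"],
        answer=input_row["generated_answer"],
    )
    # judge_fn is any async call that returns the judge model's reply as text,
    # e.g. a thin wrapper around the inference provider added to api_dependencies.
    feedback = await judge_fn(prompt)
    # JUDGE_PROMPT asks for "Total rating: <int between 0 and 5>"; take the last match.
    ratings = re.findall(r"Total rating:\s*([0-5])", feedback)
    score = float(ratings[-1]) if ratings else 0.0
    return {"score": score, "judge_feedback": feedback}

Returning the raw 0-5 rating as the score is a placeholder; aggregate() would then need something other than aggregate_accuracy (for example, an average rating) to summarize these rows.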