From 16620a81856c2893236132a1ea5262060fdc8f00 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 25 Oct 2024 16:41:36 -0700 Subject: [PATCH] llm as judge, move folders --- MANIFEST.in | 1 + llama_stack/distribution/routers/routers.py | 1 + .../impls/meta_reference/scoring/__init__.py | 4 +- .../impls/meta_reference/scoring/scoring.py | 38 ++++++--- .../scoring/scoring_fn/base_scoring_fn.py | 21 ++++- .../scoring/scoring_fn/common.py | 12 +++ .../scoring/scoring_fn/equality_scoring_fn.py | 11 ++- .../scoring/scoring_fn/fn_defs/equality.json | 10 +++ .../fn_defs/llm_as_judge_8b_correctness.json | 13 +++ .../scoring/scoring_fn/fn_defs/subset_of.json | 10 +++ .../scoring_fn/llm_as_judge_scoring_fn.py | 85 ++++++++++++------- .../scoring_fn/subset_of_scoring_fn.py | 15 +--- .../providers/tests/scoring/test_scoring.py | 18 ++-- 13 files changed, 173 insertions(+), 66 deletions(-) create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json diff --git a/MANIFEST.in b/MANIFEST.in index 0517b86a8..09c34cec5 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,3 +2,4 @@ include requirements.txt include llama_stack/distribution/*.sh include llama_stack/cli/scripts/*.sh include llama_stack/templates/*/build.yaml +include llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/*.json diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 348d8449d..4bf3d0187 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -97,6 +97,7 @@ class InferenceRouter(Inference): logprobs=logprobs, ) provider = self.routing_table.get_provider_impl(model) + if stream: return (chunk async for chunk in await provider.chat_completion(**params)) else: diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py index 69d9b543a..002f74e86 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/__init__.py +++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py @@ -16,6 +16,8 @@ async def get_provider_impl( ): from .scoring import MetaReferenceScoringImpl - impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) + impl = MetaReferenceScoringImpl( + config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference] + ) await impl.initialize() return impl diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index a41209520..31ccae5b8 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -11,24 +11,25 @@ from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 - +from llama_stack.apis.inference.inference import Inference from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.equality_scoring_fn import ( EqualityScoringFn, ) +from 
llama_stack.providers.impls.meta_reference.scoring.scoring_fn.llm_as_judge_scoring_fn import ( + LlmAsJudgeScoringFn, +) + from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.subset_of_scoring_fn import ( SubsetOfScoringFn, ) from .config import MetaReferenceScoringConfig -SUPPORTED_SCORING_FNS = [ - EqualityScoringFn, - SubsetOfScoringFn, -] +FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn] -SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORING_FNS} +LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): @@ -37,17 +38,34 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): config: MetaReferenceScoringConfig, datasetio_api: DatasetIO, datasets_api: Datasets, + inference_api: Inference, ) -> None: self.config = config self.datasetio_api = datasetio_api self.datasets_api = datasets_api + self.inference_api = inference_api + self.scoring_fn_id_impls = {} - async def initialize(self) -> None: ... + async def initialize(self) -> None: + for x in FIXED_FNS: + impl = x() + await impl.initialize() + for fn_defs in impl.get_supported_scoring_fn_defs(): + self.scoring_fn_id_impls[fn_defs.identifier] = impl + for x in LLM_JUDGE_FNS: + impl = x(inference_api=self.inference_api) + await impl.initialize() + for fn_defs in impl.get_supported_scoring_fn_defs(): + self.scoring_fn_id_impls[fn_defs.identifier] = impl async def shutdown(self) -> None: ... async def list_scoring_functions(self) -> List[ScoringFnDef]: - return [x.scoring_function_def for x in SUPPORTED_SCORING_FNS] + return [ + fn_defs + for impl in self.scoring_fn_id_impls.values() + for fn_defs in impl.get_supported_scoring_fn_defs() + ] async def register_scoring_function(self, function_def: ScoringFnDef) -> None: raise NotImplementedError( @@ -99,9 +117,9 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) -> ScoreResponse: res = {} for scoring_fn_id in scoring_functions: - if scoring_fn_id not in SCORER_REGISTRY: + if scoring_fn_id not in self.scoring_fn_id_impls: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") - scoring_fn = SCORER_REGISTRY[scoring_fn_id]() + scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] score_results = await scoring_fn.score(input_rows, scoring_fn_id) agg_results = await scoring_fn.aggregate(score_results) res[scoring_fn_id] = ScoringResult( diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py index 075684976..a64bbf07b 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py @@ -7,6 +7,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 +import json class BaseScoringFn(ABC): @@ -17,14 +18,30 @@ class BaseScoringFn(ABC): - aggregate(self, scoring_fn_results) """ - scoring_function_def: ScoringFnDef - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + self.supported_fn_defs_registry = {} + self.defs_paths = [] def __str__(self) -> str: return self.__class__.__name__ + async def initialize(self) -> None: + for f in self.defs_paths: + with open(f, "r") as f: + scoring_fn_def = ScoringFnDef(**json.load(f)) + 
self.register_scoring_fn_def(scoring_fn_def) + + def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]: + return [x for x in self.supported_fn_defs_registry.values()] + + def register_scoring_fn_def(self, scoring_fn_def: ScoringFnDef) -> None: + if scoring_fn_def.identifier in self.supported_fn_defs_registry: + raise ValueError( + f"Scoring function def with identifier {scoring_fn_def.identifier} already exists." + ) + self.supported_fn_defs_registry[scoring_fn_def.identifier] = scoring_fn_def + @abstractmethod async def score_row( self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py index 52eabea2e..25bac5edc 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/common.py @@ -3,10 +3,13 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from pathlib import Path from typing import Any, Dict, List from llama_stack.apis.scoring import ScoringResultRow +FN_DEFS_PATH = Path(__file__).parent / "fn_defs" + def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: num_correct = sum(result["score"] for result in scoring_results) @@ -17,3 +20,12 @@ def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any "num_correct": num_correct, "num_total": len(scoring_results), } + + +def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + return { + "average": sum( + result["score"] for result in scoring_results if result["score"] is not None + ) + / len([_ for _ in scoring_results if _["score"] is not None]), + } diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index dae619ee8..d9a0aa651 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -10,8 +10,10 @@ from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_ from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 + from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( aggregate_accuracy, + FN_DEFS_PATH, ) @@ -20,12 +22,9 @@ class EqualityScoringFn(BaseScoringFn): A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. 
""" - scoring_function_def = ScoringFnDef( - identifier="equality", - description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", - parameters=[], - return_type=NumberType(), - ) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.defs_paths = [FN_DEFS_PATH / "equality.json"] async def score_row( self, diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json new file mode 100644 index 000000000..e5397ffc9 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json @@ -0,0 +1,10 @@ +{ + "identifier": "meta-reference::equality", + "description": "Returns 1.0 if the input is equal to the target, 0.0 otherwise.", + "metadata": {}, + "parameters": [], + "return_type": { + "type": "number" + }, + "context": null +} diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json new file mode 100644 index 000000000..64d86a7ea --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json @@ -0,0 +1,13 @@ +{ + "identifier": "meta-reference::llm_as_judge_8b_correctness", + "description": "Llm As Judge Scoring Function", + "metadata": {}, + "parameters": [], + "return_type": { + "type": "number" + }, + "context": { + "judge_model": "Llama3.1-8B-Instruct", + "prompt_template": "\nYou will be given a question, a expected_answer, and a system_answer.\nYour task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.\nGive your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.\nProvide your feedback as follows:\nFeedback:::\nTotal rating: (your rating, as a int between 0 and 5)\nNow here are the question, expected_answer, system_answer.\nQuestion: {input_query}\nExpected Answer: {expected_answer}\nSystem Answer: {generated_answer}\nFeedback:::\nTotal rating:\n" + } +} diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json new file mode 100644 index 000000000..1beb65a3d --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json @@ -0,0 +1,10 @@ +{ + "identifier": "meta-reference::subset_of", + "description": "Returns 1.0 if the expected is included in generated, 0.0 otherwise.", + "metadata": {}, + "parameters": [], + "return_type": { + "type": "number" + }, + "context": null +} diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py index 20546af50..16672434f 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py @@ -3,31 +3,19 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- +from llama_stack.apis.inference.inference import Inference from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.base_scoring_fn import ( BaseScoringFn, ) from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( - aggregate_accuracy, -) +import re -JUDGE_PROMPT = """ -You will be given a question, a expected_answer, and a system_answer. -Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question. -Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question. -Provide your feedback as follows: -Feedback::: -Total rating: (your rating, as a int between 0 and 5) -Now here are the question, expected_answer, system_answer. -Question: {question} -Expected Answer: {expected_answer} -System Answer: {answer} -Feedback::: -Total rating: -""" +from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( + aggregate_average, + FN_DEFS_PATH, +) class LlmAsJudgeScoringFn(BaseScoringFn): @@ -35,27 +23,62 @@ class LlmAsJudgeScoringFn(BaseScoringFn): A scoring_fn that assigns """ - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.scoring_fn_def_registry = {} + def __init__(self, inference_api: Inference, *arg, **kwargs) -> None: + super().__init__(*arg, **kwargs) + self.inference_api = inference_api + self.defs_paths = [FN_DEFS_PATH / "llm_as_judge_8b_correctness.json"] - def register_scoring_def(self, scoring_fn_def: ScoringFnDef) -> None: - self.scoring_function_def_registry[scoring_fn_def.identifier] = scoring_fn_def - - async def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: - assert "expected_answer" in input_row, "Expected answer not found in input row." + async def score_row( + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = None, + ) -> ScoringResultRow: assert ( - "generated_answer" in input_row - ), "Generated answer not found in input row." + scoring_fn_identifier is not None + ), "Scoring function identifier not found." + fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] + assert ( + fn_def.context is not None and fn_def.context.prompt_template is not None + ), "LLM Judge prompt_template not found." 
+ input_query = input_row["input_query"] expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] - score = 1.0 if expected_answer == generated_answer else 0.0 + + judge_input_msg = fn_def.context.prompt_template.format( + input_query=input_query, + expected_answer=expected_answer, + generated_answer=generated_answer, + ) + + judge_response = await self.inference_api.chat_completion( + model=fn_def.context.judge_model, + messages=[ + { + "role": "user", + "content": judge_input_msg, + } + ], + ) + content = judge_response.completion_message.content + rating_regexs = [ + r"Total rating: (\d+)", + r"rating: (\d+)", + r"Rating: (\d+)", + ] + judge_rating = None + for regex in rating_regexs: + match = re.search(regex, content) + if match: + judge_rating = int(match.group(1)) + break + return { - "score": score, + "score": judge_rating, + "judge_feedback": content, } async def aggregate( self, scoring_results: List[ScoringResultRow] ) -> Dict[str, Any]: - return aggregate_accuracy(scoring_results) + return aggregate_average(scoring_results) diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py index 68ff8e5a0..a358c337b 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py @@ -12,6 +12,7 @@ from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import ( aggregate_accuracy, + FN_DEFS_PATH, ) @@ -20,23 +21,15 @@ class SubsetOfScoringFn(BaseScoringFn): A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. """ - scoring_function_def = ScoringFnDef( - identifier="subset_of", - description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", - parameters=[], - return_type=NumberType(), - ) + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.defs_paths = [FN_DEFS_PATH / "subset_of.json"] async def score_row( self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = "subset_of", ) -> ScoringResultRow: - assert "expected_answer" in input_row, "Expected answer not found in input row." - assert ( - "generated_answer" in input_row - ), "Generated answer not found in input row." 
- expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] score = 1.0 if expected_answer in generated_answer else 0.0 diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index 7806c4483..52904ac1e 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -50,7 +50,9 @@ async def test_scoring_functions_list(scoring_settings): assert isinstance(scoring_functions, list) assert len(scoring_functions) > 0 function_ids = [f.identifier for f in scoring_functions] - assert "equality" in function_ids + assert "meta-reference::equality" in function_ids + assert "meta-reference::subset_of" in function_ids + assert "meta-reference::llm_as_judge_8b_correctness" in function_ids @pytest.mark.asyncio @@ -64,9 +66,15 @@ async def test_scoring_score(scoring_settings): response = await scoring_impl.score_batch( dataset_id=response[0].identifier, - scoring_functions=["equality", "subset_of"], + scoring_functions=[ + "meta-reference::equality", + "meta-reference::subset_of", + "meta-reference::llm_as_judge_8b_correctness", + ], ) - assert len(response.results) == 2 - assert "equality" in response.results - assert "subset_of" in response.results + print(response) + assert len(response.results) == 3 + assert "meta-reference::equality" in response.results + assert "meta-reference::subset_of" in response.results + assert "meta-reference::llm_as_judge_8b_correctness" in response.results
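
For reference, the mechanism this patch introduces — JSON-backed scoring-fn definitions registered per implementation, identifier-based dispatch in MetaReferenceScoringImpl, regex extraction of the judge's rating, and average aggregation — can be illustrated with a minimal standalone sketch. The names below (ToyScoringFn, ToySubsetOfFn, build_registry, extract_rating) are hypothetical stand-ins, not the llama_stack API, and the definition is inlined rather than read from fn_defs/*.json so the sketch runs on its own; it only mirrors the shape of the flow.

# Standalone sketch (hypothetical names, not the llama_stack API) of the
# pattern above: each scoring-fn implementation registers its JSON-backed
# definitions by identifier, the provider maps identifiers to implementations,
# the LLM-as-judge scorer parses a rating out of the judge's reply, and
# results are aggregated as an average.
import json
import re
from typing import Any, Dict, List, Optional

SUBSET_OF_DEF = json.loads(
    '{"identifier": "meta-reference::subset_of", "return_type": {"type": "number"}}'
)


class ToyScoringFn:
    """Base class: subclasses register their definitions by identifier."""

    def __init__(self) -> None:
        self.supported_fn_defs_registry: Dict[str, Dict[str, Any]] = {}

    def register_fn_def(self, fn_def: Dict[str, Any]) -> None:
        identifier = fn_def["identifier"]
        if identifier in self.supported_fn_defs_registry:
            raise ValueError(f"Scoring function def {identifier} already exists.")
        self.supported_fn_defs_registry[identifier] = fn_def

    def score_row(self, row: Dict[str, Any], identifier: str) -> Dict[str, Any]:
        raise NotImplementedError

    def aggregate(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Average over non-None scores, as in aggregate_average().
        scores = [r["score"] for r in results if r["score"] is not None]
        return {"average": sum(scores) / len(scores) if scores else None}


class ToySubsetOfFn(ToyScoringFn):
    def __init__(self) -> None:
        super().__init__()
        self.register_fn_def(SUBSET_OF_DEF)

    def score_row(self, row: Dict[str, Any], identifier: str) -> Dict[str, Any]:
        return {"score": 1.0 if row["expected_answer"] in row["generated_answer"] else 0.0}


def build_registry(fns: List[ToyScoringFn]) -> Dict[str, ToyScoringFn]:
    # Mirrors MetaReferenceScoringImpl.initialize(): identifier -> implementation.
    registry: Dict[str, ToyScoringFn] = {}
    for fn in fns:
        for identifier in fn.supported_fn_defs_registry:
            registry[identifier] = fn
    return registry


def extract_rating(judge_reply: str) -> Optional[int]:
    # Same regex fallback chain the judge scorer applies to the model's reply.
    for pattern in (r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"):
        match = re.search(pattern, judge_reply)
        if match:
            return int(match.group(1))
    return None


if __name__ == "__main__":
    registry = build_registry([ToySubsetOfFn()])
    fn = registry["meta-reference::subset_of"]
    rows = [{"expected_answer": "Paris", "generated_answer": "The capital is Paris."}]
    results = [fn.score_row(r, "meta-reference::subset_of") for r in rows]
    print(fn.aggregate(results))                           # {'average': 1.0}
    print(extract_rating("Feedback:::\nTotal rating: 4"))  # 4

Shipping the definitions as fn_defs/*.json (picked up by the new MANIFEST.in entry) rather than as hard-coded ScoringFnDef class attributes should let a single implementation such as LlmAsJudgeScoringFn back several judge configurations without further code changes.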