From 8aee752c19a18246dcefdcbb736cf17218573e9f Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Mon, 28 Oct 2024 11:24:32 -0700
Subject: [PATCH] change json -> class

---
 .../scoring/scoring_fn/base_scoring_fn.py     |  8 +----
 .../scoring/scoring_fn/equality_scoring_fn.py |  9 +++--
 .../scoring/scoring_fn/fn_defs/__init__.py    |  5 +++
 .../scoring/scoring_fn/fn_defs/equality.json  | 10 ------
 .../scoring/scoring_fn/fn_defs/equality.py    | 16 +++++++++
 .../fn_defs/llm_as_judge_8b_correctness.json  | 14 --------
 .../fn_defs/llm_as_judge_8b_correctness.py    | 35 +++++++++++++++++++
 .../scoring/scoring_fn/fn_defs/subset_of.json | 10 ------
 .../scoring/scoring_fn/fn_defs/subset_of.py   | 16 +++++++++
 .../scoring_fn/llm_as_judge_scoring_fn.py     |  8 +++--
 .../scoring_fn/subset_of_scoring_fn.py        |  9 +++--
 11 files changed, 93 insertions(+), 47 deletions(-)
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py
 delete mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py
 delete mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
 delete mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json
 create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py

diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py
index a64bbf07b..52b48e5bb 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/base_scoring_fn.py
@@ -7,7 +7,6 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
 from llama_stack.apis.scoring import *  # noqa: F401, F403
-import json
 
 
 class BaseScoringFn(ABC):
@@ -21,16 +20,11 @@ class BaseScoringFn(ABC):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.supported_fn_defs_registry = {}
-        self.defs_paths = []
 
     def __str__(self) -> str:
         return self.__class__.__name__
 
-    async def initialize(self) -> None:
-        for f in self.defs_paths:
-            with open(f, "r") as f:
-                scoring_fn_def = ScoringFnDef(**json.load(f))
-                self.register_scoring_fn_def(scoring_fn_def)
+    async def initialize(self) -> None: ...
 
     def get_supported_scoring_fn_defs(self) -> List[ScoringFnDef]:
         return [x for x in self.supported_fn_defs_registry.values()]
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
index d9a0aa651..1b8d531aa 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/equality_scoring_fn.py
@@ -13,7 +13,10 @@
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
-    FN_DEFS_PATH,
+)
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.equality import (
+    equality_fn_def,
 )
 
 
@@ -24,7 +27,9 @@ class EqualityScoringFn(BaseScoringFn):
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.defs_paths = [FN_DEFS_PATH / "equality.json"]
+        self.supported_fn_defs_registry = {
+            equality_fn_def.identifier: equality_fn_def,
+        }
 
     async def score_row(
         self,
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json
deleted file mode 100644
index e5397ffc9..000000000
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.json
+++ /dev/null
@@ -1,10 +0,0 @@
-{
-    "identifier": "meta-reference::equality",
-    "description": "Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
-    "metadata": {},
-    "parameters": [],
-    "return_type": {
-        "type": "number"
-    },
-    "context": null
-}
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py
new file mode 100644
index 000000000..cdc4fdc81
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/equality.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import ScoringFnDef
+
+
+equality_fn_def = ScoringFnDef(
+    identifier="meta-reference::equality",
+    description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
+    parameters=[],
+    return_type=NumberType(),
+)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json
deleted file mode 100644
index e33bc09ee..000000000
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.json
+++ /dev/null
@@ -1,14 +0,0 @@
-{
-    "identifier": "meta-reference::llm_as_judge_8b_correctness",
-    "description": "Llm As Judge Scoring Function",
-    "metadata": {},
-    "parameters": [],
-    "return_type": {
-        "type": "number"
-    },
-    "context": {
-        "judge_model": "Llama3.1-8B-Instruct",
-        "prompt_template": "\nYou will be given a question, a expected_answer, and a system_answer.\nYour task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.\nGive your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.\nProvide your feedback as follows:\nFeedback:::\nTotal rating: (your rating, as a int between 0 and 5)\nNow here are the question, expected_answer, system_answer.\nQuestion: {input_query}\nExpected Answer: {expected_answer}\nSystem Answer: {generated_answer}\nFeedback:::\nTotal rating:\n",
-        "judge_score_regex": ["Total rating: (\\d+)", "rating: (\\d+)", "Rating: (\\d+)"]
-    }
-}
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
new file mode 100644
index 000000000..215f4649e
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+from llama_stack.apis.common.type_system import NumberType
+
+JUDGE_PROMPT = """
+You will be given a question, a expected_answer, and a system_answer.
+Your task is to provide a 'total rating' scoring how well the system_answer answers compared with ground truth in expected_answer in terms of factual correctness to the question.
+Give your answer as a integer on a scale of 0 to 5, where 0 means that the system_answer is not correct at all compared with expected_answer, and 5 means that the answer completely and correctly answers the question.
+Provide your feedback as follows:
+Feedback:::
+Total rating: (your rating, as a int between 0 and 5)
+Now here are the question, expected_answer, system_answer.
+Question: {input_query}
+Expected Answer: {expected_answer}
+System Answer: {generated_answer}
+Feedback:::
+Total rating:
+"""
+llm_as_judge_8b_correctness_fn_def = ScoringFnDef(
+    identifier="meta-reference::llm_as_judge_8b_correctness",
+    description="Llm As Judge Scoring Function",
+    parameters=[],
+    return_type=NumberType(),
+    context=LLMAsJudgeContext(
+        prompt_template=JUDGE_PROMPT,
+        judge_model="Llama3.1-8B-Instruct",
+        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+    ),
+)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json
deleted file mode 100644
index 1beb65a3d..000000000
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.json
+++ /dev/null
@@ -1,10 +0,0 @@
-{
-    "identifier": "meta-reference::subset_of",
-    "description": "Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
-    "metadata": {},
-    "parameters": [],
-    "return_type": {
-        "type": "number"
-    },
-    "context": null
-}
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py
new file mode 100644
index 000000000..c3cf8e960
--- /dev/null
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/fn_defs/subset_of.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from llama_stack.apis.common.type_system import NumberType
+from llama_stack.apis.scoring_functions import ScoringFnDef
+
+
+subset_of_fn_def = ScoringFnDef(
+    identifier="meta-reference::subset_of",
+    description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.",
+    parameters=[],
+    return_type=NumberType(),
+)
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
index 559ef41c9..cc8e04048 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py
@@ -14,7 +14,9 @@
 import re
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_average,
-    FN_DEFS_PATH,
+)
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.llm_as_judge_8b_correctness import (
+    llm_as_judge_8b_correctness_fn_def,
 )
 
 
@@ -26,7 +28,9 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
     def __init__(self, inference_api: Inference, *arg, **kwargs) -> None:
         super().__init__(*arg, **kwargs)
         self.inference_api = inference_api
-        self.defs_paths = [FN_DEFS_PATH / "llm_as_judge_8b_correctness.json"]
+        self.supported_fn_defs_registry = {
+            llm_as_judge_8b_correctness_fn_def.identifier: llm_as_judge_8b_correctness_fn_def,
+        }
 
     async def score_row(
         self,
diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
index a358c337b..394aa8177 100644
--- a/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
+++ b/llama_stack/providers/impls/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py
@@ -12,7 +12,10 @@ from llama_stack.apis.scoring import *  # noqa: F401, F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.common import (
     aggregate_accuracy,
-    FN_DEFS_PATH,
+)
+
+from llama_stack.providers.impls.meta_reference.scoring.scoring_fn.fn_defs.subset_of import (
+    subset_of_fn_def,
 )
 
 
@@ -23,7 +26,9 @@ class SubsetOfScoringFn(BaseScoringFn):
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self.defs_paths = [FN_DEFS_PATH / "subset_of.json"]
+        self.supported_fn_defs_registry = {
+            subset_of_fn_def.identifier: subset_of_fn_def,
+        }
 
     async def score_row(
         self,
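For context, a minimal self-contained sketch of the registration pattern this patch moves to: each scoring-function definition becomes a module-level Python object, providers key it by identifier in supported_fn_defs_registry, and initialize() no longer reads JSON from disk. This is not part of the patch; ScoringFnDefLike, BaseScoringFnSketch, and EqualityScoringFnSketch are hypothetical stand-ins for the llama_stack types shown in the diff.

from dataclasses import dataclass, field
from typing import Dict, List


# Hypothetical stand-in for llama_stack's ScoringFnDef model.
@dataclass
class ScoringFnDefLike:
    identifier: str
    description: str
    parameters: List[str] = field(default_factory=list)


# Module-level definition object, mirroring fn_defs/equality.py above.
equality_fn_def = ScoringFnDefLike(
    identifier="meta-reference::equality",
    description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
)


class BaseScoringFnSketch:
    """Mirrors BaseScoringFn: the registry is filled in __init__, initialize() is a no-op."""

    def __init__(self) -> None:
        self.supported_fn_defs_registry: Dict[str, ScoringFnDefLike] = {}

    async def initialize(self) -> None: ...

    def get_supported_scoring_fn_defs(self) -> List[ScoringFnDefLike]:
        return list(self.supported_fn_defs_registry.values())


class EqualityScoringFnSketch(BaseScoringFnSketch):
    def __init__(self) -> None:
        super().__init__()
        # Register the class-based definition instead of loading equality.json from disk.
        self.supported_fn_defs_registry = {
            equality_fn_def.identifier: equality_fn_def,
        }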