From fd424e7900512e30c0ecb22cb36a52c7089bfbff Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 11 Nov 2024 23:02:56 -0500 Subject: [PATCH] util --- .../basic/scoring_fn/equality_scoring_fn.py | 2 +- .../scoring_fn/regex_parser_scoring_fn.py | 2 +- .../basic/scoring_fn/subset_of_scoring_fn.py | 2 +- .../scoring_fn/base_scoring_fn.py | 61 ------------------- .../scoring_fn/llm_as_judge_scoring_fn.py | 2 +- .../scoring}/base_scoring_fn.py | 7 ++- 6 files changed, 8 insertions(+), 68 deletions(-) delete mode 100644 llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/base_scoring_fn.py rename llama_stack/providers/{inline/scoring/basic/scoring_fn => utils/scoring}/base_scoring_fn.py (91%) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index 877b64e4e..7eba4a21b 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index 33773b7bb..fd036ced1 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import re -from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py index fe5988160..1ff3c9b1c 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/base_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/base_scoring_fn.py deleted file mode 100644 index e356bc289..000000000 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/base_scoring_fn.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. -from abc import ABC, abstractmethod -from typing import Any, Dict, List -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 - - -class BaseScoringFn(ABC): - """ - Base interface class for all meta-reference scoring_fns. - Each scoring_fn needs to implement the following methods: - - score_row(self, row) - - aggregate(self, scoring_fn_results) - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.supported_fn_defs_registry = {} - - def __str__(self) -> str: - return self.__class__.__name__ - - def get_supported_scoring_fn_defs(self) -> List[ScoringFn]: - return [x for x in self.supported_fn_defs_registry.values()] - - def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None: - if scoring_fn.identifier in self.supported_fn_defs_registry: - raise ValueError( - f"Scoring function def with identifier {scoring_fn.identifier} already exists." - ) - self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn - - @abstractmethod - async def score_row( - self, - input_row: Dict[str, Any], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, - ) -> ScoringResultRow: - raise NotImplementedError() - - @abstractmethod - async def aggregate( - self, scoring_results: List[ScoringResultRow] - ) -> Dict[str, Any]: - raise NotImplementedError() - - async def score( - self, - input_rows: List[Dict[str, Any]], - scoring_fn_identifier: Optional[str] = None, - scoring_params: Optional[ScoringFnParams] = None, - ) -> List[ScoringResultRow]: - return [ - await self.score_row(input_row, scoring_fn_identifier, scoring_params) - for input_row in input_rows - ] diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index e1f19e640..a950f35f9 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_stack.apis.inference.inference import Inference -from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py similarity index 91% rename from llama_stack/providers/inline/scoring/basic/scoring_fn/base_scoring_fn.py rename to llama_stack/providers/utils/scoring/base_scoring_fn.py index e356bc289..8cd101c50 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/base_scoring_fn.py +++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Dict, List -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 +from typing import Any, Dict, List, Optional + +from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow +from llama_stack.apis.scoring_functions import ScoringFn class BaseScoringFn(ABC):