From 84c6fbbd933f58a87f5c7eb312c13c032753f8d5 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Tue, 12 Nov 2024 10:35:44 -0500
Subject: [PATCH] fix tests after registration migration & rename meta-reference -> basic / llm_as_judge provider (#424)

* rename meta-reference -> basic
* config rename
* impl rename
* rename llm_as_judge, fix test
* util
* rebase
* naming fix
---
 .../inline/scoring/basic/__init__.py | 25 ++
 .../{meta_reference => basic}/config.py | 4 +-
 .../{meta_reference => basic}/scoring.py | 33 ++---
 .../scoring_fn/__init__.py | 0
 .../scoring_fn/equality_scoring_fn.py | 2 +-
 .../scoring_fn/fn_defs/__init__.py | 0
 .../scoring_fn/fn_defs/equality.py | 4 +-
 .../regex_parser_multiple_choice_answer.py | 4 +-
 .../scoring_fn/fn_defs/subset_of.py | 4 +-
 .../scoring_fn/regex_parser_scoring_fn.py | 2 +-
 .../scoring_fn/subset_of_scoring_fn.py | 2 +-
 .../inline/scoring/braintrust/braintrust.py | 8 +-
 .../__init__.py | 8 +-
 .../inline/scoring/llm_as_judge/config.py | 9 ++
 .../inline/scoring/llm_as_judge/scoring.py | 131 ++++++++++++++++++
 .../llm_as_judge/scoring_fn/__init__.py | 5 +
 .../scoring_fn/fn_defs/__init__.py | 5 +
 .../scoring_fn/fn_defs/llm_as_judge_base.py | 4 +-
 .../scoring_fn/llm_as_judge_scoring_fn.py | 2 +-
 llama_stack/providers/registry/scoring.py | 19 ++-
 .../providers/tests/scoring/conftest.py | 27 ++--
 .../providers/tests/scoring/fixtures.py | 23 ++-
 .../providers/tests/scoring/test_scoring.py | 13 +-
 .../scoring}/base_scoring_fn.py | 7 +-
 24 files changed, 268 insertions(+), 73 deletions(-)
 create mode 100644 llama_stack/providers/inline/scoring/basic/__init__.py
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/config.py (65%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring.py (80%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/__init__.py (100%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/equality_scoring_fn.py (95%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/fn_defs/__init__.py (100%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/fn_defs/equality.py (86%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py (95%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/fn_defs/subset_of.py (86%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/regex_parser_scoring_fn.py (96%)
 rename llama_stack/providers/inline/scoring/{meta_reference => basic}/scoring_fn/subset_of_scoring_fn.py (95%)
 rename llama_stack/providers/inline/scoring/{meta_reference => llm_as_judge}/__init__.py (73%)
 create mode 100644 llama_stack/providers/inline/scoring/llm_as_judge/config.py
 create mode 100644 llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
 create mode 100644 llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py
 create mode 100644 llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py
 rename llama_stack/providers/inline/scoring/{meta_reference => llm_as_judge}/scoring_fn/fn_defs/llm_as_judge_base.py (84%)
 rename llama_stack/providers/inline/scoring/{meta_reference => llm_as_judge}/scoring_fn/llm_as_judge_scoring_fn.py (97%)
 rename llama_stack/providers/{inline/scoring/meta_reference/scoring_fn => utils/scoring}/base_scoring_fn.py (91%)

diff --git a/llama_stack/providers/inline/scoring/basic/__init__.py
b/llama_stack/providers/inline/scoring/basic/__init__.py new file mode 100644 index 000000000..c72434e9e --- /dev/null +++ b/llama_stack/providers/inline/scoring/basic/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Dict + +from llama_stack.distribution.datatypes import Api, ProviderSpec + +from .config import BasicScoringConfig + + +async def get_provider_impl( + config: BasicScoringConfig, + deps: Dict[Api, ProviderSpec], +): + from .scoring import BasicScoringImpl + + impl = BasicScoringImpl( + config, + deps[Api.datasetio], + deps[Api.datasets], + ) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/scoring/meta_reference/config.py b/llama_stack/providers/inline/scoring/basic/config.py similarity index 65% rename from llama_stack/providers/inline/scoring/meta_reference/config.py rename to llama_stack/providers/inline/scoring/basic/config.py index bd4dcb9f0..d9dbe71bc 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/config.py +++ b/llama_stack/providers/inline/scoring/basic/config.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.apis.scoring import * # noqa: F401, F403 +from pydantic import BaseModel -class MetaReferenceScoringConfig(BaseModel): ... +class BasicScoringConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py similarity index 80% rename from llama_stack/providers/inline/scoring/meta_reference/scoring.py rename to llama_stack/providers/inline/scoring/basic/scoring.py index b78379062..98803ae4a 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -11,44 +11,33 @@ from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 -from llama_stack.apis.inference.inference import Inference from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate -from .config import MetaReferenceScoringConfig +from .config import BasicScoringConfig from .scoring_fn.equality_scoring_fn import EqualityScoringFn -from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] -LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] - -class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): +class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): def __init__( self, - config: MetaReferenceScoringConfig, + config: BasicScoringConfig, datasetio_api: DatasetIO, datasets_api: Datasets, - inference_api: Inference, ) -> None: self.config = config self.datasetio_api = datasetio_api self.datasets_api = datasets_api - self.inference_api = inference_api self.scoring_fn_id_impls = {} async def initialize(self) -> None: - for x in FIXED_FNS: - impl = x() + for fn in FIXED_FNS: + impl = fn() for fn_defs in impl.get_supported_scoring_fn_defs(): 
self.scoring_fn_id_impls[fn_defs.identifier] = impl - for x in LLM_JUDGE_FNS: - impl = x(inference_api=self.inference_api) - for fn_defs in impl.get_supported_scoring_fn_defs(): - self.scoring_fn_id_impls[fn_defs.identifier] = impl - self.llm_as_judge_fn = impl async def shutdown(self) -> None: ... @@ -61,8 +50,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): for f in scoring_fn_defs_list: assert f.identifier.startswith( - "meta-reference" - ), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! " + "basic" + ), "All basic scoring fn must have identifier prefixed with 'basic'! " return scoring_fn_defs_list @@ -70,18 +59,18 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): raise NotImplementedError("Register scoring function not implemented yet") async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: - dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def.schema or len(dataset_def.schema) == 0: raise ValueError( f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." ) for required_column in ["generated_answer", "expected_answer", "input_query"]: - if required_column not in dataset_def.dataset_schema: + if required_column not in dataset_def.schema: raise ValueError( f"Dataset {dataset_id} does not have a '{required_column}' column." ) - if dataset_def.dataset_schema[required_column].type != "string": + if dataset_def.schema[required_column].type != "string": raise ValueError( f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." ) diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/__init__.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/__init__.py similarity index 100% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/__init__.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/__init__.py diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py similarity index 95% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/equality_scoring_fn.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index 877b64e4e..7eba4a21b 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/__init__.py similarity index 100% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/__init__.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/__init__.py diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py similarity index 86% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/equality.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py index b3fbb5d2f..8403119f6 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py @@ -9,10 +9,10 @@ from llama_stack.apis.scoring_functions import ScoringFn equality = ScoringFn( - identifier="meta-reference::equality", + identifier="basic::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", params=None, - provider_id="meta-reference", + provider_id="basic", provider_resource_id="equality", return_type=NumberType(), ) diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py similarity index 95% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py index 20b59c273..9d028a468 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py @@ -57,10 +57,10 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( ) regex_parser_multiple_choice_answer = ScoringFn( - identifier="meta-reference::regex_parser_multiple_choice_answer", + identifier="basic::regex_parser_multiple_choice_answer", description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result", return_type=NumberType(), - provider_id="meta-reference", + provider_id="basic", provider_resource_id="regex-parser-multiple-choice-answer", params=RegexParserScoringFnParams( parsing_regexes=[ diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py similarity index 86% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/subset_of.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py index b2759f3ee..ab2a9c60b 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/subset_of.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py @@ -9,9 +9,9 @@ from 
llama_stack.apis.scoring_functions import ScoringFn subset_of = ScoringFn( - identifier="meta-reference::subset_of", + identifier="basic::subset_of", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", return_type=NumberType(), - provider_id="meta-reference", + provider_id="basic", provider_resource_id="subset-of", ) diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py similarity index 96% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index 33773b7bb..fd036ced1 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import re -from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py similarity index 95% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/subset_of_scoring_fn.py rename to llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py index fe5988160..1ff3c9b1c 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/subset_of_scoring_fn.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index 9105a4978..973232f4e 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -63,18 +63,18 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: - dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) - if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def.schema or len(dataset_def.schema) == 0: raise ValueError( f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." 
) for required_column in ["generated_answer", "expected_answer", "input_query"]: - if required_column not in dataset_def.dataset_schema: + if required_column not in dataset_def.schema: raise ValueError( f"Dataset {dataset_id} does not have a '{required_column}' column." ) - if dataset_def.dataset_schema[required_column].type != "string": + if dataset_def.schema[required_column].type != "string": raise ValueError( f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." ) diff --git a/llama_stack/providers/inline/scoring/meta_reference/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py similarity index 73% rename from llama_stack/providers/inline/scoring/meta_reference/__init__.py rename to llama_stack/providers/inline/scoring/llm_as_judge/__init__.py index 002f74e86..806aef272 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/__init__.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py @@ -7,16 +7,16 @@ from typing import Dict from llama_stack.distribution.datatypes import Api, ProviderSpec -from .config import MetaReferenceScoringConfig +from .config import LlmAsJudgeScoringConfig async def get_provider_impl( - config: MetaReferenceScoringConfig, + config: LlmAsJudgeScoringConfig, deps: Dict[Api, ProviderSpec], ): - from .scoring import MetaReferenceScoringImpl + from .scoring import LlmAsJudgeScoringImpl - impl = MetaReferenceScoringImpl( + impl = LlmAsJudgeScoringImpl( config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference] ) await impl.initialize() diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/config.py b/llama_stack/providers/inline/scoring/llm_as_judge/config.py new file mode 100644 index 000000000..1b538420c --- /dev/null +++ b/llama_stack/providers/inline/scoring/llm_as_judge/config.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from pydantic import BaseModel + + +class LlmAsJudgeScoringConfig(BaseModel): ... diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py new file mode 100644 index 000000000..0cb81e114 --- /dev/null +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from typing import Any, Dict, List, Optional + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.inference.inference import Inference + +from llama_stack.apis.scoring import ( + ScoreBatchResponse, + ScoreResponse, + Scoring, + ScoringResult, +) +from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams +from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate + +from .config import LlmAsJudgeScoringConfig +from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn + + +LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] + + +class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): + def __init__( + self, + config: LlmAsJudgeScoringConfig, + datasetio_api: DatasetIO, + datasets_api: Datasets, + inference_api: Inference, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + self.inference_api = inference_api + self.scoring_fn_id_impls = {} + + async def initialize(self) -> None: + for fn in LLM_JUDGE_FNS: + impl = fn(inference_api=self.inference_api) + for fn_defs in impl.get_supported_scoring_fn_defs(): + self.scoring_fn_id_impls[fn_defs.identifier] = impl + self.llm_as_judge_fn = impl + + async def shutdown(self) -> None: ... + + async def list_scoring_functions(self) -> List[ScoringFn]: + scoring_fn_defs_list = [ + fn_def + for impl in self.scoring_fn_id_impls.values() + for fn_def in impl.get_supported_scoring_fn_defs() + ] + + for f in scoring_fn_defs_list: + assert f.identifier.startswith( + "llm-as-judge" + ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! " + + return scoring_fn_defs_list + + async def register_scoring_function(self, function_def: ScoringFn) -> None: + raise NotImplementedError("Register scoring function not implemented yet") + + async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def.schema or len(dataset_def.schema) == 0: + raise ValueError( + f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." + ) + + for required_column in ["generated_answer", "expected_answer", "input_query"]: + if required_column not in dataset_def.schema: + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column." + ) + if dataset_def.schema[required_column].type != "string": + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." 
+ ) + + async def score_batch( + self, + dataset_id: str, + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + all_rows = await self.datasetio_api.get_rows_paginated( + dataset_id=dataset_id, + rows_in_page=-1, + ) + res = await self.score( + input_rows=all_rows.rows, + scoring_functions=scoring_functions, + ) + if save_results_dataset: + # TODO: persist and register dataset on to server for reading + # self.datasets_api.register_dataset() + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res.results, + ) + + async def score( + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + ) -> ScoreResponse: + res = {} + for scoring_fn_id in scoring_functions.keys(): + if scoring_fn_id not in self.scoring_fn_id_impls: + raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") + scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] + scoring_fn_params = scoring_functions.get(scoring_fn_id, None) + score_results = await scoring_fn.score( + input_rows, scoring_fn_id, scoring_fn_params + ) + agg_results = await scoring_fn.aggregate(score_results) + res[scoring_fn_id] = ScoringResult( + score_rows=score_results, + aggregated_results=agg_results, + ) + + return ScoreResponse( + results=res, + ) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
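The `LlmAsJudgeScoringImpl.list_scoring_functions` method above enforces the new naming rule: every scoring-fn identifier must carry the `llm-as-judge` prefix, mirroring the `basic` prefix check earlier in this patch. A minimal sketch of what an additional definition under that rule could look like (the identifier and description are illustrative assumptions, not part of this change):

```python
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn

# Hypothetical extra fn_def following the "llm-as-judge::<name>" naming rule
# asserted in LlmAsJudgeScoringImpl.list_scoring_functions; not shipped here.
llm_as_judge_helpfulness = ScoringFn(
    identifier="llm-as-judge::helpfulness",
    description="Judge the helpfulness of the generated answer on a 0-9 scale.",
    return_type=NumberType(),
    provider_id="llm-as-judge",
    provider_resource_id="llm-as-judge-helpfulness",
    params=None,  # judge model/prompt can be supplied per request, as the tests below do
)
```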
diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py similarity index 84% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py rename to llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py index ad07ea1b8..51517a0b0 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/fn_defs/llm_as_judge_base.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py @@ -9,9 +9,9 @@ from llama_stack.apis.scoring_functions import ScoringFn llm_as_judge_base = ScoringFn( - identifier="meta-reference::llm_as_judge_base", + identifier="llm-as-judge::llm_as_judge_base", description="Llm As Judge Scoring Function", return_type=NumberType(), - provider_id="meta-reference", + provider_id="llm-as-judge", provider_resource_id="llm-as-judge-base", ) diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py similarity index 97% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/llm_as_judge_scoring_fn.py rename to llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index e1f19e640..a950f35f9 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from llama_stack.apis.inference.inference import Inference -from .base_scoring_fn import BaseScoringFn +from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.apis.scoring_functions import * # noqa: F401, F403 from llama_stack.apis.scoring import * # noqa: F401, F403 from llama_stack.apis.common.type_system import * # noqa: F403 diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py index a63b21c65..2da9797bc 100644 --- a/llama_stack/providers/registry/scoring.py +++ b/llama_stack/providers/registry/scoring.py @@ -13,10 +13,21 @@ def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.scoring, - provider_type="inline::meta-reference", + provider_type="inline::basic", pip_packages=[], - module="llama_stack.providers.inline.scoring.meta_reference", - config_class="llama_stack.providers.inline.scoring.meta_reference.MetaReferenceScoringConfig", + module="llama_stack.providers.inline.scoring.basic", + config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + ), + InlineProviderSpec( + api=Api.scoring, + provider_type="inline::llm-as-judge", + pip_packages=[], + module="llama_stack.providers.inline.scoring.llm_as_judge", + config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig", api_dependencies=[ Api.datasetio, Api.datasets, @@ -25,7 +36,7 @@ def available_providers() -> List[ProviderSpec]: ), InlineProviderSpec( api=Api.scoring, - provider_type="braintrust", + provider_type="inline::braintrust", pip_packages=["autoevals", "openai"], module="llama_stack.providers.inline.scoring.braintrust", config_class="llama_stack.providers.inline.scoring.braintrust.BraintrustScoringConfig", 
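The registry above now lists three inline scoring providers: `inline::basic`, `inline::llm-as-judge`, and `inline::braintrust`. Downstream run configurations and the test fixtures below refer to them by these provider types; a minimal sketch of declaring all three side by side (the `Provider` import path is an assumption, taken to match what the fixtures use):

```python
from llama_stack.distribution.datatypes import Provider

# Sketch only: the three inline scoring providers registered above, declared the
# same way the pytest fixtures below declare them (provider_id is free-form).
scoring_providers = [
    Provider(provider_id="basic", provider_type="inline::basic", config={}),
    Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge", config={}),
    Provider(provider_id="braintrust", provider_type="inline::braintrust", config={}),
]
```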
diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py index ed56df230..e8ecfaa68 100644 --- a/llama_stack/providers/tests/scoring/conftest.py +++ b/llama_stack/providers/tests/scoring/conftest.py @@ -15,21 +15,12 @@ from .fixtures import SCORING_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { - "scoring": "meta_reference", - "datasetio": "localfs", - "inference": "fireworks", - }, - id="meta_reference_scoring_fireworks_inference", - marks=pytest.mark.meta_reference_scoring_fireworks_inference, - ), - pytest.param( - { - "scoring": "meta_reference", + "scoring": "basic", "datasetio": "localfs", "inference": "together", }, - id="meta_reference_scoring_together_inference", - marks=pytest.mark.meta_reference_scoring_together_inference, + id="basic_scoring_together_inference", + marks=pytest.mark.basic_scoring_together_inference, ), pytest.param( { @@ -40,13 +31,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="braintrust_scoring_together_inference", marks=pytest.mark.braintrust_scoring_together_inference, ), + pytest.param( + { + "scoring": "llm_as_judge", + "datasetio": "localfs", + "inference": "together", + }, + id="llm_as_judge_scoring_together_inference", + marks=pytest.mark.llm_as_judge_scoring_together_inference, + ), ] def pytest_configure(config): for fixture_name in [ - "meta_reference_scoring_fireworks_inference", - "meta_reference_scoring_together_inference", + "basic_scoring_together_inference", "braintrust_scoring_together_inference", ]: config.addinivalue_line( diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py index 20631f5cf..14095b526 100644 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -19,12 +19,12 @@ def scoring_remote() -> ProviderFixture: @pytest.fixture(scope="session") -def scoring_meta_reference() -> ProviderFixture: +def scoring_basic() -> ProviderFixture: return ProviderFixture( providers=[ Provider( - provider_id="meta-reference", - provider_type="meta-reference", + provider_id="basic", + provider_type="inline::basic", config={}, ) ], @@ -37,14 +37,27 @@ def scoring_braintrust() -> ProviderFixture: providers=[ Provider( provider_id="braintrust", - provider_type="braintrust", + provider_type="inline::braintrust", config={}, ) ], ) -SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"] +@pytest.fixture(scope="session") +def scoring_llm_as_judge() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="llm-as-judge", + provider_type="inline::llm-as-judge", + config={}, + ) + ], + ) + + +SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index f3c925048..08a05681f 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -43,6 +43,13 @@ class TestScoring: scoring_stack[Api.datasets], scoring_stack[Api.models], ) + scoring_fns_list = await scoring_functions_impl.list_scoring_functions() + provider_id = scoring_fns_list[0].provider_id + if provider_id == "llm-as-judge": + pytest.skip( + f"{provider_id} provider does not support scoring without params" + ) + await register_dataset(datasets_impl) response = await datasets_impl.list_datasets() assert len(response) == 1 @@ -111,8 +118,8 @@ class TestScoring: 
scoring_fns_list = await scoring_functions_impl.list_scoring_functions() provider_id = scoring_fns_list[0].provider_id - if provider_id == "braintrust": - pytest.skip("Braintrust provider does not support scoring with params") + if provider_id == "braintrust" or provider_id == "basic": + pytest.skip(f"{provider_id} provider does not support scoring with params") # scoring individual rows rows = await datasetio_impl.get_rows_paginated( @@ -122,7 +129,7 @@ class TestScoring: assert len(rows.rows) == 3 scoring_functions = { - "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams( + "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams( judge_model="Llama3.1-405B-Instruct", prompt_template="Output a number response in the following format: Score: , where is the number between 0 and 9.", judge_score_regexes=[r"Score: (\d+)"], diff --git a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/base_scoring_fn.py b/llama_stack/providers/utils/scoring/base_scoring_fn.py similarity index 91% rename from llama_stack/providers/inline/scoring/meta_reference/scoring_fn/base_scoring_fn.py rename to llama_stack/providers/utils/scoring/base_scoring_fn.py index e356bc289..8cd101c50 100644 --- a/llama_stack/providers/inline/scoring/meta_reference/scoring_fn/base_scoring_fn.py +++ b/llama_stack/providers/utils/scoring/base_scoring_fn.py @@ -4,9 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. from abc import ABC, abstractmethod -from typing import Any, Dict, List -from llama_stack.apis.scoring_functions import * # noqa: F401, F403 -from llama_stack.apis.scoring import * # noqa: F401, F403 +from typing import Any, Dict, List, Optional + +from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow +from llama_stack.apis.scoring_functions import ScoringFn class BaseScoringFn(ABC):
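With `BaseScoringFn` relocated to `llama_stack/providers/utils/scoring/`, both providers share the same base class, and callers now request judge-based scoring through the `llm-as-judge::` identifiers with per-request params. A hedged sketch of the call the updated test builds up to (the `scoring_impl` handle and the prompt wording are illustrative assumptions; rows need the `input_query`, `generated_answer`, and `expected_answer` columns validated above):

```python
from typing import Any, Dict, List

from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def judge_rows(scoring_impl, rows: List[Dict[str, Any]]):
    # `scoring_impl` stands in for an initialized LlmAsJudgeScoringImpl (or the
    # scoring API routed to it); identifiers now carry the "llm-as-judge::" prefix.
    params = LLMAsJudgeScoringFnParams(
        judge_model="Llama3.1-405B-Instruct",
        prompt_template="Output a number response in the following format: Score: <number>",
        judge_score_regexes=[r"Score: (\d+)"],
    )
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions={"llm-as-judge::llm_as_judge_base": params},
    )
    # ScoreResponse.results is keyed by scoring-fn identifier.
    return response.results["llm-as-judge::llm_as_judge_base"].aggregated_results
```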