diff --git a/llama_stack/providers/inline/scoring/basic/__init__.py b/llama_stack/providers/inline/scoring/basic/__init__.py
index 77bb9e6e6..c72434e9e 100644
--- a/llama_stack/providers/inline/scoring/basic/__init__.py
+++ b/llama_stack/providers/inline/scoring/basic/__init__.py
@@ -17,7 +17,9 @@ async def get_provider_impl(
     from .scoring import BasicScoringImpl
 
     impl = BasicScoringImpl(
-        config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
+        config,
+        deps[Api.datasetio],
+        deps[Api.datasets],
     )
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py
index 897e8ef36..6632d5b81 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring.py
+++ b/llama_stack/providers/inline/scoring/basic/scoring.py
@@ -11,19 +11,15 @@ from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.apis.datasets import *  # noqa: F403
-from llama_stack.apis.inference.inference import Inference
 
 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
 from .config import BasicScoringConfig
 from .scoring_fn.equality_scoring_fn import EqualityScoringFn
-from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
 from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
 from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
 
 FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
 
-LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
-
 
 class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     def __init__(
@@ -31,12 +27,10 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         config: BasicScoringConfig,
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
-        inference_api: Inference,
     ) -> None:
         self.config = config
         self.datasetio_api = datasetio_api
         self.datasets_api = datasets_api
-        self.inference_api = inference_api
         self.scoring_fn_id_impls = {}
 
     async def initialize(self) -> None:
@@ -44,11 +38,6 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             impl = x()
             for fn_defs in impl.get_supported_scoring_fn_defs():
                 self.scoring_fn_id_impls[fn_defs.identifier] = impl
-        for x in LLM_JUDGE_FNS:
-            impl = x(inference_api=self.inference_api)
-            for fn_defs in impl.get_supported_scoring_fn_defs():
-                self.scoring_fn_id_impls[fn_defs.identifier] = impl
-            self.llm_as_judge_fn = impl
 
     async def shutdown(self) -> None: ...
 
@@ -61,8 +50,8 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
 
         for f in scoring_fn_defs_list:
             assert f.identifier.startswith(
-                "meta-reference"
-            ), "All meta-reference scoring fn must have identifier prefixed with 'meta-reference'! "
+                "basic"
+            ), "All basic scoring fn must have identifier prefixed with 'basic'! "
" return scoring_fn_defs_list diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py index b3fbb5d2f..8403119f6 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py @@ -9,10 +9,10 @@ from llama_stack.apis.scoring_functions import ScoringFn equality = ScoringFn( - identifier="meta-reference::equality", + identifier="basic::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", params=None, - provider_id="meta-reference", + provider_id="basic", provider_resource_id="equality", return_type=NumberType(), ) diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py index 20b59c273..9d028a468 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/regex_parser_multiple_choice_answer.py @@ -57,10 +57,10 @@ MULTILINGUAL_ANSWER_PATTERN_TEMPLATE = ( ) regex_parser_multiple_choice_answer = ScoringFn( - identifier="meta-reference::regex_parser_multiple_choice_answer", + identifier="basic::regex_parser_multiple_choice_answer", description="Extract answer from response matching Answer: [the_answer_letter], and compare with expected result", return_type=NumberType(), - provider_id="meta-reference", + provider_id="basic", provider_resource_id="regex-parser-multiple-choice-answer", params=RegexParserScoringFnParams( parsing_regexes=[ diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py index b2759f3ee..ab2a9c60b 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py @@ -9,9 +9,9 @@ from llama_stack.apis.scoring_functions import ScoringFn subset_of = ScoringFn( - identifier="meta-reference::subset_of", + identifier="basic::subset_of", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", return_type=NumberType(), - provider_id="meta-reference", + provider_id="basic", provider_resource_id="subset-of", ) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py new file mode 100644 index 000000000..806aef272 --- /dev/null +++ b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec
+
+from .config import LlmAsJudgeScoringConfig
+
+
+async def get_provider_impl(
+    config: LlmAsJudgeScoringConfig,
+    deps: Dict[Api, ProviderSpec],
+):
+    from .scoring import LlmAsJudgeScoringImpl
+
+    impl = LlmAsJudgeScoringImpl(
+        config, deps[Api.datasetio], deps[Api.datasets], deps[Api.inference]
+    )
+    await impl.initialize()
+    return impl
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/config.py b/llama_stack/providers/inline/scoring/llm_as_judge/config.py
new file mode 100644
index 000000000..1b538420c
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/config.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from pydantic import BaseModel
+
+
+class LlmAsJudgeScoringConfig(BaseModel): ...
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
new file mode 100644
index 000000000..b20a189b1
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py
@@ -0,0 +1,131 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Any, Dict, List, Optional
+
+from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.inference.inference import Inference
+
+from llama_stack.apis.scoring import (
+    ScoreBatchResponse,
+    ScoreResponse,
+    Scoring,
+    ScoringResult,
+)
+from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
+from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
+
+from .config import LlmAsJudgeScoringConfig
+from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
+
+
+LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
+
+
+class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
+    def __init__(
+        self,
+        config: LlmAsJudgeScoringConfig,
+        datasetio_api: DatasetIO,
+        datasets_api: Datasets,
+        inference_api: Inference,
+    ) -> None:
+        self.config = config
+        self.datasetio_api = datasetio_api
+        self.datasets_api = datasets_api
+        self.inference_api = inference_api
+        self.scoring_fn_id_impls = {}
+
+    async def initialize(self) -> None:
+        for x in LLM_JUDGE_FNS:
+            impl = x(inference_api=self.inference_api)
+            for fn_defs in impl.get_supported_scoring_fn_defs():
+                self.scoring_fn_id_impls[fn_defs.identifier] = impl
+            self.llm_as_judge_fn = impl
+
+    async def shutdown(self) -> None: ...
+
+    async def list_scoring_functions(self) -> List[ScoringFn]:
+        scoring_fn_defs_list = [
+            fn_def
+            for impl in self.scoring_fn_id_impls.values()
+            for fn_def in impl.get_supported_scoring_fn_defs()
+        ]
+
+        for f in scoring_fn_defs_list:
+            assert f.identifier.startswith(
+                "llm-as-judge"
+            ), "All llm-as-judge scoring fn must have identifier prefixed with 'llm-as-judge'! "
" + + return scoring_fn_defs_list + + async def register_scoring_function(self, function_def: ScoringFn) -> None: + raise NotImplementedError("Register scoring function not implemented yet") + + async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: + dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def.schema or len(dataset_def.schema) == 0: + raise ValueError( + f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." + ) + + for required_column in ["generated_answer", "expected_answer", "input_query"]: + if required_column not in dataset_def.schema: + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column." + ) + if dataset_def.schema[required_column].type != "string": + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." + ) + + async def score_batch( + self, + dataset_id: str, + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + all_rows = await self.datasetio_api.get_rows_paginated( + dataset_id=dataset_id, + rows_in_page=-1, + ) + res = await self.score( + input_rows=all_rows.rows, + scoring_functions=scoring_functions, + ) + if save_results_dataset: + # TODO: persist and register dataset on to server for reading + # self.datasets_api.register_dataset() + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res.results, + ) + + async def score( + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, + ) -> ScoreResponse: + res = {} + for scoring_fn_id in scoring_functions.keys(): + if scoring_fn_id not in self.scoring_fn_id_impls: + raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") + scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] + scoring_fn_params = scoring_functions.get(scoring_fn_id, None) + score_results = await scoring_fn.score( + input_rows, scoring_fn_id, scoring_fn_params + ) + agg_results = await scoring_fn.aggregate(score_results) + res[scoring_fn_id] = ScoringResult( + score_rows=score_results, + aggregated_results=agg_results, + ) + + return ScoreResponse( + results=res, + ) diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/base_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/base_scoring_fn.py new file mode 100644 index 000000000..e356bc289 --- /dev/null +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/base_scoring_fn.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+from llama_stack.apis.scoring_functions import *  # noqa: F401, F403
+from llama_stack.apis.scoring import *  # noqa: F401, F403
+
+
+class BaseScoringFn(ABC):
+    """
+    Base interface class for all llm-as-judge scoring_fns.
+    Each scoring_fn needs to implement the following methods:
+    - score_row(self, row)
+    - aggregate(self, scoring_fn_results)
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.supported_fn_defs_registry = {}
+
+    def __str__(self) -> str:
+        return self.__class__.__name__
+
+    def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
+        return [x for x in self.supported_fn_defs_registry.values()]
+
+    def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
+        if scoring_fn.identifier in self.supported_fn_defs_registry:
+            raise ValueError(
+                f"Scoring function def with identifier {scoring_fn.identifier} already exists."
+            )
+        self.supported_fn_defs_registry[scoring_fn.identifier] = scoring_fn
+
+    @abstractmethod
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        raise NotImplementedError()
+
+    @abstractmethod
+    async def aggregate(
+        self, scoring_results: List[ScoringResultRow]
+    ) -> Dict[str, Any]:
+        raise NotImplementedError()
+
+    async def score(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> List[ScoringResultRow]:
+        return [
+            await self.score_row(input_row, scoring_fn_identifier, scoring_params)
+            for input_row in input_rows
+        ]
diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py
similarity index 84%
rename from llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/llm_as_judge_base.py
rename to llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py
index ad07ea1b8..51517a0b0 100644
--- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/llm_as_judge_base.py
+++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py
@@ -9,9 +9,9 @@ from llama_stack.apis.scoring_functions import ScoringFn
 
 
 llm_as_judge_base = ScoringFn(
-    identifier="meta-reference::llm_as_judge_base",
+    identifier="llm-as-judge::llm_as_judge_base",
     description="Llm As Judge Scoring Function",
     return_type=NumberType(),
-    provider_id="meta-reference",
+    provider_id="llm-as-judge",
     provider_resource_id="llm-as-judge-base",
 )
diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
similarity index 100%
rename from llama_stack/providers/inline/scoring/basic/scoring_fn/llm_as_judge_scoring_fn.py
rename to llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py
diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py
index ccf303d59..9666c66e3 100644
--- a/llama_stack/providers/registry/scoring.py
+++ b/llama_stack/providers/registry/scoring.py
@@ -13,10 +13,21 @@ def available_providers() -> List[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.scoring,
-            provider_type="meta-reference",
+            provider_type="basic",
             pip_packages=[],
             module="llama_stack.providers.inline.scoring.basic",
             config_class="llama_stack.providers.inline.scoring.basic.BasicScoringConfig",
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+            ],
+        ),
+        InlineProviderSpec(
+            api=Api.scoring,
+            provider_type="llm-as-judge",
+            pip_packages=[],
+            module="llama_stack.providers.inline.scoring.llm_as_judge",
+            config_class="llama_stack.providers.inline.scoring.llm_as_judge.LlmAsJudgeScoringConfig",
             api_dependencies=[
                 Api.datasetio,
                 Api.datasets,
diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py
index ed56df230..e8ecfaa68 100644
--- a/llama_stack/providers/tests/scoring/conftest.py
+++ b/llama_stack/providers/tests/scoring/conftest.py
@@ -15,21 +15,12 @@ from .fixtures import SCORING_FIXTURES
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
-            "scoring": "meta_reference",
-            "datasetio": "localfs",
-            "inference": "fireworks",
-        },
-        id="meta_reference_scoring_fireworks_inference",
-        marks=pytest.mark.meta_reference_scoring_fireworks_inference,
-    ),
-    pytest.param(
-        {
-            "scoring": "meta_reference",
+            "scoring": "basic",
             "datasetio": "localfs",
             "inference": "together",
         },
-        id="meta_reference_scoring_together_inference",
-        marks=pytest.mark.meta_reference_scoring_together_inference,
+        id="basic_scoring_together_inference",
+        marks=pytest.mark.basic_scoring_together_inference,
     ),
     pytest.param(
         {
@@ -40,13 +31,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
         id="braintrust_scoring_together_inference",
         marks=pytest.mark.braintrust_scoring_together_inference,
     ),
+    pytest.param(
+        {
+            "scoring": "llm_as_judge",
+            "datasetio": "localfs",
+            "inference": "together",
+        },
+        id="llm_as_judge_scoring_together_inference",
+        marks=pytest.mark.llm_as_judge_scoring_together_inference,
+    ),
 ]
 
 
 def pytest_configure(config):
     for fixture_name in [
-        "meta_reference_scoring_fireworks_inference",
-        "meta_reference_scoring_together_inference",
+        "basic_scoring_together_inference",
         "braintrust_scoring_together_inference",
     ]:
         config.addinivalue_line(
diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py
index 20631f5cf..5b28dcfa8 100644
--- a/llama_stack/providers/tests/scoring/fixtures.py
+++ b/llama_stack/providers/tests/scoring/fixtures.py
@@ -19,12 +19,12 @@ def scoring_remote() -> ProviderFixture:
 
 @pytest.fixture(scope="session")
-def scoring_meta_reference() -> ProviderFixture:
+def scoring_basic() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
-                provider_id="meta-reference",
-                provider_type="meta-reference",
+                provider_id="basic",
+                provider_type="basic",
                 config={},
             )
         ],
     )
@@ -44,7 +44,20 @@ def scoring_braintrust() -> ProviderFixture:
     )
 
 
-SCORING_FIXTURES = ["meta_reference", "remote", "braintrust"]
+@pytest.fixture(scope="session")
+def scoring_llm_as_judge() -> ProviderFixture:
+    return ProviderFixture(
+        providers=[
+            Provider(
+                provider_id="llm-as-judge",
+                provider_type="llm-as-judge",
+                config={},
+            )
+        ],
+    )
+
+
+SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
 
 
 @pytest_asyncio.fixture(scope="session")
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index f3c925048..08a05681f 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -43,6 +43,13 @@ class TestScoring:
             scoring_stack[Api.datasets],
             scoring_stack[Api.models],
         )
+        scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
+        provider_id = scoring_fns_list[0].provider_id
+        if provider_id == "llm-as-judge":
+            pytest.skip(
+                f"{provider_id} provider does not support scoring without params"
+            )
+
         await register_dataset(datasets_impl)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1
@@ -111,8 +118,8 @@ class TestScoring:
 
         scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
         provider_id = scoring_fns_list[0].provider_id
-        if provider_id == "braintrust":
-            pytest.skip("Braintrust provider does not support scoring with params")
+        if provider_id == "braintrust" or provider_id == "basic":
+            pytest.skip(f"{provider_id} provider does not support scoring with params")
 
         # scoring individual rows
         rows = await datasetio_impl.get_rows_paginated(
@@ -122,7 +129,7 @@ class TestScoring:
         assert len(rows.rows) == 3
 
         scoring_functions = {
-            "meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
+            "llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
                 judge_score_regexes=[r"Score: (\d+)"],