From cb8403456748a17c93beb04ce9f2870a10ec9800 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 24 Oct 2024 14:52:30 -0700 Subject: [PATCH] [Evals API][3/n] scoring_functions / scoring meta-reference implementations (#296) * wip * dataset validation * test_scoring * cleanup * clean up test * comments * error checking * dataset client * test client: * datasetio client * clean up * basic scoring function works * scorer wip * equality scorer * score batch impl * score batch * update scoring test * refactor * validate scorer input * address comments * add all rows scores to ScoringResult * bugfix * scoring function def rename --- llama_stack/apis/datasetio/client.py | 103 ++++++++++++++ llama_stack/apis/datasetio/datasetio.py | 2 +- llama_stack/apis/datasets/client.py | 116 +++++++++++++++ llama_stack/apis/datasets/datasets.py | 2 +- llama_stack/apis/scoring/client.py | 132 ++++++++++++++++++ llama_stack/apis/scoring/scoring.py | 20 ++- .../scoring_functions/scoring_functions.py | 46 ++---- llama_stack/distribution/datatypes.py | 5 + llama_stack/distribution/distribution.py | 4 + llama_stack/distribution/resolver.py | 4 + llama_stack/distribution/routers/__init__.py | 12 +- llama_stack/distribution/routers/routers.py | 54 +++++++ .../distribution/routers/routing_tables.py | 36 ++++- llama_stack/providers/datatypes.py | 13 +- .../meta_reference/datasetio/datasetio.py | 21 ++- .../impls/meta_reference/scoring/__init__.py | 21 +++ .../impls/meta_reference/scoring/config.py | 9 ++ .../meta_reference/scoring/scorer/__init__.py | 5 + .../scoring/scorer/base_scorer.py | 37 +++++ .../scoring/scorer/equality_scorer.py | 49 +++++++ .../impls/meta_reference/scoring/scoring.py | 109 +++++++++++++++ llama_stack/providers/registry/scoring.py | 25 ++++ .../tests/datasetio/test_dataset.csv | 6 + .../tests/datasetio/test_datasetio.py | 36 ++++- .../providers/tests/scoring/__init__.py | 5 + .../scoring/provider_config_example.yaml | 9 ++ .../providers/tests/scoring/test_scoring.py | 69 +++++++++ tests/examples/evals-tgi-run.yaml | 5 + 28 files changed, 904 insertions(+), 51 deletions(-) create mode 100644 llama_stack/apis/datasetio/client.py create mode 100644 llama_stack/apis/datasets/client.py create mode 100644 llama_stack/apis/scoring/client.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/__init__.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/config.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py create mode 100644 llama_stack/providers/impls/meta_reference/scoring/scoring.py create mode 100644 llama_stack/providers/registry/scoring.py create mode 100644 llama_stack/providers/tests/datasetio/test_dataset.csv create mode 100644 llama_stack/providers/tests/scoring/__init__.py create mode 100644 llama_stack/providers/tests/scoring/provider_config_example.yaml create mode 100644 llama_stack/providers/tests/scoring/test_scoring.py diff --git a/llama_stack/apis/datasetio/client.py b/llama_stack/apis/datasetio/client.py new file mode 100644 index 000000000..b62db9085 --- /dev/null +++ b/llama_stack/apis/datasetio/client.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import os +from pathlib import Path +from typing import Optional + +import fire +import httpx +from termcolor import cprint + +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasets.client import DatasetsClient +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class DatasetIOClient(DatasetIO): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def get_rows_paginated( + self, + dataset_id: str, + rows_in_page: int, + page_token: Optional[str] = None, + filter_condition: Optional[str] = None, + ) -> PaginatedRowsResult: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasetio/get_rows_paginated", + params={ + "dataset_id": dataset_id, + "rows_in_page": rows_in_page, + "page_token": page_token, + "filter_condition": filter_condition, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return PaginatedRowsResult(**response.json()) + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + # datsetio client to get the rows + datasetio_client = DatasetIOClient(f"http://{host}:{port}") + response = await datasetio_client.get_rows_paginated( + dataset_id="test-dataset", + rows_in_page=4, + page_token=None, + filter_condition=None, + ) + cprint(f"Returned {len(response.rows)} rows \n {response}", "green") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index e8811d233..b321b260e 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -29,7 +29,7 @@ class DatasetIO(Protocol): # keeping for aligning with inference/safety, but this is not used dataset_store: DatasetStore - @webmethod(route="/dataio/get_rows_paginated") + @webmethod(route="/datasetio/get_rows_paginated", method="GET") async def get_rows_paginated( self, dataset_id: str, diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py new file mode 100644 index 000000000..9e5891e74 --- /dev/null +++ b/llama_stack/apis/datasets/client.py @@ -0,0 +1,116 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import json +import os +from pathlib import Path +from typing import Optional + +import fire +import httpx +from termcolor import cprint + +from .datasets import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class DatasetsClient(Datasets): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_dataset( + self, + dataset_def: DatasetDefWithProvider, + ) -> None: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/datasets/register", + json={ + "dataset_def": json.loads(dataset_def.json()), + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + return + + async def get_dataset( + self, + dataset_identifier: str, + ) -> Optional[DatasetDefWithProvider]: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasets/get", + params={ + "dataset_identifier": dataset_identifier, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return DatasetDefWithProvider(**response.json()) + + async def list_datasets(self) -> List[DatasetDefWithProvider]: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasets/list", + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return [DatasetDefWithProvider(**x) for x in response.json()] + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index e2b764d7f..7a56049bf 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -20,7 +20,7 @@ class DatasetDef(BaseModel): identifier: str = Field( description="A unique name for the dataset", ) - columns_schema: Dict[str, ParamType] = Field( + dataset_schema: Dict[str, ParamType] = Field( description="The schema definition for this dataset", ) url: URL diff --git a/llama_stack/apis/scoring/client.py b/llama_stack/apis/scoring/client.py new file mode 100644 index 000000000..f08fa4bc0 --- /dev/null +++ b/llama_stack/apis/scoring/client.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import asyncio +import os +from pathlib import Path + +import fire +import httpx +from termcolor import cprint + +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio.client import DatasetIOClient +from llama_stack.apis.datasets.client import DatasetsClient +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class ScoringClient(Scoring): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def score_batch( + self, dataset_id: str, scoring_functions: List[str] + ) -> ScoreBatchResponse: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/scoring/score_batch", + json={ + "dataset_id": dataset_id, + "scoring_functions": scoring_functions, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return ScoreBatchResponse(**response.json()) + + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/scoring/score", + json={ + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return ScoreResponse(**response.json()) + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + # datsetio client to get the rows + datasetio_client = DatasetIOClient(f"http://{host}:{port}") + response = await datasetio_client.get_rows_paginated( + dataset_id="test-dataset", + rows_in_page=4, + page_token=None, + filter_condition=None, + ) + cprint(f"Returned {len(response.rows)} rows \n {response}", "green") + + # scoring client to score the rows + scoring_client = ScoringClient(f"http://{host}:{port}") + response = await scoring_client.score( + input_rows=response.rows, + scoring_functions=["equality"], + ) + cprint(f"score response={response}", "blue") + + # test scoring batch using datasetio api + scoring_client = ScoringClient(f"http://{host}:{port}") + response = await scoring_client.score_batch( + dataset_id="test-dataset", + scoring_functions=["equality"], + ) + cprint(f"score_batch response={response}", "cyan") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index ec50ecab1..adac34d55 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -13,18 +13,27 @@ from 
llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.scoring_functions import * # noqa: F403 -ScoringResult = Dict[str, Any] +# mapping of metric to value +ScoringResultRow = Dict[str, Any] + + +@json_schema_type +class ScoringResult(BaseModel): + score_rows: List[ScoringResultRow] + # aggregated metrics to value + aggregated_results: Dict[str, Any] @json_schema_type class ScoreBatchResponse(BaseModel): - dataset_id: str + dataset_id: Optional[str] = None + results: Dict[str, ScoringResult] @json_schema_type class ScoreResponse(BaseModel): # each key in the dict is a scoring function name - results: List[Dict[str, ScoringResult]] + results: Dict[str, ScoringResult] class ScoringFunctionStore(Protocol): @@ -37,7 +46,10 @@ class Scoring(Protocol): @webmethod(route="/scoring/score_batch") async def score_batch( - self, dataset_id: str, scoring_functions: List[str] + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @webmethod(route="/scoring/score") diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 1d71c51f3..a242215c6 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,20 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import ( - Any, - Dict, - List, - Literal, - Optional, - Protocol, - runtime_checkable, - Union, -) +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field -from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType @@ -33,45 +23,37 @@ class Parameter(BaseModel): # with standard metrics so they can be rolled up? +class LLMAsJudgeContext(BaseModel): + judge_model: str + prompt_template: Optional[str] = None + + @json_schema_type -class CommonDef(BaseModel): - name: str +class ScoringFunctionDef(BaseModel): + identifier: str description: Optional[str] = None metadata: Dict[str, Any] = Field( default_factory=dict, description="Any additional metadata for this definition", ) - # Hack: same with memory_banks for union defs - provider_id: str = "" - - -@json_schema_type -class DeterministicFunctionDef(CommonDef): - type: Literal["deterministic"] = "deterministic" parameters: List[Parameter] = Field( description="List of parameters for the deterministic function", + default_factory=list, ) return_type: ParamType = Field( description="The return type of the deterministic function", ) + context: Optional[LLMAsJudgeContext] = None # We can optionally add information here to support packaging of code, etc. 
@json_schema_type -class LLMJudgeFunctionDef(CommonDef): - type: Literal["judge"] = "judge" - model: str = Field( - description="The LLM model to use for the judge function", +class ScoringFunctionDefWithProvider(ScoringFunctionDef): + provider_id: str = Field( + description="ID of the provider which serves this dataset", ) -ScoringFunctionDef = Annotated[ - Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type") -] - -ScoringFunctionDefWithProvider = ScoringFunctionDef - - @runtime_checkable class ScoringFunctions(Protocol): @webmethod(route="/scoring_functions/list", method="GET") @@ -84,5 +66,5 @@ class ScoringFunctions(Protocol): @webmethod(route="/scoring_functions/register", method="POST") async def register_scoring_function( - self, function: ScoringFunctionDefWithProvider + self, function_def: ScoringFunctionDefWithProvider ) -> None: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 10f78b78f..318809baf 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -15,10 +15,12 @@ from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety +from llama_stack.apis.scoring import Scoring LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2" @@ -32,6 +34,7 @@ RoutableObject = Union[ ShieldDef, MemoryBankDef, DatasetDef, + ScoringFunctionDef, ] RoutableObjectWithProvider = Union[ @@ -39,6 +42,7 @@ RoutableObjectWithProvider = Union[ ShieldDefWithProvider, MemoryBankDefWithProvider, DatasetDefWithProvider, + ScoringFunctionDefWithProvider, ] RoutedProtocol = Union[ @@ -46,6 +50,7 @@ RoutedProtocol = Union[ Safety, Memory, DatasetIO, + Scoring, ] diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 53d544471..2149162a6 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: routing_table_api=Api.datasets, router_api=Api.datasetio, ), + AutoRoutedApiInfo( + routing_table_api=Api.scoring_functions, + router_api=Api.scoring, + ), ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 2e6b64a53..b9b9fb229 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -20,6 +20,8 @@ from llama_stack.apis.memory import Memory from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.models import Models from llama_stack.apis.safety import Safety +from llama_stack.apis.scoring import Scoring +from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.distribution.distribution import ( @@ -42,6 +44,8 @@ def api_protocol_map() -> Dict[Api, Any]: Api.telemetry: Telemetry, Api.datasets: Datasets, Api.datasetio: DatasetIO, + Api.scoring_functions: ScoringFunctions, + Api.scoring: Scoring, } diff --git a/llama_stack/distribution/routers/__init__.py 
b/llama_stack/distribution/routers/__init__.py index 4970e93e1..2cc89848e 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -11,6 +11,7 @@ from .routing_tables import ( DatasetsRoutingTable, MemoryBanksRoutingTable, ModelsRoutingTable, + ScoringFunctionsRoutingTable, ShieldsRoutingTable, ) @@ -25,7 +26,9 @@ async def get_routing_table_impl( "models": ModelsRoutingTable, "shields": ShieldsRoutingTable, "datasets": DatasetsRoutingTable, + "scoring_functions": ScoringFunctionsRoutingTable, } + if api.value not in api_to_tables: raise ValueError(f"API {api.value} not found in router map") @@ -35,13 +38,20 @@ async def get_routing_table_impl( async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: - from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter + from .routers import ( + DatasetIORouter, + InferenceRouter, + MemoryRouter, + SafetyRouter, + ScoringRouter, + ) api_to_routers = { "memory": MemoryRouter, "inference": InferenceRouter, "safety": SafetyRouter, "datasetio": DatasetIORouter, + "scoring": ScoringRouter, } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 31b8efa48..348d8449d 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -13,6 +13,7 @@ from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 class MemoryRouter(Memory): @@ -192,3 +193,56 @@ class DatasetIORouter(DatasetIO): page_token=page_token, filter_condition=filter_condition, ) + + +class ScoringRouter(Scoring): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def score_batch( + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + res = {} + for fn_identifier in scoring_functions: + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score_batch( + dataset_id=dataset_id, + scoring_functions=[fn_identifier], + ) + res.update(score_response.results) + + if save_results_dataset: + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res, + ) + + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: + res = {} + # look up and map each scoring function to its provider impl + for fn_identifier in scoring_functions: + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score( + input_rows=input_rows, + scoring_functions=[fn_identifier], + ) + res.update(score_response.results) + + return ScoreResponse(results=res) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index db0946d81..dcd588a9e 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -30,6 +30,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: await p.register_memory_bank(obj) elif 
api == Api.datasetio: await p.register_dataset(obj) + elif api == Api.scoring: + await p.register_scoring_function(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -93,7 +95,15 @@ class CommonRoutingTableImpl(RoutingTable): for d in datasets: d.provider_id = pid - add_objects(datasets) + elif api == Api.scoring: + p.scoring_function_store = self + scoring_functions = await p.list_scoring_functions() + add_objects( + [ + ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid) + for s in scoring_functions + ] + ) async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): @@ -109,6 +119,10 @@ class CommonRoutingTableImpl(RoutingTable): return ("Safety", "shield") elif isinstance(self, MemoryBanksRoutingTable): return ("Memory", "memory_bank") + elif isinstance(self, DatasetsRoutingTable): + return ("DatasetIO", "dataset") + elif isinstance(self, ScoringFunctionsRoutingTable): + return ("Scoring", "scoring_function") else: raise ValueError("Unknown routing table type") @@ -218,7 +232,25 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def get_dataset( self, dataset_identifier: str ) -> Optional[DatasetDefWithProvider]: - return self.get_object_by_identifier(identifier) + return self.get_object_by_identifier(dataset_identifier) async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None: await self.register_object(dataset_def) + + +class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): + async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]: + objects = [] + for objs in self.registry.values(): + objects.extend(objs) + return objects + + async def get_scoring_function( + self, name: str + ) -> Optional[ScoringFunctionDefWithProvider]: + return self.get_object_by_identifier(name) + + async def register_scoring_function( + self, function_def: ScoringFunctionDefWithProvider + ) -> None: + await self.register_object(function_def) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index d7e2d4d0c..903ff5438 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -11,10 +11,9 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetDef - from llama_stack.apis.memory_banks import MemoryBankDef - from llama_stack.apis.models import ModelDef +from llama_stack.apis.scoring_functions import ScoringFunctionDef from llama_stack.apis.shields import ShieldDef @@ -25,6 +24,7 @@ class Api(Enum): agents = "agents" memory = "memory" datasetio = "datasetio" + scoring = "scoring" telemetry = "telemetry" @@ -32,6 +32,7 @@ class Api(Enum): shields = "shields" memory_banks = "memory_banks" datasets = "datasets" + scoring_functions = "scoring_functions" # built-in API inspect = "inspect" @@ -61,6 +62,14 @@ class DatasetsProtocolPrivate(Protocol): async def register_datasets(self, dataset_def: DatasetDef) -> None: ... +class ScoringFunctionsProtocolPrivate(Protocol): + async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ... + + async def register_scoring_function( + self, function_def: ScoringFunctionDef + ) -> None: ... 
+ + @json_schema_type class ProviderSpec(BaseModel): api: Api diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index a8e648e46..43664f394 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -3,17 +3,20 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import io from typing import List, Optional import pandas - from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 +import base64 from abc import ABC, abstractmethod from dataclasses import dataclass +from urllib.parse import unquote from llama_stack.providers.datatypes import DatasetsProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import parse_data_url from .config import MetaReferenceDatasetIOConfig @@ -52,11 +55,20 @@ class PandasDataframeDataset(BaseDataset): return len(self.df) def __getitem__(self, idx): + assert self.df is not None, "Dataset not loaded. Please call .load() first" if isinstance(idx, slice): return self.df.iloc[idx].to_dict(orient="records") else: return self.df.iloc[idx].to_dict() + def _validate_dataset_schema(self, df) -> pandas.DataFrame: + # note that we will drop any columns in dataset that are not in the schema + df = df[self.dataset_def.dataset_schema.keys()] + # check all columns in dataset schema are present + assert len(df.columns) == len(self.dataset_def.dataset_schema) + # TODO: type checking against column types in dataset schema + return df + def load(self) -> None: if self.df is not None: return @@ -87,7 +99,7 @@ class PandasDataframeDataset(BaseDataset): else: raise ValueError(f"Unsupported file type: {self.dataset_def.url}") - self.df = df + self.df = self._validate_dataset_schema(df) class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): @@ -123,7 +135,10 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() - if page_token is None: + if page_token and not page_token.isnumeric(): + raise ValueError("Invalid page_token") + + if page_token is None or len(page_token) == 0: next_page_token = 0 else: next_page_token = int(page_token) diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py new file mode 100644 index 000000000..69d9b543a --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+from typing import Dict + +from llama_stack.distribution.datatypes import Api, ProviderSpec + +from .config import MetaReferenceScoringConfig + + +async def get_provider_impl( + config: MetaReferenceScoringConfig, + deps: Dict[Api, ProviderSpec], +): + from .scoring import MetaReferenceScoringImpl + + impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/impls/meta_reference/scoring/config.py b/llama_stack/providers/impls/meta_reference/scoring/config.py new file mode 100644 index 000000000..bd4dcb9f0 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/config.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.scoring import * # noqa: F401, F403 + + +class MetaReferenceScoringConfig(BaseModel): ... diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py new file mode 100644 index 000000000..ea8a3f063 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/base_scorer.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from abc import ABC, abstractmethod +from typing import Any, Dict, List +from llama_stack.apis.scoring_functions import * # noqa: F401, F403 +from llama_stack.apis.scoring import * # noqa: F401, F403 + + +class BaseScorer(ABC): + """ + Base interface class for all meta-reference scorers. + Each scorer needs to implement the following methods: + - score_row(self, row) + - aggregate(self, scorer_results) + """ + + scoring_function_def: ScoringFunctionDef + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __str__(self) -> str: + return self.__class__.__name__ + + @abstractmethod + def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: + raise NotImplementedError() + + @abstractmethod + def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + raise NotImplementedError() + + def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]: + return [self.score_row(input_row) for input_row in input_rows] diff --git a/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py new file mode 100644 index 000000000..ce765bfb5 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scorer/equality_scorer.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import ( + BaseScorer, +) +from llama_stack.apis.scoring_functions import * # noqa: F401, F403 +from llama_stack.apis.scoring import * # noqa: F401, F403 +from llama_stack.apis.common.type_system import * # noqa: F403 + + +class EqualityScorer(BaseScorer): + """ + A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. + """ + + scoring_function_def = ScoringFunctionDef( + identifier="equality", + description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", + parameters=[], + return_type=NumberType(), + ) + + def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow: + assert "expected_answer" in input_row, "Expected answer not found in input row." + assert ( + "generated_answer" in input_row + ), "Generated answer not found in input row." + + expected_answer = input_row["expected_answer"] + generated_answer = input_row["generated_answer"] + score = 1.0 if expected_answer == generated_answer else 0.0 + return { + "score": score, + } + + def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]: + assert len(scoring_results) > 0, "Empty scoring results provided." + num_correct = sum(result["score"] for result in scoring_results) + avg_score = num_correct / len(scoring_results) + + return { + "accuracy": avg_score, + "num_correct": num_correct, + "num_total": len(scoring_results), + } diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py new file mode 100644 index 000000000..0d32c8195 --- /dev/null +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import List + +from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 + +from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate +from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import ( + EqualityScorer, +) + +from .config import MetaReferenceScoringConfig + +SUPPORTED_SCORERS = [ + EqualityScorer, +] + +SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS} + + +class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): + def __init__( + self, + config: MetaReferenceScoringConfig, + datasetio_api: DatasetIO, + datasets_api: Datasets, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + + async def initialize(self) -> None: ... + + async def shutdown(self) -> None: ... 
+ + async def list_scoring_functions(self) -> List[ScoringFunctionDef]: + return [x.scoring_function_def for x in SUPPORTED_SCORERS] + + async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None: + raise NotImplementedError( + "Dynamically registering scoring functions is not supported" + ) + + async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: + dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: + raise ValueError( + f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset." + ) + + for required_column in ["generated_answer", "expected_answer", "input_query"]: + if required_column not in dataset_def.dataset_schema: + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column." + ) + if dataset_def.dataset_schema[required_column].type != "string": + raise ValueError( + f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'." + ) + + async def score_batch( + self, + dataset_id: str, + scoring_functions: List[str], + save_results_dataset: bool = False, + ) -> ScoreBatchResponse: + await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) + all_rows = await self.datasetio_api.get_rows_paginated( + dataset_id=dataset_id, + rows_in_page=-1, + ) + res = await self.score( + input_rows=all_rows.rows, scoring_functions=scoring_functions + ) + if save_results_dataset: + # TODO: persist and register dataset on to server for reading + # self.datasets_api.register_dataset() + raise NotImplementedError("Save results dataset not implemented yet") + + return ScoreBatchResponse( + results=res.results, + ) + + async def score( + self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + ) -> ScoreResponse: + res = {} + for scoring_fn_id in scoring_functions: + if scoring_fn_id not in SCORER_REGISTRY: + raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") + scorer = SCORER_REGISTRY[scoring_fn_id]() + score_results = scorer.score(input_rows) + agg_results = scorer.aggregate(score_results) + res[scoring_fn_id] = ScoringResult( + score_rows=score_results, + aggregated_results=agg_results, + ) + + return ScoreResponse( + results=res, + ) diff --git a/llama_stack/providers/registry/scoring.py b/llama_stack/providers/registry/scoring.py new file mode 100644 index 000000000..4543449b4 --- /dev/null +++ b/llama_stack/providers/registry/scoring.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import List + +from llama_stack.distribution.datatypes import * # noqa: F403 + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.scoring, + provider_type="meta-reference", + pip_packages=[], + module="llama_stack.providers.impls.meta_reference.scoring", + config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + ), + ] diff --git a/llama_stack/providers/tests/datasetio/test_dataset.csv b/llama_stack/providers/tests/datasetio/test_dataset.csv new file mode 100644 index 000000000..a1a250753 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/test_dataset.csv @@ -0,0 +1,6 @@ +input_query,generated_answer,expected_answer +What is the capital of France?,London,Paris +Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg +What is the largest planet in our solar system?,Jupiter,Jupiter +What is the smallest country in the world?,China,Vatican City +What is the currency of Japan?,Yen,Yen diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 85235a64b..9a351ba30 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -8,8 +8,13 @@ import os import pytest import pytest_asyncio +from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +import base64 +import mimetypes +from pathlib import Path + from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: @@ -41,14 +46,35 @@ async def datasetio_settings(): } +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url + + async def register_dataset(datasets_impl: Datasets): + test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv" + test_url = data_url_from_file(str(test_file)) dataset = DatasetDefWithProvider( identifier="test_dataset", provider_id=os.environ["PROVIDER_ID"], url=URL( - uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv", + uri=test_url, ), - columns_schema={}, + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, ) await datasets_impl.register_dataset(dataset) @@ -100,10 +126,10 @@ async def test_get_rows_paginated(datasetio_settings): # iterate over all rows response = await datasetio_impl.get_rows_paginated( dataset_id="test_dataset", - rows_in_page=10, + rows_in_page=2, page_token=response.next_page_token, ) assert isinstance(response.rows, list) - assert len(response.rows) == 10 - assert response.next_page_token == "13" + assert len(response.rows) == 2 + assert response.next_page_token == "5" diff --git a/llama_stack/providers/tests/scoring/__init__.py b/llama_stack/providers/tests/scoring/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/tests/scoring/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml new file mode 100644 index 000000000..9a8895149 --- /dev/null +++ b/llama_stack/providers/tests/scoring/provider_config_example.yaml @@ -0,0 +1,9 @@ +providers: + datasetio: + - provider_id: test-meta + provider_type: meta-reference + config: {} + scoring: + - provider_id: test-meta + provider_type: meta-reference + config: {} diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py new file mode 100644 index 000000000..2218faa54 --- /dev/null +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import pytest +import pytest_asyncio + +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.distribution.datatypes import * # noqa: F403 + +from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset +from llama_stack.providers.tests.resolver import resolve_impls_for_test + +# How to run this test: +# +# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky +# since it depends on the provider you are testing. On top of that you need +# `pytest` and `pytest-asyncio` installed. +# +# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. +# +# 3. 
Run: +# +# ```bash +# PROVIDER_ID= \ +# PROVIDER_CONFIG=provider_config.yaml \ +# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \ +# --tb=short --disable-warnings +# ``` + + +@pytest_asyncio.fixture(scope="session") +async def scoring_settings(): + impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio]) + return { + "scoring_impl": impls[Api.scoring], + "scoring_functions_impl": impls[Api.scoring_functions], + "datasets_impl": impls[Api.datasets], + } + + +@pytest.mark.asyncio +async def test_scoring_functions_list(scoring_settings): + scoring_functions_impl = scoring_settings["scoring_functions_impl"] + scoring_functions = await scoring_functions_impl.list_scoring_functions() + assert isinstance(scoring_functions, list) + assert len(scoring_functions) > 0 + function_ids = [f.identifier for f in scoring_functions] + assert "equality" in function_ids + + +@pytest.mark.asyncio +async def test_scoring_score(scoring_settings): + scoring_impl = scoring_settings["scoring_impl"] + datasets_impl = scoring_settings["datasets_impl"] + await register_dataset(datasets_impl) + + response = await datasets_impl.list_datasets() + assert len(response) == 1 + + response = await scoring_impl.score_batch( + dataset_id=response[0].identifier, + scoring_functions=["equality"], + ) + + assert len(response.results) == 1 + assert "equality" in response.results diff --git a/tests/examples/evals-tgi-run.yaml b/tests/examples/evals-tgi-run.yaml index 8edb050cc..e56c43420 100644 --- a/tests/examples/evals-tgi-run.yaml +++ b/tests/examples/evals-tgi-run.yaml @@ -13,7 +13,12 @@ apis: - inference - datasets - datasetio +- scoring providers: + scoring: + - provider_id: meta0 + provider_type: meta-reference + config: {} datasetio: - provider_id: meta0 provider_type: meta-reference
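
Taken together, the scorer pieces in this patch define a small extension seam for meta-reference scoring: a scorer subclasses `BaseScorer`, declares its `ScoringFunctionDef` as the class-level `scoring_function_def`, implements `score_row` and `aggregate`, and is picked up by `SUPPORTED_SCORERS` / `SCORER_REGISTRY` in `scoring.py`. As a rough, non-authoritative sketch of how a second scorer could follow the same pattern as `EqualityScorer` — the `SubsetOfScorer` class and the `subset_of` identifier below are hypothetical and not part of this patch — it might look like:

```python
# Hypothetical sketch (not included in this PR): a scorer that gives 1.0 when
# the expected answer appears as a substring of the generated answer,
# implemented against the BaseScorer interface introduced above.
from typing import Any, Dict, List

from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFunctionDef
from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
    BaseScorer,
)


class SubsetOfScorer(BaseScorer):
    scoring_function_def = ScoringFunctionDef(
        identifier="subset_of",  # hypothetical identifier, chosen for this sketch
        description="Returns 1.0 if the expected answer is contained in the generated answer.",
        parameters=[],
        return_type=NumberType(),
    )

    def score_row(self, input_row: Dict[str, Any]) -> Dict[str, Any]:
        # per-row score: substring containment check on the two answer columns
        expected = input_row["expected_answer"]
        generated = input_row["generated_answer"]
        return {"score": 1.0 if expected in generated else 0.0}

    def aggregate(self, scoring_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        # roll the per-row scores up into an accuracy-style summary,
        # mirroring the aggregation shape used by EqualityScorer
        assert len(scoring_results) > 0, "Empty scoring results provided."
        num_correct = sum(r["score"] for r in scoring_results)
        return {
            "accuracy": num_correct / len(scoring_results),
            "num_correct": num_correct,
            "num_total": len(scoring_results),
        }
```

Under that assumption, wiring it in would amount to appending `SubsetOfScorer` to `SUPPORTED_SCORERS` in `scoring.py`, after which `score` and `score_batch` would resolve it by its identifier exactly the way they resolve `equality`.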