From 70c08e694d0f5bcff73eac5d462c1879a19d0826 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 23 Oct 2024 14:42:28 -0700 Subject: [PATCH] basic scoring function works --- llama_stack/apis/scoring/client.py | 116 ++++++++++++++++++ .../scoring_functions/scoring_functions.py | 9 +- llama_stack/distribution/routers/routers.py | 12 +- .../distribution/routers/routing_tables.py | 16 +++ .../impls/meta_reference/scoring/scoring.py | 17 ++- 5 files changed, 164 insertions(+), 6 deletions(-) diff --git a/llama_stack/apis/scoring/client.py b/llama_stack/apis/scoring/client.py index 756f351d8..0e3998275 100644 --- a/llama_stack/apis/scoring/client.py +++ b/llama_stack/apis/scoring/client.py @@ -3,3 +3,119 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +import asyncio +import os +from pathlib import Path + +import fire +import httpx +from termcolor import cprint + +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasetio.client import DatasetIOClient +from llama_stack.apis.datasets.client import DatasetsClient +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class ScoringClient(Scoring): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def score_batch( + self, dataset_id: str, scoring_functions: List[str] + ) -> ScoreBatchResponse: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/scoring/score_batch", + json={"dataset_id": dataset_id, "scoring_functions": scoring_functions}, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return ScoreBatchResponse(**response.json()) + + async def score( + self, input_rows: List[Dict[str, Any]], 
scoring_functions: List[str] + ) -> ScoreResponse: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/scoring/score", + json={ + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return ScoreResponse(**response.json()) + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + # datasetio client to get the rows + datasetio_client = DatasetIOClient(f"http://{host}:{port}") + response = await datasetio_client.get_rows_paginated( + dataset_id="test-dataset", + rows_in_page=4, + page_token=None, + filter_condition=None, + ) + cprint(f"Returned {len(response.rows)} rows \n {response}", "green") + + # scoring client to score the rows + scoring_client = ScoringClient(f"http://{host}:{port}") + response = await scoring_client.score( + input_rows=response.rows, + scoring_functions=["equality"], + ) + cprint(f"scoring response={response}", "blue") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 025c62c94..5888f08f5 100644 --- 
a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -34,8 +34,8 @@ class Parameter(BaseModel): @json_schema_type -class CommonDef(BaseModel): - name: str +class CommonFunctionDef(BaseModel): + identifier: str description: Optional[str] = None metadata: Dict[str, Any] = Field( default_factory=dict, @@ -46,10 +46,11 @@ class CommonDef(BaseModel): @json_schema_type -class DeterministicFunctionDef(CommonDef): +class DeterministicFunctionDef(CommonFunctionDef): type: Literal["deterministic"] = "deterministic" parameters: List[Parameter] = Field( description="List of parameters for the deterministic function", + default_factory=list, ) return_type: ParamType = Field( description="The return type of the deterministic function", @@ -58,7 +59,7 @@ class DeterministicFunctionDef(CommonDef): @json_schema_type -class LLMJudgeFunctionDef(CommonDef): +class LLMJudgeFunctionDef(CommonFunctionDef): type: Literal["judge"] = "judge" model: str = Field( description="The LLM model to use for the judge function", diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index ab058ca8a..168cf9235 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -217,4 +217,14 @@ class ScoringRouter(Scoring): async def score( self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] ) -> ScoreResponse: - pass + # look up and map each scoring function to its provider impl + for fn_identifier in scoring_functions: + score_response = await self.routing_table.get_provider_impl( + fn_identifier + ).score( + input_rows=input_rows, + scoring_functions=[fn_identifier], + ) + print( + f"fn_identifier={fn_identifier}, score_response={score_response}", + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index f13a046c0..10b39e522 100644 --- 
a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -30,6 +30,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: await p.register_memory_bank(obj) elif api == Api.datasetio: await p.register_dataset(obj) + elif api == Api.scoring: + await p.register_scoring_function(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -95,6 +97,16 @@ class CommonRoutingTableImpl(RoutingTable): add_objects(datasets) + elif api == Api.scoring: + p.scoring_function_store = self + scoring_functions = await p.list_scoring_functions() + + # do in-memory updates due to pesky Annotated unions + for s in scoring_functions: + s.provider_id = pid + + add_objects(scoring_functions) + async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() @@ -109,6 +121,10 @@ class CommonRoutingTableImpl(RoutingTable): return ("Safety", "shield") elif isinstance(self, MemoryBanksRoutingTable): return ("Memory", "memory_bank") + elif isinstance(self, DatasetsRoutingTable): + return ("DatasetIO", "dataset") + elif isinstance(self, ScoringFunctionsRoutingTable): + return ("Scoring", "scoring_function") else: raise ValueError("Unknown routing table type") diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 1ec843983..71283d97c 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -7,6 +7,8 @@ from typing import List from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.scoring_functions import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from termcolor import cprint @@ -28,6 +30,19 @@ class 
MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def shutdown(self) -> None: ... + async def list_scoring_functions(self) -> List[ScoringFunctionDef]: + return [ + DeterministicFunctionDef( + identifier="equality", + description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", + parameters=[], + return_type=NumberType(), + ) + ] + + async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None: + pass + async def score_batch( self, dataset_id: str, scoring_functions: List[str] ) -> ScoreBatchResponse: @@ -36,4 +51,4 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def score( self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] ) -> ScoreResponse: - print("score") + print("!!!!score")