From bb433695211f1e2bea4f99e6b6ec9e9c6266211d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 23 Oct 2024 13:53:58 -0700 Subject: [PATCH] dataset client --- llama_stack/apis/datasets/client.py | 132 ++++++++++++++++++ llama_stack/apis/scoring/client.py | 5 + llama_stack/distribution/resolver.py | 6 - .../impls/meta_reference/scoring/__init__.py | 8 +- .../impls/meta_reference/scoring/scoring.py | 10 +- tests/examples/evals-tgi-run.yaml | 5 + 6 files changed, 156 insertions(+), 10 deletions(-) create mode 100644 llama_stack/apis/datasets/client.py create mode 100644 llama_stack/apis/scoring/client.py diff --git a/llama_stack/apis/datasets/client.py b/llama_stack/apis/datasets/client.py new file mode 100644 index 000000000..e387eca6d --- /dev/null +++ b/llama_stack/apis/datasets/client.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import base64 +import json +import mimetypes +import os +from pathlib import Path +from typing import Optional + +import fire +import httpx +from termcolor import cprint + +from .datasets import * # noqa: F403 +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 + + +def data_url_from_file(file_path: str) -> str: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + with open(file_path, "rb") as file: + file_content = file.read() + + base64_content = base64.b64encode(file_content).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_path) + + data_url = f"data:{mime_type};base64,{base64_content}" + + return data_url + + +class DatasetsClient(Datasets): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def register_dataset( + self, + dataset_def: DatasetDefWithProvider, + ) -> None: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/datasets/register", + json={ + "dataset_def": json.loads(dataset_def.json()), + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + return + + async def get_dataset( + self, + dataset_identifier: str, + ) -> Optional[DatasetDefWithProvider]: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasets/get", + params={ + "dataset_identifier": dataset_identifier, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return DatasetDefWithProvider(**response.json()) + + async def list_datasets(self) -> List[DatasetDefWithProvider]: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasets/list", + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return [DatasetDefWithProvider(**x) for x in response.json()] + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + 
identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/scoring/client.py b/llama_stack/apis/scoring/client.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/apis/scoring/client.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 53da099ce..b9b9fb229 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -130,12 +130,6 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]: ) } - if info.router_api.value == "scoring": - print("SCORING API") - - # p = all_api_providers[api][provider.provider_type] - # p.deps__ = [a.value for a in p.api_dependencies] - providers_with_specs[info.router_api.value] = { "__builtin__": ProviderWithSpec( provider_id="__autorouted__", diff --git a/llama_stack/providers/impls/meta_reference/scoring/__init__.py b/llama_stack/providers/impls/meta_reference/scoring/__init__.py index 31c93faef..48e177324 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/__init__.py +++ b/llama_stack/providers/impls/meta_reference/scoring/__init__.py @@ -3,16 +3,20 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Dict + +from llama_stack.distribution.datatypes import Api, ProviderSpec from .config import MetaReferenceScoringConfig async def get_provider_impl( config: MetaReferenceScoringConfig, - _deps, + deps: Dict[Api, ProviderSpec], ): + print("get_provider_impl", deps) from .scoring import MetaReferenceScoringImpl - impl = MetaReferenceScoringImpl(config) + impl = MetaReferenceScoringImpl(config, deps[Api.datasetio]) await impl.initialize() return impl diff --git a/llama_stack/providers/impls/meta_reference/scoring/scoring.py b/llama_stack/providers/impls/meta_reference/scoring/scoring.py index 39ae40c13..1ec843983 100644 --- a/llama_stack/providers/impls/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/impls/meta_reference/scoring/scoring.py @@ -7,6 +7,9 @@ from typing import List from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 + +from termcolor import cprint from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate @@ -14,9 +17,12 @@ from .config import MetaReferenceScoringConfig class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): - def __init__(self, config: MetaReferenceScoringConfig) -> None: + def __init__( + self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO + ) -> None: self.config = config - self.dataset_infos = {} + self.datasetio_api = datasetio_api + cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red") async def initialize(self) -> None: ... 
diff --git a/tests/examples/evals-tgi-run.yaml b/tests/examples/evals-tgi-run.yaml index 8edb050cc..e56c43420 100644 --- a/tests/examples/evals-tgi-run.yaml +++ b/tests/examples/evals-tgi-run.yaml @@ -13,7 +13,12 @@ apis: - inference - datasets - datasetio +- scoring providers: + scoring: + - provider_id: meta0 + provider_type: meta-reference + config: {} datasetio: - provider_id: meta0 provider_type: meta-reference
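With both the datasetio and scoring providers declared in the run config, MetaReferenceScoringImpl can pull rows of a registered dataset through its new datasetio_api handle. The sketch below shows one hypothetical way to do that; the get_rows_paginated method, the shape of its result, and the rows_in_page=-1 convention are assumptions about the DatasetIO interface rather than part of this patch.

# Hypothetical sketch (not part of this patch): exact-match scoring over the
# generated_answer / expected_answer columns registered by the dataset client
# above. The DatasetIO method name and result fields are assumptions.
from typing import Any, Dict, List

from llama_stack.providers.impls.meta_reference.scoring.scoring import (
    MetaReferenceScoringImpl,
)


async def score_exact_match(
    impl: MetaReferenceScoringImpl, dataset_id: str
) -> List[Dict[str, Any]]:
    result = await impl.datasetio_api.get_rows_paginated(
        dataset_id=dataset_id,
        rows_in_page=-1,  # assumed convention for "fetch all rows"
    )
    scored = []
    for row in result.rows:
        scored.append(
            {
                **row,
                "score": float(row["generated_answer"] == row["expected_answer"]),
            }
        )
    return scored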
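The dataset client added in this patch can then be exercised end to end against a stack started from tests/examples/evals-tgi-run.yaml. A minimal sketch, reusing run_main from the new client module; the host and port are illustrative assumptions, not part of this patch.

import asyncio

from llama_stack.apis.datasets.client import run_main

# Registers providers/tests/datasetio/test_dataset.csv via a data URL and then
# lists the registered datasets; host/port are assumptions for illustration.
asyncio.run(run_main("localhost", 5000))

Because the module ends with fire.Fire(main), the same flow is also available from the command line by invoking the client module with a host and a port argument.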