dataset client

This commit is contained in:
Xi Yan 2024-10-23 13:53:58 -07:00
parent c5db025320
commit bb43369521
6 changed files with 156 additions and 10 deletions

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import base64
import json
import mimetypes
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from .datasets import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
def data_url_from_file(file_path: str) -> str:
    """Encode the contents of a local file as a base64 ``data:`` URL.

    The MIME type is guessed from the file name. When it cannot be guessed,
    ``application/octet-stream`` is used so the URL never carries the literal
    text ``None`` as its media type.

    Args:
        file_path: Path of the file to embed.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")

    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        # guess_type yields None for unknown or missing extensions.
        mime_type = "application/octet-stream"

    return f"data:{mime_type};base64,{base64_content}"
class DatasetsClient(Datasets):
    """HTTP client for the ``/datasets/*`` endpoints of a remote server.

    Every call opens a short-lived ``httpx.AsyncClient``, issues a single
    request against ``base_url`` with a 60 second timeout, and raises on a
    non-success HTTP status.
    """

    def __init__(self, base_url: str):
        """Store the server root URL, e.g. ``http://host:port``."""
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op; the client holds no persistent resources."""
        pass

    async def shutdown(self) -> None:
        """No-op; the client holds no persistent resources."""
        pass

    async def register_dataset(
        self,
        dataset_def: DatasetDefWithProvider,
    ) -> None:
        """POST a dataset definition to ``/datasets/register``.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/datasets/register",
                json={
                    # Round-trip through the model's JSON encoder so the
                    # payload contains only JSON-serializable values.
                    "dataset_def": json.loads(dataset_def.json()),
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            return

    async def get_dataset(
        self,
        dataset_identifier: str,
    ) -> Optional[DatasetDefWithProvider]:
        """Fetch one dataset definition by identifier.

        Returns:
            The parsed definition, or ``None`` when the response body is
            empty/null.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/get",
                params={
                    "dataset_identifier": dataset_identifier,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            # Decode the body once instead of re-parsing it for each use.
            payload = response.json()
            if not payload:
                return None

            return DatasetDefWithProvider(**payload)

    async def list_datasets(self) -> List[DatasetDefWithProvider]:
        """Fetch every dataset definition known to the server.

        Returns:
            The parsed definitions; an empty list when the response body is
            empty/null, so callers can always iterate the result.

        Raises:
            httpx.HTTPStatusError: If the server answers with an error status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/list",
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            payload = response.json()
            if not payload:
                # Honor the declared list return type rather than leaking None.
                return []

            return [DatasetDefWithProvider(**x) for x in payload]
async def run_main(host: str, port: int):
    """Register a sample CSV dataset on a running server, then list datasets.

    The CSV is located relative to this file, embedded as a base64 data URL,
    and registered under the identifier ``test-dataset``. The resulting
    dataset listing is printed in blue.

    Args:
        host: Host name of the server.
        port: Port of the server.
    """
    client = DatasetsClient(f"http://{host}:{port}")

    # register dataset
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    # register_dataset returns None, so there is no result worth binding.
    await client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")
def main(host: str, port: int):
    """Synchronous entry point: run ``run_main`` to completion on a new event loop."""
    demo = run_main(host, port)
    asyncio.run(demo)


if __name__ == "__main__":
    # Hand ``main`` to python-fire, which builds the command-line interface.
    fire.Fire(main)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -130,12 +130,6 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
)
}
if info.router_api.value == "scoring":
print("SCORING API")
# p = all_api_providers[api][provider.provider_type]
# p.deps__ = [a.value for a in p.api_dependencies]
providers_with_specs[info.router_api.value] = {
"__builtin__": ProviderWithSpec(
provider_id="__autorouted__",

View file

@ -3,16 +3,20 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceScoringConfig
# Construct a MetaReferenceScoringImpl from the given config, await its
# initialize(), and return it.
# NOTE(review): this span looks like a rendered diff whose removed and added
# lines were merged without +/- markers: `_deps,` / `deps: ...` and the two
# `impl = ...` assignments appear to be before/after pairs. Confirm against
# the real file before relying on the signature shown here.
async def get_provider_impl(
config: MetaReferenceScoringConfig,
_deps,
deps: Dict[Api, ProviderSpec],
):
# NOTE(review): debug print, presumably temporary; confirm whether it should stay.
print("get_provider_impl", deps)
# Imported inside the function rather than at module top.
from .scoring import MetaReferenceScoringImpl
impl = MetaReferenceScoringImpl(config)
# The datasetio entry of `deps` is passed as the second constructor argument.
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
await impl.initialize()
return impl

View file

@ -7,6 +7,9 @@ from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from termcolor import cprint
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
@ -14,9 +17,12 @@ from .config import MetaReferenceScoringConfig
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
def __init__(self, config: MetaReferenceScoringConfig) -> None:
def __init__(
self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
) -> None:
self.config = config
self.dataset_infos = {}
self.datasetio_api = datasetio_api
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red")
async def initialize(self) -> None: ...

View file

@ -13,7 +13,12 @@ apis:
- inference
- datasets
- datasetio
- scoring
providers:
scoring:
- provider_id: meta0
provider_type: meta-reference
config: {}
datasetio:
- provider_id: meta0
provider_type: meta-reference