mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

dataset client

commit bb43369521 (parent c5db025320)
6 changed files with 156 additions and 10 deletions
llama_stack/apis/datasets/client.py (new file, 132 lines)

@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import base64
import json
import mimetypes
import os
from pathlib import Path
from typing import Optional

import fire
import httpx
from termcolor import cprint

from .datasets import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.common.type_system import *  # noqa: F403


def data_url_from_file(file_path: str) -> str:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")
    mime_type, _ = mimetypes.guess_type(file_path)

    data_url = f"data:{mime_type};base64,{base64_content}"

    return data_url


class DatasetsClient(Datasets):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_dataset(
        self,
        dataset_def: DatasetDefWithProvider,
    ) -> None:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/datasets/register",
                json={
                    "dataset_def": json.loads(dataset_def.json()),
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            return

    async def get_dataset(
        self,
        dataset_identifier: str,
    ) -> Optional[DatasetDefWithProvider]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/get",
                params={
                    "dataset_identifier": dataset_identifier,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            if not response.json():
                return

            return DatasetDefWithProvider(**response.json())

    async def list_datasets(self) -> List[DatasetDefWithProvider]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/list",
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            if not response.json():
                return

            return [DatasetDefWithProvider(**x) for x in response.json()]


async def run_main(host: str, port: int):
    client = DatasetsClient(f"http://{host}:{port}")

    # register dataset
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    response = await client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")


def main(host: str, port: int):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)
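Beyond the fire-driven entry point at the bottom of the file, the new client can also be exercised from a short script. A minimal sketch, assuming a llama-stack server is already listening; the module path, host, and port below are assumptions, not part of this commit:

import asyncio

# Module path assumed from the file location added in this commit.
from llama_stack.apis.datasets.client import DatasetsClient


async def demo() -> None:
    # Placeholder endpoint; point this at a running llama-stack server.
    client = DatasetsClient("http://localhost:5000")
    # list_datasets() returns None when the server reports no datasets.
    datasets = await client.list_datasets()
    for ds in datasets or []:
        print(ds.identifier)


if __name__ == "__main__":
    asyncio.run(demo())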
llama_stack/apis/scoring/client.py (new file, 5 lines)

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -130,12 +130,6 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
             )
         }

-        if info.router_api.value == "scoring":
-            print("SCORING API")
-
-        # p = all_api_providers[api][provider.provider_type]
-        # p.deps__ = [a.value for a in p.api_dependencies]
-
         providers_with_specs[info.router_api.value] = {
             "__builtin__": ProviderWithSpec(
                 provider_id="__autorouted__",
@@ -3,16 +3,20 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Dict
+
+from llama_stack.distribution.datatypes import Api, ProviderSpec

 from .config import MetaReferenceScoringConfig


 async def get_provider_impl(
     config: MetaReferenceScoringConfig,
-    _deps,
+    deps: Dict[Api, ProviderSpec],
 ):
+    print("get_provider_impl", deps)
     from .scoring import MetaReferenceScoringImpl

-    impl = MetaReferenceScoringImpl(config)
+    impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
     await impl.initialize()
     return impl
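With this change the provider entry point receives its API dependencies as a mapping keyed by Api and pulls out the DatasetIO implementation itself. A hedged sketch of wiring it by hand, e.g. in a test; the import paths and the datasetio_impl name are assumptions:

from llama_stack.distribution.datatypes import Api

# Paths assumed from the repository layout; adjust to where the provider lives.
from llama_stack.providers.impls.meta_reference.scoring import get_provider_impl
from llama_stack.providers.impls.meta_reference.scoring.config import (
    MetaReferenceScoringConfig,
)


async def build_scoring(datasetio_impl):
    # deps maps each dependent API to a live implementation; get_provider_impl
    # forwards deps[Api.datasetio] into MetaReferenceScoringImpl.
    return await get_provider_impl(
        MetaReferenceScoringConfig(),
        deps={Api.datasetio: datasetio_impl},
    )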
@@ -7,6 +7,9 @@ from typing import List

 from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_stack.apis.scoring import *  # noqa: F403
+from llama_stack.apis.datasetio import *  # noqa: F403
+
+from termcolor import cprint

 from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate

@@ -14,9 +17,12 @@ from .config import MetaReferenceScoringConfig


 class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
-    def __init__(self, config: MetaReferenceScoringConfig) -> None:
+    def __init__(
+        self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
+    ) -> None:
         self.config = config
         self.dataset_infos = {}
+        self.datasetio_api = datasetio_api
+        cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red")

     async def initialize(self) -> None: ...

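The constructor change makes sense once scoring needs data: scoring a registered dataset means pulling rows through the stored DatasetIO handle. A rough sketch of that direction; the get_rows_paginated method name and its parameters are assumptions, not something this diff shows:

# Hedged sketch only: how scoring code might read rows via the DatasetIO
# dependency. The method name and arguments are assumed, not confirmed here.
async def fetch_rows(datasetio_api, dataset_id: str, limit: int = 100):
    page = await datasetio_api.get_rows_paginated(  # assumed DatasetIO method
        dataset_id=dataset_id,
        rows_in_page=limit,
    )
    return page.rows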
@@ -13,7 +13,12 @@ apis:
 - inference
 - datasets
 - datasetio
+- scoring
 providers:
+  scoring:
+  - provider_id: meta0
+    provider_type: meta-reference
+    config: {}
   datasetio:
   - provider_id: meta0
     provider_type: meta-reference
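Note that the datasetio provider declared here uses provider_id meta0, which is exactly what run_main in the new client passes as provider_id when registering its test dataset, so registration is routed to that provider. A small sketch tying the two together; the module paths, host, port, dataset name, and CSV path are placeholders:

import asyncio

from llama_stack.apis.datasets.client import DatasetsClient, data_url_from_file
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.common.type_system import *  # noqa: F403


async def register(csv_path: str) -> None:
    client = DatasetsClient("http://localhost:5000")  # placeholder endpoint
    await client.register_dataset(
        DatasetDefWithProvider(
            identifier="my-eval-set",  # any unique dataset name
            provider_id="meta0",  # must match a datasetio provider in the run config
            url=URL(uri=data_url_from_file(csv_path)),
            dataset_schema={
                "input_query": StringType(),
                "expected_answer": StringType(),
                "generated_answer": StringType(),
            },
        )
    )


if __name__ == "__main__":
    asyncio.run(register("test_dataset.csv"))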