mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 23:29:43 +00:00
dataset client
This commit is contained in:
parent
c5db025320
commit
bb43369521
6 changed files with 156 additions and 10 deletions
132
llama_stack/apis/datasets/client.py
Normal file
132
llama_stack/apis/datasets/client.py
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import mimetypes
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import httpx
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
|
from .datasets import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
|
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
def data_url_from_file(file_path: str) -> str:
    """Encode the file at *file_path* as a base64 ``data:`` URL.

    Args:
        file_path: Path of the file to read (opened in binary mode).

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    base64_content = base64.b64encode(file_content).decode("utf-8")

    # guess_type() returns (None, None) for unrecognised extensions; without a
    # fallback the URL would literally start with "data:None;".
    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        mime_type = "application/octet-stream"

    return f"data:{mime_type};base64,{base64_content}"
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetsClient(Datasets):
    """HTTP client for the ``/datasets`` endpoints of a server at ``base_url``.

    Every call opens a short-lived ``httpx.AsyncClient``, sends one request
    with a 60 second timeout, and calls ``raise_for_status()`` so that error
    responses surface as exceptions instead of being parsed as data.
    """

    def __init__(self, base_url: str):
        # Prefix for every endpoint, e.g. "http://localhost:5000".
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op; the client holds no long-lived resources."""
        pass

    async def shutdown(self) -> None:
        """No-op; each request closes its own HTTP client."""
        pass

    async def register_dataset(
        self,
        dataset_def: DatasetDefWithProvider,
    ) -> None:
        """POST *dataset_def* to ``/datasets/register``."""
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/datasets/register",
                json={
                    # Round-trip through the model's own JSON encoder so the
                    # payload contains only plain JSON types.
                    "dataset_def": json.loads(dataset_def.json()),
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()

    async def get_dataset(
        self,
        dataset_identifier: str,
    ) -> Optional[DatasetDefWithProvider]:
        """GET ``/datasets/get`` for *dataset_identifier*.

        Returns:
            The dataset definition, or ``None`` when the response body is
            empty/falsy.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/get",
                params={
                    "dataset_identifier": dataset_identifier,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()

        # Decode the body once and reuse it.
        payload = response.json()
        if not payload:
            return None

        return DatasetDefWithProvider(**payload)

    async def list_datasets(self) -> List[DatasetDefWithProvider]:
        """GET ``/datasets/list``.

        Returns:
            One definition per entry in the response; an empty list when the
            response body is empty/falsy, so the result is always iterable.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/list",
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()

        # Decode the body once and reuse it.
        payload = response.json()
        if not payload:
            return []

        return [DatasetDefWithProvider(**x) for x in payload]
|
||||||
|
|
||||||
|
|
||||||
|
async def run_main(host: str, port: int):
    """Register a sample dataset on ``http://host:port`` and print all datasets.

    The sample is ``providers/tests/datasetio/test_dataset.csv``, located
    three directory levels above this file, sent inline as a base64 data URL.

    Args:
        host: Hostname of the server to contact.
        port: Port of the server to contact.
    """
    client = DatasetsClient(f"http://{host}:{port}")

    # register dataset
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    # register_dataset() returns None, so its result is not bound.
    await client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")
|
||||||
|
|
||||||
|
|
||||||
|
def main(host: str, port: int):
    """Synchronous entry point: run ``run_main(host, port)`` to completion."""
    demo = run_main(host, port)
    asyncio.run(demo)


# Expose ``main`` on the command line through python-fire when executed as a
# script.
if __name__ == "__main__":
    fire.Fire(main)
|
5
llama_stack/apis/scoring/client.py
Normal file
5
llama_stack/apis/scoring/client.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
|
@ -130,12 +130,6 @@ async def resolve_impls(run_config: StackRunConfig) -> Dict[Api, Any]:
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.router_api.value == "scoring":
|
|
||||||
print("SCORING API")
|
|
||||||
|
|
||||||
# p = all_api_providers[api][provider.provider_type]
|
|
||||||
# p.deps__ = [a.value for a in p.api_dependencies]
|
|
||||||
|
|
||||||
providers_with_specs[info.router_api.value] = {
|
providers_with_specs[info.router_api.value] = {
|
||||||
"__builtin__": ProviderWithSpec(
|
"__builtin__": ProviderWithSpec(
|
||||||
provider_id="__autorouted__",
|
provider_id="__autorouted__",
|
||||||
|
|
|
@ -3,16 +3,20 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
from .config import MetaReferenceScoringConfig
|
from .config import MetaReferenceScoringConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(
    config: MetaReferenceScoringConfig,
    deps: Dict[Api, ProviderSpec],
):
    """Build and initialize the meta-reference scoring implementation.

    Args:
        config: Configuration passed through to the implementation.
        deps: Resolved dependencies keyed by API; the ``Api.datasetio`` entry
            is required and handed to the implementation.

    Returns:
        The initialized ``MetaReferenceScoringImpl``.

    Raises:
        KeyError: If ``deps`` has no ``Api.datasetio`` entry.
    """
    # Imported here rather than at module top, as in the original code.
    from .scoring import MetaReferenceScoringImpl

    impl = MetaReferenceScoringImpl(config, deps[Api.datasetio])
    await impl.initialize()
    return impl
|
||||||
|
|
|
@ -7,6 +7,9 @@ from typing import List
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.apis.scoring import * # noqa: F403
|
from llama_stack.apis.scoring import * # noqa: F403
|
||||||
|
from llama_stack.apis.datasetio import * # noqa: F403
|
||||||
|
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||||
|
|
||||||
|
@ -14,9 +17,12 @@ from .config import MetaReferenceScoringConfig
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||||
def __init__(self, config: MetaReferenceScoringConfig) -> None:
|
def __init__(
|
||||||
|
self, config: MetaReferenceScoringConfig, datasetio_api: DatasetIO
|
||||||
|
) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.dataset_infos = {}
|
self.datasetio_api = datasetio_api
|
||||||
|
cprint(f"!!! MetaReferenceScoringImpl init {config} {datasetio_api}", "red")
|
||||||
|
|
||||||
async def initialize(self) -> None: ...
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,12 @@ apis:
|
||||||
- inference
|
- inference
|
||||||
- datasets
|
- datasets
|
||||||
- datasetio
|
- datasetio
|
||||||
|
- scoring
|
||||||
providers:
|
providers:
|
||||||
|
scoring:
|
||||||
|
- provider_id: meta0
|
||||||
|
provider_type: meta-reference
|
||||||
|
config: {}
|
||||||
datasetio:
|
datasetio:
|
||||||
- provider_id: meta0
|
- provider_id: meta0
|
||||||
provider_type: meta-reference
|
provider_type: meta-reference
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue