diff --git a/llama_stack/apis/datasetio/client.py b/llama_stack/apis/datasetio/client.py new file mode 100644 index 000000000..b62db9085 --- /dev/null +++ b/llama_stack/apis/datasetio/client.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import os +from pathlib import Path +from typing import Optional + +import fire +import httpx +from termcolor import cprint + +from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.datasetio import * # noqa: F403 +from llama_stack.apis.common.type_system import * # noqa: F403 +from llama_stack.apis.datasets.client import DatasetsClient +from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file + + +class DatasetIOClient(DatasetIO): + def __init__(self, base_url: str): + self.base_url = base_url + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def get_rows_paginated( + self, + dataset_id: str, + rows_in_page: int, + page_token: Optional[str] = None, + filter_condition: Optional[str] = None, + ) -> PaginatedRowsResult: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.base_url}/datasetio/get_rows_paginated", + params={ + "dataset_id": dataset_id, + "rows_in_page": rows_in_page, + "page_token": page_token, + "filter_condition": filter_condition, + }, + headers={"Content-Type": "application/json"}, + timeout=60, + ) + response.raise_for_status() + if not response.json(): + return + + return PaginatedRowsResult(**response.json()) + + +async def run_main(host: str, port: int): + client = DatasetsClient(f"http://{host}:{port}") + + # register dataset + test_file = ( + Path(os.path.abspath(__file__)).parent.parent.parent + / "providers/tests/datasetio/test_dataset.csv" + ) + test_url = data_url_from_file(str(test_file)) + response = await client.register_dataset( + DatasetDefWithProvider( + identifier="test-dataset", + provider_id="meta0", + url=URL( + uri=test_url, + ), + dataset_schema={ + "generated_answer": StringType(), + "expected_answer": StringType(), + "input_query": StringType(), + }, + ) + ) + + # list datasets + list_dataset = await client.list_datasets() + cprint(list_dataset, "blue") + + # datsetio client to get the rows + datasetio_client = DatasetIOClient(f"http://{host}:{port}") + response = await datasetio_client.get_rows_paginated( + dataset_id="test-dataset", + rows_in_page=4, + page_token=None, + filter_condition=None, + ) + cprint(f"Returned {len(response.rows)} rows \n {response}", "green") + + +def main(host: str, port: int): + asyncio.run(run_main(host, port)) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py index e8811d233..b321b260e 100644 --- a/llama_stack/apis/datasetio/datasetio.py +++ b/llama_stack/apis/datasetio/datasetio.py @@ -29,7 +29,7 @@ class DatasetIO(Protocol): # keeping for aligning with inference/safety, but this is not used dataset_store: DatasetStore - @webmethod(route="/dataio/get_rows_paginated") + @webmethod(route="/datasetio/get_rows_paginated", method="GET") async def get_rows_paginated( self, dataset_id: str, diff --git a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py index 938cc23e3..f38311a0c 100644 --- a/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py +++ b/llama_stack/providers/impls/meta_reference/datasetio/datasetio.py @@ -142,7 +142,7 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): dataset_info = self.dataset_infos.get(dataset_id) dataset_info.dataset_impl.load() - if page_token is None: + if page_token is None or not page_token.isnumeric(): next_page_token = 0 else: next_page_token = int(page_token)