From 9e6d99f7b196248417ee6d3c02c73e4b65acca90 Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Sat, 15 Mar 2025 16:15:45 -0700
Subject: [PATCH] fix endpoint, only sdk change

---
 docs/_static/llama-stack-spec.html            |  12 +-
 docs/_static/llama-stack-spec.yaml            |  12 +-
 docs/openapi_generator/pyopenapi/generator.py |   4 +-
 llama_stack/apis/datasetio/datasetio.py       |   9 +-
 llama_stack/distribution/routers/routers.py   | 135 +++++++++++++-----
 .../datasetio/huggingface/huggingface.py      |  40 +++---
 tests/integration/datasets/test_datasets.py   |  13 +-
 tests/integration/datasets/test_script.py     |   8 +-
 8 files changed, 161 insertions(+), 72 deletions(-)

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index 462034c3d..e3c81ddb9 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -40,7 +40,7 @@
         }
     ],
     "paths": {
-        "/v1/datasets/{dataset_id}/append-rows": {
+        "/v1/datasetio/append-rows/{dataset_id}": {
             "post": {
                 "responses": {
                     "200": {
@@ -60,7 +60,7 @@
                     }
                 },
                 "tags": [
-                    "Datasets"
+                    "DatasetIO"
                 ],
                 "description": "",
                 "parameters": [
@@ -2177,7 +2177,7 @@
                 }
             }
         },
-        "/v1/datasets/{dataset_id}/iterrows": {
+        "/v1/datasetio/iterrows/{dataset_id}": {
             "get": {
                 "responses": {
                     "200": {
@@ -2204,7 +2204,7 @@
                     }
                 },
                 "tags": [
-                    "Datasets"
+                    "DatasetIO"
                ],
                 "description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
                 "parameters": [
@@ -10274,7 +10274,7 @@
             "name": "Benchmarks"
         },
         {
-            "name": "Datasets"
+            "name": "DatasetIO"
         },
         {
             "name": "Datasets"
@@ -10342,7 +10342,7 @@
                 "Agents",
                 "BatchInference (Coming Soon)",
                 "Benchmarks",
-                "Datasets",
+                "DatasetIO",
                 "Datasets",
                 "Eval",
                 "Files",
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 79adc221a..a3d4dbcc9 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -10,7 +10,7 @@ info:
 servers:
   - url: http://any-hosted-llama-stack.com
 paths:
-  /v1/datasets/{dataset_id}/append-rows:
+  /v1/datasetio/append-rows/{dataset_id}:
     post:
       responses:
         '200':
@@ -26,7 +26,7 @@ paths:
         default:
           $ref: '#/components/responses/DefaultError'
       tags:
-        - Datasets
+        - DatasetIO
       description: ''
       parameters:
         - name: dataset_id
@@ -1457,7 +1457,7 @@ paths:
             schema:
               $ref: '#/components/schemas/InvokeToolRequest'
         required: true
-  /v1/datasets/{dataset_id}/iterrows:
+  /v1/datasetio/iterrows/{dataset_id}:
     get:
       responses:
         '200':
@@ -1477,7 +1477,7 @@ paths:
         default:
           $ref: '#/components/responses/DefaultError'
       tags:
-        - Datasets
+        - DatasetIO
       description: >-
         Get a paginated list of rows from a dataset. Uses cursor-based pagination.
       parameters:
@@ -6931,7 +6931,7 @@ tags:
     Agents API for creating and interacting with agentic systems.
   - name: BatchInference (Coming Soon)
   - name: Benchmarks
-  - name: Datasets
+  - name: DatasetIO
   - name: Datasets
   - name: Eval
     x-displayName: >-
@@ -6971,7 +6971,7 @@ x-tagGroups:
       - Agents
       - BatchInference (Coming Soon)
       - Benchmarks
-      - Datasets
+      - DatasetIO
       - Datasets
       - Eval
       - Files
diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py
index a7ee87125..02a4776e4 100644
--- a/docs/openapi_generator/pyopenapi/generator.py
+++ b/docs/openapi_generator/pyopenapi/generator.py
@@ -552,8 +552,8 @@ class Generator:
         print(op.defining_class.__name__)
 
         # TODO (xiyan): temporary fix for datasetio inner impl + datasets api
-        if op.defining_class.__name__ in ["DatasetIO"]:
-            op.defining_class.__name__ = "Datasets"
+        # if op.defining_class.__name__ in ["DatasetIO"]:
+        #     op.defining_class.__name__ = "Datasets"
 
         doc_string = parse_type(op.func_ref)
         doc_params = dict(
diff --git a/llama_stack/apis/datasetio/datasetio.py b/llama_stack/apis/datasetio/datasetio.py
index e7073fd29..8545a7189 100644
--- a/llama_stack/apis/datasetio/datasetio.py
+++ b/llama_stack/apis/datasetio/datasetio.py
@@ -34,7 +34,8 @@ class DatasetIO(Protocol):
     # keeping for aligning with inference/safety, but this is not used
     dataset_store: DatasetStore
 
-    @webmethod(route="/datasets/{dataset_id}/iterrows", method="GET")
+    # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
+    @webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
     async def iterrows(
         self,
         dataset_id: str,
@@ -49,5 +50,7 @@ class DatasetIO(Protocol):
         """
         ...
 
-    @webmethod(route="/datasets/{dataset_id}/append-rows", method="POST")
-    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
+    @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
+    async def append_rows(
+        self, dataset_id: str, rows: List[Dict[str, Any]]
+    ) -> None: ...
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 2cf38f544..879fc924b 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -8,9 +8,9 @@ import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
 from llama_stack.apis.common.content_types import (
-    URL,
     InterleavedContent,
     InterleavedContentItem,
+    URL,
 )
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
@@ -94,7 +94,9 @@ class VectorIORouter(VectorIO):
         provider_id: Optional[str] = None,
         provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
+        logger.debug(
+            f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
+        )
         await self.routing_table.register_vector_db(
             vector_db_id,
             embedding_model,
@@ -112,7 +114,9 @@ class VectorIORouter(VectorIO):
         logger.debug(
             f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
         )
-        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
+            vector_db_id, chunks, ttl_seconds
+        )
 
     async def query_chunks(
         self,
@@ -121,7 +125,9 @@ class VectorIORouter(VectorIO):
         params: Optional[Dict[str, Any]] = None,
     ) -> QueryChunksResponse:
         logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
-        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
+            vector_db_id, query, params
+        )
 
 
 class InferenceRouter(Inference):
@@ -158,7 +164,9 @@ class InferenceRouter(Inference):
         logger.debug(
             f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
         )
-        await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
+        await self.routing_table.register_model(
+            model_id, provider_model_id, provider_id, metadata, model_type
+        )
 
     def _construct_metrics(
         self,
@@ -212,11 +220,16 @@ class InferenceRouter(Inference):
         total_tokens: int,
         model: Model,
     ) -> List[MetricInResponse]:
-        metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model)
+        metrics = self._construct_metrics(
+            prompt_tokens, completion_tokens, total_tokens, model
+        )
         if self.telemetry:
             for metric in metrics:
                 await self.telemetry.log_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
+        return [
+            MetricInResponse(metric=metric.metric, value=metric.value)
+            for metric in metrics
+        ]
 
     async def _count_tokens(
         self,
@@ -241,7 +254,9 @@ class InferenceRouter(Inference):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
         tool_config: Optional[ToolConfig] = None,
-    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
+    ) -> Union[
+        ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
+    ]:
         logger.debug(
             f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
         )
@@ -251,12 +266,19 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         if tool_config:
             if tool_choice and tool_choice != tool_config.tool_choice:
                 raise ValueError("tool_choice and tool_config.tool_choice must match")
-            if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format:
-                raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match")
+            if (
+                tool_prompt_format
+                and tool_prompt_format != tool_config.tool_prompt_format
+            ):
+                raise ValueError(
+                    "tool_prompt_format and tool_config.tool_prompt_format must match"
+                )
         else:
             params = {}
             if tool_choice:
@@ -274,9 +296,14 @@ class InferenceRouter(Inference):
                 pass
             else:
                 # verify tool_choice is one of the tools
-                tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools]
+                tool_names = [
+                    t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
+                    for t in tools
+                ]
                 if tool_config.tool_choice not in tool_names:
-                    raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}")
+                    raise ValueError(
+                        f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
+                    )
 
         params = dict(
             model_id=model_id,
@@ -291,17 +318,25 @@ class InferenceRouter(Inference):
             tool_config=tool_config,
         )
         provider = self.routing_table.get_provider_impl(model_id)
-        prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format)
+        prompt_tokens = await self._count_tokens(
+            messages, tool_config.tool_prompt_format
+        )
 
         if stream:
 
             async def stream_generator():
                 completion_text = ""
                 async for chunk in await provider.chat_completion(**params):
-                    if chunk.event.event_type == ChatCompletionResponseEventType.progress:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.progress
+                    ):
                         if chunk.event.delta.type == "text":
                             completion_text += chunk.event.delta.text
-                    if chunk.event.event_type == ChatCompletionResponseEventType.complete:
+                    if (
+                        chunk.event.event_type
+                        == ChatCompletionResponseEventType.complete
+                    ):
                         completion_tokens = await self._count_tokens(
                             [
                                 CompletionMessage(
@@ -318,7 +353,11 @@ class InferenceRouter(Inference):
                             total_tokens,
                             model,
                         )
-                        chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                        chunk.metrics = (
+                            metrics
+                            if chunk.metrics is None
+                            else chunk.metrics + metrics
+                        )
                     yield chunk
 
             return stream_generator()
@@ -335,7 +374,9 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-            response.metrics = metrics if response.metrics is None else response.metrics + metrics
+            response.metrics = (
+                metrics if response.metrics is None else response.metrics + metrics
+            )
             return response
 
     async def completion(
@@ -356,7 +397,9 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.embedding:
-            raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
+            raise ValueError(
+                f"Model '{model_id}' is an embedding model and does not support chat completions"
+            )
         provider = self.routing_table.get_provider_impl(model_id)
         params = dict(
             model_id=model_id,
@@ -376,7 +419,11 @@ class InferenceRouter(Inference):
             async for chunk in await provider.completion(**params):
                 if hasattr(chunk, "delta"):
                     completion_text += chunk.delta
-                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry:
+                if (
+                    hasattr(chunk, "stop_reason")
+                    and chunk.stop_reason
+                    and self.telemetry
+                ):
                     completion_tokens = await self._count_tokens(completion_text)
                     total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
                     metrics = await self._compute_and_log_token_usage(
@@ -385,7 +432,11 @@ class InferenceRouter(Inference):
                         total_tokens,
                         model,
                     )
-                    chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics
+                    chunk.metrics = (
+                        metrics
+                        if chunk.metrics is None
+                        else chunk.metrics + metrics
+                    )
                 yield chunk
 
             return stream_generator()
@@ -399,7 +450,9 @@ class InferenceRouter(Inference):
             total_tokens,
             model,
         )
-        response.metrics = metrics if response.metrics is None else response.metrics + metrics
+        response.metrics = (
+            metrics if response.metrics is None else response.metrics + metrics
+        )
         return response
 
     async def embeddings(
@@ -415,7 +468,9 @@ class InferenceRouter(Inference):
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
         if model.model_type == ModelType.llm:
-            raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
+            raise ValueError(
+                f"Model '{model_id}' is an LLM model and does not support embeddings"
+            )
         return await self.routing_table.get_provider_impl(model_id).embeddings(
             model_id=model_id,
             contents=contents,
@@ -449,7 +504,9 @@ class SafetyRouter(Safety):
         params: Optional[Dict[str, Any]] = None,
     ) -> Shield:
         logger.debug(f"SafetyRouter.register_shield: {shield_id}")
-        return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
+        return await self.routing_table.register_shield(
+            shield_id, provider_shield_id, provider_id, params
+        )
 
     async def run_shield(
         self,
@@ -546,7 +603,9 @@ class ScoringRouter(Scoring):
         logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
         res = {}
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score_batch(
                 dataset_id=dataset_id,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -564,11 +623,15 @@ class ScoringRouter(Scoring):
         input_rows: List[Dict[str, Any]],
         scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
-        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
+        logger.debug(
+            f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
+        )
         res = {}
         # look up and map each scoring function to its provider impl
         for fn_identifier in scoring_functions.keys():
-            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
+            score_response = await self.routing_table.get_provider_impl(
+                fn_identifier
+            ).score(
                 input_rows=input_rows,
                 scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
@@ -611,7 +674,9 @@ class EvalRouter(Eval):
         scoring_functions: List[str],
         benchmark_config: BenchmarkConfig,
     ) -> EvaluateResponse:
-        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
+        logger.debug(
+            f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
+        )
         return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
             benchmark_id=benchmark_id,
             input_rows=input_rows,
@@ -625,7 +690,9 @@ class EvalRouter(Eval):
         job_id: str,
     ) -> Optional[JobStatus]:
         logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
-        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
+        return await self.routing_table.get_provider_impl(benchmark_id).job_status(
+            benchmark_id, job_id
+        )
 
     async def job_cancel(
         self,
@@ -679,9 +746,9 @@ class ToolRuntimeRouter(ToolRuntime):
             logger.debug(
                 f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
             )
-            return await self.routing_table.get_provider_impl("insert_into_memory").insert(
-                documents, vector_db_id, chunk_size_in_tokens
-            )
+            return await self.routing_table.get_provider_impl(
+                "insert_into_memory"
+            ).insert(documents, vector_db_id, chunk_size_in_tokens)
 
     def __init__(
         self,
@@ -714,4 +781,6 @@ class ToolRuntimeRouter(ToolRuntime):
         self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
     ) -> List[ToolDef]:
         logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
-        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
+        return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
+            tool_group_id, mcp_endpoint
+        )
diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
index d59edda30..82a76f8bc 100644
--- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
+++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py
@@ -5,6 +5,8 @@
 # the root directory of this source tree.
 from typing import Any, Dict, List, Optional
 
+from urllib.parse import parse_qs, urlparse
+
 import datasets as hf_datasets
 
 from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
@@ -16,24 +18,17 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
 from .config import HuggingfaceDatasetIOConfig
 
 DATASETS_PREFIX = "datasets:"
+from rich.pretty import pprint
 
 
-def load_hf_dataset(dataset_def: Dataset):
-    if dataset_def.metadata.get("path", None):
-        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
-    else:
-        df = get_dataframe_from_url(dataset_def.url)
+def parse_hf_params(dataset_def: Dataset):
+    uri = dataset_def.source.uri
+    parsed_uri = urlparse(uri)
+    params = parse_qs(parsed_uri.query)
+    params = {k: v[0] for k, v in params.items()}
+    path = parsed_uri.path.lstrip("/")
 
-    if df is None:
-        raise ValueError(f"Failed to load dataset from {dataset_def.url}")
-
-    dataset = hf_datasets.Dataset.from_pandas(df)
-
-    # drop columns not specified by schema
-    if dataset_def.dataset_schema:
-        dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
-
-    return dataset
+    return path, params
 
 
 class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@@ -60,6 +55,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         self,
         dataset_def: Dataset,
     ) -> None:
+        print("register_dataset")
         # Store in kvstore
         key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
         await self.kvstore.set(
@@ -80,7 +76,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
         limit: Optional[int] = None,
     ) -> IterrowsResponse:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)
 
         start_index = start_index or 0
 
@@ -98,15 +95,20 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
 
     async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
         dataset_def = self.dataset_infos[dataset_id]
-        loaded_dataset = load_hf_dataset(dataset_def)
+        path, params = parse_hf_params(dataset_def)
+        loaded_dataset = hf_datasets.load_dataset(path, **params)
 
         # Convert rows to HF Dataset format
         new_dataset = hf_datasets.Dataset.from_list(rows)
 
         # Concatenate the new rows with existing dataset
-        updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset])
+        updated_dataset = hf_datasets.concatenate_datasets(
+            [loaded_dataset, new_dataset]
+        )
 
         if dataset_def.metadata.get("path", None):
             updated_dataset.push_to_hub(dataset_def.metadata["path"])
         else:
-            raise NotImplementedError("Uploading to URL-based datasets is not supported yet")
+            raise NotImplementedError(
+                "Uploading to URL-based datasets is not supported yet"
+            )
diff --git a/tests/integration/datasets/test_datasets.py b/tests/integration/datasets/test_datasets.py
index 38abb54c9..0715be06e 100644
--- a/tests/integration/datasets/test_datasets.py
+++ b/tests/integration/datasets/test_datasets.py
@@ -19,6 +19,15 @@ import pytest
 def test_register_dataset(llama_stack_client):
     dataset = llama_stack_client.datasets.register(
         purpose="eval/messages-answer",
-        source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"},
+        source={
+            "type": "uri",
+            "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+        },
     )
-    print(dataset)
+    assert dataset.identifier is not None
+    assert dataset.provider_id == "huggingface"
+    iterrow_response = llama_stack_client.datasets.iterrows(
+        dataset.identifier, limit=10
+    )
+    assert len(iterrow_response.data) == 10
+    assert iterrow_response.next_index is not None
diff --git a/tests/integration/datasets/test_script.py b/tests/integration/datasets/test_script.py
index a3f5f626b..afdde25ca 100644
--- a/tests/integration/datasets/test_script.py
+++ b/tests/integration/datasets/test_script.py
@@ -6,9 +6,15 @@ def test_register_dataset():
     client = LlamaStackClient(base_url="http://localhost:8321")
     dataset = client.datasets.register(
         purpose="eval/messages-answer",
-        source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"},
+        source={
+            "type": "uri",
+            "uri": "huggingface://datasets/llamastack/simpleqa?split=train",
+        },
     )
+    dataset_id = dataset.identifier
     pprint(dataset)
+    rows = client.datasets.iterrows(dataset_id=dataset_id, limit=10)
+    pprint(rows)
 
 
 if __name__ == "__main__":