fix endpoint, only sdk change

This commit is contained in:
Xi Yan 2025-03-15 16:15:45 -07:00
parent 13c7c5b6a1
commit 9e6d99f7b1
8 changed files with 161 additions and 72 deletions

View file

@ -40,7 +40,7 @@
} }
], ],
"paths": { "paths": {
"/v1/datasets/{dataset_id}/append-rows": { "/v1/datasetio/append-rows/{dataset_id}": {
"post": { "post": {
"responses": { "responses": {
"200": { "200": {
@ -60,7 +60,7 @@
} }
}, },
"tags": [ "tags": [
"Datasets" "DatasetIO"
], ],
"description": "", "description": "",
"parameters": [ "parameters": [
@ -2177,7 +2177,7 @@
} }
} }
}, },
"/v1/datasets/{dataset_id}/iterrows": { "/v1/datasetio/iterrows/{dataset_id}": {
"get": { "get": {
"responses": { "responses": {
"200": { "200": {
@ -2204,7 +2204,7 @@
} }
}, },
"tags": [ "tags": [
"Datasets" "DatasetIO"
], ],
"description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.", "description": "Get a paginated list of rows from a dataset. Uses cursor-based pagination.",
"parameters": [ "parameters": [
@ -10274,7 +10274,7 @@
"name": "Benchmarks" "name": "Benchmarks"
}, },
{ {
"name": "Datasets" "name": "DatasetIO"
}, },
{ {
"name": "Datasets" "name": "Datasets"
@ -10342,7 +10342,7 @@
"Agents", "Agents",
"BatchInference (Coming Soon)", "BatchInference (Coming Soon)",
"Benchmarks", "Benchmarks",
"Datasets", "DatasetIO",
"Datasets", "Datasets",
"Eval", "Eval",
"Files", "Files",

View file

@ -10,7 +10,7 @@ info:
servers: servers:
- url: http://any-hosted-llama-stack.com - url: http://any-hosted-llama-stack.com
paths: paths:
/v1/datasets/{dataset_id}/append-rows: /v1/datasetio/append-rows/{dataset_id}:
post: post:
responses: responses:
'200': '200':
@ -26,7 +26,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Datasets - DatasetIO
description: '' description: ''
parameters: parameters:
- name: dataset_id - name: dataset_id
@ -1457,7 +1457,7 @@ paths:
schema: schema:
$ref: '#/components/schemas/InvokeToolRequest' $ref: '#/components/schemas/InvokeToolRequest'
required: true required: true
/v1/datasets/{dataset_id}/iterrows: /v1/datasetio/iterrows/{dataset_id}:
get: get:
responses: responses:
'200': '200':
@ -1477,7 +1477,7 @@ paths:
default: default:
$ref: '#/components/responses/DefaultError' $ref: '#/components/responses/DefaultError'
tags: tags:
- Datasets - DatasetIO
description: >- description: >-
Get a paginated list of rows from a dataset. Uses cursor-based pagination. Get a paginated list of rows from a dataset. Uses cursor-based pagination.
parameters: parameters:
@ -6931,7 +6931,7 @@ tags:
Agents API for creating and interacting with agentic systems. Agents API for creating and interacting with agentic systems.
- name: BatchInference (Coming Soon) - name: BatchInference (Coming Soon)
- name: Benchmarks - name: Benchmarks
- name: Datasets - name: DatasetIO
- name: Datasets - name: Datasets
- name: Eval - name: Eval
x-displayName: >- x-displayName: >-
@ -6971,7 +6971,7 @@ x-tagGroups:
- Agents - Agents
- BatchInference (Coming Soon) - BatchInference (Coming Soon)
- Benchmarks - Benchmarks
- Datasets - DatasetIO
- Datasets - Datasets
- Eval - Eval
- Files - Files

View file

@ -552,8 +552,8 @@ class Generator:
print(op.defining_class.__name__) print(op.defining_class.__name__)
# TODO (xiyan): temporary fix for datasetio inner impl + datasets api # TODO (xiyan): temporary fix for datasetio inner impl + datasets api
if op.defining_class.__name__ in ["DatasetIO"]: # if op.defining_class.__name__ in ["DatasetIO"]:
op.defining_class.__name__ = "Datasets" # op.defining_class.__name__ = "Datasets"
doc_string = parse_type(op.func_ref) doc_string = parse_type(op.func_ref)
doc_params = dict( doc_params = dict(

View file

@ -34,7 +34,8 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used # keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore dataset_store: DatasetStore
@webmethod(route="/datasets/{dataset_id}/iterrows", method="GET") # TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows( async def iterrows(
self, self,
dataset_id: str, dataset_id: str,
@ -49,5 +50,7 @@ class DatasetIO(Protocol):
""" """
... ...
@webmethod(route="/datasets/{dataset_id}/append-rows", method="POST") @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ... async def append_rows(
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@ -8,9 +8,9 @@ import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL,
InterleavedContent, InterleavedContent,
InterleavedContentItem, InterleavedContentItem,
URL,
) )
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.apis.datasets import DatasetPurpose, DataSource
@ -94,7 +94,9 @@ class VectorIORouter(VectorIO):
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
provider_vector_db_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None,
) -> None: ) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") logger.debug(
f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}"
)
await self.routing_table.register_vector_db( await self.routing_table.register_vector_db(
vector_db_id, vector_db_id,
embedding_model, embedding_model,
@ -112,7 +114,9 @@ class VectorIORouter(VectorIO):
logger.debug( logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
) )
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
vector_db_id, chunks, ttl_seconds
)
async def query_chunks( async def query_chunks(
self, self,
@ -121,7 +125,9 @@ class VectorIORouter(VectorIO):
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
vector_db_id, query, params
)
class InferenceRouter(Inference): class InferenceRouter(Inference):
@ -158,7 +164,9 @@ class InferenceRouter(Inference):
logger.debug( logger.debug(
f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}",
) )
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata, model_type
)
def _construct_metrics( def _construct_metrics(
self, self,
@ -212,11 +220,16 @@ class InferenceRouter(Inference):
total_tokens: int, total_tokens: int,
model: Model, model: Model,
) -> List[MetricInResponse]: ) -> List[MetricInResponse]:
metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) metrics = self._construct_metrics(
prompt_tokens, completion_tokens, total_tokens, model
)
if self.telemetry: if self.telemetry:
for metric in metrics: for metric in metrics:
await self.telemetry.log_event(metric) await self.telemetry.log_event(metric)
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] return [
MetricInResponse(metric=metric.metric, value=metric.value)
for metric in metrics
]
async def _count_tokens( async def _count_tokens(
self, self,
@ -241,7 +254,9 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: ) -> Union[
ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]
]:
logger.debug( logger.debug(
f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}",
) )
@ -251,12 +266,19 @@ class InferenceRouter(Inference):
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
if tool_config: if tool_config:
if tool_choice and tool_choice != tool_config.tool_choice: if tool_choice and tool_choice != tool_config.tool_choice:
raise ValueError("tool_choice and tool_config.tool_choice must match") raise ValueError("tool_choice and tool_config.tool_choice must match")
if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: if (
raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") tool_prompt_format
and tool_prompt_format != tool_config.tool_prompt_format
):
raise ValueError(
"tool_prompt_format and tool_config.tool_prompt_format must match"
)
else: else:
params = {} params = {}
if tool_choice: if tool_choice:
@ -274,9 +296,14 @@ class InferenceRouter(Inference):
pass pass
else: else:
# verify tool_choice is one of the tools # verify tool_choice is one of the tools
tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] tool_names = [
t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value
for t in tools
]
if tool_config.tool_choice not in tool_names: if tool_config.tool_choice not in tool_names:
raise ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") raise ValueError(
f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}"
)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
@ -291,17 +318,25 @@ class InferenceRouter(Inference):
tool_config=tool_config, tool_config=tool_config,
) )
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) prompt_tokens = await self._count_tokens(
messages, tool_config.tool_prompt_format
)
if stream: if stream:
async def stream_generator(): async def stream_generator():
completion_text = "" completion_text = ""
async for chunk in await provider.chat_completion(**params): async for chunk in await provider.chat_completion(**params):
if chunk.event.event_type == ChatCompletionResponseEventType.progress: if (
chunk.event.event_type
== ChatCompletionResponseEventType.progress
):
if chunk.event.delta.type == "text": if chunk.event.delta.type == "text":
completion_text += chunk.event.delta.text completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete: if (
chunk.event.event_type
== ChatCompletionResponseEventType.complete
):
completion_tokens = await self._count_tokens( completion_tokens = await self._count_tokens(
[ [
CompletionMessage( CompletionMessage(
@ -318,7 +353,11 @@ class InferenceRouter(Inference):
total_tokens, total_tokens,
model, model,
) )
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics chunk.metrics = (
metrics
if chunk.metrics is None
else chunk.metrics + metrics
)
yield chunk yield chunk
return stream_generator() return stream_generator()
@ -335,7 +374,9 @@ class InferenceRouter(Inference):
total_tokens, total_tokens,
model, model,
) )
response.metrics = metrics if response.metrics is None else response.metrics + metrics response.metrics = (
metrics if response.metrics is None else response.metrics + metrics
)
return response return response
async def completion( async def completion(
@ -356,7 +397,9 @@ class InferenceRouter(Inference):
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
@ -376,7 +419,11 @@ class InferenceRouter(Inference):
async for chunk in await provider.completion(**params): async for chunk in await provider.completion(**params):
if hasattr(chunk, "delta"): if hasattr(chunk, "delta"):
completion_text += chunk.delta completion_text += chunk.delta
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: if (
hasattr(chunk, "stop_reason")
and chunk.stop_reason
and self.telemetry
):
completion_tokens = await self._count_tokens(completion_text) completion_tokens = await self._count_tokens(completion_text)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
metrics = await self._compute_and_log_token_usage( metrics = await self._compute_and_log_token_usage(
@ -385,7 +432,11 @@ class InferenceRouter(Inference):
total_tokens, total_tokens,
model, model,
) )
chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics chunk.metrics = (
metrics
if chunk.metrics is None
else chunk.metrics + metrics
)
yield chunk yield chunk
return stream_generator() return stream_generator()
@ -399,7 +450,9 @@ class InferenceRouter(Inference):
total_tokens, total_tokens,
model, model,
) )
response.metrics = metrics if response.metrics is None else response.metrics + metrics response.metrics = (
metrics if response.metrics is None else response.metrics + metrics
)
return response return response
async def embeddings( async def embeddings(
@ -415,7 +468,9 @@ class InferenceRouter(Inference):
if model is None: if model is None:
raise ValueError(f"Model '{model_id}' not found") raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm: if model.model_type == ModelType.llm:
raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings( return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id, model_id=model_id,
contents=contents, contents=contents,
@ -449,7 +504,9 @@ class SafetyRouter(Safety):
params: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None,
) -> Shield: ) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}") logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) return await self.routing_table.register_shield(
shield_id, provider_shield_id, provider_id, params
)
async def run_shield( async def run_shield(
self, self,
@ -546,7 +603,9 @@ class ScoringRouter(Scoring):
logger.debug(f"ScoringRouter.score_batch: {dataset_id}") logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {} res = {}
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id, dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
) )
@ -564,11 +623,15 @@ class ScoringRouter(Scoring):
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") logger.debug(
f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions"
)
res = {} res = {}
# look up and map each scoring function to its provider impl # look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys(): for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score( score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
input_rows=input_rows, input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
) )
@ -611,7 +674,9 @@ class EvalRouter(Eval):
scoring_functions: List[str], scoring_functions: List[str],
benchmark_config: BenchmarkConfig, benchmark_config: BenchmarkConfig,
) -> EvaluateResponse: ) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") logger.debug(
f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows"
)
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id, benchmark_id=benchmark_id,
input_rows=input_rows, input_rows=input_rows,
@ -625,7 +690,9 @@ class EvalRouter(Eval):
job_id: str, job_id: str,
) -> Optional[JobStatus]: ) -> Optional[JobStatus]:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) return await self.routing_table.get_provider_impl(benchmark_id).job_status(
benchmark_id, job_id
)
async def job_cancel( async def job_cancel(
self, self,
@ -679,9 +746,9 @@ class ToolRuntimeRouter(ToolRuntime):
logger.debug( logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
) )
return await self.routing_table.get_provider_impl("insert_into_memory").insert( return await self.routing_table.get_provider_impl(
documents, vector_db_id, chunk_size_in_tokens "insert_into_memory"
) ).insert(documents, vector_db_id, chunk_size_in_tokens)
def __init__( def __init__(
self, self,
@ -714,4 +781,6 @@ class ToolRuntimeRouter(ToolRuntime):
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ) -> List[ToolDef]:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) return await self.routing_table.get_provider_impl(tool_group_id).list_tools(
tool_group_id, mcp_endpoint
)

View file

@ -5,6 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
@ -16,24 +18,17 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig from .config import HuggingfaceDatasetIOConfig
DATASETS_PREFIX = "datasets:" DATASETS_PREFIX = "datasets:"
from rich.pretty import pprint
def load_hf_dataset(dataset_def: Dataset): def parse_hf_params(dataset_def: Dataset):
if dataset_def.metadata.get("path", None): uri = dataset_def.source.uri
dataset = hf_datasets.load_dataset(**dataset_def.metadata) parsed_uri = urlparse(uri)
else: params = parse_qs(parsed_uri.query)
df = get_dataframe_from_url(dataset_def.url) params = {k: v[0] for k, v in params.items()}
path = parsed_uri.path.lstrip("/")
if df is None: return path, params
raise ValueError(f"Failed to load dataset from {dataset_def.url}")
dataset = hf_datasets.Dataset.from_pandas(df)
# drop columns not specified by schema
if dataset_def.dataset_schema:
dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
return dataset
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -60,6 +55,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
self, self,
dataset_def: Dataset, dataset_def: Dataset,
) -> None: ) -> None:
print("register_dataset")
# Store in kvstore # Store in kvstore
key = f"{DATASETS_PREFIX}{dataset_def.identifier}" key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set( await self.kvstore.set(
@ -80,7 +76,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
limit: Optional[int] = None, limit: Optional[int] = None,
) -> IterrowsResponse: ) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id] dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def) path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
start_index = start_index or 0 start_index = start_index or 0
@ -98,15 +95,20 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id] dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def) path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
# Convert rows to HF Dataset format # Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows) new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset # Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets([loaded_dataset, new_dataset]) updated_dataset = hf_datasets.concatenate_datasets(
[loaded_dataset, new_dataset]
)
if dataset_def.metadata.get("path", None): if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"]) updated_dataset.push_to_hub(dataset_def.metadata["path"])
else: else:
raise NotImplementedError("Uploading to URL-based datasets is not supported yet") raise NotImplementedError(
"Uploading to URL-based datasets is not supported yet"
)

View file

@ -19,6 +19,15 @@ import pytest
def test_register_dataset(llama_stack_client): def test_register_dataset(llama_stack_client):
dataset = llama_stack_client.datasets.register( dataset = llama_stack_client.datasets.register(
purpose="eval/messages-answer", purpose="eval/messages-answer",
source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"}, source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
) )
print(dataset) assert dataset.identifier is not None
assert dataset.provider_id == "huggingface"
iterrow_response = llama_stack_client.datasets.iterrows(
dataset.identifier, limit=10
)
assert len(iterrow_response.data) == 10
assert iterrow_response.next_index is not None

View file

@ -6,9 +6,15 @@ def test_register_dataset():
client = LlamaStackClient(base_url="http://localhost:8321") client = LlamaStackClient(base_url="http://localhost:8321")
dataset = client.datasets.register( dataset = client.datasets.register(
purpose="eval/messages-answer", purpose="eval/messages-answer",
source={"type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train"}, source={
"type": "uri",
"uri": "huggingface://datasets/llamastack/simpleqa?split=train",
},
) )
dataset_id = dataset.identifier
pprint(dataset) pprint(dataset)
rows = client.datasets.iterrows(dataset_id=dataset_id, limit=10)
pprint(rows)
if __name__ == "__main__": if __name__ == "__main__":