feat(api): (1/n) datasets api clean up (#1573)

## PR Stack
- https://github.com/meta-llama/llama-stack/pull/1573
- https://github.com/meta-llama/llama-stack/pull/1625
- https://github.com/meta-llama/llama-stack/pull/1656
- https://github.com/meta-llama/llama-stack/pull/1657
- https://github.com/meta-llama/llama-stack/pull/1658
- https://github.com/meta-llama/llama-stack/pull/1659
- https://github.com/meta-llama/llama-stack/pull/1660

**Client SDK**
- https://github.com/meta-llama/llama-stack-client-python/pull/203

**CI**
- 1391130488
<img width="1042" alt="image"
src="https://github.com/user-attachments/assets/69636067-376d-436b-9204-896e2dd490ca"
/>
-- the test_rag_agent_with_attachments is flaky and not related to this
PR

## Doc
<img width="789" alt="image"
src="https://github.com/user-attachments/assets/b88390f3-73d6-4483-b09a-a192064e32d9"
/>


## Client Usage
```python
client.datasets.register(
    source={
        "type": "uri",
        "uri": "lsfs://mydata.jsonl",
    },
    schema="jsonl_messages",
    # optional 
    dataset_id="my_first_train_data"
)

# quick prototype debugging
client.datasets.register(
    data_reference={
        "type": "rows",
        "rows": [
                "messages": [...],
        ],
    },
    schema="jsonl_messages",
)
```

## Test Plan
- CI:
1387805545

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/scoring/test_scoring.py
```

```
pytest -v -s --nbval-lax ./docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb
```
This commit is contained in:
Xi Yan 2025-03-17 16:55:45 -07:00 committed by GitHub
parent 3b35a39b8b
commit 5287b437ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 2593 additions and 2296 deletions

View file

@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
@ -160,7 +161,11 @@ class InferenceRouter(Inference):
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
@ -298,7 +303,12 @@ class InferenceRouter(Inference):
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
@ -471,21 +481,36 @@ class DatasetIORouter(DatasetIO):
logger.debug("DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: