feat(api): (1/n) datasets api clean up (#1573)

## PR Stack
- https://github.com/meta-llama/llama-stack/pull/1573
- https://github.com/meta-llama/llama-stack/pull/1625
- https://github.com/meta-llama/llama-stack/pull/1656
- https://github.com/meta-llama/llama-stack/pull/1657
- https://github.com/meta-llama/llama-stack/pull/1658
- https://github.com/meta-llama/llama-stack/pull/1659
- https://github.com/meta-llama/llama-stack/pull/1660

**Client SDK**
- https://github.com/meta-llama/llama-stack-client-python/pull/203

**CI**
- 1391130488
<img width="1042" alt="image"
src="https://github.com/user-attachments/assets/69636067-376d-436b-9204-896e2dd490ca"
/>
-- the test_rag_agent_with_attachments is flaky and not related to this
PR

## Doc
<img width="789" alt="image"
src="https://github.com/user-attachments/assets/b88390f3-73d6-4483-b09a-a192064e32d9"
/>


## Client Usage
```python
client.datasets.register(
    source={
        "type": "uri",
        "uri": "lsfs://mydata.jsonl",
    },
    schema="jsonl_messages",
    # optional 
    dataset_id="my_first_train_data"
)

# quick prototype debugging
client.datasets.register(
    data_reference={
        "type": "rows",
        "rows": [
                "messages": [...],
        ],
    },
    schema="jsonl_messages",
)
```

## Test Plan
- CI:
1387805545

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/datasets/test_datasets.py
```

```
LLAMA_STACK_CONFIG=fireworks pytest -v tests/integration/scoring/test_scoring.py
```

```
pytest -v -s --nbval-lax ./docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb
```
This commit is contained in:
Xi Yan 2025-03-17 16:55:45 -07:00 committed by GitHub
parent 3b35a39b8b
commit 5287b437ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 2593 additions and 2296 deletions

View file

@ -14,16 +14,11 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
get_valid_schemas,
validate_dataset_schema,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)