feat(dataset api): (1.1/n) dataset api implementation fix pre-commit (#1625)

# What does this PR do?
- fix pre-commit with api updates

[//]: # (If resolving an issue, uncomment and update the line below)
[//]: # (Closes #[issue-number])

## Test Plan
```
pre-commit
```

[//]: # (## Documentation)
This commit is contained in:
Xi Yan 2025-03-13 16:41:03 -07:00 committed by GitHub
parent a6095820af
commit 7606e49dbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 42 additions and 97 deletions

View file

@@ -55,6 +55,4 @@ class DatasetIO(Protocol):
... ...
@webmethod(route="/datasets/{dataset_id}/rows", method="POST") @webmethod(route="/datasets/{dataset_id}/rows", method="POST")
async def append_rows( async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...
self, dataset_id: str, rows: List[Dict[str, Any]]
) -> None: ...

View file

@@ -13,7 +13,7 @@ from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(Enum): class DatasetPurpose(str, Enum):
""" """
Purpose of the dataset. Each purpose has a required input data schema. Purpose of the dataset. Each purpose has a required input data schema.

View file

@@ -6,7 +6,7 @@
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
from llama_stack.apis.common.content_types import URL from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
from llama_stack.apis.models.models import ModelType from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
BenchmarkInput, BenchmarkInput,
@@ -171,60 +171,34 @@ def get_distribution_template() -> DistributionTemplate:
DatasetInput( DatasetInput(
dataset_id="simpleqa", dataset_id="simpleqa",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/simpleqa", uri="huggingface://llamastack/simpleqa?split=train",
"split": "train", ),
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
DatasetInput( DatasetInput(
dataset_id="mmlu_cot", dataset_id="mmlu_cot",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/mmlu_cot", uri="huggingface://llamastack/mmlu_cot?split=test&name=all",
"name": "all", ),
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
DatasetInput( DatasetInput(
dataset_id="gpqa_cot", dataset_id="gpqa_cot",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/gpqa_0shot_cot", uri="huggingface://llamastack/gpqa_0shot_cot?split=test&name=gpqa_main",
"name": "gpqa_main", ),
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
DatasetInput( DatasetInput(
dataset_id="math_500", dataset_id="math_500",
provider_id="huggingface", provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"), purpose=DatasetPurpose.eval_messages_answer,
metadata={ source=URIDataSource(
"path": "llamastack/math_500", uri="huggingface://llamastack/math_500?split=test",
"split": "test", ),
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
), ),
] ]

View file

@@ -158,62 +158,32 @@ shields:
- shield_id: meta-llama/Llama-Guard-3-8B - shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: [] vector_dbs: []
datasets: datasets:
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://llamastack/simpleqa?split=train
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/simpleqa
metadata:
path: llamastack/simpleqa
split: train
dataset_id: simpleqa dataset_id: simpleqa
provider_id: huggingface provider_id: huggingface
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://llamastack/mmlu_cot?split=test&name=all
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
metadata:
path: llamastack/mmlu_cot
name: all
split: test
dataset_id: mmlu_cot dataset_id: mmlu_cot
provider_id: huggingface provider_id: huggingface
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
metadata:
path: llamastack/gpqa_0shot_cot
name: gpqa_main
split: train
dataset_id: gpqa_cot dataset_id: gpqa_cot
provider_id: huggingface provider_id: huggingface
- dataset_schema: - purpose: eval/messages-answer
input_query: source:
type: string type: uri
expected_answer: uri: huggingface://llamastack/math_500?split=test
type: string metadata: {}
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/math_500
metadata:
path: llamastack/math_500
split: test
dataset_id: math_500 dataset_id: math_500
provider_id: huggingface provider_id: huggingface
scoring_fns: [] scoring_fns: []

View file

@@ -11,6 +11,7 @@ import jinja2
import yaml import yaml
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.models.models import ModelType from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
Api, Api,
@@ -214,7 +215,9 @@ class DistributionTemplate(BaseModel):
# Register YAML representer for ModelType # Register YAML representer for ModelType
yaml.add_representer(ModelType, enum_representer) yaml.add_representer(ModelType, enum_representer)
yaml.add_representer(DatasetPurpose, enum_representer)
yaml.SafeDumper.add_representer(ModelType, enum_representer) yaml.SafeDumper.add_representer(ModelType, enum_representer)
yaml.SafeDumper.add_representer(DatasetPurpose, enum_representer)
for output_dir in [yaml_output_dir, doc_output_dir]: for output_dir in [yaml_output_dir, doc_output_dir]:
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)