commit a69759613a
parent a8b0467ec3
Author: Xi Yan
Date:   2025-03-18 15:01:41 -07:00

7 changed files with 2486 additions and 389 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -62,7 +62,7 @@ class Benchmarks(Protocol):
         """
         Register a new benchmark.

-        :param dataset_id: The ID of the dataset to used to run the benchmark.
+        :param dataset_id: The ID of the dataset to be used to run the benchmark.
         :param grader_ids: List of grader ids to use for this benchmark.
         :param benchmark_id: (Optional) The ID of the benchmark to register. If not provided, an ID will be generated.
         :param metadata: (Optional) Metadata for this benchmark for additional descriptions.
@@ -87,3 +87,10 @@ class Benchmarks(Protocol):
         :param benchmark_id: The ID of the benchmark to get.
         """
         ...
+
+    @webmethod(route="/benchmarks/{benchmark_id}", method="DELETE")
+    async def unregister_benchmark(self, benchmark_id: str) -> None:
+        """
+        Unregister a benchmark by ID.
+        """
+        ...
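The hunk above adds the first DELETE route to the Benchmarks protocol. A minimal sketch of exercising it from a plain HTTP client; the server address and the POST /benchmarks register route are assumptions for illustration only, since the diff shows just the DELETE path:

# Sketch only: base URL and the register route are assumed, not shown in the diff.
import httpx

BASE_URL = "http://localhost:8321"  # hypothetical server address

# Register a benchmark, then unregister it via the new DELETE route.
resp = httpx.post(
    f"{BASE_URL}/benchmarks",
    json={
        "dataset_id": "my-dataset",      # dataset to be used to run the benchmark
        "grader_ids": ["my-grader"],     # graders applied to each row
        "benchmark_id": "my-benchmark",  # optional; generated if omitted
    },
)
resp.raise_for_status()

httpx.delete(f"{BASE_URL}/benchmarks/my-benchmark").raise_for_status()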


@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 from datetime import datetime
 from enum import Enum
-from typing import Optional

 from pydantic import BaseModel
@@ -38,12 +37,12 @@ class CommonJobFields(BaseModel):
     :param id: The ID of the job.
     :param status: The status of the job.
     :param created_at: The time the job was created.
-    :param ended_at: The time the job ended.
+    :param completed_at: The time the job completed.
     :param error: If status of the job is failed, this will contain the error message.
     """

     id: str
     status: JobStatus
     created_at: datetime
-    ended_at: Optional[datetime] = None
-    error: Optional[str] = None
+    completed_at: datetime | None = None
+    error: str | None = None
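This change renames ended_at to completed_at and moves the optional fields from Optional[X] to PEP 604 X | None syntax, which is also why the typing import can be dropped. A self-contained sketch of the model as it reads after the diff; the JobStatus members are stand-ins, since the enum body is not part of this hunk:

# Standalone sketch of the model after this change; JobStatus values are
# stubbed here because the enum body is not shown in the diff.
from datetime import datetime
from enum import Enum

from pydantic import BaseModel


class JobStatus(Enum):
    completed = "completed"  # assumed value for illustration
    failed = "failed"        # assumed value for illustration


class CommonJobFields(BaseModel):
    id: str
    status: JobStatus
    created_at: datetime
    completed_at: datetime | None = None  # was ended_at: Optional[datetime]
    error: str | None = None


job = CommonJobFields(id="job-1", status=JobStatus.completed, created_at=datetime.now())
print(job.completed_at)  # None until the job completes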


@@ -48,28 +48,28 @@ EvaluationCandidate = register_schema(
 @json_schema_type
-class BenchmarkTask(BaseModel):
-    type: Literal["benchmark_id"] = "benchmark_id"
+class BenchmarkEvaluationTask(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
     benchmark_id: str


 @json_schema_type
-class DatasetGraderTask(BaseModel):
-    type: Literal["dataset_grader"] = "dataset_grader"
+class DatasetEvaluationTask(BaseModel):
+    type: Literal["dataset"] = "dataset"
     dataset_id: str
     grader_ids: List[str]


 @json_schema_type
-class DataSourceGraderTask(BaseModel):
-    type: Literal["data_source_grader"] = "data_source_grader"
+class DataEvaluationTask(BaseModel):
+    type: Literal["data"] = "data"
     data_source: DataSource
     grader_ids: List[str]


 EvaluationTask = register_schema(
     Annotated[
-        Union[BenchmarkTask, DatasetGraderTask, DataSourceGraderTask],
+        Union[BenchmarkEvaluationTask, DatasetEvaluationTask, DataEvaluationTask],
         Field(discriminator="type"),
     ],
     name="EvaluationTask",


@@ -29,6 +29,13 @@ from .graders import *  # noqa: F401 F403
 class GraderType(Enum):
     """
     A type of grader. Each type is a criteria for evaluating answers.
+
+    :cvar llm: Use an LLM to score the answer.
+    :cvar regex_parser: Use a regex parser to score the answer.
+    :cvar equality: Check if the answer is equal to the reference answer.
+    :cvar subset_of: Check if the answer is a subset of the reference answer.
+    :cvar factuality: Check if the answer is factually correct using LLM as judge.
+    :cvar faithfulness: Check if the answer is faithful to the reference answer using LLM as judge.
     """

     llm = "llm"
@@ -221,9 +228,9 @@ class Graders(Protocol):
         ...

     @webmethod(route="/graders/{grader_id:path}", method="DELETE")
-    async def delete_grader(self, grader_id: str) -> None:
+    async def unregister_grader(self, grader_id: str) -> None:
         """
-        Delete a grader by ID.
+        Unregister a grader by ID.
         :param grader_id: The ID of the grader.
         """
         ...
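delete_grader becomes unregister_grader, aligning with unregister_benchmark above while keeping the DELETE route unchanged. A hedged sketch of an in-memory class satisfying the renamed method; the dict storage and the not-found behavior are invented for illustration:

# Hedged sketch of an implementation satisfying the renamed method; the
# Graders protocol's other methods are omitted and the storage is invented.
import asyncio


class InMemoryGraders:
    def __init__(self) -> None:
        self._graders: dict[str, dict] = {}

    async def unregister_grader(self, grader_id: str) -> None:
        # Mirrors DELETE /graders/{grader_id}; treating an unknown ID as an
        # error is a design assumption, not something the diff specifies.
        if grader_id not in self._graders:
            raise KeyError(f"Grader not found: {grader_id}")
        del self._graders[grader_id]


graders = InMemoryGraders()
graders._graders["g1"] = {"type": "llm"}
asyncio.run(graders.unregister_grader("g1"))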


@ -17,6 +17,7 @@ from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.evaluation import Evaluation from llama_stack.apis.evaluation import Evaluation
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.graders import Graders from llama_stack.apis.graders import Graders
@@ -26,6 +27,8 @@ from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.providers import Providers
 from llama_stack.apis.safety import Safety
+from llama_stack.apis.scoring import Scoring
+from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
@@ -66,6 +69,9 @@ class LlamaStack(
     Files,
     Graders,
     Evaluation,
+    Eval,
+    ScoringFunctions,
+    Scoring,
 ):
     pass
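With Eval, ScoringFunctions, and Scoring added, LlamaStack remains a pure composition of per-API protocols. A minimal sketch of that pattern with two toy protocols standing in for the real API surfaces:

# Minimal sketch of the protocol-composition pattern used by LlamaStack;
# the two protocols here are toys, not the real API definitions.
from typing import Protocol


class Scoring(Protocol):
    async def score(self, row: dict) -> float: ...


class Eval(Protocol):
    async def run_eval(self, benchmark_id: str) -> str: ...


class Stack(Scoring, Eval, Protocol):
    """The combined surface: anything implementing Stack serves both APIs."""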
@ -111,7 +117,9 @@ class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""): def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name self.var_name = var_name
self.path = path self.path = path
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}") super().__init__(
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@@ -202,7 +210,9 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
         if not key:
             raise ValueError(f"Empty key in environment variable pair: {env_pair}")
         if not all(c.isalnum() or c == "_" for c in key):
-            raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
+            raise ValueError(
+                f"Key must contain only alphanumeric characters and underscores: {key}"
+            )
         return key, value
     except ValueError as e:
         raise ValueError(
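The re-wrapped call leaves the validation logic untouched: keys must be non-empty and contain only alphanumerics or underscores. A standalone sketch of that check, assuming the function splits its input on the first "=" (the split itself is outside this hunk):

# Standalone sketch of the key validation shown above; the KEY=VALUE split
# is assumed, since that part of the function is not in the hunk.
def validate_env_pair(env_pair: str) -> tuple[str, str]:
    key, sep, value = env_pair.partition("=")
    key = key.strip()
    if not sep:
        raise ValueError(f"Invalid environment variable pair: {env_pair}")
    if not key:
        raise ValueError(f"Empty key in environment variable pair: {env_pair}")
    if not all(c.isalnum() or c == "_" for c in key):
        raise ValueError(
            f"Key must contain only alphanumeric characters and underscores: {key}"
        )
    return key, value


print(validate_env_pair("LLAMA_STACK_PORT=8321"))  # ('LLAMA_STACK_PORT', '8321')
# validate_env_pair("BAD-KEY=1") raises ValueError for the "-" in the key.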
@@ -215,14 +225,20 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
-    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
+    dist_registry, _ = await create_dist_registry(
+        run_config.metadata_store, run_config.image_name
+    )
+    impls = await resolve_impls(
+        run_config, provider_registry or get_provider_registry(), dist_registry
+    )
     await register_resources(run_config, impls)
     return impls


 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
-    template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
+    template_path = (
+        importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
+    )

     with importlib.resources.as_file(template_path) as path:
         if not path.exists():
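get_stack_run_config_from_template resolves run.yaml from inside the installed package via importlib.resources; as_file then materializes the resource to a real filesystem path, which also works when the package ships as a zip archive. A small sketch of the same pattern against a hypothetical package and resource:

# Sketch of the importlib.resources pattern used above, pointed at an
# assumed package/resource pair purely for illustration.
import importlib.resources

resource = importlib.resources.files("mypkg") / "templates/dev/run.yaml"

# as_file yields a real path even if the package contents live in a zip.
with importlib.resources.as_file(resource) as path:
    if not path.exists():
        raise ValueError(f"Template not found: {path}")
    config_text = path.read_text()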
@@ -265,7 +281,9 @@ def run_config_from_adhoc_config_spec(
         # call method "sample_run_config" on the provider spec config class
         provider_config_type = instantiate_class_type(provider_spec.config_class)
-        provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
+        provider_config = replace_env_vars(
+            provider_config_type.sample_run_config(__distro_dir__=distro_dir)
+        )

         provider_configs_by_api[api_str] = [
             Provider(
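Each provider's sample config is generated and then passed through replace_env_vars before being wrapped in a Provider. A hedged sketch of what an env-substitution pass of that shape might do; the ${env.NAME} placeholder syntax and the helper itself are assumptions, as the real implementation is not shown in this diff:

# Hedged sketch of an env-substitution pass of the kind replace_env_vars
# performs; the placeholder syntax is an assumption.
import os
import re
from typing import Any


def substitute_env(value: Any) -> Any:
    """Recursively replace ${env.NAME} placeholders with os.environ values."""
    if isinstance(value, str):
        return re.sub(
            r"\$\{env\.(\w+)\}",
            lambda m: os.environ.get(m.group(1), ""),
            value,
        )
    if isinstance(value, dict):
        return {k: substitute_env(v) for k, v in value.items()}
    if isinstance(value, list):
        return [substitute_env(v) for v in value]
    return value


os.environ["PORT"] = "8321"
print(substitute_env({"url": "http://localhost:${env.PORT}"}))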