Xi Yan 2025-03-16 19:33:57 -07:00
parent d34b70e3ab
commit 035b2dcb60
9 changed files with 2365 additions and 2190 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -12,11 +12,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod
 class CommonBenchmarkFields(BaseModel):
+    """
+    :param dataset_id: The ID of the dataset used to run the benchmark.
+    :param grader_ids: The grader ids to use for this benchmark.
+    :param metadata: Metadata for this benchmark for additional descriptions.
+    """
+
     dataset_id: str
-    scoring_functions: List[str]
+    grader_ids: List[str]
     metadata: Dict[str, Any] = Field(
         default_factory=dict,
-        description="Metadata for this evaluation task",
+        description="Metadata for this benchmark",
     )
@@ -45,22 +51,39 @@ class ListBenchmarksResponse(BaseModel):
 @runtime_checkable
 class Benchmarks(Protocol):
+    @webmethod(route="/eval/benchmarks", method="POST")
+    async def register_benchmark(
+        self,
+        dataset_id: str,
+        grader_ids: List[str],
+        benchmark_id: Optional[str] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> Benchmark:
+        """
+        Register a new benchmark.
+
+        :param dataset_id: The ID of the dataset used to run the benchmark.
+        :param grader_ids: List of grader ids to use for this benchmark.
+        :param benchmark_id: (Optional) The ID of the benchmark to register. If not provided, an ID will be generated.
+        :param metadata: (Optional) Metadata for this benchmark for additional descriptions.
+        """
+        ...
+
     @webmethod(route="/eval/benchmarks", method="GET")
-    async def list_benchmarks(self) -> ListBenchmarksResponse: ...
+    async def list_benchmarks(self) -> ListBenchmarksResponse:
+        """
+        List all benchmarks.
+        """
+        ...

     @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
     async def get_benchmark(
         self,
         benchmark_id: str,
-    ) -> Optional[Benchmark]: ...
-
-    @webmethod(route="/eval/benchmarks", method="POST")
-    async def register_benchmark(
-        self,
-        benchmark_id: str,
-        dataset_id: str,
-        scoring_functions: List[str],
-        provider_benchmark_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        metadata: Optional[Dict[str, Any]] = None,
-    ) -> None: ...
+    ) -> Benchmark:
+        """
+        Get a benchmark by ID.
+
+        :param benchmark_id: The ID of the benchmark to get.
+        """
+        ...
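
Editor's note: a minimal usage sketch of the reworked Benchmarks protocol, assuming an object `benchmarks_impl` that implements it; the dataset, grader, and benchmark IDs are placeholders, not values defined by this commit. The point of the change is that `benchmark_id` is now optional and a dataset/grader pair replaces the old `scoring_functions` field.

# Hedged sketch: exercising the reworked Benchmarks protocol end to end.
# `benchmarks_impl` and every ID below are placeholders, not defined by this commit.
async def register_and_fetch(benchmarks_impl) -> None:
    benchmark = await benchmarks_impl.register_benchmark(
        dataset_id="my-eval-dataset",        # dataset the benchmark runs over
        grader_ids=["grader::exact-match"],  # graders applied to each row
        benchmark_id="my-benchmark",         # optional; an ID is generated if omitted
        metadata={"notes": "registered via the grader-based API"},
    )
    all_benchmarks = await benchmarks_impl.list_benchmarks()
    fetched = await benchmarks_impl.get_benchmark(benchmark_id="my-benchmark")
    print(benchmark, all_benchmarks, fetched)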


@@ -3,21 +3,49 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from datetime import datetime
 from enum import Enum
+from typing import Optional

 from pydantic import BaseModel

 from llama_stack.schema_utils import json_schema_type


-@json_schema_type
-class Job(BaseModel):
-    job_id: str
+class JobType(Enum):
+    batch_inference = "batch_inference"
+    evaluation = "evaluation"
+    finetuning = "finetuning"


-@json_schema_type
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
     failed = "failed"
     scheduled = "scheduled"
+    cancelled = "cancelled"
+
+
+class JobArtifact(BaseModel):
+    """
+    A job artifact is a file or directory that is produced by a job.
+    """
+
+    path: str
+
+
+@json_schema_type
+class CommonJobFields(BaseModel):
+    """Common fields for all jobs.
+
+    :param id: The ID of the job.
+    :param status: The status of the job.
+    :param created_at: The time the job was created.
+    :param ended_at: The time the job ended.
+    :param error: If the status of the job is failed, this will contain the error message.
+    """
+
+    id: str
+    status: JobStatus
+    created_at: datetime
+    ended_at: Optional[datetime] = None
+    error: Optional[str] = None
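
Editor's note: to illustrate the new shared job model, a minimal sketch of a concrete job built on `CommonJobFields`. The `FinetuningJob` subclass and the job ID are hypothetical; only the import path, enum members, and field names come from this commit.

from datetime import datetime, timezone
from typing import Literal

from llama_stack.apis.common.job_types import CommonJobFields, JobStatus, JobType


class FinetuningJob(CommonJobFields):
    # Mirrors the EvaluationJob pattern later in this commit; the subclass itself is hypothetical.
    type: Literal[JobType.finetuning.value] = JobType.finetuning.value


job = FinetuningJob(
    id="job-123",
    status=JobStatus.scheduled,
    created_at=datetime.now(timezone.utc),
)
assert job.ended_at is None and job.error is None  # not yet finished, no error recorded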


@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
 from typing_extensions import Annotated

 from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job, JobStatus
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.inference import SamplingParams, SystemMessage
 from llama_stack.apis.scoring import ScoringResult
 from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -91,7 +91,7 @@ class Eval(Protocol):
         self,
         benchmark_id: str,
         benchmark_config: BenchmarkConfig,
-    ) -> Job:
+    ) -> None:
         """Run an evaluation on a benchmark.

         :param benchmark_id: The ID of the benchmark to run the evaluation on.
@@ -135,7 +135,9 @@ class Eval(Protocol):
         """
         ...

-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
+    @webmethod(
+        route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET"
+    )
     async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
         """Get the result of a job.


@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .evaluation import *  # noqa: F401 F403


@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Dict, List, Literal, Optional, Protocol, Union
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
+
+from llama_stack.apis.agents import AgentConfig
+from llama_stack.apis.common.job_types import CommonJobFields, JobType
+from llama_stack.apis.datasets import DataSource
+from llama_stack.apis.inference import SamplingParams, SystemMessage
+from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+
+
+@json_schema_type
+class ModelCandidate(BaseModel):
+    """A model candidate for evaluation.
+
+    :param model_id: The model ID to evaluate.
+    :param sampling_params: The sampling parameters for the model.
+    :param system_message: (Optional) The system message providing instructions or context to the model.
+    """
+
+    type: Literal["model"] = "model"
+    model_id: str
+    sampling_params: SamplingParams
+    system_message: Optional[SystemMessage] = None
+
+
+@json_schema_type
+class AgentCandidate(BaseModel):
+    """An agent candidate for evaluation.
+
+    :param config: The configuration for the agent candidate.
+    """
+
+    type: Literal["agent"] = "agent"
+    config: AgentConfig
+
+
+EvaluationCandidate = register_schema(
+    Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
+    name="EvaluationCandidate",
+)
+
+
+@json_schema_type
+class BenchmarkTask(BaseModel):
+    type: Literal["benchmark_id"] = "benchmark_id"
+    benchmark_id: str
+
+
+@json_schema_type
+class DatasetGraderTask(BaseModel):
+    type: Literal["dataset_grader"] = "dataset_grader"
+    dataset_id: str
+    grader_ids: List[str]
+
+
+@json_schema_type
+class DataSourceGraderTask(BaseModel):
+    type: Literal["data_source_grader"] = "data_source_grader"
+    data_source: DataSource
+    grader_ids: List[str]
+
+
+EvaluationTask = register_schema(
+    Annotated[
+        Union[BenchmarkTask, DatasetGraderTask, DataSourceGraderTask],
+        Field(discriminator="type"),
+    ],
+    name="EvaluationTask",
+)
+
+
+@json_schema_type
+class EvaluationJob(CommonJobFields):
+    type: Literal[JobType.evaluation.value] = JobType.evaluation.value
+
+    # input params for the submitted evaluation job
+    task: EvaluationTask
+    candidate: EvaluationCandidate
+
+
+@json_schema_type
+class ScoringResult(BaseModel):
+    """
+    A scoring result for a single row.
+
+    :param scores: The scoring result for each row. Each row is a map of grader column name to value.
+    :param metrics: Map of metric name to aggregated value.
+    """
+
+    scores: List[Dict[str, Any]]
+    metrics: Dict[str, Any]
+
+
+@json_schema_type
+class EvaluationResponse(BaseModel):
+    """
+    A response to an inline evaluation.
+
+    :param generations: The generations in rows for the evaluation.
+    :param scores: The scores for the evaluation. Map of grader id to ScoringResult.
+    """
+
+    generations: List[Dict[str, Any]]
+    scores: Dict[str, ScoringResult]
+
+
+class Evaluation(Protocol):
+    @webmethod(route="/evaluation/run", method="POST")
+    async def run(
+        self,
+        task: EvaluationTask,
+        candidate: EvaluationCandidate,
+    ) -> EvaluationJob:
+        """
+        Run an evaluation job.
+
+        :param task: The task to evaluate. One of:
+            - BenchmarkTask: Run the evaluation task against a benchmark_id.
+            - DatasetGraderTask: Run the evaluation task against a dataset_id and a list of grader_ids.
+            - DataSourceGraderTask: Run the evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids.
+        :param candidate: The candidate to evaluate.
+        """
+        ...
+
+    @webmethod(route="/evaluation/run_inline", method="POST")
+    async def run_inline(
+        self,
+        task: EvaluationTask,
+        candidate: EvaluationCandidate,
+    ) -> EvaluationResponse:
+        """
+        Run an evaluation job inline.
+
+        :param task: The task to evaluate. One of:
+            - BenchmarkTask: Run the evaluation task against a benchmark_id.
+            - DatasetGraderTask: Run the evaluation task against a dataset_id and a list of grader_ids.
+            - DataSourceGraderTask: Run the evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids.
+        :param candidate: The candidate to evaluate.
+        """
+        ...
+
+    @webmethod(route="/evaluation/grade", method="POST")
+    async def grade(self, task: EvaluationTask) -> EvaluationJob:
+        """
+        Run a grading job with generated results. Use this when the dataset already contains generated results from inference.
+
+        :param task: The task to evaluate. One of:
+            - BenchmarkTask: Run the evaluation task against a benchmark_id.
+            - DatasetGraderTask: Run the evaluation task against a dataset_id and a list of grader_ids.
+            - DataSourceGraderTask: Run the evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids.
+        :return: The evaluation job containing grader scores.
+        """
+        ...
+
+    @webmethod(route="/evaluation/grade_inline", method="POST")
+    async def grade_inline(self, task: EvaluationTask) -> EvaluationResponse:
+        """
+        Run a grading job with generated results inline.
+
+        :param task: The task to evaluate. One of:
+            - BenchmarkTask: Run the evaluation task against a benchmark_id.
+            - DatasetGraderTask: Run the evaluation task against a dataset_id and a list of grader_ids.
+            - DataSourceGraderTask: Run the evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids.
+        :return: The evaluation job containing grader scores. "generations" is not populated in the response.
+        """
+        ...
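
Editor's note: to make the new surface concrete, a hedged sketch of driving the `Evaluation` protocol with a model candidate and a dataset/grader task. The `evaluation_impl` object, model ID, dataset ID, and grader ID are placeholders; the sketch also assumes `SamplingParams()` is constructible with defaults and that the task/candidate classes are re-exported from `llama_stack.apis.evaluation` via the new `__init__.py` above.

from llama_stack.apis.evaluation import DatasetGraderTask, ModelCandidate  # assumes star re-export
from llama_stack.apis.inference import SamplingParams


async def demo_evaluation(evaluation_impl) -> None:
    # Candidate: the model that will produce generations (IDs are hypothetical).
    candidate = ModelCandidate(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        sampling_params=SamplingParams(),  # assuming defaults are acceptable
    )
    # Task: score rows of an existing dataset with a set of graders.
    task = DatasetGraderTask(
        dataset_id="my-eval-dataset",
        grader_ids=["grader::exact-match"],
    )

    # Asynchronous path: returns an EvaluationJob whose id/status can be polled.
    job = await evaluation_impl.run(task=task, candidate=candidate)
    print(job.id, job.status)

    # Grading-only path: the dataset already holds generations to score.
    graded = await evaluation_impl.grade_inline(task=task)
    for grader_id, result in graded.scores.items():
        print(grader_id, result.metrics)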


@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .graders import *  # noqa: F401 F403


@@ -17,16 +17,15 @@ from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval
+from llama_stack.apis.evaluation import Evaluation
 from llama_stack.apis.files import Files
+from llama_stack.apis.graders import Graders
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.providers import Providers
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.scoring import Scoring
-from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
@@ -56,10 +55,7 @@ class LlamaStack(
     Telemetry,
     PostTraining,
     VectorIO,
-    Eval,
     Benchmarks,
-    Scoring,
-    ScoringFunctions,
     DatasetIO,
     Models,
     Shields,
@@ -68,6 +64,8 @@ class LlamaStack(
     ToolRuntime,
     RAGToolRuntime,
     Files,
+    Graders,
+    Evaluation,
 ):
     pass
@@ -113,7 +111,9 @@ class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
         self.var_name = var_name
         self.path = path
-        super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
+        super().__init__(
+            f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
+        )


 def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@@ -204,7 +204,9 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
         if not key:
             raise ValueError(f"Empty key in environment variable pair: {env_pair}")
         if not all(c.isalnum() or c == "_" for c in key):
-            raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
+            raise ValueError(
+                f"Key must contain only alphanumeric characters and underscores: {key}"
+            )
         return key, value
     except ValueError as e:
         raise ValueError(
@@ -217,14 +219,20 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
-    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
+    dist_registry, _ = await create_dist_registry(
+        run_config.metadata_store, run_config.image_name
+    )
+    impls = await resolve_impls(
+        run_config, provider_registry or get_provider_registry(), dist_registry
+    )
     await register_resources(run_config, impls)
     return impls


 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
-    template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
+    template_path = (
+        importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
+    )

     with importlib.resources.as_file(template_path) as path:
         if not path.exists():
@@ -267,7 +275,9 @@ def run_config_from_adhoc_config_spec(
     # call method "sample_run_config" on the provider spec config class
     provider_config_type = instantiate_class_type(provider_spec.config_class)
-    provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
+    provider_config = replace_env_vars(
+        provider_config_type.sample_run_config(__distro_dir__=distro_dir)
+    )
     provider_configs_by_api[api_str] = [
         Provider(