Xi Yan 2025-03-16 19:33:57 -07:00
parent d34b70e3ab
commit 035b2dcb60
9 changed files with 2365 additions and 2190 deletions

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@@ -12,11 +12,17 @@ from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel):
"""
:param dataset_id: The ID of the dataset used to run the benchmark.
:param grader_ids: The grader IDs to use for this benchmark.
:param metadata: Metadata providing additional descriptions for this benchmark.
"""
dataset_id: str
scoring_functions: List[str]
grader_ids: List[str]
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
description="Metadata for this benchmark",
)
@@ -45,22 +51,39 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(
self,
dataset_id: str,
grader_ids: List[str],
benchmark_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Benchmark:
"""
Register a new benchmark.
:param dataset_id: The ID of the dataset used to run the benchmark.
:param grader_ids: List of grader IDs to use for this benchmark.
:param benchmark_id: (Optional) The ID of the benchmark to register. If not provided, an ID will be generated.
:param metadata: (Optional) Metadata providing additional descriptions for this benchmark.
"""
...
@webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""
List all benchmarks.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark(
self,
benchmark_id: str,
) -> Optional[Benchmark]: ...
) -> Benchmark:
"""
Get a benchmark by ID.
:param benchmark_id: The ID of the benchmark to get.
"""
...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
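A minimal client-side sketch of the new registration flow, assuming an object benchmarks_impl that implements the Benchmarks protocol above (a client or in-process implementation, not part of this diff); the dataset and grader IDs are hypothetical.

async def register_and_list(benchmarks_impl) -> None:
    # Register a benchmark under the new signature: dataset_id plus grader_ids,
    # leaving benchmark_id for the server to generate.
    benchmark = await benchmarks_impl.register_benchmark(
        dataset_id="mmlu-validation",                # hypothetical dataset ID
        grader_ids=["llm-as-judge", "exact-match"],  # hypothetical grader IDs
        metadata={"description": "MMLU with two graders"},
    )
    print(benchmark)

    # List everything registered so far.
    response = await benchmarks_impl.list_benchmarks()
    print(response)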


@@ -3,21 +3,49 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):
job_id: str
class JobType(Enum):
batch_inference = "batch_inference"
evaluation = "evaluation"
finetuning = "finetuning"
@json_schema_type
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
cancelled = "cancelled"
class JobArtifact(BaseModel):
"""
A job artifact is a file or directory that is produced by a job.
"""
path: str
@json_schema_type
class CommonJobFields(BaseModel):
"""Common fields for all jobs.
:param id: The ID of the job.
:param status: The status of the job.
:param created_at: The time the job was created.
:param ended_at: The time the job ended.
:param error: If the status of the job is failed, this will contain the error message.
"""
id: str
status: JobStatus
created_at: datetime
ended_at: Optional[datetime] = None
error: Optional[str] = None
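To illustrate the shared job model introduced here, the sketch below fills in the common fields directly; the job ID is hypothetical and the import path follows the module shown in this diff.

# Hedged sketch: populating the common job fields defined above.
from datetime import datetime, timezone

from llama_stack.apis.common.job_types import CommonJobFields, JobStatus

job = CommonJobFields(
    id="eval-job-0001",                     # hypothetical job ID
    status=JobStatus.in_progress,
    created_at=datetime.now(timezone.utc),
    # ended_at and error remain None until the job finishes or fails.
)
print(job)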


@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
@@ -91,7 +91,7 @@ class Eval(Protocol):
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
) -> None:
"""Run an evaluation on a benchmark.
:param benchmark_id: The ID of the benchmark to run the evaluation on.
@@ -135,7 +135,9 @@ class Eval(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET"
)
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Get the result of a job.


@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .evaluation import * # noqa: F401 F403


@@ -0,0 +1,175 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import CommonJobFields, JobType
from llama_stack.apis.datasets import DataSource
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class ModelCandidate(BaseModel):
"""A model candidate for evaluation.
:param model_id: The model ID to evaluate.
:param sampling_params: The sampling parameters for the model.
:param system_message: (Optional) The system message providing instructions or context to the model.
"""
type: Literal["model"] = "model"
model_id: str
sampling_params: SamplingParams
system_message: Optional[SystemMessage] = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvaluationCandidate = register_schema(
Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
name="EvaluationCandidate",
)
@json_schema_type
class BenchmarkTask(BaseModel):
type: Literal["benchmark_id"] = "benchmark_id"
benchmark_id: str
@json_schema_type
class DatasetGraderTask(BaseModel):
type: Literal["dataset_grader"] = "dataset_grader"
dataset_id: str
grader_ids: List[str]
@json_schema_type
class DataSourceGraderTask(BaseModel):
type: Literal["data_source_grader"] = "data_source_grader"
data_source: DataSource
grader_ids: List[str]
EvaluationTask = register_schema(
Annotated[
Union[BenchmarkTask, DatasetGraderTask, DataSourceGraderTask],
Field(discriminator="type"),
],
name="EvaluationTask",
)
@json_schema_type
class EvaluationJob(CommonJobFields):
type: Literal[JobType.evaluation.value] = JobType.evaluation.value
# input params for the submitted evaluation job
task: EvaluationTask
candidate: EvaluationCandidate
@json_schema_type
class ScoringResult(BaseModel):
"""
A scoring result for a single row.
:param scores: The scoring result for each row. Each row is a map of grader column name to value.
:param metrics: Map of metric name to aggregated value.
"""
scores: List[Dict[str, Any]]
metrics: Dict[str, Any]
@json_schema_type
class EvaluationResponse(BaseModel):
"""
A response to an inline evaluation.
:param generations: The generations in rows for the evaluation.
:param scores: The scores for the evaluation. Map of grader id to ScoringResult.
"""
generations: List[Dict[str, Any]]
scores: Dict[str, ScoringResult]
class Evaluation(Protocol):
@webmethod(route="/evaluation/run", method="POST")
async def run(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationJob:
"""
Run an evaluation job.
:param task: The task to evaluate. One of:
- BenchmarkTask: Run evaluation task against a benchmark_id
- DatasetGraderTask: Run evaluation task against a dataset_id and a list of grader_ids
- DataSourceGraderTask: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:param candidate: The candidate to evaluate.
"""
...
@webmethod(route="/evaluation/run_inline", method="POST")
async def run_inline(
self,
task: EvaluationTask,
candidate: EvaluationCandidate,
) -> EvaluationResponse:
"""
Run an evaluation job inline.
:param task: The task to evaluate. One of:
- BenchmarkTask: Run evaluation task against a benchmark_id
- DatasetGraderTask: Run evaluation task against a dataset_id and a list of grader_ids
- DataSourceGraderTask: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:param candidate: The candidate to evaluate.
"""
...
@webmethod(route="/evaluation/grade", method="POST")
async def grade(self, task: EvaluationTask) -> EvaluationJob:
"""
Run a grading job with generated results. Use this when you already have inference results stored in a dataset.
:param task: The task to evaluate. One of:
- BenchmarkTask: Run evaluation task against a benchmark_id
- DatasetGraderTask: Run evaluation task against a dataset_id and a list of grader_ids
- DataSourceGraderTask: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:return: The evaluation job containing grader scores.
"""
...
@webmethod(route="/evaluation/grade_inline", method="POST")
async def grade_inline(self, task: EvaluationTask) -> EvaluationResponse:
"""
Run a grading job with generated results inline.
:param task: The task to evaluate. One of:
- BenchmarkTask: Run evaluation task against a benchmark_id
- DatasetGraderTask: Run evaluation task against a dataset_id and a list of grader_ids
- DataSourceGraderTask: Run evaluation task against a data source (e.g. rows, uri, etc.) and a list of grader_ids
:return: The evaluation job containing grader scores. "generations" is not populated in the response.
"""
...
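To make the new surface concrete, here is a hedged sketch of an inline evaluation followed by a grading-only pass; evaluation_impl stands in for whatever object implements the Evaluation protocol (a client or provider, not part of this diff), the model, dataset, and grader IDs are hypothetical, and SamplingParams is assumed to be constructible with its defaults.

from llama_stack.apis.evaluation import DatasetGraderTask, ModelCandidate
from llama_stack.apis.inference import SamplingParams

async def evaluate_and_grade(evaluation_impl) -> None:
    candidate = ModelCandidate(
        model_id="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model ID
        sampling_params=SamplingParams(),             # library defaults assumed
    )
    task = DatasetGraderTask(
        dataset_id="my-eval-rows",   # hypothetical dataset ID
        grader_ids=["exact-match"],  # hypothetical grader ID
    )

    # run_inline blocks and returns generations plus per-grader scores.
    result = await evaluation_impl.run_inline(task=task, candidate=candidate)
    print(result.scores)

    # grade_inline skips generation and only scores pre-computed results.
    graded = await evaluation_impl.grade_inline(task=task)
    print(graded.scores)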


@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .graders import * # noqa: F401 F403


@@ -17,16 +17,15 @@ from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.evaluation import Evaluation
from llama_stack.apis.files import Files
from llama_stack.apis.graders import Graders
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.providers import Providers
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.telemetry import Telemetry
@@ -56,10 +55,7 @@ class LlamaStack(
Telemetry,
PostTraining,
VectorIO,
Eval,
Benchmarks,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
@@ -68,6 +64,8 @@ class LlamaStack(
ToolRuntime,
RAGToolRuntime,
Files,
Graders,
Evaluation,
):
pass
@@ -113,7 +111,9 @@ class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
super().__init__(
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@@ -204,7 +204,9 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
raise ValueError(
f"Key must contain only alphanumeric characters and underscores: {key}"
)
return key, value
except ValueError as e:
raise ValueError(
@@ -217,14 +219,20 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
async def construct_stack(
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
dist_registry, _ = await create_dist_registry(
run_config.metadata_store, run_config.image_name
)
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(), dist_registry
)
await register_resources(run_config, impls)
return impls
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
template_path = (
importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
)
with importlib.resources.as_file(template_path) as path:
if not path.exists():
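For orientation, a hedged sketch of how the two helpers above fit together at startup, assuming it runs in (or imports from) this module and that an "ollama" template ships with the package; neither assumption is established by this diff.

import asyncio

async def boot() -> None:
    # Build a run config from a bundled template, then construct the
    # provider implementations keyed by Api.
    run_config = get_stack_run_config_from_template("ollama")
    impls = await construct_stack(run_config)
    print(list(impls))

asyncio.run(boot())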
@@ -267,7 +275,9 @@ def run_config_from_adhoc_config_spec(
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
provider_config = replace_env_vars(
provider_config_type.sample_run_config(__distro_dir__=distro_dir)
)
provider_configs_by_api[api_str] = [
Provider(