diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 5b4433041..8bba8c4d4 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -10,7 +10,7 @@ from pydantic import BaseModel, Field from typing_extensions import Annotated from llama_stack.apis.agents import AgentConfig -from llama_stack.apis.common.job_types import JobStatus +from llama_stack.apis.common.job_types import CommonJobFields, JobStatus from llama_stack.apis.inference import SamplingParams, SystemMessage from llama_stack.apis.scoring import ScoringResult from llama_stack.apis.scoring_functions import ScoringFnParams @@ -91,7 +91,7 @@ class Eval(Protocol): self, benchmark_id: str, benchmark_config: BenchmarkConfig, - ) -> None: + ) -> CommonJobFields: """Run an evaluation on a benchmark. :param benchmark_id: The ID of the benchmark to run the evaluation on. diff --git a/llama_stack/apis/graders/graders.py b/llama_stack/apis/graders/graders.py index 077497414..08ccb9715 100644 --- a/llama_stack/apis/graders/graders.py +++ b/llama_stack/apis/graders/graders.py @@ -4,9 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .graders import * # noqa: F401 F403 from enum import Enum - from typing import ( Annotated, Any, @@ -15,17 +13,18 @@ from typing import ( Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) -from llama_stack.apis.datasets import DatasetPurpose - -from llama_stack.apis.resource import Resource, ResourceType - -from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field +from llama_stack.apis.datasets import DatasetPurpose +from llama_stack.apis.resource import Resource, ResourceType +from llama_stack.schema_utils import json_schema_type, register_schema, webmethod + +from .graders import * # noqa: F401 F403 + class GraderType(Enum): """ diff --git a/llama_stack/apis/resource.py b/llama_stack/apis/resource.py index fc590b118..5f4f9876c 100644 --- a/llama_stack/apis/resource.py +++ b/llama_stack/apis/resource.py @@ -25,9 +25,7 @@ class ResourceType(Enum): class Resource(BaseModel): """Base class for all Llama Stack resources""" - identifier: str = Field( - description="Unique identifier for this resource in llama stack" - ) + identifier: str = Field(description="Unique identifier for this resource in llama stack") provider_resource_id: str = Field( description="Unique identifier for this resource in the provider", @@ -36,6 +34,4 @@ class Resource(BaseModel): provider_id: str = Field(description="ID of the provider that owns this resource") - type: ResourceType = Field( - description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)" - ) + type: ResourceType = Field(description="Type of resource (e.g. 'model', 'shield', 'vector_db', etc.)") diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index cd1c58348..b4862537a 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -111,9 +111,7 @@ class EnvVarError(Exception): def __init__(self, var_name: str, path: str = ""): self.var_name = var_name self.path = path - super().__init__( - f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}" - ) + super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}") def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]: @@ -204,9 +202,7 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]: if not key: raise ValueError(f"Empty key in environment variable pair: {env_pair}") if not all(c.isalnum() or c == "_" for c in key): - raise ValueError( - f"Key must contain only alphanumeric characters and underscores: {key}" - ) + raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}") return key, value except ValueError as e: raise ValueError( @@ -219,20 +215,14 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]: async def construct_stack( run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None ) -> Dict[Api, Any]: - dist_registry, _ = await create_dist_registry( - run_config.metadata_store, run_config.image_name - ) - impls = await resolve_impls( - run_config, provider_registry or get_provider_registry(), dist_registry - ) + dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) + impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry) await register_resources(run_config, impls) return impls def get_stack_run_config_from_template(template: str) -> StackRunConfig: - template_path = ( - importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml" - ) + template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml" with importlib.resources.as_file(template_path) as path: if not path.exists(): @@ -275,9 +265,7 @@ def run_config_from_adhoc_config_spec( # call method "sample_run_config" on the provider spec config class provider_config_type = instantiate_class_type(provider_spec.config_class) - provider_config = replace_env_vars( - provider_config_type.sample_run_config(__distro_dir__=distro_dir) - ) + provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir)) provider_configs_by_api[api_str] = [ Provider(