diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py
index f644e5137..597572758 100644
--- a/llama_stack/apis/datatypes.py
+++ b/llama_stack/apis/datatypes.py
@@ -20,8 +20,6 @@ class Api(Enum):
     agents = "agents"
     vector_io = "vector_io"
     datasetio = "datasetio"
-    scoring = "scoring"
-    eval = "eval"
     post_training = "post_training"
     tool_runtime = "tool_runtime"
 
@@ -31,7 +29,6 @@ class Api(Enum):
     shields = "shields"
     vector_dbs = "vector_dbs"
     datasets = "datasets"
-    scoring_functions = "scoring_functions"
     benchmarks = "benchmarks"
     tool_groups = "tool_groups"
 
diff --git a/llama_stack/apis/eval/__init__.py b/llama_stack/apis/eval/__init__.py
deleted file mode 100644
index 5f91ad70d..000000000
--- a/llama_stack/apis/eval/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .eval import *  # noqa: F401 F403
diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py
deleted file mode 100644
index 51c38b16a..000000000
--- a/llama_stack/apis/eval/eval.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict, List, Literal, Optional, Protocol, Union
-
-from pydantic import BaseModel, Field
-from typing_extensions import Annotated
-
-from llama_stack.apis.agents import AgentConfig
-from llama_stack.apis.common.job_types import Job, JobStatus
-from llama_stack.apis.inference import SamplingParams, SystemMessage
-from llama_stack.apis.scoring import ScoringResult
-from llama_stack.apis.scoring_functions import ScoringFnParams
-from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
-
-
-@json_schema_type
-class ModelCandidate(BaseModel):
-    """A model candidate for evaluation.
-
-    :param model: The model ID to evaluate.
-    :param sampling_params: The sampling parameters for the model.
-    :param system_message: (Optional) The system message providing instructions or context to the model.
-    """
-
-    type: Literal["model"] = "model"
-    model: str
-    sampling_params: SamplingParams
-    system_message: Optional[SystemMessage] = None
-
-
-@json_schema_type
-class AgentCandidate(BaseModel):
-    """An agent candidate for evaluation.
-
-    :param config: The configuration for the agent candidate.
-    """
-
-    type: Literal["agent"] = "agent"
-    config: AgentConfig
-
-
-EvalCandidate = register_schema(
-    Annotated[Union[ModelCandidate, AgentCandidate], Field(discriminator="type")],
-    name="EvalCandidate",
-)
-
-
-@json_schema_type
-class BenchmarkConfig(BaseModel):
-    """A benchmark configuration for evaluation.
-
-    :param eval_candidate: The candidate to evaluate.
-    :param scoring_params: Map between scoring function id and parameters for each scoring function you want to run
-    :param num_examples: (Optional) The number of examples to evaluate. If not provided, all examples in the dataset will be evaluated
-    """
-
-    eval_candidate: EvalCandidate
-    scoring_params: Dict[str, ScoringFnParams] = Field(
-        description="Map between scoring function id and parameters for each scoring function you want to run",
-        default_factory=dict,
-    )
-    num_examples: Optional[int] = Field(
-        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
-        default=None,
-    )
-    # we could optinally add any specific dataset config here
-
-
-@json_schema_type
-class EvaluateResponse(BaseModel):
-    """The response from an evaluation.
-
-    :param generations: The generations from the evaluation.
-    :param scores: The scores from the evaluation.
-    """
-
-    generations: List[Dict[str, Any]]
-    # each key in the dict is a scoring function name
-    scores: Dict[str, ScoringResult]
-
-
-class Eval(Protocol):
-    """Llama Stack Evaluation API for running evaluations on model and agent candidates."""
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
-    async def run_eval(
-        self,
-        benchmark_id: str,
-        benchmark_config: BenchmarkConfig,
-    ) -> Job:
-        """Run an evaluation on a benchmark.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param benchmark_config: The configuration for the benchmark.
-        :return: The job that was created to run the evaluation.
-        """
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
-    async def evaluate_rows(
-        self,
-        benchmark_id: str,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        benchmark_config: BenchmarkConfig,
-    ) -> EvaluateResponse:
-        """Evaluate a list of rows on a benchmark.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param input_rows: The rows to evaluate.
-        :param scoring_functions: The scoring functions to use for the evaluation.
-        :param benchmark_config: The configuration for the benchmark.
-        :return: EvaluateResponse object containing generations and scores
-        """
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
-    async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus:
-        """Get the status of a job.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to get the status of.
-        :return: The status of the evaluationjob.
-        """
-        ...
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
-    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
-        """Cancel a job.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to cancel.
-        """
-        ...
-
-    @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
-    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
-        """Get the result of a job.
-
-        :param benchmark_id: The ID of the benchmark to run the evaluation on.
-        :param job_id: The ID of the job to get the result of.
-        :return: The result of the job.
-        """
diff --git a/llama_stack/apis/scoring/__init__.py b/llama_stack/apis/scoring/__init__.py
deleted file mode 100644
index 0739dfc80..000000000
--- a/llama_stack/apis/scoring/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .scoring import *  # noqa: F401 F403
diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py
deleted file mode 100644
index 54a9ac2aa..000000000
--- a/llama_stack/apis/scoring/scoring.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
-
-from pydantic import BaseModel
-
-from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-# mapping of metric to value
-ScoringResultRow = Dict[str, Any]
-
-
-@json_schema_type
-class ScoringResult(BaseModel):
-    """
-    A scoring result for a single row.
-
-    :param score_rows: The scoring result for each row. Each row is a map of column name to value.
-    :param aggregated_results: Map of metric name to aggregated value
-    """
-
-    score_rows: List[ScoringResultRow]
-    # aggregated metrics to value
-    aggregated_results: Dict[str, Any]
-
-
-@json_schema_type
-class ScoreBatchResponse(BaseModel):
-    dataset_id: Optional[str] = None
-    results: Dict[str, ScoringResult]
-
-
-@json_schema_type
-class ScoreResponse(BaseModel):
-    """
-    The response from scoring.
-
-    :param results: A map of scoring function name to ScoringResult.
-    """
-
-    # each key in the dict is a scoring function name
-    results: Dict[str, ScoringResult]
-
-
-class ScoringFunctionStore(Protocol):
-    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...
-
-
-@runtime_checkable
-class Scoring(Protocol):
-    scoring_function_store: ScoringFunctionStore
-
-    @webmethod(route="/scoring/score-batch", method="POST")
-    async def score_batch(
-        self,
-        dataset_id: str,
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
-        save_results_dataset: bool = False,
-    ) -> ScoreBatchResponse: ...
-
-    @webmethod(route="/scoring/score", method="POST")
-    async def score(
-        self,
-        input_rows: List[Dict[str, Any]],
-        scoring_functions: Dict[str, Optional[ScoringFnParams]],
-    ) -> ScoreResponse:
-        """Score a list of rows.
-
-        :param input_rows: The rows to score.
-        :param scoring_functions: The scoring functions to use for the scoring.
-        :return: ScoreResponse object containing rows and aggregated results
-        """
-        ...
diff --git a/llama_stack/apis/scoring_functions/__init__.py b/llama_stack/apis/scoring_functions/__init__.py
deleted file mode 100644
index b96acb45f..000000000
--- a/llama_stack/apis/scoring_functions/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .scoring_functions import *  # noqa: F401 F403
diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py
deleted file mode 100644
index b02a7a0c4..000000000
--- a/llama_stack/apis/scoring_functions/scoring_functions.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import (
-    Any,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Protocol,
-    Union,
-    runtime_checkable,
-)
-
-from pydantic import BaseModel, Field
-from typing_extensions import Annotated
-
-from llama_stack.apis.common.type_system import ParamType
-from llama_stack.apis.resource import Resource, ResourceType
-from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
-
-
-# Perhaps more structure can be imposed on these functions. Maybe they could be associated
-# with standard metrics so they can be rolled up?
-@json_schema_type
-class ScoringFnParamsType(Enum):
-    llm_as_judge = "llm_as_judge"
-    regex_parser = "regex_parser"
-    basic = "basic"
-
-
-@json_schema_type
-class AggregationFunctionType(Enum):
-    average = "average"
-    median = "median"
-    categorical_count = "categorical_count"
-    accuracy = "accuracy"
-
-
-@json_schema_type
-class LLMAsJudgeScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value
-    judge_model: str
-    prompt_template: Optional[str] = None
-    judge_score_regexes: Optional[List[str]] = Field(
-        description="Regexes to extract the answer from generated response",
-        default_factory=list,
-    )
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
-        description="Aggregation functions to apply to the scores of each row",
-        default_factory=list,
-    )
-
-
-@json_schema_type
-class RegexParserScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value
-    parsing_regexes: Optional[List[str]] = Field(
-        description="Regex to extract the answer from generated response",
-        default_factory=list,
-    )
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
-        description="Aggregation functions to apply to the scores of each row",
-        default_factory=list,
-    )
-
-
-@json_schema_type
-class BasicScoringFnParams(BaseModel):
-    type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
-    aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
-        description="Aggregation functions to apply to the scores of each row",
-        default_factory=list,
-    )
-
-
-ScoringFnParams = register_schema(
-    Annotated[
-        Union[
-            LLMAsJudgeScoringFnParams,
-            RegexParserScoringFnParams,
-            BasicScoringFnParams,
-        ],
-        Field(discriminator="type"),
-    ],
-    name="ScoringFnParams",
-)
-
-
-class CommonScoringFnFields(BaseModel):
-    description: Optional[str] = None
-    metadata: Dict[str, Any] = Field(
-        default_factory=dict,
-        description="Any additional metadata for this definition",
-    )
-    return_type: ParamType = Field(
-        description="The return type of the deterministic function",
-    )
-    params: Optional[ScoringFnParams] = Field(
-        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
-        default=None,
-    )
-
-
-@json_schema_type
-class ScoringFn(CommonScoringFnFields, Resource):
-    type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value
-
-    @property
-    def scoring_fn_id(self) -> str:
-        return self.identifier
-
-    @property
-    def provider_scoring_fn_id(self) -> str:
-        return self.provider_resource_id
-
-
-class ScoringFnInput(CommonScoringFnFields, BaseModel):
-    scoring_fn_id: str
-    provider_id: Optional[str] = None
-    provider_scoring_fn_id: Optional[str] = None
-
-
-class ListScoringFunctionsResponse(BaseModel):
-    data: List[ScoringFn]
-
-
-@runtime_checkable
-class ScoringFunctions(Protocol):
-    @webmethod(route="/scoring-functions", method="GET")
-    async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
-
-    @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
-    async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ...
-
-    @webmethod(route="/scoring-functions", method="POST")
-    async def register_scoring_function(
-        self,
-        scoring_fn_id: str,
-        description: str,
-        return_type: ParamType,
-        provider_scoring_fn_id: Optional[str] = None,
-        provider_id: Optional[str] = None,
-        params: Optional[ScoringFnParams] = None,
-    ) -> None: ...
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index ddb727663..233712c60 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -39,14 +39,6 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
             routing_table_api=Api.datasets,
             router_api=Api.datasetio,
         ),
-        AutoRoutedApiInfo(
-            routing_table_api=Api.scoring_functions,
-            router_api=Api.scoring,
-        ),
-        AutoRoutedApiInfo(
-            routing_table_api=Api.benchmarks,
-            router_api=Api.eval,
-        ),
         AutoRoutedApiInfo(
             routing_table_api=Api.tool_groups,
             router_api=Api.tool_runtime,
@@ -55,8 +47,14 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
 
 
 def providable_apis() -> List[Api]:
-    routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
-    return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
+    routing_table_apis = {
+        x.routing_table_api for x in builtin_automatically_routed_apis()
+    }
+    return [
+        api
+        for api in Api
+        if api not in routing_table_apis and api != Api.inspect and api != Api.providers
+    ]
 
 
 def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py
index 3e44d2926..e78393fed 100644
--- a/llama_stack/distribution/routers/routing_tables.py
+++ b/llama_stack/distribution/routers/routing_tables.py
@@ -23,12 +23,6 @@ from llama_stack.apis.datasets import (
 )
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
 from llama_stack.apis.resource import ResourceType
-from llama_stack.apis.scoring_functions import (
-    ListScoringFunctionsResponse,
-    ScoringFn,
-    ScoringFnParams,
-    ScoringFunctions,
-)
 from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
 from llama_stack.apis.tools import (
     ListToolGroupsResponse,
@@ -68,10 +62,6 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
         return await p.register_vector_db(obj)
     elif api == Api.datasetio:
         return await p.register_dataset(obj)
-    elif api == Api.scoring:
-        return await p.register_scoring_function(obj)
-    elif api == Api.eval:
-        return await p.register_benchmark(obj)
     elif api == Api.tool_runtime:
         return await p.register_tool(obj)
     else:
@@ -105,7 +95,9 @@ class CommonRoutingTableImpl(RoutingTable):
         self.dist_registry = dist_registry
 
     async def initialize(self) -> None:
-        async def add_objects(objs: List[RoutableObjectWithProvider], provider_id: str, cls) -> None:
+        async def add_objects(
+            objs: List[RoutableObjectWithProvider], provider_id: str, cls
+        ) -> None:
             for obj in objs:
                 if cls is None:
                     obj.provider_id = provider_id
@@ -127,12 +119,6 @@ class CommonRoutingTableImpl(RoutingTable):
                 p.vector_db_store = self
             elif api == Api.datasetio:
                 p.dataset_store = self
-            elif api == Api.scoring:
-                p.scoring_function_store = self
-                scoring_functions = await p.list_scoring_functions()
-                await add_objects(scoring_functions, pid, ScoringFn)
-            elif api == Api.eval:
-                p.benchmark_store = self
             elif api == Api.tool_runtime:
                 p.tool_store = self
 
@@ -140,7 +126,9 @@ class CommonRoutingTableImpl(RoutingTable):
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
 
-    def get_provider_impl(self, routing_key: str, provider_id: Optional[str] = None) -> Any:
+    def get_provider_impl(
+        self, routing_key: str, provider_id: Optional[str] = None
+    ) -> Any:
         def apiname_object():
             if isinstance(self, ModelsRoutingTable):
                 return ("Inference", "model")
@@ -150,8 +138,6 @@ class CommonRoutingTableImpl(RoutingTable):
                 return ("VectorIO", "vector_db")
             elif isinstance(self, DatasetsRoutingTable):
                 return ("DatasetIO", "dataset")
-            elif isinstance(self, ScoringFunctionsRoutingTable):
-                return ("Scoring", "scoring_function")
             elif isinstance(self, BenchmarksRoutingTable):
                 return ("Eval", "benchmark")
             elif isinstance(self, ToolGroupsRoutingTable):
@@ -178,7 +164,9 @@ class CommonRoutingTableImpl(RoutingTable):
 
         raise ValueError(f"Provider not found for `{routing_key}`")
 
-    async def get_object_by_identifier(self, type: str, identifier: str) -> Optional[RoutableObjectWithProvider]:
+    async def get_object_by_identifier(
+        self, type: str, identifier: str
+    ) -> Optional[RoutableObjectWithProvider]:
         # Get from disk registry
         obj = await self.dist_registry.get(type, identifier)
         if not obj:
@@ -188,9 +176,13 @@ class CommonRoutingTableImpl(RoutingTable):
 
     async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
         await self.dist_registry.delete(obj.type, obj.identifier)
-        await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
+        await unregister_object_from_provider(
+            obj, self.impls_by_provider_id[obj.provider_id]
+        )
 
-    async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
+    async def register_object(
+        self, obj: RoutableObjectWithProvider
+    ) -> RoutableObjectWithProvider:
         # if provider_id is not specified, pick an arbitrary one from existing entries
         if not obj.provider_id and len(self.impls_by_provider_id) > 0:
             obj.provider_id = list(self.impls_by_provider_id.keys())[0]
@@ -248,7 +240,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
         if model_type is None:
             model_type = ModelType.llm
         if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
-            raise ValueError("Embedding model must have an embedding dimension in its metadata")
+            raise ValueError(
+                "Embedding model must have an embedding dimension in its metadata"
+            )
         model = Model(
             identifier=model_id,
             provider_resource_id=provider_model_id,
@@ -268,7 +262,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
     async def list_shields(self) -> ListShieldsResponse:
-        return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
+        return ListShieldsResponse(
+            data=await self.get_all_with_type(ResourceType.shield.value)
+        )
 
     async def get_shield(self, identifier: str) -> Shield:
         shield = await self.get_object_by_identifier("shield", identifier)
@@ -333,14 +329,18 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
                 f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
             )
         else:
-            raise ValueError("No provider available. Please configure a vector_io provider.")
+            raise ValueError(
+                "No provider available. Please configure a vector_io provider."
+            )
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
-            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
+            raise ValueError(
+                f"Model {embedding_model} does not have an embedding dimension"
+            )
         vector_db_data = {
             "identifier": vector_db_id,
             "type": ResourceType.vector_db.value,
@@ -362,7 +362,9 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
     async def list_datasets(self) -> ListDatasetsResponse:
-        return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
+        return ListDatasetsResponse(
+            data=await self.get_all_with_type(ResourceType.dataset.value)
+        )
 
     async def get_dataset(self, dataset_id: str) -> Dataset:
         dataset = await self.get_object_by_identifier("dataset", dataset_id)
@@ -418,10 +420,14 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
-        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
+        return ListScoringFunctionsResponse(
+            data=await self.get_all_with_type(ResourceType.scoring_function.value)
+        )
 
     async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
-        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
+        scoring_fn = await self.get_object_by_identifier(
+            "scoring_function", scoring_fn_id
+        )
         if scoring_fn is None:
             raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
         return scoring_fn
@@ -485,7 +491,9 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
         # TODO (xiyan): we will need a way to infer provider_id for evaluation
         # keep it as meta-reference for now
         if len(self.impls_by_provider_id) == 0:
-            raise ValueError("No evaluation providers available. Please configure an evaluation provider.")
+            raise ValueError(
+                "No evaluation providers available. Please configure an evaluation provider."
+            )
         provider_id = list(self.impls_by_provider_id.keys())[0]
 
         benchmark = Benchmark(
@@ -527,8 +535,12 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         args: Optional[Dict[str, Any]] = None,
     ) -> None:
         tools = []
-        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
-        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(
+            toolgroup_id, mcp_endpoint
+        )
+        tool_host = (
+            ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
+        )
 
         for tool_def in tool_defs:
             tools.append(
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index ccd75f6f6..e2b792e52 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -17,7 +17,6 @@ from llama_stack.apis.batch_inference import BatchInference
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval
 from llama_stack.apis.evaluation import Evaluation
 from llama_stack.apis.files import Files
 from llama_stack.apis.graders import Graders
@@ -27,8 +26,6 @@ from llama_stack.apis.models import Models
 from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.providers import Providers
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.scoring import Scoring
-from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
 from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.telemetry import Telemetry
@@ -117,7 +114,9 @@ class EnvVarError(Exception):
     def __init__(self, var_name: str, path: str = ""):
         self.var_name = var_name
         self.path = path
-        super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")
+        super().__init__(
+            f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
+        )
 
 
 def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
@@ -208,7 +207,9 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
         if not key:
             raise ValueError(f"Empty key in environment variable pair: {env_pair}")
         if not all(c.isalnum() or c == "_" for c in key):
-            raise ValueError(f"Key must contain only alphanumeric characters and underscores: {key}")
+            raise ValueError(
+                f"Key must contain only alphanumeric characters and underscores: {key}"
+            )
         return key, value
     except ValueError as e:
         raise ValueError(
@@ -221,14 +222,20 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
-    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
+    dist_registry, _ = await create_dist_registry(
+        run_config.metadata_store, run_config.image_name
+    )
+    impls = await resolve_impls(
+        run_config, provider_registry or get_provider_registry(), dist_registry
+    )
     await register_resources(run_config, impls)
     return impls
 
 
 def get_stack_run_config_from_template(template: str) -> StackRunConfig:
-    template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
+    template_path = (
+        importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
+    )
 
     with importlib.resources.as_file(template_path) as path:
         if not path.exists():
@@ -271,7 +278,9 @@ def run_config_from_adhoc_config_spec(
 
     # call method "sample_run_config" on the provider spec config class
     provider_config_type = instantiate_class_type(provider_spec.config_class)
-    provider_config = replace_env_vars(provider_config_type.sample_run_config(__distro_dir__=distro_dir))
+    provider_config = replace_env_vars(
+        provider_config_type.sample_run_config(__distro_dir__=distro_dir)
+    )
 
     provider_configs_by_api[api_str] = [
         Provider(
diff --git a/pyproject.toml b/pyproject.toml
index cf4e81ab8..da29e0b9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -168,7 +168,6 @@ exclude = [
     "^llama_stack/apis/common/training_types\\.py$",
     "^llama_stack/apis/datasetio/datasetio\\.py$",
     "^llama_stack/apis/datasets/datasets\\.py$",
-    "^llama_stack/apis/eval/eval\\.py$",
     "^llama_stack/apis/files/files\\.py$",
     "^llama_stack/apis/inference/inference\\.py$",
     "^llama_stack/apis/inspect/inspect\\.py$",
@@ -177,8 +176,6 @@ exclude = [
     "^llama_stack/apis/providers/providers\\.py$",
     "^llama_stack/apis/resource\\.py$",
    "^llama_stack/apis/safety/safety\\.py$",
-    "^llama_stack/apis/scoring/scoring\\.py$",
-    "^llama_stack/apis/scoring_functions/scoring_functions\\.py$",
     "^llama_stack/apis/shields/shields\\.py$",
     "^llama_stack/apis/synthetic_data_generation/synthetic_data_generation\\.py$",
     "^llama_stack/apis/telemetry/telemetry\\.py$",