diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index e16e047e5..183ee9f34 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -11,12 +11,9 @@ from pydantic import BaseModel, Field from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Dataset, DatasetInput -from llama_stack.apis.eval import Eval from llama_stack.apis.inference import Inference from llama_stack.apis.models import Model, ModelInput from llama_stack.apis.safety import Safety -from llama_stack.apis.scoring import Scoring -from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnInput from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput @@ -36,7 +33,6 @@ RoutableObject = Union[ Shield, VectorDB, Dataset, - ScoringFn, Benchmark, Tool, ToolGroup, @@ -49,7 +45,6 @@ RoutableObjectWithProvider = Annotated[ Shield, VectorDB, Dataset, - ScoringFn, Benchmark, Tool, ToolGroup, @@ -62,8 +57,6 @@ RoutedProtocol = Union[ Safety, VectorIO, DatasetIO, - Scoring, - Eval, ToolRuntime, ] @@ -195,7 +188,9 @@ a default SQLite store will be used.""", benchmarks: List[BenchmarkInput] = Field(default_factory=list) tool_groups: List[ToolGroupInput] = Field(default_factory=list) - logging: Optional[LoggingConfig] = Field(default=None, description="Configuration for Llama Stack Logging") + logging: Optional[LoggingConfig] = Field( + default=None, description="Configuration for Llama Stack Logging" + ) server: ServerConfig = Field( default_factory=ServerConfig, @@ -206,7 +201,9 @@ a default SQLite store will be used.""", class BuildConfig(BaseModel): version: str = LLAMA_STACK_BUILD_CONFIG_VERSION - distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ") + distribution_spec: DistributionSpec = Field( + description="The distribution spec to build including API providers. 
" + ) image_type: str = Field( default="conda", description="Type of package to build (conda | container | venv)", diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index e9e406699..1a96e0d75 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -11,15 +11,12 @@ from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.eval import Eval from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.models import Models from llama_stack.apis.post_training import PostTraining from llama_stack.apis.providers import Providers as ProvidersAPI from llama_stack.apis.safety import Safety -from llama_stack.apis.scoring import Scoring -from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime @@ -72,9 +69,6 @@ def api_protocol_map() -> Dict[Api, Any]: Api.telemetry: Telemetry, Api.datasetio: DatasetIO, Api.datasets: Datasets, - Api.scoring: Scoring, - Api.scoring_functions: ScoringFunctions, - Api.eval: Eval, Api.benchmarks: Benchmarks, Api.post_training: PostTraining, Api.tool_groups: ToolGroups, @@ -89,12 +83,6 @@ def additional_protocols_map() -> Dict[Api, Any]: Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs), Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), - Api.scoring: ( - ScoringFunctionsProtocolPrivate, - ScoringFunctions, - Api.scoring_functions, - ), - Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks), } @@ -117,7 +105,9 @@ async def resolve_impls( 2. Sorting them in dependency order. 3. Instantiating them with required dependencies. 
""" - routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()} + routing_table_apis = { + x.routing_table_api for x in builtin_automatically_routed_apis() + } router_apis = {x.router_api for x in builtin_automatically_routed_apis()} providers_with_specs = validate_and_prepare_providers( @@ -125,7 +115,9 @@ async def resolve_impls( ) apis_to_serve = run_config.apis or set( - list(providers_with_specs.keys()) + [x.value for x in routing_table_apis] + [x.value for x in router_apis] + list(providers_with_specs.keys()) + + [x.value for x in routing_table_apis] + + [x.value for x in router_apis] ) providers_with_specs.update(specs_for_autorouted_apis(apis_to_serve)) @@ -135,7 +127,9 @@ async def resolve_impls( return await instantiate_providers(sorted_providers, router_apis, dist_registry) -def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, Dict[str, ProviderWithSpec]]: +def specs_for_autorouted_apis( + apis_to_serve: List[str] | Set[str], +) -> Dict[str, Dict[str, ProviderWithSpec]]: """Generates specifications for automatically routed APIs.""" specs = {} for info in builtin_automatically_routed_apis(): @@ -177,7 +171,10 @@ def specs_for_autorouted_apis(apis_to_serve: List[str] | Set[str]) -> Dict[str, def validate_and_prepare_providers( - run_config: StackRunConfig, provider_registry: ProviderRegistry, routing_table_apis: Set[Api], router_apis: Set[Api] + run_config: StackRunConfig, + provider_registry: ProviderRegistry, + routing_table_apis: Set[Api], + router_apis: Set[Api], ) -> Dict[str, Dict[str, ProviderWithSpec]]: """Validates providers, handles deprecations, and organizes them into a spec dictionary.""" providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]] = {} @@ -185,17 +182,23 @@ def validate_and_prepare_providers( for api_str, providers in run_config.providers.items(): api = Api(api_str) if api in routing_table_apis: - raise ValueError(f"Provider for `{api_str}` is automatically provided and cannot be overridden") + raise ValueError( + f"Provider for `{api_str}` is automatically provided and cannot be overridden" + ) specs = {} for provider in providers: if not provider.provider_id or provider.provider_id == "__disabled__": - logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") + logger.warning( + f"Provider `{provider.provider_type}` for API `{api}` is disabled" + ) continue validate_provider(provider, api, provider_registry) p = provider_registry[api][provider.provider_type] - p.deps__ = [a.value for a in p.api_dependencies] + [a.value for a in p.optional_api_dependencies] + p.deps__ = [a.value for a in p.api_dependencies] + [ + a.value for a in p.optional_api_dependencies + ] spec = ProviderWithSpec(spec=p, **provider.model_dump()) specs[provider.provider_id] = spec @@ -205,10 +208,14 @@ def validate_and_prepare_providers( return providers_with_specs -def validate_provider(provider: Provider, api: Api, provider_registry: ProviderRegistry): +def validate_provider( + provider: Provider, api: Api, provider_registry: ProviderRegistry +): """Validates if the provider is allowed and handles deprecations.""" if provider.provider_type not in provider_registry[api]: - raise ValueError(f"Provider `{provider.provider_type}` is not available for API `{api}`") + raise ValueError( + f"Provider `{provider.provider_type}` is not available for API `{api}`" + ) p = provider_registry[api][provider.provider_type] if p.deprecation_error: @@ -221,7 +228,8 @@ def validate_provider(provider: Provider, api: 
Api, provider_registry: ProviderR def sort_providers_by_deps( - providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], run_config: StackRunConfig + providers_with_specs: Dict[str, Dict[str, ProviderWithSpec]], + run_config: StackRunConfig, ) -> List[Tuple[str, ProviderWithSpec]]: """Sorts providers based on their dependencies.""" sorted_providers: List[Tuple[str, ProviderWithSpec]] = topological_sort( @@ -276,11 +284,15 @@ def sort_providers_by_deps( async def instantiate_providers( - sorted_providers: List[Tuple[str, ProviderWithSpec]], router_apis: Set[Api], dist_registry: DistributionRegistry + sorted_providers: List[Tuple[str, ProviderWithSpec]], + router_apis: Set[Api], + dist_registry: DistributionRegistry, ) -> Dict: """Instantiates providers asynchronously while managing dependencies.""" impls: Dict[Api, Any] = {} - inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} + inner_impls_by_provider_id: Dict[str, Dict[str, Any]] = { + f"inner-{x.value}": {} for x in router_apis + } for api_str, provider in sorted_providers: deps = {a: impls[a] for a in provider.spec.api_dependencies} for a in provider.spec.optional_api_dependencies: @@ -289,7 +301,9 @@ async def instantiate_providers( inner_impls = {} if isinstance(provider.spec, RoutingTableProviderSpec): - inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] + inner_impls = inner_impls_by_provider_id[ + f"inner-{provider.spec.router_api.value}" + ] impl = await instantiate_provider(provider, deps, inner_impls, dist_registry) @@ -347,7 +361,9 @@ async def instantiate_provider( provider_spec = provider.spec if not hasattr(provider_spec, "module"): - raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") + raise AttributeError( + f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute" + ) module = importlib.import_module(provider_spec.module) args = [] @@ -384,7 +400,10 @@ async def instantiate_provider( # TODO: check compliance for special tool groups # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol check_protocol_compliance(impl, protocols[provider_spec.api]) - if not isinstance(provider_spec, AutoRoutedProviderSpec) and provider_spec.api in additional_protocols: + if ( + not isinstance(provider_spec, AutoRoutedProviderSpec) + and provider_spec.api in additional_protocols + ): additional_api, _, _ = additional_protocols[provider_spec.api] check_protocol_compliance(impl, additional_api) @@ -412,12 +431,19 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: obj_params = set(obj_sig.parameters) obj_params.discard("self") if not (proto_params <= obj_params): - logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}") + logger.error( + f"Method {name} incompatible proto: {proto_params} vs. 
obj: {obj_params}" + ) missing_methods.append((name, "signature_mismatch")) else: # Check if the method is actually implemented in the class - method_owner = next((cls for cls in mro if name in cls.__dict__), None) - if method_owner is None or method_owner.__name__ == protocol.__name__: + method_owner = next( + (cls for cls in mro if name in cls.__dict__), None + ) + if ( + method_owner is None + or method_owner.__name__ == protocol.__name__ + ): missing_methods.append((name, "not_actually_implemented")) if missing_methods: diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index d0fca8771..803c94a92 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -32,7 +32,6 @@ async def get_routing_table_impl( "models": ModelsRoutingTable, "shields": ShieldsRoutingTable, "datasets": DatasetsRoutingTable, - "scoring_functions": ScoringFunctionsRoutingTable, "benchmarks": BenchmarksRoutingTable, "tool_groups": ToolGroupsRoutingTable, } @@ -45,7 +44,9 @@ async def get_routing_table_impl( return impl -async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: Dict[str, Any]) -> Any: +async def get_auto_router_impl( + api: Api, routing_table: RoutingTable, deps: Dict[str, Any] +) -> Any: from .routers import ( DatasetIORouter, EvalRouter, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 2cf38f544..369789a16 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,9 +8,9 @@ import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from llama_stack.apis.common.content_types import ( - URL, InterleavedContent, InterleavedContentItem, + URL, ) from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse from llama_stack.apis.datasets import DatasetPurpose, DataSource @@ -94,7 +94,9 @@ class VectorIORouter(VectorIO): provider_id: Optional[str] = None, provider_vector_db_id: Optional[str] = None, ) -> None: - logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") + logger.debug( + f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}" + ) await self.routing_table.register_vector_db( vector_db_id, embedding_model, @@ -112,7 +114,9 @@ class VectorIORouter(VectorIO): logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' 
if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks( + vector_db_id, chunks, ttl_seconds + ) async def query_chunks( self, @@ -121,7 +125,9 @@ class VectorIORouter(VectorIO): params: Optional[Dict[str, Any]] = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks( + vector_db_id, query, params + ) class InferenceRouter(Inference): @@ -158,7 +164,9 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.register_model: {model_id=} {provider_model_id=} {provider_id=} {metadata=} {model_type=}", ) - await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type) + await self.routing_table.register_model( + model_id, provider_model_id, provider_id, metadata, model_type + ) def _construct_metrics( self, @@ -212,11 +220,16 @@ class InferenceRouter(Inference): total_tokens: int, model: Model, ) -> List[MetricInResponse]: - metrics = self._construct_metrics(prompt_tokens, completion_tokens, total_tokens, model) + metrics = self._construct_metrics( + prompt_tokens, completion_tokens, total_tokens, model + ) if self.telemetry: for metric in metrics: await self.telemetry.log_event(metric) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + return [ + MetricInResponse(metric=metric.metric, value=metric.value) + for metric in metrics + ] async def _count_tokens( self, @@ -241,7 +254,9 @@ class InferenceRouter(Inference): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, - ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: + ) -> Union[ + ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk] + ]: logger.debug( f"InferenceRouter.chat_completion: {model_id=}, {stream=}, {messages=}, {tools=}, {tool_config=}, {response_format=}", ) @@ -251,12 +266,19 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) if tool_config: if tool_choice and tool_choice != tool_config.tool_choice: raise ValueError("tool_choice and tool_config.tool_choice must match") - if tool_prompt_format and tool_prompt_format != tool_config.tool_prompt_format: - raise ValueError("tool_prompt_format and tool_config.tool_prompt_format must match") + if ( + tool_prompt_format + and tool_prompt_format != tool_config.tool_prompt_format + ): + raise ValueError( + "tool_prompt_format and tool_config.tool_prompt_format must match" + ) else: params = {} if tool_choice: @@ -274,9 +296,14 @@ class InferenceRouter(Inference): pass else: # verify tool_choice is one of the tools - tool_names = [t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value for t in tools] + tool_names = [ + t.tool_name if isinstance(t.tool_name, str) else t.tool_name.value + for t in tools + ] if tool_config.tool_choice not in tool_names: - raise 
ValueError(f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}") + raise ValueError( + f"Tool choice {tool_config.tool_choice} is not one of the tools: {tool_names}" + ) params = dict( model_id=model_id, @@ -291,17 +318,25 @@ class InferenceRouter(Inference): tool_config=tool_config, ) provider = self.routing_table.get_provider_impl(model_id) - prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) + prompt_tokens = await self._count_tokens( + messages, tool_config.tool_prompt_format + ) if stream: async def stream_generator(): completion_text = "" async for chunk in await provider.chat_completion(**params): - if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.progress + ): if chunk.event.delta.type == "text": completion_text += chunk.event.delta.text - if chunk.event.event_type == ChatCompletionResponseEventType.complete: + if ( + chunk.event.event_type + == ChatCompletionResponseEventType.complete + ): completion_tokens = await self._count_tokens( [ CompletionMessage( @@ -318,7 +353,11 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -335,7 +374,9 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def completion( @@ -356,7 +397,9 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: - raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") + raise ValueError( + f"Model '{model_id}' is an embedding model and does not support chat completions" + ) provider = self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, @@ -376,7 +419,11 @@ class InferenceRouter(Inference): async for chunk in await provider.completion(**params): if hasattr(chunk, "delta"): completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + if ( + hasattr(chunk, "stop_reason") + and chunk.stop_reason + and self.telemetry + ): completion_tokens = await self._count_tokens(completion_text) total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) metrics = await self._compute_and_log_token_usage( @@ -385,7 +432,11 @@ class InferenceRouter(Inference): total_tokens, model, ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics + chunk.metrics = ( + metrics + if chunk.metrics is None + else chunk.metrics + metrics + ) yield chunk return stream_generator() @@ -399,7 +450,9 @@ class InferenceRouter(Inference): total_tokens, model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics + response.metrics = ( + metrics if response.metrics is None else response.metrics + metrics + ) return response async def embeddings( @@ -415,7 +468,9 @@ class InferenceRouter(Inference): if model is None: raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: - raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") + raise 
ValueError( + f"Model '{model_id}' is an LLM model and does not support embeddings" + ) return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, @@ -449,7 +504,9 @@ class SafetyRouter(Safety): params: Optional[Dict[str, Any]] = None, ) -> Shield: logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + return await self.routing_table.register_shield( + shield_id, provider_shield_id, provider_id, params + ) async def run_shield( self, @@ -521,135 +578,6 @@ class DatasetIORouter(DatasetIO): ) -class ScoringRouter(Scoring): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ScoringRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("ScoringRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("ScoringRouter.shutdown") - pass - - async def score_batch( - self, - dataset_id: str, - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: - logger.debug(f"ScoringRouter.score_batch: {dataset_id}") - res = {} - for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( - dataset_id=dataset_id, - scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - if save_results_dataset: - raise NotImplementedError("Save results dataset not implemented yet") - - return ScoreBatchResponse( - results=res, - ) - - async def score( - self, - input_rows: List[Dict[str, Any]], - scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, - ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") - res = {} - # look up and map each scoring function to its provider impl - for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( - input_rows=input_rows, - scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - return ScoreResponse(results=res) - - -class EvalRouter(Eval): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing EvalRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("EvalRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("EvalRouter.shutdown") - pass - - async def run_eval( - self, - benchmark_id: str, - benchmark_config: BenchmarkConfig, - ) -> Job: - logger.debug(f"EvalRouter.run_eval: {benchmark_id}") - return await self.routing_table.get_provider_impl(benchmark_id).run_eval( - benchmark_id=benchmark_id, - benchmark_config=benchmark_config, - ) - - async def evaluate_rows( - self, - benchmark_id: str, - input_rows: List[Dict[str, Any]], - scoring_functions: List[str], - benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") - return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( - benchmark_id=benchmark_id, - input_rows=input_rows, - scoring_functions=scoring_functions, - benchmark_config=benchmark_config, - ) - - async def job_status( - self, - benchmark_id: str, - job_id: str, - ) -> 
Optional[JobStatus]: - logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) - - async def job_cancel( - self, - benchmark_id: str, - job_id: str, - ) -> None: - logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") - await self.routing_table.get_provider_impl(benchmark_id).job_cancel( - benchmark_id, - job_id, - ) - - async def job_result( - self, - benchmark_id: str, - job_id: str, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_result( - benchmark_id, - job_id, - ) - - class ToolRuntimeRouter(ToolRuntime): class RagToolImpl(RAGToolRuntime): def __init__( @@ -679,9 +607,9 @@ class ToolRuntimeRouter(ToolRuntime): logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) + return await self.routing_table.get_provider_impl( + "insert_into_memory" + ).insert(documents, vector_db_id, chunk_size_in_tokens) def __init__( self, @@ -714,4 +642,6 @@ class ToolRuntimeRouter(ToolRuntime): self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None ) -> List[ToolDef]: logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) + return await self.routing_table.get_provider_impl(tool_group_id).list_tools( + tool_group_id, mcp_endpoint + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index e78393fed..9aaf83483 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -418,50 +418,6 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): await self.unregister_object(dataset) -class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): - async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse( - data=await self.get_all_with_type(ResourceType.scoring_function.value) - ) - - async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: - scoring_fn = await self.get_object_by_identifier( - "scoring_function", scoring_fn_id - ) - if scoring_fn is None: - raise ValueError(f"Scoring function '{scoring_fn_id}' not found") - return scoring_fn - - async def register_scoring_function( - self, - scoring_fn_id: str, - description: str, - return_type: ParamType, - provider_scoring_fn_id: Optional[str] = None, - provider_id: Optional[str] = None, - params: Optional[ScoringFnParams] = None, - ) -> None: - if provider_scoring_fn_id is None: - provider_scoring_fn_id = scoring_fn_id - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." 
- ) - scoring_fn = ScoringFn( - identifier=scoring_fn_id, - description=description, - return_type=return_type, - provider_resource_id=provider_scoring_fn_id, - provider_id=provider_id, - params=params, - ) - scoring_fn.provider_id = provider_id - await self.register_object(scoring_fn) - - class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): async def list_benchmarks(self) -> ListBenchmarksResponse: return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index e2b792e52..fcb53da9b 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -78,12 +78,6 @@ RESOURCES = [ ("shields", Api.shields, "register_shield", "list_shields"), ("vector_dbs", Api.vector_dbs, "register_vector_db", "list_vector_dbs"), ("datasets", Api.datasets, "register_dataset", "list_datasets"), - ( - "scoring_fns", - Api.scoring_functions, - "register_scoring_function", - "list_scoring_functions", - ), ("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"), ("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"), ] diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index 40caccda0..0e2e1d14f 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -22,11 +22,16 @@ class LlamaStackApi: }, ) - def run_scoring(self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]): + def run_scoring( + self, row, scoring_function_ids: list[str], scoring_params: Optional[dict] + ): """Run scoring on a single row""" if not scoring_params: scoring_params = {fn_id: None for fn_id in scoring_function_ids} - return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params) + + # TODO(xiyan): fix this + # return self.client.scoring.score(input_rows=[row], scoring_functions=scoring_params) + raise NotImplementedError("Scoring is not implemented") llama_stack_api = LlamaStackApi() diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py index 5e10e6e80..da42c468c 100644 --- a/llama_stack/distribution/ui/page/distribution/resources.py +++ b/llama_stack/distribution/ui/page/distribution/resources.py @@ -4,14 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from streamlit_option_menu import option_menu - from llama_stack.distribution.ui.page.distribution.datasets import datasets from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks from llama_stack.distribution.ui.page.distribution.models import models -from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions from llama_stack.distribution.ui.page.distribution.shields import shields from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs +from streamlit_option_menu import option_menu def resources_page(): @@ -43,8 +41,9 @@ def resources_page(): datasets() elif selected_resource == "Models": models() - elif selected_resource == "Scoring Functions": - scoring_functions() + # TODO(xiyan): fix this + # elif selected_resource == "Scoring Functions": + # scoring_functions() elif selected_resource == "Shields": shields() diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 384582423..76873d188 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -13,7 +13,6 @@ from llama_stack.apis.benchmarks import Benchmark from llama_stack.apis.datasets import Dataset from llama_stack.apis.datatypes import Api from llama_stack.apis.models import Model -from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield from llama_stack.apis.tools import Tool from llama_stack.apis.vector_dbs import VectorDB @@ -42,12 +41,6 @@ class DatasetsProtocolPrivate(Protocol): async def unregister_dataset(self, dataset_id: str) -> None: ... -class ScoringFunctionsProtocolPrivate(Protocol): - async def list_scoring_functions(self) -> List[ScoringFn]: ... - - async def register_scoring_function(self, scoring_fn: ScoringFn) -> None: ... - - class BenchmarksProtocolPrivate(Protocol): async def register_benchmark(self, benchmark: Benchmark) -> None: ... diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py index d9b129a8b..ac41df000 100644 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/context_entity_recall.py @@ -20,5 +20,7 @@ context_entity_recall_fn_def = ScoringFn( provider_id="braintrust", provider_resource_id="context-entity-recall", return_type=NumberType(), - params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.average]), + params=BasicScoringFnParams( + aggregation_functions=[AggregationFunctionType.average] + ), )
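For context on the `check_protocol_compliance` hunk reformatted above: the logic walks the implementation's MRO so that a method merely inherited from the protocol's own `...` stub is reported as `not_actually_implemented` rather than treated as a real override, and a signature that drops protocol parameters is reported as `signature_mismatch`. Below is a minimal standalone sketch of that idea, independent of llama_stack; the `Greeter`, `PartialGreeter`, and `missing_methods` names are hypothetical and exist only for illustration, not part of the codebase touched by this diff.

```python
import inspect
from typing import Protocol


class Greeter(Protocol):
    async def greet(self, name: str) -> str: ...
    async def farewell(self, name: str) -> str: ...


class PartialGreeter(Greeter):
    # Implements greet, but only inherits farewell's "..." stub from the protocol.
    async def greet(self, name: str) -> str:
        return f"hello {name}"


def missing_methods(obj: object, protocol: type) -> list[tuple[str, str]]:
    """Report protocol methods that are missing, mismatched, or only inherited stubs."""
    missing: list[tuple[str, str]] = []
    mro = type(obj).__mro__
    for name, value in inspect.getmembers(protocol):
        # Only consider the protocol's public methods.
        if not inspect.isfunction(value) or name.startswith("_"):
            continue
        if not hasattr(obj, name):
            missing.append((name, "missing"))
            continue
        proto_params = set(inspect.signature(value).parameters) - {"self"}
        obj_params = set(inspect.signature(getattr(obj, name)).parameters)
        if not proto_params <= obj_params:
            missing.append((name, "signature_mismatch"))
            continue
        # Find the class in the MRO that actually defines the method; if that
        # class is the protocol itself, the object only has the "..." stub.
        owner = next((cls for cls in mro if name in cls.__dict__), None)
        if owner is None or owner is protocol:
            missing.append((name, "not_actually_implemented"))
    return missing


print(missing_methods(PartialGreeter(), Greeter))
# [('farewell', 'not_actually_implemented')]
```

Running the sketch flags only `farewell`, since it is inherited from the protocol stub, while `greet` is a genuine override; this mirrors the MRO check in the reformatted `check_protocol_compliance` block, which compares the owning class against the protocol by name rather than by identity.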