# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
    pypi_url = "https://pypi.org/pypi/llama-stack/json"
    headers = {
        # PyPI rate-limits / blocks generic user agents; present pip's UA
        # (as the original change intended) and explicitly request JSON.
        'User-Agent': 'pip/23.0.1 (python 3.11)',  # Mimic pip's user agent
        'Accept': 'application/json',
    }
    # `requests` has NO default timeout — without one, a PyPI outage would hang
    # the docs build forever. Response.json() decodes the payload directly,
    # replacing the json.loads(resp.text) round-trip.
    version_tag = requests.get(pypi_url, headers=headers, timeout=30).json()["info"]["version"]
print(f"{version_tag=}")
VectorDBsRoutingTable + api_to_tables = { "vector_dbs": VectorDBsRoutingTable, "models": ModelsRoutingTable, @@ -50,15 +48,12 @@ async def get_routing_table_impl( async def get_auto_router_impl( api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig ) -> Any: - from .routers import ( - DatasetIORouter, - EvalRouter, - InferenceRouter, - SafetyRouter, - ScoringRouter, - ToolRuntimeRouter, - VectorIORouter, - ) + from .datasets import DatasetIORouter + from .eval_scoring import EvalRouter, ScoringRouter + from .inference import InferenceRouter + from .safety import SafetyRouter + from .tool_runtime import ToolRuntimeRouter + from .vector_io import VectorIORouter api_to_routers = { "vector_io": VectorIORouter, diff --git a/llama_stack/distribution/routers/datasets.py b/llama_stack/distribution/routers/datasets.py new file mode 100644 index 000000000..6f28756c9 --- /dev/null +++ b/llama_stack/distribution/routers/datasets.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 

from typing import Any

from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class DatasetIORouter(DatasetIO):
    """Thin router that forwards DatasetIO calls to the provider owning each dataset.

    Dataset registration is a routing-table concern; row-level operations are
    dispatched to the provider implementation resolved by ``dataset_id``.
    """

    def __init__(self, routing_table: RoutingTable) -> None:
        logger.debug("Initializing DatasetIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        # Nothing to set up beyond construction; hook kept for lifecycle symmetry.
        logger.debug("DatasetIORouter.initialize")

    async def shutdown(self) -> None:
        # No resources owned directly; providers are shut down via the routing table.
        logger.debug("DatasetIORouter.shutdown")

    async def register_dataset(
        self,
        purpose: DatasetPurpose,
        source: DataSource,
        metadata: dict[str, Any] | None = None,
        dataset_id: str | None = None,
    ) -> None:
        """Register a dataset through the routing table (not a provider call)."""
        logger.debug(
            f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
        )
        await self.routing_table.register_dataset(
            purpose=purpose,
            source=source,
            metadata=metadata,
            dataset_id=dataset_id,
        )

    async def iterrows(
        self,
        dataset_id: str,
        start_index: int | None = None,
        limit: int | None = None,
    ) -> PaginatedResponse:
        """Page through rows of a dataset via its owning provider."""
        logger.debug(
            f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
        )
        provider = self.routing_table.get_provider_impl(dataset_id)
        return await provider.iterrows(
            dataset_id=dataset_id,
            start_index=start_index,
            limit=limit,
        )

    async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
        """Append rows to a dataset via its owning provider."""
        logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
        provider = self.routing_table.get_provider_impl(dataset_id)
        return await provider.append_rows(dataset_id=dataset_id, rows=rows)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
    ScoreBatchResponse,
    ScoreResponse,
    Scoring,
    ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable

logger = get_logger(name=__name__, category="core")


class ScoringRouter(Scoring):
    """Routes scoring requests to the provider registered for each scoring function."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing ScoringRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("ScoringRouter.initialize")

    async def shutdown(self) -> None:
        logger.debug("ScoringRouter.shutdown")

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: dict[str, ScoringFnParams | None] | None = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Score a full dataset, fanning out one provider call per scoring function.

        Raises:
            NotImplementedError: if save_results_dataset is requested.
        """
        logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
        # FIX: the previous annotation claimed a dict but defaulted to None,
        # which crashed on .keys() when the argument was omitted. Treat a
        # missing mapping as "no scoring functions requested".
        scoring_functions = scoring_functions or {}
        res = {}
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
                dataset_id=dataset_id,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self,
        input_rows: list[dict[str, Any]],
        scoring_functions: dict[str, ScoringFnParams | None] | None = None,
    ) -> ScoreResponse:
        """Score in-memory rows with each requested scoring function."""
        # Same None-default fix as score_batch: guard before len() below.
        scoring_functions = scoring_functions or {}
        logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions.keys():
            score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
                input_rows=input_rows,
                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)


class EvalRouter(Eval):
    """Routes eval-benchmark operations to the provider registered for each benchmark."""

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        logger.debug("Initializing EvalRouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        logger.debug("EvalRouter.initialize")

    async def shutdown(self) -> None:
        logger.debug("EvalRouter.shutdown")

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        """Kick off an eval run on the provider owning this benchmark."""
        logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
            benchmark_id=benchmark_id,
            benchmark_config=benchmark_config,
        )

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: list[dict[str, Any]],
        scoring_functions: list[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        """Synchronously evaluate the given rows against a benchmark."""
        logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
        return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
            benchmark_id=benchmark_id,
            input_rows=input_rows,
            scoring_functions=scoring_functions,
            benchmark_config=benchmark_config,
        )

    async def job_status(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> Job:
        """Query the status of a previously started eval job."""
        logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)

    async def job_cancel(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> None:
        """Cancel a running eval job."""
        logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
        await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
            benchmark_id,
            job_id,
        )

    async def job_result(
        self,
        benchmark_id: str,
        job_id: str,
    ) -> EvaluateResponse:
        """Fetch the final result of a completed eval job."""
        logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
        return await self.routing_table.get_provider_impl(benchmark_id).job_result(
            benchmark_id,
            job_id,
        )
llama_stack.models.llama.llama3.tokenizer import Tokenizer @@ -83,62 +61,6 @@ from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") -class VectorIORouter(VectorIO): - """Routes to an provider based on the vector db identifier""" - - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing VectorIORouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("VectorIORouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("VectorIORouter.shutdown") - pass - - async def register_vector_db( - self, - vector_db_id: str, - embedding_model: str, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> None: - logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") - await self.routing_table.register_vector_db( - vector_db_id, - embedding_model, - embedding_dimension, - provider_id, - provider_vector_db_id, - ) - - async def insert_chunks( - self, - vector_db_id: str, - chunks: list[Chunk], - ttl_seconds: int | None = None, - ) -> None: - logger.debug( - f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' 
if len(chunks) > 3 else ''}", - ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) - - async def query_chunks( - self, - vector_db_id: str, - query: InterleavedContent, - params: dict[str, Any] | None = None, - ) -> QueryChunksResponse: - logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) - - class InferenceRouter(Inference): """Routes to an provider based on the model""" @@ -671,295 +593,3 @@ class InferenceRouter(Inference): status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" ) return health_statuses - - -class SafetyRouter(Safety): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing SafetyRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("SafetyRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("SafetyRouter.shutdown") - pass - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) - - async def run_shield( - self, - shield_id: str, - messages: list[Message], - params: dict[str, Any] = None, - ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( - shield_id=shield_id, - messages=messages, - params=params, - ) - - -class DatasetIORouter(DatasetIO): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing DatasetIORouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - 
logger.debug("DatasetIORouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("DatasetIORouter.shutdown") - pass - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> None: - logger.debug( - f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}", - ) - await self.routing_table.register_dataset( - purpose=purpose, - source=source, - metadata=metadata, - dataset_id=dataset_id, - ) - - async def iterrows( - self, - dataset_id: str, - start_index: int | None = None, - limit: int | None = None, - ) -> PaginatedResponse: - logger.debug( - f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", - ) - return await self.routing_table.get_provider_impl(dataset_id).iterrows( - dataset_id=dataset_id, - start_index=start_index, - limit=limit, - ) - - async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: - logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") - return await self.routing_table.get_provider_impl(dataset_id).append_rows( - dataset_id=dataset_id, - rows=rows, - ) - - -class ScoringRouter(Scoring): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ScoringRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("ScoringRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("ScoringRouter.shutdown") - pass - - async def score_batch( - self, - dataset_id: str, - scoring_functions: dict[str, ScoringFnParams | None] = None, - save_results_dataset: bool = False, - ) -> ScoreBatchResponse: - logger.debug(f"ScoringRouter.score_batch: {dataset_id}") - res = {} - for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( - dataset_id=dataset_id, - 
scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - if save_results_dataset: - raise NotImplementedError("Save results dataset not implemented yet") - - return ScoreBatchResponse( - results=res, - ) - - async def score( - self, - input_rows: list[dict[str, Any]], - scoring_functions: dict[str, ScoringFnParams | None] = None, - ) -> ScoreResponse: - logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions") - res = {} - # look up and map each scoring function to its provider impl - for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( - input_rows=input_rows, - scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, - ) - res.update(score_response.results) - - return ScoreResponse(results=res) - - -class EvalRouter(Eval): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing EvalRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("EvalRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("EvalRouter.shutdown") - pass - - async def run_eval( - self, - benchmark_id: str, - benchmark_config: BenchmarkConfig, - ) -> Job: - logger.debug(f"EvalRouter.run_eval: {benchmark_id}") - return await self.routing_table.get_provider_impl(benchmark_id).run_eval( - benchmark_id=benchmark_id, - benchmark_config=benchmark_config, - ) - - async def evaluate_rows( - self, - benchmark_id: str, - input_rows: list[dict[str, Any]], - scoring_functions: list[str], - benchmark_config: BenchmarkConfig, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") - return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( - benchmark_id=benchmark_id, - input_rows=input_rows, - scoring_functions=scoring_functions, - 
benchmark_config=benchmark_config, - ) - - async def job_status( - self, - benchmark_id: str, - job_id: str, - ) -> Job: - logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) - - async def job_cancel( - self, - benchmark_id: str, - job_id: str, - ) -> None: - logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") - await self.routing_table.get_provider_impl(benchmark_id).job_cancel( - benchmark_id, - job_id, - ) - - async def job_result( - self, - benchmark_id: str, - job_id: str, - ) -> EvaluateResponse: - logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_result( - benchmark_id, - job_id, - ) - - -class ToolRuntimeRouter(ToolRuntime): - class RagToolImpl(RAGToolRuntime): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") - self.routing_table = routing_table - - async def query( - self, - content: InterleavedContent, - vector_db_ids: list[str], - query_config: RAGQueryConfig | None = None, - ) -> RAGQueryResult: - logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") - return await self.routing_table.get_provider_impl("knowledge_search").query( - content, vector_db_ids, query_config - ) - - async def insert( - self, - documents: list[RAGDocument], - vector_db_id: str, - chunk_size_in_tokens: int = 512, - ) -> None: - logger.debug( - f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" - ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) - - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing ToolRuntimeRouter") - self.routing_table = routing_table - - # HACK ALERT this should be 
in sync with "get_all_api_endpoints()" - self.rag_tool = self.RagToolImpl(routing_table) - for method in ("query", "insert"): - setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) - - async def initialize(self) -> None: - logger.debug("ToolRuntimeRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("ToolRuntimeRouter.shutdown") - pass - - async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: - logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") - return await self.routing_table.get_provider_impl(tool_name).invoke_tool( - tool_name=tool_name, - kwargs=kwargs, - ) - - async def list_runtime_tools( - self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None - ) -> ListToolDefsResponse: - logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") - return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py deleted file mode 100644 index c04562197..000000000 --- a/llama_stack/distribution/routers/routing_tables.py +++ /dev/null @@ -1,634 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import logging -import time -import uuid -from typing import Any - -from pydantic import TypeAdapter - -from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse -from llama_stack.apis.common.content_types import URL -from llama_stack.apis.common.type_system import ParamType -from llama_stack.apis.datasets import ( - Dataset, - DatasetPurpose, - Datasets, - DatasetType, - DataSource, - ListDatasetsResponse, - RowsDataSource, - URIDataSource, -) -from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel -from llama_stack.apis.resource import ResourceType -from llama_stack.apis.scoring_functions import ( - ListScoringFunctionsResponse, - ScoringFn, - ScoringFnParams, - ScoringFunctions, -) -from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields -from llama_stack.apis.tools import ( - ListToolGroupsResponse, - ListToolsResponse, - Tool, - ToolGroup, - ToolGroups, - ToolHost, -) -from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs -from llama_stack.distribution.access_control import check_access -from llama_stack.distribution.datatypes import ( - AccessAttributes, - BenchmarkWithACL, - DatasetWithACL, - ModelWithACL, - RoutableObject, - RoutableObjectWithProvider, - RoutedProtocol, - ScoringFnWithACL, - ShieldWithACL, - ToolGroupWithACL, - ToolWithACL, - VectorDBWithACL, -) -from llama_stack.distribution.request_headers import get_auth_attributes -from llama_stack.distribution.store import DistributionRegistry -from llama_stack.providers.datatypes import Api, RoutingTable - -logger = logging.getLogger(__name__) - - -def get_impl_api(p: Any) -> Api: - return p.__provider_spec__.api - - -# TODO: this should return the registered object for all APIs -async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: - api = get_impl_api(p) - - assert obj.provider_id != "remote", "Remote provider should not be 
registered" - - if api == Api.inference: - return await p.register_model(obj) - elif api == Api.safety: - return await p.register_shield(obj) - elif api == Api.vector_io: - return await p.register_vector_db(obj) - elif api == Api.datasetio: - return await p.register_dataset(obj) - elif api == Api.scoring: - return await p.register_scoring_function(obj) - elif api == Api.eval: - return await p.register_benchmark(obj) - elif api == Api.tool_runtime: - return await p.register_tool(obj) - else: - raise ValueError(f"Unknown API {api} for registering object with provider") - - -async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: - api = get_impl_api(p) - if api == Api.vector_io: - return await p.unregister_vector_db(obj.identifier) - elif api == Api.inference: - return await p.unregister_model(obj.identifier) - elif api == Api.datasetio: - return await p.unregister_dataset(obj.identifier) - elif api == Api.tool_runtime: - return await p.unregister_tool(obj.identifier) - else: - raise ValueError(f"Unregister not supported for {api}") - - -Registry = dict[str, list[RoutableObjectWithProvider]] - - -class CommonRoutingTableImpl(RoutingTable): - def __init__( - self, - impls_by_provider_id: dict[str, RoutedProtocol], - dist_registry: DistributionRegistry, - ) -> None: - self.impls_by_provider_id = impls_by_provider_id - self.dist_registry = dist_registry - - async def initialize(self) -> None: - async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: - for obj in objs: - if cls is None: - obj.provider_id = provider_id - else: - # Create a copy of the model data and explicitly set provider_id - model_data = obj.model_dump() - model_data["provider_id"] = provider_id - obj = cls(**model_data) - await self.dist_registry.register(obj) - - # Register all objects from providers - for pid, p in self.impls_by_provider_id.items(): - api = get_impl_api(p) - if api == Api.inference: - p.model_store = self - elif api == 
Api.safety: - p.shield_store = self - elif api == Api.vector_io: - p.vector_db_store = self - elif api == Api.datasetio: - p.dataset_store = self - elif api == Api.scoring: - p.scoring_function_store = self - scoring_functions = await p.list_scoring_functions() - await add_objects(scoring_functions, pid, ScoringFn) - elif api == Api.eval: - p.benchmark_store = self - elif api == Api.tool_runtime: - p.tool_store = self - - async def shutdown(self) -> None: - for p in self.impls_by_provider_id.values(): - await p.shutdown() - - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: - def apiname_object(): - if isinstance(self, ModelsRoutingTable): - return ("Inference", "model") - elif isinstance(self, ShieldsRoutingTable): - return ("Safety", "shield") - elif isinstance(self, VectorDBsRoutingTable): - return ("VectorIO", "vector_db") - elif isinstance(self, DatasetsRoutingTable): - return ("DatasetIO", "dataset") - elif isinstance(self, ScoringFunctionsRoutingTable): - return ("Scoring", "scoring_function") - elif isinstance(self, BenchmarksRoutingTable): - return ("Eval", "benchmark") - elif isinstance(self, ToolGroupsRoutingTable): - return ("Tools", "tool") - else: - raise ValueError("Unknown routing table type") - - apiname, objtype = apiname_object() - - # Get objects from disk registry - obj = self.dist_registry.get_cached(objtype, routing_key) - if not obj: - provider_ids = list(self.impls_by_provider_id.keys()) - if len(provider_ids) > 1: - provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" - else: - provider_ids_str = f"provider: `{provider_ids[0]}`" - raise ValueError( - f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." 
- ) - - if not provider_id or provider_id == obj.provider_id: - return self.impls_by_provider_id[obj.provider_id] - - raise ValueError(f"Provider not found for `{routing_key}`") - - async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: - # Get from disk registry - obj = await self.dist_registry.get(type, identifier) - if not obj: - return None - - # Check if user has permission to access this object - if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): - logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") - return None - - return obj - - async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: - await self.dist_registry.delete(obj.type, obj.identifier) - await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) - - async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: - # if provider_id is not specified, pick an arbitrary one from existing entries - if not obj.provider_id and len(self.impls_by_provider_id) > 0: - obj.provider_id = list(self.impls_by_provider_id.keys())[0] - - if obj.provider_id not in self.impls_by_provider_id: - raise ValueError(f"Provider `{obj.provider_id}` not found") - - p = self.impls_by_provider_id[obj.provider_id] - - # If object supports access control but no attributes set, use creator's attributes - if not obj.access_attributes: - creator_attributes = get_auth_attributes() - if creator_attributes: - obj.access_attributes = AccessAttributes(**creator_attributes) - logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") - - registered_obj = await register_object_with_provider(obj, p) - # TODO: This needs to be fixed for all APIs once they return the registered object - if obj.type == ResourceType.model.value: - await self.dist_registry.register(registered_obj) - return 
registered_obj - - else: - await self.dist_registry.register(obj) - return obj - - async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: - objs = await self.dist_registry.get_all() - filtered_objs = [obj for obj in objs if obj.type == type] - - # Apply attribute-based access control filtering - if filtered_objs: - filtered_objs = [ - obj - for obj in filtered_objs - if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) - ] - - return filtered_objs - - -class ModelsRoutingTable(CommonRoutingTableImpl, Models): - async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) - - async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") - openai_models = [ - OpenAIModel( - id=model.identifier, - object="model", - created=int(time.time()), - owned_by="llama_stack", - ) - for model in models - ] - return OpenAIListModelsResponse(data=openai_models) - - async def get_model(self, model_id: str) -> Model: - model = await self.get_object_by_identifier("model", model_id) - if model is None: - raise ValueError(f"Model '{model_id}' not found") - return model - - async def register_model( - self, - model_id: str, - provider_model_id: str | None = None, - provider_id: str | None = None, - metadata: dict[str, Any] | None = None, - model_type: ModelType | None = None, - ) -> Model: - if provider_model_id is None: - provider_model_id = model_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this model - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. 
Available providers: {self.impls_by_provider_id.keys()}" - ) - if metadata is None: - metadata = {} - if model_type is None: - model_type = ModelType.llm - if "embedding_dimension" not in metadata and model_type == ModelType.embedding: - raise ValueError("Embedding model must have an embedding dimension in its metadata") - model = ModelWithACL( - identifier=model_id, - provider_resource_id=provider_model_id, - provider_id=provider_id, - metadata=metadata, - model_type=model_type, - ) - registered_model = await self.register_object(model) - return registered_model - - async def unregister_model(self, model_id: str) -> None: - existing_model = await self.get_model(model_id) - if existing_model is None: - raise ValueError(f"Model {model_id} not found") - await self.unregister_object(existing_model) - - -class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): - async def list_shields(self) -> ListShieldsResponse: - return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) - - async def get_shield(self, identifier: str) -> Shield: - shield = await self.get_object_by_identifier("shield", identifier) - if shield is None: - raise ValueError(f"Shield '{identifier}' not found") - return shield - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - if provider_shield_id is None: - provider_shield_id = shield_id - if provider_id is None: - # If provider_id not specified, use the only provider if it supports this shield type - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." 
- ) - if params is None: - params = {} - shield = ShieldWithACL( - identifier=shield_id, - provider_resource_id=provider_shield_id, - provider_id=provider_id, - params=params, - ) - await self.register_object(shield) - return shield - - -class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): - async def list_vector_dbs(self) -> ListVectorDBsResponse: - return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) - - async def get_vector_db(self, vector_db_id: str) -> VectorDB: - vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) - if vector_db is None: - raise ValueError(f"Vector DB '{vector_db_id}' not found") - return vector_db - - async def register_vector_db( - self, - vector_db_id: str, - embedding_model: str, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorDB: - if provider_vector_db_id is None: - provider_vector_db_id = vector_db_id - if provider_id is None: - if len(self.impls_by_provider_id) > 0: - provider_id = list(self.impls_by_provider_id.keys())[0] - if len(self.impls_by_provider_id) > 1: - logger.warning( - f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." - ) - else: - raise ValueError("No provider available. 
Please configure a vector_io provider.") - model = await self.get_object_by_identifier("model", embedding_model) - if model is None: - raise ValueError(f"Model {embedding_model} not found") - if model.model_type != ModelType.embedding: - raise ValueError(f"Model {embedding_model} is not an embedding model") - if "embedding_dimension" not in model.metadata: - raise ValueError(f"Model {embedding_model} does not have an embedding dimension") - vector_db_data = { - "identifier": vector_db_id, - "type": ResourceType.vector_db.value, - "provider_id": provider_id, - "provider_resource_id": provider_vector_db_id, - "embedding_model": embedding_model, - "embedding_dimension": model.metadata["embedding_dimension"], - } - vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data) - await self.register_object(vector_db) - return vector_db - - async def unregister_vector_db(self, vector_db_id: str) -> None: - existing_vector_db = await self.get_vector_db(vector_db_id) - if existing_vector_db is None: - raise ValueError(f"Vector DB {vector_db_id} not found") - await self.unregister_object(existing_vector_db) - - -class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): - async def list_datasets(self) -> ListDatasetsResponse: - return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) - - async def get_dataset(self, dataset_id: str) -> Dataset: - dataset = await self.get_object_by_identifier("dataset", dataset_id) - if dataset is None: - raise ValueError(f"Dataset '{dataset_id}' not found") - return dataset - - async def register_dataset( - self, - purpose: DatasetPurpose, - source: DataSource, - metadata: dict[str, Any] | None = None, - dataset_id: str | None = None, - ) -> Dataset: - if isinstance(source, dict): - if source["type"] == "uri": - source = URIDataSource.parse_obj(source) - elif source["type"] == "rows": - source = RowsDataSource.parse_obj(source) - - if not dataset_id: - dataset_id = f"dataset-{str(uuid.uuid4())}" 
- - provider_dataset_id = dataset_id - - # infer provider from source - if metadata: - if metadata.get("provider_id"): - provider_id = metadata.get("provider_id") # pass through from nvidia datasetio - elif source.type == DatasetType.rows.value: - provider_id = "localfs" - elif source.type == DatasetType.uri.value: - # infer provider from uri - if source.uri.startswith("huggingface"): - provider_id = "huggingface" - else: - provider_id = "localfs" - else: - raise ValueError(f"Unknown data source type: {source.type}") - - if metadata is None: - metadata = {} - - dataset = DatasetWithACL( - identifier=dataset_id, - provider_resource_id=provider_dataset_id, - provider_id=provider_id, - purpose=purpose, - source=source, - metadata=metadata, - ) - - await self.register_object(dataset) - return dataset - - async def unregister_dataset(self, dataset_id: str) -> None: - dataset = await self.get_dataset(dataset_id) - if dataset is None: - raise ValueError(f"Dataset {dataset_id} not found") - await self.unregister_object(dataset) - - -class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): - async def list_scoring_functions(self) -> ListScoringFunctionsResponse: - return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) - - async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: - scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) - if scoring_fn is None: - raise ValueError(f"Scoring function '{scoring_fn_id}' not found") - return scoring_fn - - async def register_scoring_function( - self, - scoring_fn_id: str, - description: str, - return_type: ParamType, - provider_scoring_fn_id: str | None = None, - provider_id: str | None = None, - params: ScoringFnParams | None = None, - ) -> None: - if provider_scoring_fn_id is None: - provider_scoring_fn_id = scoring_fn_id - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = 
list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) - scoring_fn = ScoringFnWithACL( - identifier=scoring_fn_id, - description=description, - return_type=return_type, - provider_resource_id=provider_scoring_fn_id, - provider_id=provider_id, - params=params, - ) - scoring_fn.provider_id = provider_id - await self.register_object(scoring_fn) - - -class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): - async def list_benchmarks(self) -> ListBenchmarksResponse: - return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) - - async def get_benchmark(self, benchmark_id: str) -> Benchmark: - benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) - if benchmark is None: - raise ValueError(f"Benchmark '{benchmark_id}' not found") - return benchmark - - async def register_benchmark( - self, - benchmark_id: str, - dataset_id: str, - scoring_functions: list[str], - metadata: dict[str, Any] | None = None, - provider_benchmark_id: str | None = None, - provider_id: str | None = None, - ) -> None: - if metadata is None: - metadata = {} - if provider_id is None: - if len(self.impls_by_provider_id) == 1: - provider_id = list(self.impls_by_provider_id.keys())[0] - else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." 
- ) - if provider_benchmark_id is None: - provider_benchmark_id = benchmark_id - benchmark = BenchmarkWithACL( - identifier=benchmark_id, - dataset_id=dataset_id, - scoring_functions=scoring_functions, - metadata=metadata, - provider_id=provider_id, - provider_resource_id=provider_benchmark_id, - ) - await self.register_object(benchmark) - - -class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): - async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: - tools = await self.get_all_with_type("tool") - if toolgroup_id: - tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id] - return ListToolsResponse(data=tools) - - async def list_tool_groups(self) -> ListToolGroupsResponse: - return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) - - async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: - tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group '{toolgroup_id}' not found") - return tool_group - - async def get_tool(self, tool_name: str) -> Tool: - return await self.get_object_by_identifier("tool", tool_name) - - async def register_tool_group( - self, - toolgroup_id: str, - provider_id: str, - mcp_endpoint: URL | None = None, - args: dict[str, Any] | None = None, - ) -> None: - tools = [] - tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint) - tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution - - for tool_def in tool_defs.data: - tools.append( - ToolWithACL( - identifier=tool_def.name, - toolgroup_id=toolgroup_id, - description=tool_def.description or "", - parameters=tool_def.parameters or [], - provider_id=provider_id, - provider_resource_id=tool_def.name, - metadata=tool_def.metadata, - tool_host=tool_host, - ) - ) - for tool in tools: - existing_tool = await self.get_tool(tool.identifier) - # Compare existing and 
new object if one exists - if existing_tool: - existing_dict = existing_tool.model_dump() - new_dict = tool.model_dump() - - if existing_dict != new_dict: - raise ValueError( - f"Object {tool.identifier} already exists in registry. Please use a different identifier." - ) - await self.register_object(tool) - - await self.dist_registry.register( - ToolGroupWithACL( - identifier=toolgroup_id, - provider_id=provider_id, - provider_resource_id=toolgroup_id, - mcp_endpoint=mcp_endpoint, - args=args, - ) - ) - - async def unregister_toolgroup(self, toolgroup_id: str) -> None: - tool_group = await self.get_tool_group(toolgroup_id) - if tool_group is None: - raise ValueError(f"Tool group {toolgroup_id} not found") - tools = await self.list_tools(toolgroup_id) - for tool in getattr(tools, "data", []): - await self.unregister_object(tool) - await self.unregister_object(tool_group) - - async def shutdown(self) -> None: - pass diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py new file mode 100644 index 000000000..9761d2db0 --- /dev/null +++ b/llama_stack/distribution/routers/safety.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.inference import ( + Message, +) +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class SafetyRouter(Safety): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing SafetyRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("SafetyRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("SafetyRouter.shutdown") + pass + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + logger.debug(f"SafetyRouter.register_shield: {shield_id}") + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + + async def run_shield( + self, + shield_id: str, + messages: list[Message], + params: dict[str, Any] = None, + ) -> RunShieldResponse: + logger.debug(f"SafetyRouter.run_shield: {shield_id}") + return await self.routing_table.get_provider_impl(shield_id).run_shield( + shield_id=shield_id, + messages=messages, + params=params, + ) diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py new file mode 100644 index 000000000..2d4734a2e --- /dev/null +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.common.content_types import ( + URL, + InterleavedContent, +) +from llama_stack.apis.tools import ( + ListToolDefsResponse, + RAGDocument, + RAGQueryConfig, + RAGQueryResult, + RAGToolRuntime, + ToolRuntime, +) +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class ToolRuntimeRouter(ToolRuntime): + class RagToolImpl(RAGToolRuntime): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing ToolRuntimeRouter.RagToolImpl") + self.routing_table = routing_table + + async def query( + self, + content: InterleavedContent, + vector_db_ids: list[str], + query_config: RAGQueryConfig | None = None, + ) -> RAGQueryResult: + logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") + return await self.routing_table.get_provider_impl("knowledge_search").query( + content, vector_db_ids, query_config + ) + + async def insert( + self, + documents: list[RAGDocument], + vector_db_id: str, + chunk_size_in_tokens: int = 512, + ) -> None: + logger.debug( + f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" + ) + return await self.routing_table.get_provider_impl("insert_into_memory").insert( + documents, vector_db_id, chunk_size_in_tokens + ) + + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing ToolRuntimeRouter") + self.routing_table = routing_table + + # HACK ALERT this should be in sync with "get_all_api_endpoints()" + self.rag_tool = self.RagToolImpl(routing_table) + for method in ("query", "insert"): + setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method)) + + async def initialize(self) -> None: + logger.debug("ToolRuntimeRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("ToolRuntimeRouter.shutdown") + pass + + async def 
invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: + logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") + return await self.routing_table.get_provider_impl(tool_name).invoke_tool( + tool_name=tool_name, + kwargs=kwargs, + ) + + async def list_runtime_tools( + self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None + ) -> ListToolDefsResponse: + logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}") + return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint) diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py new file mode 100644 index 000000000..8c17aa890 --- /dev/null +++ b/llama_stack/distribution/routers/vector_io.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.common.content_types import ( + InterleavedContent, +) +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class VectorIORouter(VectorIO): + """Routes to an provider based on the vector db identifier""" + + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing VectorIORouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("VectorIORouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("VectorIORouter.shutdown") + pass + + async def register_vector_db( + self, + vector_db_id: str, + embedding_model: str, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, + ) -> None: + logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") + await self.routing_table.register_vector_db( + vector_db_id, + embedding_model, + embedding_dimension, + provider_id, + provider_vector_db_id, + ) + + async def insert_chunks( + self, + vector_db_id: str, + chunks: list[Chunk], + ttl_seconds: int | None = None, + ) -> None: + logger.debug( + f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' 
if len(chunks) > 3 else ''}", + ) + return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + + async def query_chunks( + self, + vector_db_id: str, + query: InterleavedContent, + params: dict[str, Any] | None = None, + ) -> QueryChunksResponse: + logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") + return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) diff --git a/llama_stack/distribution/routing_tables/__init__.py b/llama_stack/distribution/routing_tables/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/distribution/routing_tables/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/distribution/routing_tables/benchmarks.py b/llama_stack/distribution/routing_tables/benchmarks.py new file mode 100644 index 000000000..589a00c02 --- /dev/null +++ b/llama_stack/distribution/routing_tables/benchmarks.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse +from llama_stack.distribution.datatypes import ( + BenchmarkWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): + async def list_benchmarks(self) -> ListBenchmarksResponse: + return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) + + async def get_benchmark(self, benchmark_id: str) -> Benchmark: + benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) + if benchmark is None: + raise ValueError(f"Benchmark '{benchmark_id}' not found") + return benchmark + + async def register_benchmark( + self, + benchmark_id: str, + dataset_id: str, + scoring_functions: list[str], + metadata: dict[str, Any] | None = None, + provider_benchmark_id: str | None = None, + provider_id: str | None = None, + ) -> None: + if metadata is None: + metadata = {} + if provider_id is None: + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + if provider_benchmark_id is None: + provider_benchmark_id = benchmark_id + benchmark = BenchmarkWithACL( + identifier=benchmark_id, + dataset_id=dataset_id, + scoring_functions=scoring_functions, + metadata=metadata, + provider_id=provider_id, + provider_resource_id=provider_benchmark_id, + ) + await self.register_object(benchmark) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py new file mode 100644 index 000000000..95a92a5ba --- /dev/null +++ b/llama_stack/distribution/routing_tables/common.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.resource import ResourceType +from llama_stack.apis.scoring_functions import ScoringFn +from llama_stack.distribution.access_control import check_access +from llama_stack.distribution.datatypes import ( + AccessAttributes, + RoutableObject, + RoutableObjectWithProvider, + RoutedProtocol, +) +from llama_stack.distribution.request_headers import get_auth_attributes +from llama_stack.distribution.store import DistributionRegistry +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import Api, RoutingTable + +logger = get_logger(name=__name__, category="core") + + +def get_impl_api(p: Any) -> Api: + return p.__provider_spec__.api + + +# TODO: this should return the registered object for all APIs +async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject: + api = get_impl_api(p) + + assert obj.provider_id != "remote", "Remote provider should not be registered" + + if api == Api.inference: + return await p.register_model(obj) + elif api == Api.safety: + return await p.register_shield(obj) + elif api == Api.vector_io: + return await p.register_vector_db(obj) + elif api == Api.datasetio: + return await p.register_dataset(obj) + elif api == Api.scoring: + return await p.register_scoring_function(obj) + elif api == Api.eval: + return await p.register_benchmark(obj) + elif api == Api.tool_runtime: + return await p.register_tool(obj) + else: + raise ValueError(f"Unknown API {api} for registering object with provider") + + +async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: + api = get_impl_api(p) + if api == Api.vector_io: + return await p.unregister_vector_db(obj.identifier) + elif api == Api.inference: + return await p.unregister_model(obj.identifier) + elif api == Api.datasetio: + return await 
p.unregister_dataset(obj.identifier) + elif api == Api.tool_runtime: + return await p.unregister_tool(obj.identifier) + else: + raise ValueError(f"Unregister not supported for {api}") + + +Registry = dict[str, list[RoutableObjectWithProvider]] + + +class CommonRoutingTableImpl(RoutingTable): + def __init__( + self, + impls_by_provider_id: dict[str, RoutedProtocol], + dist_registry: DistributionRegistry, + ) -> None: + self.impls_by_provider_id = impls_by_provider_id + self.dist_registry = dist_registry + + async def initialize(self) -> None: + async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None: + for obj in objs: + if cls is None: + obj.provider_id = provider_id + else: + # Create a copy of the model data and explicitly set provider_id + model_data = obj.model_dump() + model_data["provider_id"] = provider_id + obj = cls(**model_data) + await self.dist_registry.register(obj) + + # Register all objects from providers + for pid, p in self.impls_by_provider_id.items(): + api = get_impl_api(p) + if api == Api.inference: + p.model_store = self + elif api == Api.safety: + p.shield_store = self + elif api == Api.vector_io: + p.vector_db_store = self + elif api == Api.datasetio: + p.dataset_store = self + elif api == Api.scoring: + p.scoring_function_store = self + scoring_functions = await p.list_scoring_functions() + await add_objects(scoring_functions, pid, ScoringFn) + elif api == Api.eval: + p.benchmark_store = self + elif api == Api.tool_runtime: + p.tool_store = self + + async def shutdown(self) -> None: + for p in self.impls_by_provider_id.values(): + await p.shutdown() + + def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + from .benchmarks import BenchmarksRoutingTable + from .datasets import DatasetsRoutingTable + from .models import ModelsRoutingTable + from .scoring_functions import ScoringFunctionsRoutingTable + from .shields import ShieldsRoutingTable + from .toolgroups import 
ToolGroupsRoutingTable + from .vector_dbs import VectorDBsRoutingTable + + def apiname_object(): + if isinstance(self, ModelsRoutingTable): + return ("Inference", "model") + elif isinstance(self, ShieldsRoutingTable): + return ("Safety", "shield") + elif isinstance(self, VectorDBsRoutingTable): + return ("VectorIO", "vector_db") + elif isinstance(self, DatasetsRoutingTable): + return ("DatasetIO", "dataset") + elif isinstance(self, ScoringFunctionsRoutingTable): + return ("Scoring", "scoring_function") + elif isinstance(self, BenchmarksRoutingTable): + return ("Eval", "benchmark") + elif isinstance(self, ToolGroupsRoutingTable): + return ("Tools", "tool") + else: + raise ValueError("Unknown routing table type") + + apiname, objtype = apiname_object() + + # Get objects from disk registry + obj = self.dist_registry.get_cached(objtype, routing_key) + if not obj: + provider_ids = list(self.impls_by_provider_id.keys()) + if len(provider_ids) > 1: + provider_ids_str = f"any of the providers: {', '.join(provider_ids)}" + else: + provider_ids_str = f"provider: `{provider_ids[0]}`" + raise ValueError( + f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}." 
+ ) + + if not provider_id or provider_id == obj.provider_id: + return self.impls_by_provider_id[obj.provider_id] + + raise ValueError(f"Provider not found for `{routing_key}`") + + async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: + # Get from disk registry + obj = await self.dist_registry.get(type, identifier) + if not obj: + return None + + # Check if user has permission to access this object + if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()): + logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch") + return None + + return obj + + async def unregister_object(self, obj: RoutableObjectWithProvider) -> None: + await self.dist_registry.delete(obj.type, obj.identifier) + await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id]) + + async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider: + # if provider_id is not specified, pick an arbitrary one from existing entries + if not obj.provider_id and len(self.impls_by_provider_id) > 0: + obj.provider_id = list(self.impls_by_provider_id.keys())[0] + + if obj.provider_id not in self.impls_by_provider_id: + raise ValueError(f"Provider `{obj.provider_id}` not found") + + p = self.impls_by_provider_id[obj.provider_id] + + # If object supports access control but no attributes set, use creator's attributes + if not obj.access_attributes: + creator_attributes = get_auth_attributes() + if creator_attributes: + obj.access_attributes = AccessAttributes(**creator_attributes) + logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity") + + registered_obj = await register_object_with_provider(obj, p) + # TODO: This needs to be fixed for all APIs once they return the registered object + if obj.type == ResourceType.model.value: + await self.dist_registry.register(registered_obj) + return 
registered_obj + + else: + await self.dist_registry.register(obj) + return obj + + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + filtered_objs = [obj for obj in objs if obj.type == type] + + # Apply attribute-based access control filtering + if filtered_objs: + filtered_objs = [ + obj + for obj in filtered_objs + if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()) + ] + + return filtered_objs diff --git a/llama_stack/distribution/routing_tables/datasets.py b/llama_stack/distribution/routing_tables/datasets.py new file mode 100644 index 000000000..4401ad47e --- /dev/null +++ b/llama_stack/distribution/routing_tables/datasets.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import uuid +from typing import Any + +from llama_stack.apis.datasets import ( + Dataset, + DatasetPurpose, + Datasets, + DatasetType, + DataSource, + ListDatasetsResponse, + RowsDataSource, + URIDataSource, +) +from llama_stack.apis.resource import ResourceType +from llama_stack.distribution.datatypes import ( + DatasetWithACL, +) +from llama_stack.log import get_logger + +from .common import CommonRoutingTableImpl + +logger = get_logger(name=__name__, category="core") + + +class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): + async def list_datasets(self) -> ListDatasetsResponse: + return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) + + async def get_dataset(self, dataset_id: str) -> Dataset: + dataset = await self.get_object_by_identifier("dataset", dataset_id) + if dataset is None: + raise ValueError(f"Dataset '{dataset_id}' not found") + return dataset + + async def register_dataset( + self, + purpose: DatasetPurpose, + source: DataSource, + 
metadata: dict[str, Any] | None = None, + dataset_id: str | None = None, + ) -> Dataset: + if isinstance(source, dict): + if source["type"] == "uri": + source = URIDataSource.parse_obj(source) + elif source["type"] == "rows": + source = RowsDataSource.parse_obj(source) + + if not dataset_id: + dataset_id = f"dataset-{str(uuid.uuid4())}" + + provider_dataset_id = dataset_id + + # infer provider from source + if metadata: + if metadata.get("provider_id"): + provider_id = metadata.get("provider_id") # pass through from nvidia datasetio + elif source.type == DatasetType.rows.value: + provider_id = "localfs" + elif source.type == DatasetType.uri.value: + # infer provider from uri + if source.uri.startswith("huggingface"): + provider_id = "huggingface" + else: + provider_id = "localfs" + else: + raise ValueError(f"Unknown data source type: {source.type}") + + if metadata is None: + metadata = {} + + dataset = DatasetWithACL( + identifier=dataset_id, + provider_resource_id=provider_dataset_id, + provider_id=provider_id, + purpose=purpose, + source=source, + metadata=metadata, + ) + + await self.register_object(dataset) + return dataset + + async def unregister_dataset(self, dataset_id: str) -> None: + dataset = await self.get_dataset(dataset_id) + if dataset is None: + raise ValueError(f"Dataset {dataset_id} not found") + await self.unregister_object(dataset) diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py new file mode 100644 index 000000000..7216d9935 --- /dev/null +++ b/llama_stack/distribution/routing_tables/models.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
import time
from typing import Any

from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
    ModelWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ModelsRoutingTable(CommonRoutingTableImpl, Models):
    """Routing table for model resources: list, look up, register and unregister."""

    async def list_models(self) -> ListModelsResponse:
        """Return every registered model."""
        return ListModelsResponse(data=await self.get_all_with_type("model"))

    async def openai_list_models(self) -> OpenAIListModelsResponse:
        """Return registered models in the OpenAI `/v1/models` response shape."""
        models = await self.get_all_with_type("model")
        openai_models = [
            OpenAIModel(
                id=model.identifier,
                object="model",
                # The registry does not persist a creation timestamp, so report "now".
                created=int(time.time()),
                owned_by="llama_stack",
            )
            for model in models
        ]
        return OpenAIListModelsResponse(data=openai_models)

    async def get_model(self, model_id: str) -> Model:
        """Look up a model by identifier.

        Raises:
            ValueError: if no model with this identifier is registered.
        """
        model = await self.get_object_by_identifier("model", model_id)
        if model is None:
            raise ValueError(f"Model '{model_id}' not found")
        return model

    async def register_model(
        self,
        model_id: str,
        provider_model_id: str | None = None,
        provider_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        model_type: ModelType | None = None,
    ) -> Model:
        """Register a model, inferring defaults for any omitted arguments.

        Defaults: ``provider_model_id`` falls back to ``model_id``; the provider
        is inferred only when exactly one is configured; ``model_type`` defaults
        to ``ModelType.llm``.

        Raises:
            ValueError: if no ``provider_id`` is given while several providers
                are configured, or if an embedding model lacks an
                ``embedding_dimension`` in its metadata.
        """
        if provider_model_id is None:
            provider_model_id = model_id
        if provider_id is None:
            # If provider_id not specified, use the only provider if it supports this model
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    f"No provider specified and multiple providers available. Please specify a provider_id. "
                    f"Available providers: {list(self.impls_by_provider_id.keys())}"
                )
        if metadata is None:
            metadata = {}
        if model_type is None:
            model_type = ModelType.llm
        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
            raise ValueError("Embedding model must have an embedding dimension in its metadata")
        model = ModelWithACL(
            identifier=model_id,
            provider_resource_id=provider_model_id,
            provider_id=provider_id,
            metadata=metadata,
            model_type=model_type,
        )
        registered_model = await self.register_object(model)
        return registered_model

    async def unregister_model(self, model_id: str) -> None:
        """Remove a model from the registry.

        Raises:
            ValueError: if the model is not registered (raised by ``get_model``).
        """
        # get_model() already raises ValueError when the model is missing, so
        # the original's extra "is None" check after it was unreachable.
        existing_model = await self.get_model(model_id)
        await self.unregister_object(existing_model)
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
    ListScoringFunctionsResponse,
    ScoringFn,
    ScoringFnParams,
    ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
    ScoringFnWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    """Routing table for scoring-function resources."""

    async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
        """Return every registered scoring function."""
        return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))

    async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
        """Look up a scoring function by identifier.

        Raises:
            ValueError: if no scoring function with this identifier exists.
        """
        scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
        if scoring_fn is None:
            raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
        return scoring_fn

    async def register_scoring_function(
        self,
        scoring_fn_id: str,
        description: str,
        return_type: ParamType,
        provider_scoring_fn_id: str | None = None,
        provider_id: str | None = None,
        params: ScoringFnParams | None = None,
    ) -> None:
        """Register a scoring function.

        ``provider_scoring_fn_id`` defaults to ``scoring_fn_id``; the provider
        is inferred only when exactly one is configured.

        Raises:
            ValueError: if no ``provider_id`` is given while several providers
                are configured.
        """
        if provider_scoring_fn_id is None:
            provider_scoring_fn_id = scoring_fn_id
        if provider_id is None:
            if len(self.impls_by_provider_id) == 1:
                provider_id = list(self.impls_by_provider_id.keys())[0]
            else:
                raise ValueError(
                    "No provider specified and multiple providers available. Please specify a provider_id."
                )
        scoring_fn = ScoringFnWithACL(
            identifier=scoring_fn_id,
            description=description,
            return_type=return_type,
            provider_resource_id=provider_scoring_fn_id,
            provider_id=provider_id,
            params=params,
        )
        # provider_id is already set via the constructor above; the original's
        # follow-up `scoring_fn.provider_id = provider_id` assignment was redundant.
        await self.register_object(scoring_fn)
from typing import Any

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups, ToolHost
from llama_stack.distribution.datatypes import (
    ToolGroupWithACL,
    ToolWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
    """Routing table for tool groups and the tools they expose."""

    async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
        """Return registered tools, optionally filtered to one tool group."""
        tools = await self.get_all_with_type("tool")
        if toolgroup_id:
            tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
        return ListToolsResponse(data=tools)

    async def list_tool_groups(self) -> ListToolGroupsResponse:
        """Return every registered tool group."""
        return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))

    async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
        """Look up a tool group by identifier.

        Raises:
            ValueError: if no tool group with this identifier exists.
        """
        tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
        if tool_group is None:
            raise ValueError(f"Tool group '{toolgroup_id}' not found")
        return tool_group

    async def get_tool(self, tool_name: str) -> Tool:
        # NOTE: unlike the other getters this returns None (not ValueError)
        # when the tool is missing; register_tool_group relies on that.
        return await self.get_object_by_identifier("tool", tool_name)

    async def register_tool_group(
        self,
        toolgroup_id: str,
        provider_id: str,
        mcp_endpoint: URL | None = None,
        args: dict[str, Any] | None = None,
    ) -> None:
        """Register a tool group and all tools the provider reports for it.

        Tools are fetched live from the provider's ``list_runtime_tools``.

        Raises:
            ValueError: if a tool with the same identifier but a different
                definition is already registered.
        """
        tools = []
        tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
        # An MCP endpoint marks tools as hosted by the model-context-protocol
        # server rather than by the distribution itself.
        tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution

        for tool_def in tool_defs.data:
            tools.append(
                ToolWithACL(
                    identifier=tool_def.name,
                    toolgroup_id=toolgroup_id,
                    description=tool_def.description or "",
                    parameters=tool_def.parameters or [],
                    provider_id=provider_id,
                    provider_resource_id=tool_def.name,
                    metadata=tool_def.metadata,
                    tool_host=tool_host,
                )
            )
        for tool in tools:
            existing_tool = await self.get_tool(tool.identifier)
            # Compare existing and new object if one exists; re-registering an
            # identical tool is a no-op-safe overwrite, a conflicting one is an error.
            if existing_tool:
                existing_dict = existing_tool.model_dump()
                new_dict = tool.model_dump()

                if existing_dict != new_dict:
                    raise ValueError(
                        f"Object {tool.identifier} already exists in registry. Please use a different identifier."
                    )
            await self.register_object(tool)

        await self.dist_registry.register(
            ToolGroupWithACL(
                identifier=toolgroup_id,
                provider_id=provider_id,
                provider_resource_id=toolgroup_id,
                mcp_endpoint=mcp_endpoint,
                args=args,
            )
        )

    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
        """Remove a tool group and every tool registered under it.

        Raises:
            ValueError: if the tool group is not registered (raised by
                ``get_tool_group``).
        """
        # get_tool_group() already raises ValueError when the group is missing,
        # so the original's extra "is None" check after it was unreachable.
        tool_group = await self.get_tool_group(toolgroup_id)
        tools = await self.list_tools(toolgroup_id)
        for tool in getattr(tools, "data", []):
            await self.unregister_object(tool)
        await self.unregister_object(tool_group)

    async def shutdown(self) -> None:
        """No resources to release; present to satisfy the provider lifecycle."""
        pass
from pydantic import TypeAdapter

from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
    VectorDBWithACL,
)
from llama_stack.log import get_logger

from .common import CommonRoutingTableImpl

logger = get_logger(name=__name__, category="core")


class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    """Routing table for vector database resources."""

    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        """Return every registered vector DB."""
        return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        """Look up a vector DB by identifier.

        Raises:
            ValueError: if no vector DB with this identifier exists.
        """
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorDB:
        """Register a vector DB backed by a registered embedding model.

        When no provider is specified, the first configured vector_io provider
        is used (with a warning if more than one exists). The embedding
        dimension recorded comes from the model's metadata, not from the
        ``embedding_dimension`` argument.

        Raises:
            ValueError: if no vector_io provider is configured, the embedding
                model is unknown, is not an embedding model, or lacks an
                ``embedding_dimension`` in its metadata.
        """
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id
        if provider_id is None:
            if len(self.impls_by_provider_id) > 0:
                provider_id = list(self.impls_by_provider_id.keys())[0]
                if len(self.impls_by_provider_id) > 1:
                    logger.warning(
                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                    )
            else:
                raise ValueError("No provider available. Please configure a vector_io provider.")
        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            "embedding_dimension": model.metadata["embedding_dimension"],
        }
        # Validate the plain dict into the ACL-bearing model before registering.
        vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Remove a vector DB from the registry.

        Raises:
            ValueError: if the vector DB is not registered (raised by
                ``get_vector_db``).
        """
        # get_vector_db() already raises ValueError when the DB is missing, so
        # the original's extra "is None" check after it was unreachable.
        existing_vector_db = await self.get_vector_db(vector_db_id)
        await self.unregister_object(existing_vector_db)
import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolParameter from llama_stack.apis.vector_dbs.vector_dbs import VectorDB -from llama_stack.distribution.routers.routing_tables import ( - BenchmarksRoutingTable, - DatasetsRoutingTable, - ModelsRoutingTable, - ScoringFunctionsRoutingTable, - ShieldsRoutingTable, - ToolGroupsRoutingTable, - VectorDBsRoutingTable, -) +from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable +from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable +from llama_stack.distribution.routing_tables.models import ModelsRoutingTable +from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable +from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable +from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: diff --git a/tests/unit/server/test_access_control.py b/tests/unit/server/test_access_control.py index b5e9c2698..e352ba54d 100644 --- a/tests/unit/server/test_access_control.py +++ b/tests/unit/server/test_access_control.py @@ -11,7 +11,7 @@ import pytest from llama_stack.apis.datatypes import Api from llama_stack.apis.models import ModelType from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL -from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable +from llama_stack.distribution.routing_tables.models import ModelsRoutingTable class AsyncMock(MagicMock): @@ -37,7 +37,7 @@ async def test_setup(cached_disk_dist_registry): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model_public 
= ModelWithACL( @@ -102,7 +102,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( @@ -132,7 +132,7 @@ async def test_access_control_and_updates(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model = ModelWithACL( @@ -154,7 +154,7 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_no_user_attributes(mock_get_auth_attributes, test_setup): registry, routing_table = test_setup model_public = ModelWithACL( @@ -185,7 +185,7 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup): @pytest.mark.asyncio -@patch("llama_stack.distribution.routers.routing_tables.get_auth_attributes") +@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes") async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup): """Test that newly created resources inherit access attributes from their creator.""" registry, routing_table = test_setup diff --git a/tests/unit/server/test_resolver.py b/tests/unit/server/test_resolver.py index 3af9535a0..bb4c15dbc 100644 --- a/tests/unit/server/test_resolver.py +++ b/tests/unit/server/test_resolver.py @@ -19,8 +19,8 @@ from 
llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.routers.routers import InferenceRouter -from llama_stack.distribution.routers.routing_tables import ModelsRoutingTable +from llama_stack.distribution.routers.inference import InferenceRouter +from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.providers.datatypes import InlineProviderSpec, ProviderSpec