From 3b83032555d83695c01c8d40cd0750401160b03f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 22 Jul 2025 15:22:48 -0700 Subject: [PATCH] feat(registry): more flexible model lookup (#2859) This PR updates model registration and lookup behavior to be slightly more general / flexible. See https://github.com/meta-llama/llama-stack/issues/2843 for more details. Note that this change is backwards compatible given the design of the `lookup_model()` method. ## Test Plan Added unit tests --- .github/workflows/integration-tests.yml | 2 +- .../workflows/integration-vector-io-tests.yml | 2 +- llama_stack/distribution/routers/datasets.py | 6 +- .../distribution/routers/eval_scoring.py | 21 ++- llama_stack/distribution/routers/inference.py | 17 +- llama_stack/distribution/routers/safety.py | 3 +- .../distribution/routers/tool_runtime.py | 13 +- llama_stack/distribution/routers/vector_io.py | 14 +- .../distribution/routing_tables/common.py | 28 +++- .../distribution/routing_tables/models.py | 36 +++-- .../distribution/routing_tables/toolgroups.py | 6 +- .../distribution/routing_tables/vector_dbs.py | 37 +++-- llama_stack/providers/datatypes.py | 2 +- .../utils/inference/embedding_mixin.py | 2 +- .../routers/test_routing_tables.py | 151 ++++++++++++++++-- 15 files changed, 265 insertions(+), 75 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f8f01756d..082f1e204 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -99,7 +99,7 @@ jobs: uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ --text-model="ollama/llama3.2:3b-instruct-fp16" \ - --embedding-model=all-MiniLM-L6-v2 \ + --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ --safety-shield=$SAFETY_MODEL \ --color=yes \ --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index ec236b33b..525c17d46 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -114,7 +114,7 @@ jobs: run: | uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ - --embedding-model all-MiniLM-L6-v2 + --embedding-model sentence-transformers/all-MiniLM-L6-v2 - name: Check Storage and Memory Available After Tests if: ${{ always() }} diff --git a/llama_stack/distribution/routers/datasets.py b/llama_stack/distribution/routers/datasets.py index 6f28756c9..d7984f729 100644 --- a/llama_stack/distribution/routers/datasets.py +++ b/llama_stack/distribution/routers/datasets.py @@ -57,7 +57,8 @@ class DatasetIORouter(DatasetIO): logger.debug( f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}", ) - return await self.routing_table.get_provider_impl(dataset_id).iterrows( + provider = await self.routing_table.get_provider_impl(dataset_id) + return await provider.iterrows( dataset_id=dataset_id, start_index=start_index, limit=limit, @@ -65,7 +66,8 @@ class DatasetIORouter(DatasetIO): async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows") - return await self.routing_table.get_provider_impl(dataset_id).append_rows( + provider = await self.routing_table.get_provider_impl(dataset_id) + return await provider.append_rows( dataset_id=dataset_id, rows=rows, ) diff --git a/llama_stack/distribution/routers/eval_scoring.py b/llama_stack/distribution/routers/eval_scoring.py index fd0bb90a7..f7a17eecf 100644 --- a/llama_stack/distribution/routers/eval_scoring.py +++ b/llama_stack/distribution/routers/eval_scoring.py @@ -44,7 +44,8 @@ class ScoringRouter(Scoring): logger.debug(f"ScoringRouter.score_batch: {dataset_id}") res = {} for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch( + provider = await self.routing_table.get_provider_impl(fn_identifier) + score_response = await provider.score_batch( dataset_id=dataset_id, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -66,7 +67,8 @@ class ScoringRouter(Scoring): res = {} # look up and map each scoring function to its provider impl for fn_identifier in scoring_functions.keys(): - score_response = await self.routing_table.get_provider_impl(fn_identifier).score( + provider = await self.routing_table.get_provider_impl(fn_identifier) + score_response = await provider.score( input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) @@ -97,7 +99,8 @@ class EvalRouter(Eval): benchmark_config: BenchmarkConfig, ) -> Job: logger.debug(f"EvalRouter.run_eval: {benchmark_id}") - return await self.routing_table.get_provider_impl(benchmark_id).run_eval( + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.run_eval( benchmark_id=benchmark_id, benchmark_config=benchmark_config, ) @@ -110,7 +113,8 @@ class EvalRouter(Eval): benchmark_config: BenchmarkConfig, ) -> EvaluateResponse: logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows") - return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows( + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.evaluate_rows( benchmark_id=benchmark_id, input_rows=input_rows, scoring_functions=scoring_functions, @@ -123,7 +127,8 @@ class EvalRouter(Eval): job_id: str, ) -> Job: logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id) + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.job_status(benchmark_id, job_id) async def job_cancel( self, @@ -131,7 +136,8 @@ class EvalRouter(Eval): job_id: str, ) -> None: logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}") - await self.routing_table.get_provider_impl(benchmark_id).job_cancel( + provider = await self.routing_table.get_provider_impl(benchmark_id) + await provider.job_cancel( benchmark_id, job_id, ) @@ -142,7 +148,8 @@ class EvalRouter(Eval): job_id: str, ) -> EvaluateResponse: logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}") - return await self.routing_table.get_provider_impl(benchmark_id).job_result( + provider = await self.routing_table.get_provider_impl(benchmark_id) + return await provider.job_result( benchmark_id, job_id, ) diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py index b39da7810..a5cc8c4b5 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/distribution/routers/inference.py @@ -231,7 +231,7 @@ class InferenceRouter(Inference): logprobs=logprobs, tool_config=tool_config, ) - provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) if stream: @@ -292,7 +292,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", ) - provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) return await provider.batch_chat_completion( model_id=model_id, messages_batch=messages_batch, @@ -322,7 +322,7 @@ class InferenceRouter(Inference): raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.embedding: raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") - provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) params = dict( model_id=model_id, content=content, @@ -378,7 +378,7 @@ class InferenceRouter(Inference): logger.debug( f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}", ) - provider = self.routing_table.get_provider_impl(model_id) + provider = await self.routing_table.get_provider_impl(model_id) return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs) async def embeddings( @@ -395,7 +395,8 @@ class InferenceRouter(Inference): raise ValueError(f"Model '{model_id}' not found") if model.model_type == ModelType.llm: raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") - return await self.routing_table.get_provider_impl(model_id).embeddings( + provider = await self.routing_table.get_provider_impl(model_id) + return await provider.embeddings( model_id=model_id, contents=contents, text_truncation=text_truncation, @@ -458,7 +459,7 @@ class InferenceRouter(Inference): suffix=suffix, ) - provider = self.routing_table.get_provider_impl(model_obj.identifier) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) return await provider.openai_completion(**params) async def openai_chat_completion( @@ -538,7 +539,7 @@ class InferenceRouter(Inference): user=user, ) - provider = self.routing_table.get_provider_impl(model_obj.identifier) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) if stream: response_stream = await provider.openai_chat_completion(**params) if self.store: @@ -575,7 +576,7 @@ class InferenceRouter(Inference): user=user, ) - provider = self.routing_table.get_provider_impl(model_obj.identifier) + provider = await self.routing_table.get_provider_impl(model_obj.identifier) return await provider.openai_embeddings(**params) async def list_chat_completions( diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py index 9761d2db0..26ee8e722 100644 --- a/llama_stack/distribution/routers/safety.py +++ b/llama_stack/distribution/routers/safety.py @@ -50,7 +50,8 @@ class SafetyRouter(Safety): params: dict[str, Any] = None, ) -> RunShieldResponse: logger.debug(f"SafetyRouter.run_shield: {shield_id}") - return await self.routing_table.get_provider_impl(shield_id).run_shield( + provider = await self.routing_table.get_provider_impl(shield_id) + return await provider.run_shield( shield_id=shield_id, messages=messages, params=params, diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/distribution/routers/tool_runtime.py index 285843dbc..5a40bc0c5 100644 --- a/llama_stack/distribution/routers/tool_runtime.py +++ b/llama_stack/distribution/routers/tool_runtime.py @@ -41,9 +41,8 @@ class ToolRuntimeRouter(ToolRuntime): query_config: RAGQueryConfig | None = None, ) -> RAGQueryResult: logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}") - return await self.routing_table.get_provider_impl("knowledge_search").query( - content, vector_db_ids, query_config - ) + provider = await self.routing_table.get_provider_impl("knowledge_search") + return await provider.query(content, vector_db_ids, query_config) async def insert( self, @@ -54,9 +53,8 @@ class ToolRuntimeRouter(ToolRuntime): logger.debug( f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}" ) - return await self.routing_table.get_provider_impl("insert_into_memory").insert( - documents, vector_db_id, chunk_size_in_tokens - ) + provider = await self.routing_table.get_provider_impl("insert_into_memory") + return await provider.insert(documents, vector_db_id, chunk_size_in_tokens) def __init__( self, @@ -80,7 +78,8 @@ class ToolRuntimeRouter(ToolRuntime): async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any: logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}") - return await self.routing_table.get_provider_impl(tool_name).invoke_tool( + provider = await self.routing_table.get_provider_impl(tool_name) + return await provider.invoke_tool( tool_name=tool_name, kwargs=kwargs, ) diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index a1dd66060..3d0996c49 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -104,7 +104,8 @@ class VectorIORouter(VectorIO): logger.debug( f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}", ) - return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds) + provider = await self.routing_table.get_provider_impl(vector_db_id) + return await provider.insert_chunks(vector_db_id, chunks, ttl_seconds) async def query_chunks( self, @@ -113,7 +114,8 @@ class VectorIORouter(VectorIO): params: dict[str, Any] | None = None, ) -> QueryChunksResponse: logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}") - return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params) + provider = await self.routing_table.get_provider_impl(vector_db_id) + return await provider.query_chunks(vector_db_id, query, params) # OpenAI Vector Stores API endpoints async def openai_create_vector_store( @@ -146,7 +148,8 @@ class VectorIORouter(VectorIO): provider_vector_db_id=vector_db_id, vector_db_name=name, ) - return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store( + provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier) + return await provider.openai_create_vector_store( name=name, file_ids=file_ids, expires_after=expires_after, @@ -172,9 +175,8 @@ class VectorIORouter(VectorIO): all_stores = [] for vector_db in vector_dbs: try: - vector_store = await self.routing_table.get_provider_impl( - vector_db.identifier - ).openai_retrieve_vector_store(vector_db.identifier) + provider = await self.routing_table.get_provider_impl(vector_db.identifier) + vector_store = await provider.openai_retrieve_vector_store(vector_db.identifier) all_stores.append(vector_store) except Exception as e: logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}") diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index bbe0113e9..2f6ac90bb 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -6,6 +6,7 @@ from typing import Any +from llama_stack.apis.models import Model from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed @@ -116,7 +117,7 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: from .benchmarks import BenchmarksRoutingTable from .datasets import DatasetsRoutingTable from .models import ModelsRoutingTable @@ -235,3 +236,28 @@ class CommonRoutingTableImpl(RoutingTable): ] return filtered_objs + + +async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> Model: + # first try to get the model by identifier + # this works if model_id is an alias or is of the form provider_id/provider_model_id + model = await routing_table.get_object_by_identifier("model", model_id) + if model is not None: + return model + + logger.warning( + f"WARNING: model identifier '{model_id}' not found in routing table. Falling back to " + "searching in all providers. This is only for backwards compatibility and will stop working " + "soon. Migrate your calls to use fully scoped `provider_id/model_id` names." + ) + # if not found, this means model_id is an unscoped provider_model_id, we need + # to iterate (given a lack of an efficient index on the KVStore) + models = await routing_table.get_all_with_type("model") + matching_models = [m for m in models if m.provider_resource_id == model_id] + if len(matching_models) == 0: + raise ValueError(f"Model '{model_id}' not found") + + if len(matching_models) > 1: + raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}") + + return matching_models[0] diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 9a9db7257..f2787b308 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import ( ) from llama_stack.log import get_logger -from .common import CommonRoutingTableImpl +from .common import CommonRoutingTableImpl, lookup_model logger = get_logger(name=__name__, category="core") @@ -36,10 +36,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return OpenAIListModelsResponse(data=openai_models) async def get_model(self, model_id: str) -> Model: - model = await self.get_object_by_identifier("model", model_id) - if model is None: - raise ValueError(f"Model '{model_id}' not found") - return model + return await lookup_model(self, model_id) + + async def get_provider_impl(self, model_id: str) -> Any: + model = await lookup_model(self, model_id) + return self.impls_by_provider_id[model.provider_id] async def register_model( self, @@ -49,24 +50,33 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): metadata: dict[str, Any] | None = None, model_type: ModelType | None = None, ) -> Model: - if provider_model_id is None: - provider_model_id = model_id if provider_id is None: # If provider_id not specified, use the only provider if it supports this model if len(self.impls_by_provider_id) == 1: provider_id = list(self.impls_by_provider_id.keys())[0] else: raise ValueError( - f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}" + f"Please specify a provider_id for model {model_id} since multiple providers are available: {self.impls_by_provider_id.keys()}.\n\n" + "Use the provider_id as a prefix to disambiguate, e.g. 'provider_id/model_id'." ) - if metadata is None: - metadata = {} - if model_type is None: - model_type = ModelType.llm + + provider_model_id = provider_model_id or model_id + metadata = metadata or {} + model_type = model_type or ModelType.llm if "embedding_dimension" not in metadata and model_type == ModelType.embedding: raise ValueError("Embedding model must have an embedding dimension in its metadata") + + # an identifier different than provider_model_id implies it is an alias, so that + # becomes the globally unique identifier. otherwise provider_model_ids can conflict, + # so as a general rule we must use the provider_id to disambiguate. + + if model_id != provider_model_id: + identifier = model_id + else: + identifier = f"{provider_id}/{provider_model_id}" + model = ModelWithOwner( - identifier=model_id, + identifier=identifier, provider_resource_id=provider_model_id, provider_id=provider_id, metadata=metadata, diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/distribution/routing_tables/toolgroups.py index b86f057bd..22c4e109a 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/distribution/routing_tables/toolgroups.py @@ -30,7 +30,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): tool_to_toolgroup: dict[str, str] = {} # overridden - def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: + async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: # we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id # TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while? @@ -40,7 +40,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): if routing_key in self.tool_to_toolgroup: routing_key = self.tool_to_toolgroup[routing_key] - return super().get_provider_impl(routing_key, provider_id) + return await super().get_provider_impl(routing_key, provider_id) async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: if toolgroup_id: @@ -59,7 +59,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return ListToolsResponse(data=all_tools) async def _index_tools(self, toolgroup: ToolGroup): - provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) + provider_impl = await super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id) tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint) # TODO: kill this Tool vs ToolDef distinction diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index b4e60c625..58ecf24da 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -27,7 +27,7 @@ from llama_stack.distribution.datatypes import ( ) from llama_stack.log import get_logger -from .common import CommonRoutingTableImpl +from .common import CommonRoutingTableImpl, lookup_model logger = get_logger(name=__name__, category="core") @@ -51,8 +51,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): provider_vector_db_id: str | None = None, vector_db_name: str | None = None, ) -> VectorDB: - if provider_vector_db_id is None: - provider_vector_db_id = vector_db_id + provider_vector_db_id = provider_vector_db_id or vector_db_id if provider_id is None: if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] @@ -62,7 +61,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): ) else: raise ValueError("No provider available. Please configure a vector_io provider.") - model = await self.get_object_by_identifier("model", embedding_model) + model = await lookup_model(self, embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") if model.model_type != ModelType.embedding: @@ -93,7 +92,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): vector_store_id: str, ) -> VectorStoreObject: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( self, @@ -103,7 +103,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_update_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store( vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -115,7 +116,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): vector_store_id: str, ) -> VectorStoreDeleteResponse: await self.assert_action_allowed("delete", "vector_db", vector_store_id) - result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + result = await provider.openai_delete_vector_store(vector_store_id) await self.unregister_vector_db(vector_store_id) return result @@ -130,7 +132,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_search_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_search_vector_store( vector_store_id=vector_store_id, query=query, filters=filters, @@ -148,7 +151,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -165,7 +169,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, order=order, @@ -180,7 +185,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileObject: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) @@ -191,7 +197,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileContentsResponse: await self.assert_action_allowed("read", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, ) @@ -203,7 +210,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): attributes: dict[str, Any], ) -> VectorStoreFileObject: await self.assert_action_allowed("update", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -215,7 +223,8 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): file_id: str, ) -> VectorStoreFileDeleteResponse: await self.assert_action_allowed("delete", "vector_db", vector_store_id) - return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index efe8a98fe..424380324 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -113,7 +113,7 @@ class ProviderSpec(BaseModel): class RoutingTable(Protocol): - def get_provider_impl(self, routing_key: str) -> Any: ... + async def get_provider_impl(self, routing_key: str) -> Any: ... # TODO: this can now be inlined into RemoteProviderSpec diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 97cf87360..32e89f987 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -88,7 +88,7 @@ class SentenceTransformerEmbeddingMixin: usage = OpenAIEmbeddingUsage(prompt_tokens=-1, total_tokens=-1) return OpenAIEmbeddingsResponse( data=data, - model=model_obj.provider_resource_id, + model=model, usage=usage, ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 30f795d33..12b05ebff 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,15 +11,17 @@ from unittest.mock import AsyncMock from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api -from llama_stack.apis.models import Model +from llama_stack.apis.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter +from llama_stack.apis.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: @@ -104,6 +106,17 @@ class ToolGroupsImpl(Impl): ) +class VectorDBImpl(Impl): + def __init__(self): + super().__init__(Api.vector_io) + + async def register_vector_db(self, vector_db: VectorDB): + return vector_db + + async def unregister_vector_db(self, vector_db_id: str): + return vector_db_id + + async def test_models_routing_table(cached_disk_dist_registry): table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) await table.initialize() @@ -115,27 +128,27 @@ async def test_models_routing_table(cached_disk_dist_registry): models = await table.list_models() assert len(models.data) == 2 model_ids = {m.identifier for m in models.data} - assert "test-model" in model_ids - assert "test-model-2" in model_ids + assert "test_provider/test-model" in model_ids + assert "test_provider/test-model-2" in model_ids # Test openai list models openai_models = await table.openai_list_models() assert len(openai_models.data) == 2 openai_model_ids = {m.id for m in openai_models.data} - assert "test-model" in openai_model_ids - assert "test-model-2" in openai_model_ids + assert "test_provider/test-model" in openai_model_ids + assert "test_provider/test-model-2" in openai_model_ids # Test get_object_by_identifier - model = await table.get_object_by_identifier("model", "test-model") + model = await table.get_object_by_identifier("model", "test_provider/test-model") assert model is not None - assert model.identifier == "test-model" + assert model.identifier == "test_provider/test-model" # Test get_object_by_identifier on non-existent object non_existent = await table.get_object_by_identifier("model", "non-existent-model") assert non_existent is None - await table.unregister_model(model_id="test-model") - await table.unregister_model(model_id="test-model-2") + await table.unregister_model(model_id="test_provider/test-model") + await table.unregister_model(model_id="test_provider/test-model-2") models = await table.list_models() assert len(models.data) == 0 @@ -160,6 +173,36 @@ async def test_shields_routing_table(cached_disk_dist_registry): assert "test-shield-2" in shield_ids +async def test_vectordbs_routing_table(cached_disk_dist_registry): + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + # Register multiple vector databases and verify listing + await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") + await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") + vector_dbs = await table.list_vector_dbs() + + assert len(vector_dbs.data) == 2 + vector_db_ids = {v.identifier for v in vector_dbs.data} + assert "test-vectordb" in vector_db_ids + assert "test-vectordb-2" in vector_db_ids + + await table.unregister_vector_db(vector_db_id="test-vectordb") + await table.unregister_vector_db(vector_db_id="test-vectordb-2") + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 0 + + async def test_datasets_routing_table(cached_disk_dist_registry): table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() @@ -245,3 +288,93 @@ async def test_tool_groups_routing_table(cached_disk_dist_registry): await table.unregister_toolgroup(toolgroup_id="test-toolgroup") tool_groups = await table.list_tool_groups() assert len(tool_groups.data) == 0 + + +async def test_models_alias_registration_and_lookup(cached_disk_dist_registry): + """Test alias registration (model_id != provider_model_id) and lookup behavior.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model with alias (model_id different from provider_model_id) + await table.register_model( + model_id="my-alias", provider_model_id="actual-provider-model", provider_id="test_provider" + ) + + # Verify the model was registered with alias as identifier (not namespaced) + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.identifier == "my-alias" # Uses alias as identifier + assert model.provider_resource_id == "actual-provider-model" + + # Test lookup by alias works + retrieved_model = await table.get_model("my-alias") + assert retrieved_model.identifier == "my-alias" + assert retrieved_model.provider_resource_id == "actual-provider-model" + + +async def test_models_multi_provider_disambiguation(cached_disk_dist_registry): + """Test registration and lookup with multiple providers having same provider_model_id.""" + table = ModelsRoutingTable( + {"provider1": InferenceImpl(), "provider2": InferenceImpl()}, cached_disk_dist_registry, {} + ) + await table.initialize() + + # Register same provider_model_id on both providers (no aliases) + await table.register_model(model_id="common-model", provider_id="provider1") + await table.register_model(model_id="common-model", provider_id="provider2") + + # Verify both models get namespaced identifiers + models = await table.list_models() + assert len(models.data) == 2 + identifiers = {m.identifier for m in models.data} + assert identifiers == {"provider1/common-model", "provider2/common-model"} + + # Test lookup by full namespaced identifier works + model1 = await table.get_model("provider1/common-model") + assert model1.provider_id == "provider1" + assert model1.provider_resource_id == "common-model" + + model2 = await table.get_model("provider2/common-model") + assert model2.provider_id == "provider2" + assert model2.provider_resource_id == "common-model" + + # Test lookup by unscoped provider_model_id fails with multiple providers error + try: + await table.get_model("common-model") + raise AssertionError("Should have raised ValueError for multiple providers") + except ValueError as e: + assert "Multiple providers found" in str(e) + assert "provider1" in str(e) and "provider2" in str(e) + + +async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): + """Test two-stage lookup: direct identifier hit vs fallback to provider_resource_id.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model without alias (gets namespaced identifier) + await table.register_model(model_id="test-model", provider_id="test_provider") + + # Verify namespaced identifier was created + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.identifier == "test_provider/test-model" + assert model.provider_resource_id == "test-model" + + # Test lookup by full namespaced identifier (direct hit via get_object_by_identifier) + retrieved_model = await table.get_model("test_provider/test-model") + assert retrieved_model.identifier == "test_provider/test-model" + + # Test lookup by unscoped provider_model_id (fallback via iteration) + retrieved_model = await table.get_model("test-model") + assert retrieved_model.identifier == "test_provider/test-model" + assert retrieved_model.provider_resource_id == "test-model" + + # Test lookup of non-existent model fails + try: + await table.get_model("non-existent") + raise AssertionError("Should have raised ValueError for non-existent model") + except ValueError as e: + assert "not found" in str(e)