From d361154102ac382916a8121128fe0ebc0d7824ec Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Fri, 11 Jul 2025 16:10:10 -0400 Subject: [PATCH] reverting to original call order for a simpler change Signed-off-by: Francisco Javier Arceo --- docs/_static/llama-stack-spec.html | 11 +++++------ docs/_static/llama-stack-spec.yaml | 13 +++++++------ llama_stack/apis/vector_dbs/vector_dbs.py | 3 ++- llama_stack/distribution/routers/vector_io.py | 7 ++++--- .../distribution/routing_tables/vector_dbs.py | 4 +++- tests/integration/vector_io/test_vector_io.py | 4 ++++ 6 files changed, 25 insertions(+), 17 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 2ec7c8369..6ca572a76 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -15610,10 +15610,6 @@ "type": "string", "description": "The identifier of the vector database to register." }, - "provider_vector_db_id": { - "type": "string", - "description": "The identifier of the vector database in the provider." - }, "embedding_model": { "type": "string", "description": "The embedding model to use." @@ -15628,13 +15624,16 @@ }, "vector_db_name": { "type": "string", - "description": "The name of the vector database." + "description": "The name of the vector database. :param provider_vector_db_id: The identifier of the vector database in the provider." + }, + "provider_vector_db_id": { + "type": "string", + "description": "The identifier of the vector database in the provider." } }, "additionalProperties": false, "required": [ "vector_db_id", - "provider_vector_db_id", "embedding_model" ], "title": "RegisterVectorDbRequest" diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 76340f061..fd3945d85 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10922,10 +10922,6 @@ components: type: string description: >- The identifier of the vector database to register. - provider_vector_db_id: - type: string - description: >- - The identifier of the vector database in the provider. embedding_model: type: string description: The embedding model to use. @@ -10937,11 +10933,16 @@ components: description: The identifier of the provider. vector_db_name: type: string - description: The name of the vector database. + description: >- + The name of the vector database. :param provider_vector_db_id: The identifier + of the vector database in the provider. + provider_vector_db_id: + type: string + description: >- + The identifier of the vector database in the provider. additionalProperties: false required: - vector_db_id - - provider_vector_db_id - embedding_model title: RegisterVectorDbRequest ResumeAgentTurnRequest: diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index e4f41ad1a..d6d638f97 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -68,11 +68,11 @@ class VectorDBs(Protocol): async def register_vector_db( self, vector_db_id: str, - provider_vector_db_id: str, embedding_model: str, embedding_dimension: int | None = 384, provider_id: str | None = None, vector_db_name: str | None = None, + provider_vector_db_id: str | None = None, ) -> VectorDB: """Register a vector database. @@ -82,6 +82,7 @@ class VectorDBs(Protocol): :param embedding_dimension: The dimension of the embedding model. :param provider_id: The identifier of the provider. :param vector_db_name: The name of the vector database. + :param provider_vector_db_id: The identifier of the vector database in the provider. :returns: A VectorDB. """ ... diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index 1f8cf56d3..d35d5fa05 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -79,20 +79,20 @@ class VectorIORouter(VectorIO): async def register_vector_db( self, vector_db_id: str, - provider_vector_db_id: str, embedding_model: str, embedding_dimension: int | None = 384, provider_id: str | None = None, vector_db_name: str | None = None, + provider_vector_db_id: str | None = None, ) -> None: logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}") await self.routing_table.register_vector_db( vector_db_id, - provider_vector_db_id, embedding_model, embedding_dimension, provider_id, vector_db_name, + provider_vector_db_id, ) async def insert_chunks( @@ -126,6 +126,7 @@ class VectorIORouter(VectorIO): embedding_model: str | None = None, embedding_dimension: int | None = None, provider_id: str | None = None, + provider_vector_db_id: str | None = None, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}") @@ -139,11 +140,11 @@ class VectorIORouter(VectorIO): vector_db_id = f"vs_{uuid.uuid4()}" registered_vector_db = await self.routing_table.register_vector_db( - vector_db_id, vector_db_id, embedding_model, embedding_dimension, provider_id, + provider_vector_db_id, name, ) return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store( diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index 87e67bfeb..f861102c8 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -32,12 +32,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): async def register_vector_db( self, vector_db_id: str, - provider_vector_db_id: str, embedding_model: str, embedding_dimension: int | None = 384, provider_id: str | None = None, + provider_vector_db_id: str | None = None, vector_db_name: str | None = None, ) -> VectorDB: + if provider_vector_db_id is None: + provider_vector_db_id = vector_db_id if provider_id is None: if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index 9cd4fc18c..f6953a4f1 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -53,6 +53,7 @@ def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embe vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, + provider_vector_db_id=vector_db_id, ) # Retrieve the memory bank and validate its properties @@ -69,6 +70,7 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, + provider_vector_db_id=vector_db_id, ) vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] @@ -96,6 +98,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, + provider_vector_db_id=vector_db_id, ) client_with_empty_registry.vector_io.insert( @@ -131,6 +134,7 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, + provider_vector_db_id=vector_db_id, ) chunks_with_embeddings = [