mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat!: Migrate Vector DB IDs to Vector Store IDs (breaking change) (#3253)
Some checks failed
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.13) (push) Failing after 2s
Test Llama Stack Build / build-single-provider (push) Failing after 3s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Test External API and Providers / test-external (venv) (push) Failing after 3s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test Llama Stack Build / build (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 35s
Pre-commit / pre-commit (push) Successful in 1m15s
Some checks failed
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 1s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 1s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Test Llama Stack Build / generate-matrix (push) Successful in 3s
Python Package Build Test / build (3.13) (push) Failing after 2s
Test Llama Stack Build / build-single-provider (push) Failing after 3s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 3s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Test External API and Providers / test-external (venv) (push) Failing after 3s
Unit Tests / unit-tests (3.13) (push) Failing after 3s
Update ReadTheDocs / update-readthedocs (push) Failing after 3s
Test Llama Stack Build / build (push) Failing after 3s
Unit Tests / unit-tests (3.12) (push) Failing after 4s
UI Tests / ui-tests (22) (push) Successful in 35s
Pre-commit / pre-commit (push) Successful in 1m15s
# What does this PR do? This change migrates the VectorDB id generation to Vector Stores. This is a breaking change for **_some users_** that may have application code using the `vector_db_id` parameter in the request of the VectorDB protocol instead of the `VectorDB.identifier` in the response. By default we will now create a Vector Store every time we register a VectorDB. The caveat with this approach is that this maps the `vector_db_id` → `vector_store.name`. This is a reasonable tradeoff to transition users towards OpenAI Vector Stores. As an added benefit, registering VectorDBs will result in them appearing in the VectorStores admin UI. ### Why? This PR makes the `POST` API call to `/v1/vector-dbs` swap the `vector_db_id` parameter in the **request body** into the VectorStore's name field and sets the `vector_db_id` to the generated vector store id (e.g., `vs_038247dd-4bbb-4dbb-a6be-d5ecfd46cfdb`). That means that users would have to do something like follows in their application code: ```python res = client.vector_dbs.register( vector_db_id='my-vector-db-id', embedding_model='ollama/all-minilm:l6-v2', embedding_dimension=384, ) vector_db_id = res.identifier ``` And then the rest of their code would behave, including `VectorIO`'s insert protocol using `vector_db_id` in the request. An alternative implementation would be to just delete the `vector_db_id` parameter in `VectorDB` but the end result would still require users having to write `vector_db_id = res.identifier` since `VectorStores.create()` generates the ID for you. So this approach felt the easiest way to migrate users towards VectorStores (subsequent PRs will be added to trigger `files.create()` and `vector_stores.files.create()`). ## Test Plan Unit tests and integration tests have been added. Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
64b2977162
commit
e2fe39aee1
4 changed files with 209 additions and 49 deletions
|
@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
provider_vector_db_id: str | None = None,
|
provider_vector_db_id: str | None = None,
|
||||||
vector_db_name: str | None = None,
|
vector_db_name: str | None = None,
|
||||||
) -> VectorDB:
|
) -> VectorDB:
|
||||||
provider_vector_db_id = provider_vector_db_id or vector_db_id
|
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
if len(self.impls_by_provider_id) > 0:
|
if len(self.impls_by_provider_id) > 0:
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||||
if "embedding_dimension" not in model.metadata:
|
if "embedding_dimension" not in model.metadata:
|
||||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||||
|
|
||||||
|
provider = self.impls_by_provider_id[provider_id]
|
||||||
|
logger.warning(
|
||||||
|
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||||
|
)
|
||||||
|
vector_store = await provider.openai_create_vector_store(
|
||||||
|
name=vector_db_name or vector_db_id,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_dimension=model.metadata["embedding_dimension"],
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_vector_db_id=provider_vector_db_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store_id = vector_store.id
|
||||||
|
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||||
|
logger.warning(
|
||||||
|
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||||
|
)
|
||||||
|
|
||||||
vector_db_data = {
|
vector_db_data = {
|
||||||
"identifier": vector_db_id,
|
"identifier": vector_store_id,
|
||||||
"type": ResourceType.vector_db.value,
|
"type": ResourceType.vector_db.value,
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"provider_resource_id": provider_vector_db_id,
|
"provider_resource_id": actual_provider_vector_db_id,
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
"vector_db_name": vector_db_name,
|
"vector_db_name": vector_store.name,
|
||||||
}
|
}
|
||||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_db)
|
||||||
|
|
|
@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models):
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
# Register a memory bank first
|
vector_db_name = "test_vector_db"
|
||||||
vector_db_id = "test_vector_db"
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
client_with_empty_registry.vector_dbs.register(
|
vector_db_id=vector_db_name,
|
||||||
vector_db_id=vector_db_id,
|
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
# Retrieve the memory bank and validate its properties
|
# Retrieve the memory bank and validate its properties
|
||||||
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id)
|
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.identifier == vector_db_id
|
assert response.identifier == actual_vector_db_id
|
||||||
assert response.embedding_model == embedding_model_id
|
assert response.embedding_model == embedding_model_id
|
||||||
assert response.provider_resource_id == vector_db_id
|
assert response.identifier.startswith("vs_")
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
vector_db_id = "test_vector_db"
|
vector_db_name = "test_vector_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
actual_vector_db_id = response.identifier
|
||||||
assert vector_dbs_after_register == [vector_db_id]
|
assert actual_vector_db_id.startswith("vs_")
|
||||||
|
assert actual_vector_db_id != vector_db_name
|
||||||
|
|
||||||
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id)
|
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
assert vector_dbs_after_register == [actual_vector_db_id]
|
||||||
|
|
||||||
|
vector_stores = client_with_empty_registry.vector_stores.list()
|
||||||
|
assert len(vector_stores.data) == 1
|
||||||
|
vector_store = vector_stores.data[0]
|
||||||
|
assert vector_store.id == actual_vector_db_id
|
||||||
|
assert vector_store.name == vector_db_name
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id)
|
||||||
|
|
||||||
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
assert len(vector_dbs) == 0
|
assert len(vector_dbs) == 0
|
||||||
|
@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
|
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
|
||||||
vector_db_id = "test_vector_db"
|
vector_db_name = "test_vector_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="What is the capital of France?",
|
query="What is the capital of France?",
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
|
||||||
|
|
||||||
query, expected_doc_id = test_case
|
query, expected_doc_id = test_case
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query=query,
|
query=query,
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
|
||||||
"remote::qdrant": {"score_threshold": -1.0},
|
"remote::qdrant": {"score_threshold": -1.0},
|
||||||
"inline::qdrant": {"score_threshold": -1.0},
|
"inline::qdrant": {"score_threshold": -1.0},
|
||||||
}
|
}
|
||||||
vector_db_id = "test_precomputed_embeddings_db"
|
vector_db_name = "test_precomputed_embeddings_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
chunks_with_embeddings = [
|
chunks_with_embeddings = [
|
||||||
Chunk(
|
Chunk(
|
||||||
content="This is a test chunk with precomputed embedding.",
|
content="This is a test chunk with precomputed embedding.",
|
||||||
|
@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
|
||||||
]
|
]
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunks=chunks_with_embeddings,
|
chunks=chunks_with_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="precomputed embedding test",
|
query="precomputed embedding test",
|
||||||
params=vector_io_provider_params_dict.get(provider, None),
|
params=vector_io_provider_params_dict.get(provider, None),
|
||||||
)
|
)
|
||||||
|
@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
||||||
"remote::qdrant": {"score_threshold": 0.0},
|
"remote::qdrant": {"score_threshold": 0.0},
|
||||||
"inline::qdrant": {"score_threshold": 0.0},
|
"inline::qdrant": {"score_threshold": 0.0},
|
||||||
}
|
}
|
||||||
vector_db_id = "test_precomputed_embeddings_db"
|
vector_db_name = "test_precomputed_embeddings_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
chunks_with_embeddings = [
|
chunks_with_embeddings = [
|
||||||
Chunk(
|
Chunk(
|
||||||
content="duplicate",
|
content="duplicate",
|
||||||
|
@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
||||||
]
|
]
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunks=chunks_with_embeddings,
|
chunks=chunks_with_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="duplicate",
|
query="duplicate",
|
||||||
params=vector_io_provider_params_dict.get(provider, None),
|
params=vector_io_provider_params_dict.get(provider, None),
|
||||||
)
|
)
|
||||||
|
|
|
@ -146,6 +146,20 @@ class VectorDBImpl(Impl):
|
||||||
async def unregister_vector_db(self, vector_db_id: str):
|
async def unregister_vector_db(self, vector_db_id: str):
|
||||||
return vector_db_id
|
return vector_db_id
|
||||||
|
|
||||||
|
async def openai_create_vector_store(self, **kwargs):
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
|
||||||
|
|
||||||
|
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
|
||||||
|
return VectorStoreObject(
|
||||||
|
id=vector_store_id,
|
||||||
|
name=kwargs.get("name", vector_store_id),
|
||||||
|
created_at=int(time.time()),
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_models_routing_table(cached_disk_dist_registry):
|
async def test_models_routing_table(cached_disk_dist_registry):
|
||||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register multiple vector databases and verify listing
|
# Register multiple vector databases and verify listing
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
|
||||||
assert len(vector_dbs.data) == 2
|
assert len(vector_dbs.data) == 2
|
||||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||||
assert "test-vectordb" in vector_db_ids
|
assert vdb1.identifier in vector_db_ids
|
||||||
assert "test-vectordb-2" in vector_db_ids
|
assert vdb2.identifier in vector_db_ids
|
||||||
|
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
# Verify they have UUID-based identifiers
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
assert vdb1.identifier.startswith("vs_")
|
||||||
|
assert vdb2.identifier.startswith("vs_")
|
||||||
|
|
||||||
|
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
|
||||||
|
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
|
||||||
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
assert len(vector_dbs.data) == 0
|
assert len(vector_dbs.data) == 0
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
# Unit tests for the routing tables vector_dbs
|
# Unit tests for the routing tables vector_dbs
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
|
||||||
class VectorDBImpl(Impl):
|
class VectorDBImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(Api.vector_io)
|
super().__init__(Api.vector_io)
|
||||||
|
self.vector_stores = {}
|
||||||
|
|
||||||
async def register_vector_db(self, vector_db: VectorDB):
|
async def register_vector_db(self, vector_db: VectorDB):
|
||||||
return vector_db
|
return vector_db
|
||||||
|
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
|
||||||
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||||
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||||
|
|
||||||
|
async def openai_create_vector_store(
|
||||||
|
self,
|
||||||
|
name=None,
|
||||||
|
embedding_model=None,
|
||||||
|
embedding_dimension=None,
|
||||||
|
provider_id=None,
|
||||||
|
provider_vector_db_id=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||||
|
vector_store = VectorStoreObject(
|
||||||
|
id=vector_store_id,
|
||||||
|
name=name or vector_store_id,
|
||||||
|
created_at=int(time.time()),
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
self.vector_stores[vector_store_id] = vector_store
|
||||||
|
return vector_store
|
||||||
|
|
||||||
|
async def openai_list_vector_stores(self, **kwargs):
|
||||||
|
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
|
||||||
|
|
||||||
|
return VectorStoreListResponse(
|
||||||
|
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
|
n = 10
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
|
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register multiple vector databases and verify listing
|
# Register multiple vector databases and verify listing
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
|
vdb_dict = {}
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
|
for i in range(n):
|
||||||
|
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||||
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
|
||||||
assert len(vector_dbs.data) == 2
|
assert len(vector_dbs.data) == len(vdb_dict)
|
||||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||||
assert "test-vectordb" in vector_db_ids
|
for k in vdb_dict:
|
||||||
assert "test-vectordb-2" in vector_db_ids
|
assert vdb_dict[k].identifier in vector_db_ids
|
||||||
|
for k in vdb_dict:
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
|
||||||
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
assert len(vector_dbs.data) == 0
|
assert len(vector_dbs.data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
|
||||||
|
n = 10
|
||||||
|
impl = VectorDBImpl()
|
||||||
|
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await m_table.initialize()
|
||||||
|
await m_table.register_model(
|
||||||
|
model_id="test-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 128},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
vdb_dict = {}
|
||||||
|
for i in range(n):
|
||||||
|
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||||
|
|
||||||
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||||
|
|
||||||
|
vector_stores = await impl.openai_list_vector_stores()
|
||||||
|
vector_store_ids = {v.id for v in vector_stores.data}
|
||||||
|
|
||||||
|
assert vector_db_ids == vector_store_ids, (
|
||||||
|
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for vector_store in vector_stores.data:
|
||||||
|
vector_db = await table.get_vector_db(vector_store.id)
|
||||||
|
assert vector_store.name == vector_db.vector_db_name, (
|
||||||
|
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for vector_db_id in vector_db_ids:
|
||||||
|
await table.unregister_vector_db(vector_db_id)
|
||||||
|
|
||||||
|
assert len((await table.list_vector_dbs()).data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
|
||||||
|
impl = VectorDBImpl()
|
||||||
|
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await m_table.initialize()
|
||||||
|
await m_table.register_model(
|
||||||
|
model_id="test-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 128},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_provided_id = "my-custom-vector-db"
|
||||||
|
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
|
||||||
|
|
||||||
|
vector_stores = await impl.openai_list_vector_stores()
|
||||||
|
assert len(vector_stores.data) == 1
|
||||||
|
|
||||||
|
vector_store = vector_stores.data[0]
|
||||||
|
|
||||||
|
assert vector_store.name == user_provided_id
|
||||||
|
|
||||||
|
assert vector_store.id.startswith("vs_")
|
||||||
|
assert vector_store.id != user_provided_id
|
||||||
|
|
||||||
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
assert len(vector_dbs.data) == 1
|
||||||
|
assert vector_dbs.data[0].identifier == vector_store.id
|
||||||
|
|
||||||
|
await table.unregister_vector_db(vector_store.id)
|
||||||
|
|
||||||
|
|
||||||
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||||
impl = VectorDBImpl()
|
impl = VectorDBImpl()
|
||||||
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||||
|
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
|
||||||
|
|
||||||
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||||
with request_provider_data_context({}, authorized_user):
|
with request_provider_data_context({}, authorized_user):
|
||||||
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||||
|
authorized_table = registered_vdb.identifier # Use the actual generated ID
|
||||||
|
|
||||||
# Authorized reader
|
# Authorized reader
|
||||||
with request_provider_data_context({}, authorized_user):
|
with request_provider_data_context({}, authorized_user):
|
||||||
|
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
|
||||||
)
|
)
|
||||||
|
|
||||||
with request_provider_data_context({}, admin_user):
|
with request_provider_data_context({}, admin_user):
|
||||||
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||||
|
vector_db_id = registered_vdb.identifier # Use the actual generated ID
|
||||||
|
|
||||||
read_methods = [
|
read_methods = [
|
||||||
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue