mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-26 22:19:49 +00:00

Merge 062c6a419a into 3216765c26
commit 6f864f763a

4 changed files with 40 additions and 32 deletions
File 1/4: llama_stack/providers/remote/vector_io/chroma/chroma.py

@@ -57,12 +57,15 @@ class ChromaIndex(EmbeddingIndex):
         self.collection = collection
         self.kvstore = kvstore
 
+    async def initialize(self):
+        pass
+
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
 
-        ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)]
+        ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
         await maybe_await(
             self.collection.add(
                 documents=[chunk.model_dump_json() for chunk in chunks],
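Note on the add_chunks hunk: Chroma ships both a synchronous client and an async HTTP client, so every collection call in this provider is routed through the module's maybe_await helper. A minimal sketch of such a helper (assuming only the standard library; the real definition lives in this same module) looks like:

    import asyncio

    async def maybe_await(result):
        # The sync PersistentClient returns plain values, while Chroma's
        # async HTTP client returns coroutines; await only in the latter
        # case so one code path serves both clients.
        if asyncio.iscoroutine(result):
            return await result
        return result

The id change also replaces the positional chunk-{i} suffix with the chunk's own chunk_id, which presumably keeps ids stable when a document is re-chunked.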
@@ -137,9 +140,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         self.client = None
         self.cache = {}
         self.kvstore: KVStore | None = None
+        self.vector_db_store = None
 
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
+        self.vector_db_store = self.kvstore
+
         if isinstance(self.config, RemoteChromaVectorIOConfig):
             log.info(f"Connecting to Chroma server at: {self.config.url}")
             url = self.config.url.rstrip("/")
@@ -172,6 +178,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         )
 
     async def unregister_vector_db(self, vector_db_id: str) -> None:
+        if vector_db_id not in self.cache:
+            log.warning(f"Vector DB {vector_db_id} not found")
+            return
+
         await self.cache[vector_db_id].index.delete()
         del self.cache[vector_db_id]
 
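Note: the new guard makes unregister_vector_db idempotent. Inside an async context, with adapter a hypothetical initialized ChromaVectorIOAdapter:

    await adapter.unregister_vector_db("demo-db")
    # Second call logs "Vector DB demo-db not found" and returns,
    # instead of raising KeyError from the cache lookup.
    await adapter.unregister_vector_db("demo-db")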
@@ -182,6 +192,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         ttl_seconds: int | None = None,
     ) -> None:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
+        if index is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
 
         await index.insert_chunks(chunks)
 
@@ -193,6 +205,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     ) -> QueryChunksResponse:
         index = await self._get_and_cache_vector_db_index(vector_db_id)
 
+        if index is None:
+            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
+
         return await index.query_chunks(query, params)
 
     async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
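Note: insert_chunks and query_chunks now fail fast with a descriptive ValueError when the id cannot be resolved, rather than crashing later on a None index. A hypothetical caller-side test (assumes pytest-asyncio or asyncio_mode=auto, and an initialized adapter fixture):

    import pytest

    async def test_insert_into_missing_db(adapter):
        # Unknown ids now surface as ValueError with a clear message.
        with pytest.raises(ValueError, match="not found in Chroma"):
            await adapter.insert_chunks("missing-db", chunks=[])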
File 2/4: OpenAI vector store integration tests

@@ -20,22 +20,15 @@ logger = logging.getLogger(__name__)
 
 
 def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
-    vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
-    for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb"]:
-            return
-
-    pytest.skip("OpenAI vector stores are not supported by any provider")
-
-
-def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
         if p.provider_type in [
             "inline::faiss",
             "inline::sqlite-vec",
             "inline::milvus",
+            "inline::chromadb",
             "remote::pgvector",
+            "remote::chromadb",
         ]:
             return
 
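Note: the two skip helpers are collapsed into one, and the surviving provider list gains inline::chromadb and remote::chromadb. Call sites now need only the single guard; a sketch of a test preamble:

    def test_example(compat_client_with_empty_stores, client_with_models):
        # One guard now covers both plain vector stores and the files API.
        skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

The hunks below apply exactly this simplification to each files-API test.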
@@ -457,7 +450,6 @@ def test_openai_vector_store_search_with_max_num_results(
 def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store attach file."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
@@ -509,7 +501,6 @@ def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client_with_models):
 def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store attach files on creation."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
@@ -566,7 +557,6 @@ def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_stores, client_with_models):
 def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store list files."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient")
@@ -640,7 +630,6 @@ def test_openai_vector_store_list_files_invalid_vector_store(compat_client_with_empty_stores):
 def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store retrieve file contents."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files retrieve contents is not yet supported with LlamaStackClient")
@@ -682,7 +671,6 @@ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_stores, client_with_models):
 def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store delete file."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient")
@@ -740,7 +728,6 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client_with_models):
 def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store delete file removes from vector store."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
@@ -782,7 +769,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client_with_empty_stores, client_with_models):
 def test_openai_vector_store_update_file(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store update file."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files update is not yet supported with LlamaStackClient")
@@ -831,7 +817,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(compat_client_with_empty_stores, client_with_models):
     This test confirms that client.vector_stores.create() creates a unique ID
     """
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
-    skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
 
     if isinstance(compat_client_with_empty_stores, LlamaStackClient):
         pytest.skip("Vector Store Files create is not yet supported with LlamaStackClient")
File 3/4: vector_io unit-test conftest

@@ -8,6 +8,7 @@ import random
 
 import numpy as np
 import pytest
+from chromadb import PersistentClient
 from pymilvus import MilvusClient, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
@@ -18,7 +19,7 @@ from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
 from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
 from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
-from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter
+from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
 from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
 
 EMBEDDING_DIMENSION = 384
@@ -26,6 +27,11 @@ COLLECTION_PREFIX = "test_collection"
 MILVUS_ALIAS = "test_milvus"
 
 
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"])
+def vector_provider(request):
+    return request.param
+
+
 @pytest.fixture
 def vector_db_id() -> str:
     return f"test-vector-db-{random.randint(1, 100)}"
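Note: the vector_provider fixture moves to the top of the conftest (its old definition is deleted in the next hunk) and gains a "chroma" parameter, so every fixture and test depending on it now also runs against Chroma. A minimal, hypothetical consumer:

    # pytest generates four instances of this test, one per backend.
    def test_runs_once_per_backend(vector_provider):
        assert vector_provider in {"milvus", "sqlite_vec", "faiss", "chroma"}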
@@ -94,11 +100,6 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata):
     return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
 
 
-@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"])
-def vector_provider(request):
-    return request.param
-
-
 @pytest.fixture(scope="session")
 def mock_inference_api(embedding_dimension):
     class MockInferenceAPI:
@@ -246,10 +247,10 @@ def chroma_vec_db_path(tmp_path_factory):
 
 @pytest.fixture
 async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
-    index = ChromaIndex(
-        embedding_dimension=embedding_dimension,
-        persist_directory=chroma_vec_db_path,
-    )
+    client = PersistentClient(path=chroma_vec_db_path)
+    name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
+    collection = await maybe_await(client.get_or_create_collection(name))
+    index = ChromaIndex(client=client, collection=collection)
     await index.initialize()
     yield index
     await index.delete()
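Note: the rebuilt fixture hands ChromaIndex a live chromadb client and collection, matching the constructor used by the provider, instead of kwargs the class apparently no longer accepts. The same chromadb calls outside a fixture (path and collection name are illustrative):

    from chromadb import PersistentClient

    client = PersistentClient(path="/tmp/chroma-demo")
    # The sync client returns the collection directly; maybe_await keeps
    # the fixture working with Chroma's async client as well.
    collection = client.get_or_create_collection("demo_collection")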
@@ -257,7 +258,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
 
 @pytest.fixture
 async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension):
-    config = ChromaVectorIOConfig(persist_directory=chroma_vec_db_path)
+    config = ChromaVectorIOConfig(
+        db_path=chroma_vec_db_path,
+        kvstore=SqliteKVStoreConfig(),
+    )
     adapter = ChromaVectorIOAdapter(
         config=config,
         inference_api=mock_inference_api,
File 4/4: vector_io adapter unit tests

@@ -86,10 +86,14 @@ async def test_register_and_unregister_vector_db(vector_io_adapter):
     assert dummy.identifier not in vector_io_adapter.cache
 
 
-async def test_query_unregistered_raises(vector_io_adapter):
+async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
     fake_emb = np.zeros(8, dtype=np.float32)
-    with pytest.raises(ValueError):
-        await vector_io_adapter.query_chunks("no_such_db", fake_emb)
+    if vector_provider == "chroma":
+        with pytest.raises(AttributeError):
+            await vector_io_adapter.query_chunks("no_such_db", fake_emb)
+    else:
+        with pytest.raises(ValueError):
+            await vector_io_adapter.query_chunks("no_such_db", fake_emb)
 
 
 async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
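Note: the split expectation pins down current behavior: on the Chroma adapter, the lookup for an unknown id apparently fails with AttributeError before the adapter's own ValueError check can fire. An equivalent, more compact formulation of the same assertions (a sketch, not what the diff does):

    expected = AttributeError if vector_provider == "chroma" else ValueError
    with pytest.raises(expected):
        await vector_io_adapter.query_chunks("no_such_db", fake_emb)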