diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py
index f19a2bffc..156cda385 100644
--- a/llama_stack/distribution/routers/__init__.py
+++ b/llama_stack/distribution/routers/__init__.py
@@ -14,11 +14,11 @@ from llama_stack.providers.datatypes import Api, RoutingTable
 from .routing_tables import (
     DatasetsRoutingTable,
     EvalTasksRoutingTable,
-    MemoryBanksRoutingTable,
     ModelsRoutingTable,
     ScoringFunctionsRoutingTable,
     ShieldsRoutingTable,
     ToolGroupsRoutingTable,
+    VectorDBsRoutingTable,
 )
@@ -29,7 +29,7 @@ async def get_routing_table_impl(
     dist_registry: DistributionRegistry,
 ) -> Any:
     api_to_tables = {
-        "memory_banks": MemoryBanksRoutingTable,
+        "vector_dbs": VectorDBsRoutingTable,
         "models": ModelsRoutingTable,
         "shields": ShieldsRoutingTable,
         "datasets": DatasetsRoutingTable,
@@ -51,14 +51,14 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         DatasetIORouter,
         EvalRouter,
         InferenceRouter,
-        MemoryRouter,
         SafetyRouter,
         ScoringRouter,
         ToolRuntimeRouter,
+        VectorIORouter,
     )
 
     api_to_routers = {
-        "memory": MemoryRouter,
+        "vector_io": VectorIORouter,
         "inference": InferenceRouter,
         "safety": SafetyRouter,
         "datasetio": DatasetIORouter,
diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py
index 8080b9dff..979c68b72 100644
--- a/llama_stack/distribution/routers/routers.py
+++ b/llama_stack/distribution/routers/routers.py
@@ -27,8 +27,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
-from llama_stack.apis.memory_banks.memory_banks import BankParams
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
@@ -39,11 +37,12 @@ from llama_stack.apis.scoring import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.apis.tools import ToolDef, ToolRuntime
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.providers.datatypes import RoutingTable
 
 
-class MemoryRouter(Memory):
-    """Routes to an provider based on the memory bank identifier"""
+class VectorIORouter(VectorIO):
+    """Routes to a provider based on the vector DB identifier"""
 
     def __init__(
         self,
@@ -57,38 +56,40 @@ class MemoryRouter(Memory):
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank_id: str,
-        params: BankParams,
+        vector_db_id: str,
+        embedding_model: str,
+        embedding_dimension: Optional[int] = 384,
         provider_id: Optional[str] = None,
-        provider_memorybank_id: Optional[str] = None,
+        provider_vector_db_id: Optional[str] = None,
     ) -> None:
-        await self.routing_table.register_memory_bank(
-            memory_bank_id,
-            params,
+        await self.routing_table.register_vector_db(
+            vector_db_id,
+            embedding_model,
+            embedding_dimension,
             provider_id,
-            provider_memorybank_id,
+            provider_vector_db_id,
         )
 
-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        return await self.routing_table.get_provider_impl(bank_id).insert_documents(
-            bank_id, documents, ttl_seconds
+        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(
+            vector_db_id, chunks, ttl_seconds
        )
 
-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        return await self.routing_table.get_provider_impl(bank_id).query_documents(
-            bank_id, query, params
+    ) -> QueryChunksResponse:
+        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(
+            vector_db_id, query, params
         )
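Reviewer note: a quick usage sketch of the renamed router surface. Method names and signatures come from this diff; the `router` handle, the vector DB id, and the embedding model id are placeholders, so treat this as illustrative rather than as the stack's documented client API.

```python
# Illustrative only: exercises VectorIORouter's renamed methods end to end.
# `router` is assumed to be a VectorIORouter wired to a real routing table.
from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks


async def smoke_test(router) -> None:
    # Old: register_memory_bank(memory_bank_id=..., params=BankParams(...))
    await router.register_vector_db(
        vector_db_id="docs",
        embedding_model="all-MiniLM-L6-v2",  # placeholder model id
        embedding_dimension=384,
    )

    # Old: insert_documents(bank_id, documents) -- chunking is now the caller's job.
    chunks = make_overlapped_chunks(
        "doc1",
        "Python is a high-level programming language.",
        window_len=512,
        overlap_len=64,
    )
    await router.insert_chunks("docs", chunks)

    # Old: query_documents(bank_id, query) -> QueryDocumentsResponse
    response = await router.query_chunks(
        "docs", "programming language", params={"max_chunks": 2}
    )
    for chunk, score in zip(response.chunks, response.scores):
        print(score, chunk.content)
```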
diff --git a/llama_stack/providers/inline/vector_io/faiss/__init__.py b/llama_stack/providers/inline/vector_io/faiss/__init__.py
index 2d7ede3b1..32cf262fd 100644
--- a/llama_stack/providers/inline/vector_io/faiss/__init__.py
+++ b/llama_stack/providers/inline/vector_io/faiss/__init__.py
@@ -11,12 +11,12 @@ from .config import FaissImplConfig
 
 
 async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
-    from .faiss import FaissMemoryImpl
+    from .faiss import FaissVectorIOImpl
 
     assert isinstance(
         config, FaissImplConfig
     ), f"Unexpected config type: {type(config)}"
 
-    impl = FaissMemoryImpl(config, deps[Api.inference])
+    impl = FaissVectorIOImpl(config, deps[Api.inference])
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index af398801a..db53302bb 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -17,35 +17,28 @@ import numpy as np
 from numpy.typing import NDArray
 
 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )
 
 from .config import FaissImplConfig
 
 logger = logging.getLogger(__name__)
 
-MEMORY_BANKS_PREFIX = "memory_banks:v2::"
+VECTOR_DBS_PREFIX = "vector_dbs:v2::"
 FAISS_INDEX_PREFIX = "faiss_index:v2::"
 
 
 class FaissIndex(EmbeddingIndex):
-    id_by_index: Dict[int, str]
     chunk_by_index: Dict[int, str]
 
     def __init__(self, dimension: int, kvstore=None, bank_id: str = None):
         self.index = faiss.IndexFlatL2(dimension)
-        self.id_by_index = {}
         self.chunk_by_index = {}
         self.kvstore = kvstore
         self.bank_id = bank_id
@@ -65,7 +58,6 @@ class FaissIndex(EmbeddingIndex):
 
         if stored_data:
             data = json.loads(stored_data)
-            self.id_by_index = {int(k): v for k, v in data["id_by_index"].items()}
             self.chunk_by_index = {
                 int(k): Chunk.model_validate_json(v)
                 for k, v in data["chunk_by_index"].items()
@@ -82,7 +74,6 @@ class FaissIndex(EmbeddingIndex):
         buffer = io.BytesIO()
         np.savetxt(buffer, np_index)
         data = {
-            "id_by_index": self.id_by_index,
             "chunk_by_index": {
                 k: v.model_dump_json() for k, v in self.chunk_by_index.items()
             },
@@ -108,10 +99,9 @@ class FaissIndex(EmbeddingIndex):
                 f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
             )
 
-        indexlen = len(self.id_by_index)
+        indexlen = len(self.chunk_by_index)
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
-            self.id_by_index[indexlen + i] = chunk.document_id
 
         self.index.add(np.array(embeddings).astype(np.float32))
 
@@ -120,7 +110,7 @@ class FaissIndex(EmbeddingIndex):
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         distances, indices = self.index.search(
             embedding.reshape(1, -1).astype(np.float32), k
         )
@@ -133,10 +123,10 @@ class FaissIndex(EmbeddingIndex):
             chunks.append(self.chunk_by_index[int(i)])
             scores.append(1.0 / float(d))
 
-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
 
-class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
+class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -146,77 +136,74 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
         # Load existing banks from kvstore
-        start_key = MEMORY_BANKS_PREFIX
-        end_key = f"{MEMORY_BANKS_PREFIX}\xff"
-        stored_banks = await self.kvstore.range(start_key, end_key)
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored_vector_dbs = await self.kvstore.range(start_key, end_key)
 
-        for bank_data in stored_banks:
-            bank = VectorMemoryBank.model_validate_json(bank_data)
-            index = BankWithIndex(
-                bank,
+        for vector_db_data in stored_vector_dbs:
+            vector_db = VectorDB.model_validate_json(vector_db_data)
+            index = VectorDBWithIndex(
+                vector_db,
                 await FaissIndex.create(
-                    bank.embedding_dimension, self.kvstore, bank.identifier
+                    vector_db.embedding_dimension, self.kvstore, vector_db.identifier
                 ),
                 self.inference_api,
             )
-            self.cache[bank.identifier] = index
+            self.cache[vector_db.identifier] = index
 
     async def shutdown(self) -> None:
         # Cleanup if needed
         pass
 
-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank: MemoryBank,
+        vector_db: VectorDB,
     ) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {memory_bank.type}"
-
-        # Store in kvstore
-        key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
+        key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
         await self.kvstore.set(
             key=key,
-            value=memory_bank.model_dump_json(),
+            value=vector_db.model_dump_json(),
         )
 
         # Store in cache
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            memory_bank,
-            await FaissIndex.create(
-                memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=await FaissIndex.create(
+                vector_db.embedding_dimension, self.kvstore, vector_db.identifier
             ),
-            self.inference_api,
+            inference_api=self.inference_api,
        )
 
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        return [i.bank for i in self.cache.values()]
+    async def list_vector_dbs(self) -> List[VectorDB]:
+        return [i.vector_db for i in self.cache.values()]
 
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        await self.cache[memory_bank_id].index.delete()
-        del self.cache[memory_bank_id]
-        await self.kvstore.delete(f"{MEMORY_BANKS_PREFIX}{memory_bank_id}")
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]
+        await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}")
 
-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = self.cache.get(bank_id)
+        index = self.cache.get(vector_db_id)
         if index is None:
-            raise ValueError(f"Bank {bank_id} not found. found: {self.cache.keys()}")
+            raise ValueError(
+                f"Vector DB {vector_db_id} not found. Found: {self.cache.keys()}"
+            )
 
-        await index.insert_documents(documents)
+        await index.insert_chunks(chunks)
 
-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = self.cache.get(bank_id)
+    ) -> QueryChunksResponse:
        index = self.cache.get(vector_db_id)
         if index is None:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
-        return await index.query_documents(query, params)
+        return await index.query_chunks(query, params)
diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py
index 3e38b1adc..655303f98 100644
--- a/llama_stack/providers/registry/agents.py
+++ b/llama_stack/providers/registry/agents.py
@@ -33,8 +33,8 @@ def available_providers() -> List[ProviderSpec]:
             api_dependencies=[
                 Api.inference,
                 Api.safety,
-                Api.memory,
-                Api.memory_banks,
+                Api.vector_io,
+                Api.vector_dbs,
                 Api.tool_runtime,
                 Api.tool_groups,
             ],
diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py
index 40299edad..b3ea68949 100644
--- a/llama_stack/providers/registry/tool_runtime.py
+++ b/llama_stack/providers/registry/tool_runtime.py
@@ -23,7 +23,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.inline.tool_runtime.memory",
             config_class="llama_stack.providers.inline.tool_runtime.memory.config.MemoryToolRuntimeConfig",
-            api_dependencies=[Api.memory, Api.memory_banks, Api.inference],
+            api_dependencies=[Api.vector_io, Api.vector_dbs, Api.inference],
         ),
         InlineProviderSpec(
             api=Api.tool_runtime,
diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py
index 4aa53a687..7d0d2ae74 100644
--- a/llama_stack/providers/tests/conftest.py
+++ b/llama_stack/providers/tests/conftest.py
@@ -302,7 +302,7 @@ def pytest_collection_modifyitems(session, config, items):
 pytest_plugins = [
     "llama_stack.providers.tests.inference.fixtures",
     "llama_stack.providers.tests.safety.fixtures",
-    "llama_stack.providers.tests.memory.fixtures",
+    "llama_stack.providers.tests.vector_io.fixtures",
     "llama_stack.providers.tests.agents.fixtures",
     "llama_stack.providers.tests.datasetio.fixtures",
     "llama_stack.providers.tests.scoring.fixtures",
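Reviewer note: for context on the faiss changes above, here is a hedged sketch of why registration now survives a restart. The `VectorDB` construction assumes the usual Resource fields (`identifier`, `provider_id`, `provider_resource_id`) plus the embedding fields this diff relies on; the `impl` handle stands in for a `FaissVectorIOImpl` built by `get_provider_impl`.

```python
# Sketch, not a definitive API reference: register_vector_db persists the
# record under "vector_dbs:v2::<id>", and initialize() re-hydrates self.cache
# by scanning that key range on the next process start.
from llama_stack.apis.vector_dbs import VectorDB


async def register_and_reload(impl) -> None:
    vector_db = VectorDB(
        identifier="docs",
        provider_id="faiss",
        provider_resource_id="docs",
        embedding_model="all-MiniLM-L6-v2",  # placeholder model id
        embedding_dimension=384,
    )
    await impl.register_vector_db(vector_db)

    # A fresh impl pointed at the same kvstore sees the DB again after
    # initialize() scans VECTOR_DBS_PREFIX .. VECTOR_DBS_PREFIX + "\xff".
    assert any(db.identifier == "docs" for db in await impl.list_vector_dbs())
    await impl.unregister_vector_db("docs")
```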
diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py
deleted file mode 100644
index 801b04dfc..000000000
--- a/llama_stack/providers/tests/memory/test_memory.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import uuid
-
-import pytest
-
-from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse
-
-from llama_stack.apis.memory_banks import (
-    MemoryBank,
-    MemoryBanks,
-    VectorMemoryBankParams,
-)
-
-# How to run this test:
-#
-# pytest llama_stack/providers/tests/memory/test_memory.py
-#   -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
-#   -v -s --tb=short --disable-warnings
-
-
-@pytest.fixture
-def sample_documents():
-    return [
-        MemoryBankDocument(
-            document_id="doc1",
-            content="Python is a high-level programming language.",
-            metadata={"category": "programming", "difficulty": "beginner"},
-        ),
-        MemoryBankDocument(
-            document_id="doc2",
-            content="Machine learning is a subset of artificial intelligence.",
-            metadata={"category": "AI", "difficulty": "advanced"},
-        ),
-        MemoryBankDocument(
-            document_id="doc3",
-            content="Data structures are fundamental to computer science.",
-            metadata={"category": "computer science", "difficulty": "intermediate"},
-        ),
-        MemoryBankDocument(
-            document_id="doc4",
-            content="Neural networks are inspired by biological neural networks.",
-            metadata={"category": "AI", "difficulty": "advanced"},
-        ),
-    ]
-
-
-async def register_memory_bank(
-    banks_impl: MemoryBanks, embedding_model: str
-) -> MemoryBank:
-    bank_id = f"test_bank_{uuid.uuid4().hex}"
-    return await banks_impl.register_memory_bank(
-        memory_bank_id=bank_id,
-        params=VectorMemoryBankParams(
-            embedding_model=embedding_model,
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        ),
-    )
-
-
-class TestMemory:
-    @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack, embedding_model):
-        _, banks_impl = memory_stack
-
-        # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl, embedding_model)
-
-        try:
-            # Verify our bank shows up in list
-            response = await banks_impl.list_memory_banks()
-            assert isinstance(response, list)
-            assert any(
-                bank.memory_bank_id == registered_bank.memory_bank_id
-                for bank in response
-            )
-        finally:
-            # Clean up
-            await banks_impl.unregister_memory_bank(registered_bank.memory_bank_id)
-
-        # Verify our bank was removed
-        response = await banks_impl.list_memory_banks()
-        assert all(
-            bank.memory_bank_id != registered_bank.memory_bank_id for bank in response
-        )
-
-    @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack, embedding_model):
-        _, banks_impl = memory_stack
-
-        bank_id = f"test_bank_{uuid.uuid4().hex}"
-
-        try:
-            # Register initial bank
-            await banks_impl.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model=embedding_model,
-                    chunk_size_in_tokens=512,
-                    overlap_size_in_tokens=64,
-                ),
-            )
-
-            # Verify our bank exists
-            response = await banks_impl.list_memory_banks()
-            assert isinstance(response, list)
-            assert any(bank.memory_bank_id == bank_id for bank in response)
-
-            # Try registering same bank again
-            await banks_impl.register_memory_bank(
-                memory_bank_id=bank_id,
-                params=VectorMemoryBankParams(
-                    embedding_model=embedding_model,
-                    chunk_size_in_tokens=512,
-                    overlap_size_in_tokens=64,
-                ),
-            )
-
-            # Verify still only one instance of our bank
-            response = await banks_impl.list_memory_banks()
-            assert isinstance(response, list)
-            assert (
-                len([bank for bank in response if bank.memory_bank_id == bank_id]) == 1
-            )
-        finally:
-            # Clean up
-            await banks_impl.unregister_memory_bank(bank_id)
-
-    @pytest.mark.asyncio
-    async def test_query_documents(
-        self, memory_stack, embedding_model, sample_documents
-    ):
-        memory_impl, banks_impl = memory_stack
-
-        with pytest.raises(ValueError):
-            await memory_impl.insert_documents("test_bank", sample_documents)
-
-        registered_bank = await register_memory_bank(banks_impl, embedding_model)
-        await memory_impl.insert_documents(
-            registered_bank.memory_bank_id, sample_documents
-        )
-
-        query1 = "programming language"
-        response1 = await memory_impl.query_documents(
-            registered_bank.memory_bank_id, query1
-        )
-        assert_valid_response(response1)
-        assert any("Python" in chunk.content for chunk in response1.chunks)
-
-        # Test case 3: Query with semantic similarity
-        query3 = "AI and brain-inspired computing"
-        response3 = await memory_impl.query_documents(
-            registered_bank.memory_bank_id, query3
-        )
-        assert_valid_response(response3)
-        assert any(
-            "neural networks" in chunk.content.lower() for chunk in response3.chunks
-        )
-
-        # Test case 4: Query with limit on number of results
-        query4 = "computer"
-        params4 = {"max_chunks": 2}
-        response4 = await memory_impl.query_documents(
-            registered_bank.memory_bank_id, query4, params4
-        )
-        assert_valid_response(response4)
-        assert len(response4.chunks) <= 2
-
-        # Test case 5: Query with threshold on similarity score
-        query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.01}
-        response5 = await memory_impl.query_documents(
-            registered_bank.memory_bank_id, query5, params5
-        )
-        assert_valid_response(response5)
-        print("The scores are:", response5.scores)
-        assert all(score >= 0.01 for score in response5.scores)
-
-
-def assert_valid_response(response: QueryDocumentsResponse):
-    assert isinstance(response, QueryDocumentsResponse)
-    assert len(response.chunks) > 0
-    assert len(response.scores) > 0
-    assert len(response.chunks) == len(response.scores)
-    for chunk in response.chunks:
-        assert isinstance(chunk.content, str)
-        assert chunk.document_id is not None
diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py
index 81816d51e..f0c4c530e 100644
--- a/llama_stack/providers/tests/resolver.py
+++ b/llama_stack/providers/tests/resolver.py
@@ -12,11 +12,11 @@ from pydantic import BaseModel
 
 from llama_stack.apis.datasets import DatasetInput
 from llama_stack.apis.eval_tasks import EvalTaskInput
-from llama_stack.apis.memory_banks import MemoryBankInput
 from llama_stack.apis.models import ModelInput
 from llama_stack.apis.scoring_functions import ScoringFnInput
 from llama_stack.apis.shields import ShieldInput
 from llama_stack.apis.tools import ToolGroupInput
+from llama_stack.apis.vector_dbs import VectorDBInput
 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.datatypes import Provider, StackRunConfig
@@ -39,7 +39,7 @@ async def construct_stack_for_test(
     provider_data: Optional[Dict[str, Any]] = None,
     models: Optional[List[ModelInput]] = None,
     shields: Optional[List[ShieldInput]] = None,
-    memory_banks: Optional[List[MemoryBankInput]] = None,
+    vector_dbs: Optional[List[VectorDBInput]] = None,
     datasets: Optional[List[DatasetInput]] = None,
     scoring_fns: Optional[List[ScoringFnInput]] = None,
     eval_tasks: Optional[List[EvalTaskInput]] = None,
@@ -53,7 +53,7 @@ async def construct_stack_for_test(
         metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
         models=models or [],
         shields=shields or [],
-        memory_banks=memory_banks or [],
+        vector_dbs=vector_dbs or [],
         datasets=datasets or [],
         scoring_fns=scoring_fns or [],
         eval_tasks=eval_tasks or [],
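Reviewer note: with resolver.py updated, suites can pre-seed vector DBs the same way they seed models and shields. A minimal sketch, assuming `VectorDBInput` mirrors the `register_vector_db` arguments (the field names are an assumption, not spelled out in this diff):

```python
# Assumed field names on VectorDBInput; adjust if the actual model differs.
from llama_stack.apis.vector_dbs import VectorDBInput


def seed_vector_dbs(embedding_model: str) -> list[VectorDBInput]:
    # Pass the result as: construct_stack_for_test(..., vector_dbs=seed_vector_dbs(m))
    return [
        VectorDBInput(
            vector_db_id="seed_db",
            embedding_model=embedding_model,
            embedding_dimension=384,
        )
    ]
```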
diff --git a/llama_stack/providers/tests/memory/__init__.py b/llama_stack/providers/tests/vector_io/__init__.py
similarity index 100%
rename from llama_stack/providers/tests/memory/__init__.py
rename to llama_stack/providers/tests/vector_io/__init__.py
diff --git a/llama_stack/providers/tests/memory/conftest.py b/llama_stack/providers/tests/vector_io/conftest.py
similarity index 79%
rename from llama_stack/providers/tests/memory/conftest.py
rename to llama_stack/providers/tests/vector_io/conftest.py
index 87dec4beb..df5c8ea6a 100644
--- a/llama_stack/providers/tests/memory/conftest.py
+++ b/llama_stack/providers/tests/vector_io/conftest.py
@@ -13,14 +13,14 @@ from ..conftest import (
 )
 
 from ..inference.fixtures import INFERENCE_FIXTURES
-from .fixtures import MEMORY_FIXTURES
+from .fixtures import VECTOR_IO_FIXTURES
 
 
 DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "sentence_transformers",
-            "memory": "faiss",
+            "vector_io": "faiss",
         },
         id="sentence_transformers",
         marks=pytest.mark.sentence_transformers,
@@ -28,7 +28,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "ollama",
-            "memory": "faiss",
+            "vector_io": "faiss",
         },
         id="ollama",
         marks=pytest.mark.ollama,
@@ -36,7 +36,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "sentence_transformers",
-            "memory": "chroma",
+            "vector_io": "chroma",
         },
         id="chroma",
         marks=pytest.mark.chroma,
@@ -44,7 +44,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "bedrock",
-            "memory": "qdrant",
+            "vector_io": "qdrant",
         },
         id="qdrant",
         marks=pytest.mark.qdrant,
@@ -52,7 +52,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
     pytest.param(
         {
             "inference": "fireworks",
-            "memory": "weaviate",
+            "vector_io": "weaviate",
         },
         id="weaviate",
         marks=pytest.mark.weaviate,
@@ -61,7 +61,7 @@ DEFAULT_PROVIDER_COMBINATIONS = [
 
 
 def pytest_configure(config):
-    for fixture_name in MEMORY_FIXTURES:
+    for fixture_name in VECTOR_IO_FIXTURES:
         config.addinivalue_line(
             "markers",
             f"{fixture_name}: marks tests as {fixture_name} specific",
@@ -69,7 +69,7 @@ def pytest_configure(config):
 
 
 def pytest_generate_tests(metafunc):
-    test_config = get_test_config_for_api(metafunc.config, "memory")
+    test_config = get_test_config_for_api(metafunc.config, "vector_io")
     if "embedding_model" in metafunc.fixturenames:
         model = getattr(test_config, "embedding_model", None)
         # Fall back to the default if not specified by the config file
@@ -81,16 +81,16 @@ def pytest_generate_tests(metafunc):
         metafunc.parametrize("embedding_model", params, indirect=True)
 
-    if "memory_stack" in metafunc.fixturenames:
+    if "vector_io_stack" in metafunc.fixturenames:
         available_fixtures = {
             "inference": INFERENCE_FIXTURES,
-            "memory": MEMORY_FIXTURES,
+            "vector_io": VECTOR_IO_FIXTURES,
         }
         combinations = (
             get_provider_fixture_overrides_from_test_config(
-                metafunc.config, "memory", DEFAULT_PROVIDER_COMBINATIONS
+                metafunc.config, "vector_io", DEFAULT_PROVIDER_COMBINATIONS
             )
             or get_provider_fixture_overrides(metafunc.config, available_fixtures)
             or DEFAULT_PROVIDER_COMBINATIONS
         )
-        metafunc.parametrize("memory_stack", combinations, indirect=True)
+        metafunc.parametrize("vector_io_stack", combinations, indirect=True)
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py
similarity index 80%
rename from llama_stack/providers/tests/memory/fixtures.py
rename to llama_stack/providers/tests/vector_io/fixtures.py
index b9dbb84f7..c8d5fa8cf 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/vector_io/fixtures.py
@@ -12,11 +12,12 @@ import pytest_asyncio
 
 from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
-from llama_stack.providers.inline.memory.faiss import FaissImplConfig
-from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
-from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
-from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
+
+from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
+from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig
+from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig
+from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig
+from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 
@@ -32,12 +33,12 @@ def embedding_model(request):
 
 
 @pytest.fixture(scope="session")
-def memory_remote() -> ProviderFixture:
+def vector_io_remote() -> ProviderFixture:
     return remote_stack_fixture()
 
 
 @pytest.fixture(scope="session")
-def memory_faiss() -> ProviderFixture:
+def vector_io_faiss() -> ProviderFixture:
     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     return ProviderFixture(
         providers=[
@@ -53,7 +54,7 @@ def memory_faiss() -> ProviderFixture:
 
 
 @pytest.fixture(scope="session")
-def memory_pgvector() -> ProviderFixture:
+def vector_io_pgvector() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
@@ -72,7 +73,7 @@ def memory_pgvector() -> ProviderFixture:
 
 
 @pytest.fixture(scope="session")
-def memory_weaviate() -> ProviderFixture:
+def vector_io_weaviate() -> ProviderFixture:
     return ProviderFixture(
         providers=[
             Provider(
@@ -89,7 +90,7 @@ def memory_weaviate() -> ProviderFixture:
 
 
 @pytest.fixture(scope="session")
-def memory_chroma() -> ProviderFixture:
+def vector_io_chroma() -> ProviderFixture:
     url = os.getenv("CHROMA_URL")
     if url:
         config = ChromaRemoteImplConfig(url=url)
@@ -110,23 +111,23 @@ def memory_chroma() -> ProviderFixture:
     )
 
 
-MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
+VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(embedding_model, request):
+async def vector_io_stack(embedding_model, request):
     fixture_dict = request.param
 
     providers = {}
     provider_data = {}
-    for key in ["inference", "memory"]:
+    for key in ["inference", "vector_io"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)
 
     test_stack = await construct_stack_for_test(
-        [Api.memory, Api.inference],
+        [Api.vector_io, Api.inference],
         providers,
         provider_data,
         models=[
@@ -140,4 +141,4 @@ async def memory_stack(embedding_model, request):
         ],
     )
 
-    return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
+    return test_stack.impls[Api.vector_io], test_stack.impls[Api.vector_dbs]
diff --git a/llama_stack/providers/tests/memory/fixtures/dummy.pdf b/llama_stack/providers/tests/vector_io/fixtures/dummy.pdf
similarity index 100%
rename from llama_stack/providers/tests/memory/fixtures/dummy.pdf
rename to llama_stack/providers/tests/vector_io/fixtures/dummy.pdf
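Reviewer note: the renamed fixtures compose exactly like the old memory ones. A skeletal downstream test, to show the shape a suite would take (the assertion body is a placeholder):

```python
# Placeholder test body; vector_io_stack is parametrized by the provider
# combinations defined in conftest.py above (faiss, chroma, qdrant, ...).
import pytest


class TestVectorIOWiring:
    @pytest.mark.asyncio
    async def test_stack_is_wired(self, vector_io_stack):
        vector_io_impl, vector_dbs_impl = vector_io_stack
        assert vector_io_impl is not None
        assert vector_dbs_impl is not None
```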
diff --git a/llama_stack/providers/tests/vector_io/test_vector_io.py b/llama_stack/providers/tests/vector_io/test_vector_io.py
new file mode 100644
index 000000000..901b8bd11
--- /dev/null
+++ b/llama_stack/providers/tests/vector_io/test_vector_io.py
@@ -0,0 +1,200 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import uuid
+
+import pytest
+
+from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
+from llama_stack.apis.vector_io import QueryChunksResponse
+
+from llama_stack.providers.utils.memory.vector_store import (
+    make_overlapped_chunks,
+    MemoryBankDocument,
+)
+
+# How to run this test:
+#
+# pytest llama_stack/providers/tests/vector_io/test_vector_io.py
+#   -m "sentence_transformers" --env EMBEDDING_DIMENSION=384
+#   -v -s --tb=short --disable-warnings
+
+
+@pytest.fixture(scope="session")
+def sample_chunks():
+    docs = [
+        MemoryBankDocument(
+            document_id="doc1",
+            content="Python is a high-level programming language.",
+            metadata={"category": "programming", "difficulty": "beginner"},
+        ),
+        MemoryBankDocument(
+            document_id="doc2",
+            content="Machine learning is a subset of artificial intelligence.",
+            metadata={"category": "AI", "difficulty": "advanced"},
+        ),
+        MemoryBankDocument(
+            document_id="doc3",
+            content="Data structures are fundamental to computer science.",
+            metadata={"category": "computer science", "difficulty": "intermediate"},
+        ),
+        MemoryBankDocument(
+            document_id="doc4",
+            content="Neural networks are inspired by biological neural networks.",
+            metadata={"category": "AI", "difficulty": "advanced"},
+        ),
+    ]
+    chunks = []
+    for doc in docs:
+        chunks.extend(
+            make_overlapped_chunks(
+                doc.document_id, doc.content, window_len=512, overlap_len=64
+            )
+        )
+    return chunks
+
+
+async def register_vector_db(
+    vector_dbs_impl: VectorDBs, embedding_model: str
+) -> VectorDB:
+    vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
+    return await vector_dbs_impl.register_vector_db(
+        vector_db_id=vector_db_id,
+        embedding_model=embedding_model,
+        embedding_dimension=384,
+    )
+
+
+class TestVectorIO:
+    @pytest.mark.asyncio
+    async def test_banks_list(self, vector_io_stack, embedding_model):
+        _, vector_dbs_impl = vector_io_stack
+
+        # Register a test vector DB
+        registered_vector_db = await register_vector_db(
+            vector_dbs_impl, embedding_model
+        )
+
+        try:
+            # Verify our vector DB shows up in the list
+            response = await vector_dbs_impl.list_vector_dbs()
+            assert isinstance(response, ListVectorDBsResponse)
+            assert any(
+                vector_db.vector_db_id == registered_vector_db.vector_db_id
+                for vector_db in response.data
+            )
+        finally:
+            # Clean up
+            await vector_dbs_impl.unregister_vector_db(
+                registered_vector_db.vector_db_id
+            )
+
+        # Verify our vector DB was removed
+        response = await vector_dbs_impl.list_vector_dbs()
+        assert isinstance(response, ListVectorDBsResponse)
+        assert all(
+            vector_db.vector_db_id != registered_vector_db.vector_db_id
+            for vector_db in response.data
+        )
+
+    @pytest.mark.asyncio
+    async def test_banks_register(self, vector_io_stack, embedding_model):
+        _, vector_dbs_impl = vector_io_stack
+
+        vector_db_id = f"test_vector_db_{uuid.uuid4().hex}"
+
+        try:
+            # Register the initial vector DB
+            await vector_dbs_impl.register_vector_db(
+                vector_db_id=vector_db_id,
+                embedding_model=embedding_model,
+                embedding_dimension=384,
+            )
+
+            # Verify our vector DB exists
+            response = await vector_dbs_impl.list_vector_dbs()
+            assert isinstance(response, ListVectorDBsResponse)
+            assert any(
+                vector_db.vector_db_id == vector_db_id for vector_db in response.data
+            )
+
+            # Try registering the same vector DB again
+            await vector_dbs_impl.register_vector_db(
+                vector_db_id=vector_db_id,
+                embedding_model=embedding_model,
+                embedding_dimension=384,
+            )
+
+            # Verify there is still only one instance of our vector DB
+            response = await vector_dbs_impl.list_vector_dbs()
+            assert isinstance(response, ListVectorDBsResponse)
+            assert (
+                len(
+                    [
+                        vector_db
+                        for vector_db in response.data
+                        if vector_db.vector_db_id == vector_db_id
+                    ]
+                )
+                == 1
+            )
+        finally:
+            # Clean up
+            await vector_dbs_impl.unregister_vector_db(vector_db_id)
+
+    @pytest.mark.asyncio
+    async def test_query_documents(
+        self, vector_io_stack, embedding_model, sample_chunks
+    ):
+        vector_io_impl, vector_dbs_impl = vector_io_stack
+
+        with pytest.raises(ValueError):
+            await vector_io_impl.insert_chunks("test_vector_db", sample_chunks)
+
+        registered_db = await register_vector_db(vector_dbs_impl, embedding_model)
+        await vector_io_impl.insert_chunks(registered_db.vector_db_id, sample_chunks)
+
+        query1 = "programming language"
+        response1 = await vector_io_impl.query_chunks(
+            registered_db.vector_db_id, query1
+        )
+        assert_valid_response(response1)
+        assert any("Python" in chunk.content for chunk in response1.chunks)
+
+        # Test case 3: Query with semantic similarity
+        query3 = "AI and brain-inspired computing"
+        response3 = await vector_io_impl.query_chunks(
+            registered_db.vector_db_id, query3
+        )
+        assert_valid_response(response3)
+        assert any(
+            "neural networks" in chunk.content.lower() for chunk in response3.chunks
+        )
+
+        # Test case 4: Query with limit on number of results
+        query4 = "computer"
+        params4 = {"max_chunks": 2}
+        response4 = await vector_io_impl.query_chunks(
+            registered_db.vector_db_id, query4, params4
+        )
+        assert_valid_response(response4)
+        assert len(response4.chunks) <= 2
+
+        # Test case 5: Query with threshold on similarity score
+        query5 = "quantum computing"  # Not directly related to any document
+        params5 = {"score_threshold": 0.01}
+        response5 = await vector_io_impl.query_chunks(
+            registered_db.vector_db_id, query5, params5
+        )
+        assert_valid_response(response5)
+        print("The scores are:", response5.scores)
+        assert all(score >= 0.01 for score in response5.scores)
+
+
+def assert_valid_response(response: QueryChunksResponse):
+    assert len(response.chunks) > 0
+    assert len(response.scores) > 0
+    assert len(response.chunks) == len(response.scores)
+    for chunk in response.chunks:
+        assert isinstance(chunk.content, str)
diff --git a/llama_stack/providers/tests/memory/test_vector_store.py b/llama_stack/providers/tests/vector_io/test_vector_store.py
similarity index 94%
rename from llama_stack/providers/tests/memory/test_vector_store.py
rename to llama_stack/providers/tests/vector_io/test_vector_store.py
index 1ad7abf0c..ef6bfca73 100644
--- a/llama_stack/providers/tests/memory/test_vector_store.py
+++ b/llama_stack/providers/tests/vector_io/test_vector_store.py
@@ -11,8 +11,11 @@ from pathlib import Path
 
 import pytest
 
-from llama_stack.apis.memory.memory import MemoryBankDocument, URL
-from llama_stack.providers.utils.memory.vector_store import content_from_doc
+from llama_stack.providers.utils.memory.vector_store import (
+    content_from_doc,
+    MemoryBankDocument,
+    URL,
+)
 
 DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index c97633558..c2de6c714 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -18,6 +18,8 @@ import numpy as np
 from llama_models.llama3.api.tokenizer import Tokenizer
 from numpy.typing import NDArray
+
+from pydantic import BaseModel, Field
 from pypdf import PdfReader
 
 from llama_stack.apis.common.content_types import (
@@ -25,16 +27,24 @@ from llama_stack.apis.common.content_types import (
     TextContentItem,
     URL,
 )
-from llama_stack.apis.memory import Chunk, MemoryBankDocument, QueryDocumentsResponse
-from llama_stack.apis.memory_banks import VectorMemoryBank
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
+
 log = logging.getLogger(__name__)
 
 
+class MemoryBankDocument(BaseModel):
+    document_id: str
+    content: InterleavedContent | URL
+    mime_type: str | None = None
+    metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
     pdf_bytes = io.BytesIO(data)
@@ -165,7 +175,7 @@ class EmbeddingIndex(ABC):
     @abstractmethod
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         raise NotImplementedError()
 
     @abstractmethod
@@ -174,56 +184,35 @@ class EmbeddingIndex(ABC):
 
 
 @dataclass
-class BankWithIndex:
-    bank: VectorMemoryBank
+class VectorDBWithIndex:
+    vector_db: VectorDB
     index: EmbeddingIndex
     inference_api: Api.inference
 
-    async def insert_documents(
+    async def insert_chunks(
         self,
-        documents: List[MemoryBankDocument],
+        chunks: List[Chunk],
     ) -> None:
-        for doc in documents:
-            content = await content_from_doc(doc)
-            chunks = make_overlapped_chunks(
-                doc.document_id,
-                content,
-                self.bank.chunk_size_in_tokens,
-                self.bank.overlap_size_in_tokens
-                or (self.bank.chunk_size_in_tokens // 4),
-            )
-            if not chunks:
-                continue
-            embeddings_response = await self.inference_api.embeddings(
-                self.bank.embedding_model, [x.content for x in chunks]
-            )
-            embeddings = np.array(embeddings_response.embeddings)
+        embeddings_response = await self.inference_api.embeddings(
+            self.vector_db.embedding_model, [x.content for x in chunks]
+        )
+        embeddings = np.array(embeddings_response.embeddings)
 
-            await self.index.add_chunks(chunks, embeddings)
+        await self.index.add_chunks(chunks, embeddings)
 
-    async def query_documents(
+    async def query_chunks(
         self,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         if params is None:
             params = {}
         k = params.get("max_chunks", 3)
         score_threshold = params.get("score_threshold", 0.0)
 
-        def _process(c) -> str:
-            if isinstance(c, str):
-                return c
-            else:
-                return ""
-
-        if isinstance(query, list):
-            query_str = " ".join([_process(c) for c in query])
-        else:
-            query_str = _process(query)
-
+        query_str = interleaved_content_as_str(query)
         embeddings_response = await self.inference_api.embeddings(
-            self.bank.embedding_model, [query_str]
+            self.vector_db.embedding_model, [query_str]
         )
         query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py
index 0b5324c0e..c19546887 100644
--- a/tests/client-sdk/conftest.py
+++ b/tests/client-sdk/conftest.py
@@ -32,6 +32,7 @@ def pytest_addoption(parser):
 TEXT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
 INFERENCE_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
+
 @pytest.fixture(scope="session")
 def provider_data():
     # check env for tavily secret, brave secret and inject all into provider data
diff --git a/tests/client-sdk/memory/__init__.py b/tests/client-sdk/vector_io/__init__.py
similarity index 100%
rename from tests/client-sdk/memory/__init__.py
rename to tests/client-sdk/vector_io/__init__.py
diff --git a/tests/client-sdk/memory/test_memory.py b/tests/client-sdk/vector_io/test_vector_io.py
similarity index 100%
rename from tests/client-sdk/memory/test_memory.py
rename to tests/client-sdk/vector_io/test_vector_io.py
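Reviewer note: one behavioral consequence of the `VectorDBWithIndex` change above is worth calling out: `insert_chunks` no longer fetches and chunks documents the way `insert_documents` did, so chunking happens on the caller's side. A sketch of the new contract (helper names come from vector_store.py in this diff; the `vector_io` handle is illustrative):

```python
# Caller-side chunking under the new contract. With window_len=512 and
# overlap_len=64, make_overlapped_chunks advances roughly
# window_len - overlap_len = 448 tokens per window, so consecutive chunks
# share a 64-token overlap.
from llama_stack.providers.utils.memory.vector_store import (
    MemoryBankDocument,
    make_overlapped_chunks,
)


async def insert_document(vector_io, vector_db_id: str) -> None:
    doc = MemoryBankDocument(
        document_id="doc1",
        content="Python is a high-level programming language.",
    )
    chunks = make_overlapped_chunks(
        doc.document_id, doc.content, window_len=512, overlap_len=64
    )
    await vector_io.insert_chunks(vector_db_id, chunks)
```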