diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 5b5175cee..2d0ad137b 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -413,8 +413,8 @@ class ChatAgent(ShieldRunnerMixin):
         session_info = await self.storage.get_session_info(session_id)
 
         # if the session has a memory bank id, let the memory tool use it
-        if session_info.memory_bank_id:
-            vector_db_ids.append(session_info.memory_bank_id)
+        if session_info.vector_db_id:
+            vector_db_ids.append(session_info.vector_db_id)
 
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
@@ -829,7 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
                 msg = await attachment_message(self.tempdir, url_items)
                 input_messages.append(msg)
                 # Since memory is present, add all the data to the memory bank
-                await self.add_to_session_memory_bank(session_id, documents)
+                await self.add_to_session_vector_db(session_id, documents)
             elif code_interpreter_tool:
                 # if only code_interpreter is available, we download the URLs to a tempdir
                 # and attach the path to them as a message to inference with the
@@ -838,7 +838,7 @@ class ChatAgent(ShieldRunnerMixin):
                 input_messages.append(msg)
             elif memory_tool:
                 # if only memory is available, we load the data from the URLs and content items to the memory bank
-                await self.add_to_session_memory_bank(session_id, documents)
+                await self.add_to_session_vector_db(session_id, documents)
             else:
                 # if no memory or code_interpreter tool is available,
                 # we try to load the data from the URLs and content items as a message to inference
@@ -848,31 +848,31 @@ class ChatAgent(ShieldRunnerMixin):
                     + await load_data_from_urls(url_items)
                 )
 
-    async def _ensure_memory_bank(self, session_id: str) -> str:
+    async def _ensure_vector_db(self, session_id: str) -> str:
         session_info = await self.storage.get_session_info(session_id)
         if session_info is None:
             raise ValueError(f"Session {session_id} not found")
 
-        if session_info.memory_bank_id is None:
-            bank_id = f"memory_bank_{session_id}"
+        if session_info.vector_db_id is None:
+            vector_db_id = f"vector_db_{session_id}"
             # TODO: the semantic for registration is definitely not "creation"
             # so we need to fix it if we expect the agent to create a new vector db
             # for each session
             await self.vector_io_api.register_vector_db(
-                vector_db_id=bank_id,
+                vector_db_id=vector_db_id,
                 embedding_model="all-MiniLM-L6-v2",
             )
-            await self.storage.add_memory_bank_to_session(session_id, bank_id)
+            await self.storage.add_vector_db_to_session(session_id, vector_db_id)
         else:
-            bank_id = session_info.memory_bank_id
+            vector_db_id = session_info.vector_db_id
 
-        return bank_id
+        return vector_db_id
 
-    async def add_to_session_memory_bank(
+    async def add_to_session_vector_db(
         self, session_id: str, data: List[Document]
     ) -> None:
-        vector_db_id = await self._ensure_memory_bank(session_id)
+        vector_db_id = await self._ensure_vector_db(session_id)
         documents = [
             RAGDocument(
                 document_id=str(uuid.uuid4()),
diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py
index 58b69858b..4b8ad6d4a 100644
--- a/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
 
 class AgentSessionInfo(BaseModel):
     session_id: str
     session_name: str
-    memory_bank_id: Optional[str] = None
+    vector_db_id: Optional[str] = None
     started_at: datetime
 
 
@@ -52,12 +52,12 @@ class AgentPersistence:
 
         return AgentSessionInfo(**json.loads(value))
 
-    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
+    async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
         session_info = await self.get_session_info(session_id)
         if session_info is None:
             raise ValueError(f"Session {session_id} not found")
 
-        session_info.memory_bank_id = bank_id
+        session_info.vector_db_id = vector_db_id
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
             value=session_info.model_dump_json(),
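The net effect of the two files above is easy to state: a session record now carries an optional vector_db_id, and the agent lazily registers one vector DB per session on first use, persisting the id back into the session info. A minimal, runnable sketch of that flow (the pydantic model mirrors the patch; the plain dict kvstore and the "agent-0" key prefix are stand-ins, not part of the patch):

import asyncio
import json
from datetime import datetime
from typing import Optional

from pydantic import BaseModel


class AgentSessionInfo(BaseModel):
    session_id: str
    session_name: str
    vector_db_id: Optional[str] = None
    started_at: datetime


kvstore: dict = {}  # stand-in for the real persisted kvstore


async def ensure_vector_db(session_id: str) -> str:
    # mirrors _ensure_vector_db: create-on-first-use, then reuse
    raw = kvstore[f"session:agent-0:{session_id}"]
    session_info = AgentSessionInfo(**json.loads(raw))
    if session_info.vector_db_id is None:
        session_info.vector_db_id = f"vector_db_{session_id}"
        kvstore[f"session:agent-0:{session_id}"] = session_info.model_dump_json()
    return session_info.vector_db_id


async def main() -> None:
    info = AgentSessionInfo(
        session_id="s1", session_name="demo", started_at=datetime.now()
    )
    kvstore["session:agent-0:s1"] = info.model_dump_json()
    print(await ensure_vector_db("s1"))  # -> vector_db_s1


asyncio.run(main())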
diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
index a7e6efc8c..205868279 100644
--- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
+++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py
@@ -29,10 +29,9 @@ from llama_stack.apis.inference import (
     SamplingParams,
     ToolChoice,
     ToolDefinition,
+    ToolPromptFormat,
     UserMessage,
 )
-from llama_stack.apis.memory import MemoryBank
-from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
 from llama_stack.apis.safety import RunShieldResponse
 from llama_stack.apis.tools import (
     Tool,
@@ -40,8 +39,9 @@ from llama_stack.apis.tools import (
     ToolGroup,
     ToolHost,
     ToolInvocationResult,
-    ToolPromptFormat,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse
+
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -110,68 +110,22 @@ class MockSafetyAPI:
         return RunShieldResponse(violation=None)
 
 
-class MockMemoryAPI:
+class MockVectorIOAPI:
     def __init__(self):
-        self.memory_banks = {}
-        self.documents = {}
+        self.chunks = {}
 
-    async def create_memory_bank(self, name, config, url=None):
-        bank_id = f"bank_{len(self.memory_banks)}"
-        bank = MemoryBank(bank_id, name, config, url)
-        self.memory_banks[bank_id] = bank
-        self.documents[bank_id] = {}
-        return bank
+    async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
+        for chunk in chunks:
+            metadata = chunk.metadata
+            # setdefault avoids a KeyError the first time a vector db id is seen
+            self.chunks.setdefault(vector_db_id, {})[metadata["document_id"]] = chunk
 
-    async def list_memory_banks(self):
-        return list(self.memory_banks.values())
+    async def query_chunks(self, vector_db_id, query, params=None):
+        if vector_db_id not in self.chunks:
+            raise ValueError(f"Bank {vector_db_id} not found")
 
-    async def get_memory_bank(self, bank_id):
-        return self.memory_banks.get(bank_id)
-
-    async def drop_memory_bank(self, bank_id):
-        if bank_id in self.memory_banks:
-            del self.memory_banks[bank_id]
-            del self.documents[bank_id]
-        return bank_id
-
-    async def insert_documents(self, bank_id, documents, ttl_seconds=None):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        for doc in documents:
-            self.documents[bank_id][doc.document_id] = doc
-
-    async def update_documents(self, bank_id, documents):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        for doc in documents:
-            if doc.document_id in self.documents[bank_id]:
-                self.documents[bank_id][doc.document_id] = doc
-
-    async def query_documents(self, bank_id, query, params=None):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        # Simple mock implementation: return all documents
-        chunks = [
-            {"content": doc.content, "token_count": 10, "document_id": doc.document_id}
-            for doc in self.documents[bank_id].values()
-        ]
+        chunks = list(self.chunks[vector_db_id].values())
         scores = [1.0] * len(chunks)
-        return {"chunks": chunks, "scores": scores}
-
-    async def get_documents(self, bank_id, document_ids):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        return [
-            self.documents[bank_id][doc_id]
-            for doc_id in document_ids
-            if doc_id in self.documents[bank_id]
-        ]
-
-    async def delete_documents(self, bank_id, document_ids):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        for doc_id in document_ids:
-            self.documents[bank_id].pop(doc_id, None)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
 
 class MockToolGroupsAPI:
@@ -241,31 +195,6 @@ class MockToolRuntimeAPI:
         return ToolInvocationResult(content={"result": "Mock tool result"})
 
 
-class MockMemoryBanksAPI:
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        return []
-
-    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
-        return None
-
-    async def register_memory_bank(
-        self,
-        memory_bank_id: str,
-        params: BankParams,
-        provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank:
-        return VectorMemoryBank(
-            identifier=memory_bank_id,
-            provider_resource_id=provider_memory_bank_id or memory_bank_id,
-            embedding_model="mock_model",
-            chunk_size_in_tokens=512,
-        )
-
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        pass
-
-
 @pytest.fixture
 def mock_inference_api():
     return MockInferenceAPI()
@@ -277,8 +206,8 @@ def mock_safety_api():
 
 
 @pytest.fixture
-def mock_memory_api():
-    return MockMemoryAPI()
+def mock_vector_io_api():
+    return MockVectorIOAPI()
 
 
 @pytest.fixture
@@ -291,17 +220,11 @@ def mock_tool_runtime_api():
     return MockToolRuntimeAPI()
 
 
-@pytest.fixture
-def mock_memory_banks_api():
-    return MockMemoryBanksAPI()
-
-
 @pytest.fixture
 async def get_agents_impl(
     mock_inference_api,
     mock_safety_api,
-    mock_memory_api,
-    mock_memory_banks_api,
+    mock_vector_io_api,
     mock_tool_runtime_api,
     mock_tool_groups_api,
 ):
@@ -314,8 +237,7 @@ async def get_agents_impl(
         ),
         inference_api=mock_inference_api,
         safety_api=mock_safety_api,
-        memory_api=mock_memory_api,
-        memory_banks_api=mock_memory_banks_api,
+        vector_io_api=mock_vector_io_api,
         tool_runtime_api=mock_tool_runtime_api,
         tool_groups_api=mock_tool_groups_api,
     )
@@ -484,7 +406,7 @@ async def test_chat_agent_tools(
         toolgroups_for_turn=[
             AgentToolGroupWithArgs(
                 name=MEMORY_TOOLGROUP,
-                args={"memory_banks": ["test_memory_bank"]},
+                args={"vector_dbs": ["test_vector_db"]},
             )
         ]
     )
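The rewritten mock reduces the old nine-method memory API to the two calls the agent actually exercises. A self-contained, runnable stand-in showing the round trip (the Chunk and QueryChunksResponse classes below are simplified stand-ins for the real llama_stack types):

import asyncio
from dataclasses import dataclass, field


@dataclass
class Chunk:  # stand-in for llama_stack.apis.vector_io.Chunk
    content: str
    metadata: dict = field(default_factory=dict)


@dataclass
class QueryChunksResponse:  # stand-in: pairs chunks with scores
    chunks: list
    scores: list


class MockVectorIOAPI:
    def __init__(self):
        self.chunks = {}

    async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
        for chunk in chunks:
            # key each chunk by its document_id, per vector db
            self.chunks.setdefault(vector_db_id, {})[chunk.metadata["document_id"]] = chunk

    async def query_chunks(self, vector_db_id, query, params=None):
        if vector_db_id not in self.chunks:
            raise ValueError(f"Bank {vector_db_id} not found")
        chunks = list(self.chunks[vector_db_id].values())
        return QueryChunksResponse(chunks=chunks, scores=[1.0] * len(chunks))


async def main():
    api = MockVectorIOAPI()
    await api.insert_chunks("db1", [Chunk("hello", {"document_id": "d1"})])
    resp = await api.query_chunks("db1", query="anything")
    print(len(resp.chunks), resp.scores)  # -> 1 [1.0]


asyncio.run(main())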
"token_count": 10, "document_id": doc.document_id} - for doc in self.documents[bank_id].values() - ] + chunks = list(self.chunks[vector_db_id].values()) scores = [1.0] * len(chunks) - return {"chunks": chunks, "scores": scores} - - async def get_documents(self, bank_id, document_ids): - if bank_id not in self.documents: - raise ValueError(f"Bank {bank_id} not found") - return [ - self.documents[bank_id][doc_id] - for doc_id in document_ids - if doc_id in self.documents[bank_id] - ] - - async def delete_documents(self, bank_id, document_ids): - if bank_id not in self.documents: - raise ValueError(f"Bank {bank_id} not found") - for doc_id in document_ids: - self.documents[bank_id].pop(doc_id, None) + return QueryChunksResponse(chunks=chunks, scores=scores) class MockToolGroupsAPI: @@ -241,31 +195,6 @@ class MockToolRuntimeAPI: return ToolInvocationResult(content={"result": "Mock tool result"}) -class MockMemoryBanksAPI: - async def list_memory_banks(self) -> List[MemoryBank]: - return [] - - async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: - return None - - async def register_memory_bank( - self, - memory_bank_id: str, - params: BankParams, - provider_id: Optional[str] = None, - provider_memory_bank_id: Optional[str] = None, - ) -> MemoryBank: - return VectorMemoryBank( - identifier=memory_bank_id, - provider_resource_id=provider_memory_bank_id or memory_bank_id, - embedding_model="mock_model", - chunk_size_in_tokens=512, - ) - - async def unregister_memory_bank(self, memory_bank_id: str) -> None: - pass - - @pytest.fixture def mock_inference_api(): return MockInferenceAPI() @@ -277,8 +206,8 @@ def mock_safety_api(): @pytest.fixture -def mock_memory_api(): - return MockMemoryAPI() +def mock_vector_io_api(): + return MockVectorIOAPI() @pytest.fixture @@ -291,17 +220,11 @@ def mock_tool_runtime_api(): return MockToolRuntimeAPI() -@pytest.fixture -def mock_memory_banks_api(): - return MockMemoryBanksAPI() - - @pytest.fixture async def get_agents_impl( mock_inference_api, mock_safety_api, - mock_memory_api, - mock_memory_banks_api, + mock_vector_io_api, mock_tool_runtime_api, mock_tool_groups_api, ): @@ -314,8 +237,7 @@ async def get_agents_impl( ), inference_api=mock_inference_api, safety_api=mock_safety_api, - memory_api=mock_memory_api, - memory_banks_api=mock_memory_banks_api, + vector_io_api=mock_vector_io_api, tool_runtime_api=mock_tool_runtime_api, tool_groups_api=mock_tool_groups_api, ) @@ -484,7 +406,7 @@ async def test_chat_agent_tools( toolgroups_for_turn=[ AgentToolGroupWithArgs( name=MEMORY_TOOLGROUP, - args={"memory_banks": ["test_memory_bank"]}, + args={"vector_dbs": ["test_vector_db"]}, ) ] ) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index c04d775ca..e33980be1 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -6,25 +6,20 @@ import asyncio import json import logging -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse import chromadb from numpy.typing import NDArray from llama_stack.apis.inference import InterleavedContent -from llama_stack.apis.memory import ( - Chunk, - Memory, - MemoryBankDocument, - QueryDocumentsResponse, -) -from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType -from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate -from 
llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( - BankWithIndex, EmbeddingIndex, + VectorDBWithIndex, ) from .config import ChromaRemoteImplConfig @@ -61,7 +56,7 @@ class ChromaIndex(EmbeddingIndex): async def query( self, embedding: NDArray, k: int, score_threshold: float - ) -> QueryDocumentsResponse: + ) -> QueryChunksResponse: results = await maybe_await( self.collection.query( query_embeddings=[embedding.tolist()], @@ -85,13 +80,13 @@ class ChromaIndex(EmbeddingIndex): chunks.append(chunk) scores.append(1.0 / float(dist)) - return QueryDocumentsResponse(chunks=chunks, scores=scores) + return QueryChunksResponse(chunks=chunks, scores=scores) async def delete(self): await maybe_await(self.client.delete_collection(self.collection.name)) -class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): +class ChromaMemoryAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig], @@ -123,60 +118,58 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def shutdown(self) -> None: pass - async def register_memory_bank( + async def register_vector_db( self, - memory_bank: MemoryBank, + vector_db: VectorDB, ) -> None: - assert ( - memory_bank.memory_bank_type == MemoryBankType.vector.value - ), f"Only vector banks are supported {memory_bank.memory_bank_type}" - collection = await maybe_await( self.client.get_or_create_collection( - name=memory_bank.identifier, - metadata={"bank": memory_bank.model_dump_json()}, + name=vector_db.identifier, + metadata={"vector_db": vector_db.model_dump_json()}, ) ) - self.cache[memory_bank.identifier] = BankWithIndex( - memory_bank, ChromaIndex(self.client, collection), self.inference_api + self.cache[vector_db.identifier] = VectorDBWithIndex( + vector_db, ChromaIndex(self.client, collection), self.inference_api ) - async def unregister_memory_bank(self, memory_bank_id: str) -> None: - await self.cache[memory_bank_id].index.delete() - del self.cache[memory_bank_id] + async def unregister_vector_db(self, vector_db_id: str) -> None: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] - async def insert_documents( + async def insert_chunks( self, - bank_id: str, - documents: List[MemoryBankDocument], - ttl_seconds: Optional[int] = None, + vector_db_id: str, + chunks: List[Chunk], + embeddings: NDArray, ) -> None: - index = await self._get_and_cache_bank_index(bank_id) + index = await self._get_and_cache_vector_db_index(vector_db_id) - await index.insert_documents(documents) + await index.insert_chunks(chunks, embeddings) - async def query_documents( + async def query_chunks( self, - bank_id: str, + vector_db_id: str, query: InterleavedContent, params: Optional[Dict[str, Any]] = None, - ) -> QueryDocumentsResponse: - index = await self._get_and_cache_bank_index(bank_id) + ) -> QueryChunksResponse: + index = await self._get_and_cache_vector_db_index(vector_db_id) - return await index.query_documents(query, params) + return await index.query_chunks(query, params) - async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex: - if bank_id in self.cache: - return self.cache[bank_id] 
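Every adapter in this patch (Chroma here, then pgvector, Qdrant, and Weaviate below) ends up with the same _get_and_cache_vector_db_index shape: serve from the in-memory cache, otherwise look the vector DB up in the registry store, wrap it with an index, and cache the wrapper. A skeleton of that pattern, reduced to its moving parts (the store dict and FakeIndex are hypothetical stand-ins):

import asyncio


class FakeIndex:  # stand-in for ChromaIndex / PGVectorIndex / etc.
    def __init__(self, name):
        self.name = name


class VectorDBWithIndex:  # mirrors the wrapper the patch introduces
    def __init__(self, vector_db, index):
        self.vector_db = vector_db
        self.index = index


class Adapter:
    def __init__(self, store):
        self.store = store  # maps vector_db_id -> registered vector DB
        self.cache = {}     # vector_db_id -> VectorDBWithIndex

    async def _get_and_cache_vector_db_index(self, vector_db_id):
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]
        vector_db = self.store.get(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
        index = VectorDBWithIndex(vector_db, FakeIndex(vector_db_id))
        self.cache[vector_db_id] = index
        return index


async def main():
    adapter = Adapter(store={"db1": {"identifier": "db1"}})
    first = await adapter._get_and_cache_vector_db_index("db1")
    second = await adapter._get_and_cache_vector_db_index("db1")
    print(first is second)  # -> True: second call is served from the cache


asyncio.run(main())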
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index b2c720b2c..3605f038c 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -12,21 +12,16 @@ from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
-from pydantic import BaseModel, parse_obj_as
+from pydantic import BaseModel, TypeAdapter
 
 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )
 
 from .config import PGVectorConfig
@@ -50,20 +45,20 @@ def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
         """
     )
 
-    values = [(key, Json(model.dict())) for key, model in keys_models]
+    values = [(key, Json(model.model_dump())) for key, model in keys_models]
     execute_values(cur, query, values, template="(%s, %s)")
 
 
 def load_models(cur, cls):
     cur.execute("SELECT key, data FROM metadata_store")
     rows = cur.fetchall()
-    return [parse_obj_as(cls, row["data"]) for row in rows]
+    return [TypeAdapter(cls).validate_python(row["data"]) for row in rows]
 
 
 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
+    def __init__(self, vector_db: VectorDB, dimension: int, cursor):
         self.cursor = cursor
-        self.table_name = f"vector_store_{bank.identifier}"
+        self.table_name = f"vector_store_{vector_db.identifier}"
 
         self.cursor.execute(
             f"""
@@ -85,7 +80,7 @@ class PGVectorIndex(EmbeddingIndex):
             values.append(
                 (
                     f"{chunk.document_id}:chunk-{i}",
-                    Json(chunk.dict()),
+                    Json(chunk.model_dump()),
                     embeddings[i].tolist(),
                 )
             )
@@ -101,7 +96,7 @@ class PGVectorIndex(EmbeddingIndex):
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         self.cursor.execute(
             f"""
         SELECT document, embedding <-> %s::vector AS distance
@@ -119,13 +114,13 @@ class PGVectorIndex(EmbeddingIndex):
             chunks.append(Chunk(**doc))
             scores.append(1.0 / float(dist))
 
-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
     async def delete(self):
         self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
 
-class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
+class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api
@@ -167,46 +162,45 @@
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        upsert_models(self.cursor, [(vector_db.identifier, vector_db)])
 
-        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
-        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            memory_bank, index, self.inference_api
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db, index, self.inference_api
         )
 
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        await self.cache[memory_bank_id].index.delete()
-        del self.cache[memory_bank_id]
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]
 
-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
-        await index.insert_documents(documents)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        await index.insert_chunks(chunks)
 
-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
-        return await index.query_documents(query, params)
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        return await index.query_chunks(query, params)
 
-        self.inference_api = inference_api
+    async def _get_and_cache_vector_db_index(
+        self, vector_db_id: str
+    ) -> VectorDBWithIndex:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
 
-    async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
-
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
-        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
-        return self.cache[bank_id]
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        self.cache[vector_db_id] = VectorDBWithIndex(
+            vector_db, index, self.inference_api
+        )
+        return self.cache[vector_db_id]
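Alongside the rename, this file migrates pydantic v1 idioms to v2: parse_obj_as(cls, data) becomes TypeAdapter(cls).validate_python(data), and .dict() becomes .model_dump(). A minimal equivalence check (the VectorDBRecord model is a hypothetical stand-in for the stored metadata model; pydantic v2 assumed):

from pydantic import BaseModel, TypeAdapter


class VectorDBRecord(BaseModel):  # hypothetical stand-in
    identifier: str
    embedding_dimension: int


row = {"identifier": "db1", "embedding_dimension": 384}

# v2 replacement for parse_obj_as(VectorDBRecord, row)
record = TypeAdapter(VectorDBRecord).validate_python(row)

# v2 replacement for record.dict()
assert record.model_dump() == row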
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index b1d5bd7fa..d3257b4c9 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -13,19 +13,14 @@ from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct
 
 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
-from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )
+from .config import QdrantConfig
 
 log = logging.getLogger(__name__)
 CHUNK_ID_KEY = "_chunk_id"
@@ -76,7 +71,7 @@ class QdrantIndex(EmbeddingIndex):
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,
@@ -101,10 +96,10 @@ class QdrantIndex(EmbeddingIndex):
             chunks.append(chunk)
             scores.append(point.score)
 
-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
 
-class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
+class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
@@ -117,58 +112,56 @@
     async def shutdown(self) -> None:
         self.client.close()
 
-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank: MemoryBank,
+        vector_db: VectorDB,
     ) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
-        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=QdrantIndex(self.client, memory_bank.identifier),
+        index = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=QdrantIndex(self.client, vector_db.identifier),
             inference_api=self.inference_api,
         )
 
-        self.cache[memory_bank.identifier] = index
+        self.cache[vector_db.identifier] = index
 
-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
+    async def _get_and_cache_vector_db_index(
+        self, vector_db_id: str
+    ) -> Optional[VectorDBWithIndex]:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
 
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if not bank:
-            raise ValueError(f"Bank {bank_id} not found")
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        if not vector_db:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
-        index = BankWithIndex(
-            bank=bank,
-            index=QdrantIndex(client=self.client, collection_name=bank_id),
+        index = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
             inference_api=self.inference_api,
         )
-        self.cache[bank_id] = index
+        self.cache[vector_db_id] = index
         return index
 
-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
-        await index.insert_documents(documents)
+        await index.insert_chunks(chunks)
 
-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
-        return await index.query_documents(query, params)
+        return await index.query_chunks(query, params)
diff --git a/llama_stack/providers/remote/vector_io/sample/sample.py b/llama_stack/providers/remote/vector_io/sample/sample.py
index b051eb544..e311be39d 100644
--- a/llama_stack/providers/remote/vector_io/sample/sample.py
+++ b/llama_stack/providers/remote/vector_io/sample/sample.py
@@ -4,19 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBank
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import VectorIO
 
 from .config import SampleConfig
 
 
-class SampleMemoryImpl(Memory):
+class SampleMemoryImpl(VectorIO):
     def __init__(self, config: SampleConfig):
         self.config = config
 
-    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
-        # these are the memory banks the Llama Stack will use to route requests to this provider
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        # these are the vector dbs the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
 
     async def initialize(self):
         pass
+
+    async def shutdown(self):
+        pass
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index f1433090d..ea9ce5185 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -15,18 +15,13 @@ from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter
 
 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )
 
 from .config import WeaviateConfig, WeaviateRequestProviderData
@@ -49,7 +44,7 @@ class WeaviateIndex(EmbeddingIndex):
             data_objects.append(
                 wvc.data.DataObject(
                     properties={
-                        "chunk_content": chunk.json(),
+                        "chunk_content": chunk.model_dump_json(),
                     },
                     vector=embeddings[i].tolist(),
                 )
@@ -63,7 +58,7 @@ class WeaviateIndex(EmbeddingIndex):
 
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)
 
         results = collection.query.near_vector(
@@ -86,7 +81,7 @@ class WeaviateIndex(EmbeddingIndex):
             chunks.append(chunk)
             scores.append(1.0 / doc.metadata.distance)
 
-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
     async def delete(self, chunk_ids: List[str]) -> None:
         collection = self.client.collections.get(self.collection_name)
@@ -96,9 +91,9 @@
 
 
 class WeaviateMemoryAdapter(
-    Memory,
+    VectorIO,
     NeedsRequestProviderData,
-    MemoryBanksProtocolPrivate,
+    VectorDBsProtocolPrivate,
 ):
     def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
@@ -129,20 +124,16 @@ class WeaviateMemoryAdapter(
         for client in self.client_cache.values():
             client.close()
 
-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank: MemoryBank,
+        vector_db: VectorDB,
     ) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
-
         client = self._get_client()
 
         # Create collection if it doesn't exist
-        if not client.collections.exists(memory_bank.identifier):
+        if not client.collections.exists(vector_db.identifier):
             client.collections.create(
-                name=memory_bank.identifier,
+                name=vector_db.identifier,
                 vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                 properties=[
                     wvc.config.Property(
@@ -152,52 +143,54 @@ class WeaviateMemoryAdapter(
                 ],
             )
 
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            memory_bank,
-            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db,
+            WeaviateIndex(client=client, collection_name=vector_db.identifier),
             self.inference_api,
         )
 
-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
+    async def _get_and_cache_vector_db_index(
+        self, vector_db_id: str
+    ) -> Optional[VectorDBWithIndex]:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
 
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if not bank:
-            raise ValueError(f"Bank {bank_id} not found")
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        if not vector_db:
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
         client = self._get_client()
-        if not client.collections.exists(bank.identifier):
-            raise ValueError(f"Collection with name `{bank.identifier}` not found")
+        if not client.collections.exists(vector_db.identifier):
+            raise ValueError(f"Collection with name `{vector_db.identifier}` not found")
 
-        index = BankWithIndex(
-            bank=bank,
-            index=WeaviateIndex(client=client, collection_name=bank_id),
+        index = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
             inference_api=self.inference_api,
         )
-        self.cache[bank_id] = index
+        self.cache[vector_db_id] = index
         return index
 
-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
-        await index.insert_documents(documents)
+        await index.insert_chunks(chunks)
 
-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")
 
-        return await index.query_documents(query, params)
+        return await index.query_chunks(query, params)
diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py
index 37bb0527a..009e65fb3 100644
--- a/llama_stack/providers/tests/eval/fixtures.py
+++ b/llama_stack/providers/tests/eval/fixtures.py
@@ -53,7 +53,7 @@ async def eval_stack(
         "inference",
         "agents",
         "safety",
-        "memory",
+        "vector_io",
         "tool_runtime",
     ]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
@@ -69,7 +69,7 @@ async def eval_stack(
             Api.scoring,
             Api.agents,
             Api.safety,
-            Api.memory,
+            Api.vector_io,
             Api.tool_runtime,
         ],
         providers,
diff --git a/llama_stack/providers/tests/tools/fixtures.py b/llama_stack/providers/tests/tools/fixtures.py
index a559dbf8c..03752881a 100644
--- a/llama_stack/providers/tests/tools/fixtures.py
+++ b/llama_stack/providers/tests/tools/fixtures.py
@@ -83,7 +83,7 @@ async def tools_stack(
     providers = {}
     provider_data = {}
-    for key in ["inference", "memory", "tool_runtime"]:
+    for key in ["inference", "vector_io", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":
@@ -117,7 +117,12 @@ async def tools_stack(
     )
 
     test_stack = await construct_stack_for_test(
-        [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
+        [
+            Api.tool_groups,
+            Api.inference,
+            Api.vector_io,
+            Api.tool_runtime,
+        ],
         providers,
         provider_data,
         models=models,
diff --git a/llama_stack/providers/tests/tools/test_tools.py b/llama_stack/providers/tests/tools/test_tools.py
index 16081b939..62b18ea66 100644
--- a/llama_stack/providers/tests/tools/test_tools.py
+++ b/llama_stack/providers/tests/tools/test_tools.py
@@ -8,10 +8,7 @@ import os
 
 import pytest
 
-from llama_stack.apis.inference import UserMessage
-from llama_stack.apis.memory import MemoryBankDocument
-from llama_stack.apis.memory_banks import VectorMemoryBankParams
-from llama_stack.apis.tools import ToolInvocationResult
+from llama_stack.apis.tools import RAGDocument, RAGQueryResult, ToolInvocationResult
 from llama_stack.providers.datatypes import Api
 
 
@@ -36,7 +33,7 @@ def sample_documents():
         "lora_finetune.rst",
     ]
     return [
-        MemoryBankDocument(
+        RAGDocument(
             document_id=f"num-{i}",
             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
             mime_type="text/plain",
@@ -57,7 +54,7 @@ class TestTools:
 
         # Execute the tool
         response = await tools_impl.invoke_tool(
-            tool_name="web_search", args={"query": sample_search_query}
+            tool_name="web_search", kwargs={"query": sample_search_query}
         )
 
         # Verify the response
@@ -75,7 +72,7 @@ class TestTools:
         tools_impl = tools_stack.impls[Api.tool_runtime]
 
         response = await tools_impl.invoke_tool(
-            tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
+            tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
         )
 
         # Verify the response
@@ -85,43 +82,33 @@ class TestTools:
         assert isinstance(response.content, str)
 
     @pytest.mark.asyncio
-    async def test_memory_tool(self, tools_stack, sample_documents):
+    async def test_rag_tool(self, tools_stack, sample_documents):
         """Test the memory tool functionality."""
-        memory_banks_impl = tools_stack.impls[Api.memory_banks]
-        memory_impl = tools_stack.impls[Api.memory]
+        vector_dbs_impl = tools_stack.impls[Api.vector_dbs]
         tools_impl = tools_stack.impls[Api.tool_runtime]
 
         # Register memory bank
-        await memory_banks_impl.register_memory_bank(
-            memory_bank_id="test_bank",
-            params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
-                overlap_size_in_tokens=64,
-            ),
+        await vector_dbs_impl.register(
+            vector_db_id="test_bank",
+            embedding_model="all-MiniLM-L6-v2",
+            embedding_dimension=384,
             provider_id="faiss",
         )
 
         # Insert documents into memory
-        await memory_impl.insert_documents(
-            bank_id="test_bank",
+        await tools_impl.rag_tool.insert_documents(
             documents=sample_documents,
+            vector_db_id="test_bank",
+            chunk_size_in_tokens=512,
         )
 
         # Execute the memory tool
-        response = await tools_impl.invoke_tool(
-            tool_name="memory",
-            args={
-                "messages": [
-                    UserMessage(
-                        content="What are the main topics covered in the documentation?",
-                    )
-                ],
-                "memory_bank_ids": ["test_bank"],
-            },
+        response = await tools_impl.rag_tool.query_context(
+            content="What are the main topics covered in the documentation?",
+            vector_db_ids=["test_bank"],
         )
 
         # Verify the response
-        assert isinstance(response, ToolInvocationResult)
+        assert isinstance(response, RAGQueryResult)
         assert response.content is not None
         assert len(response.content) > 0
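Taken together, the updated test sketches the full client-side flow of the renamed surface: register a vector DB, let the RAG tool chunk and index documents into it, then query it for context. Distilled into one helper (a usage sketch, not new API; it assumes a constructed tools_stack like the fixture above provides, and mirrors the exact calls in test_rag_tool):

from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.datatypes import Api


async def index_and_query(tools_stack, documents: list) -> str:
    vector_dbs_impl = tools_stack.impls[Api.vector_dbs]
    tools_impl = tools_stack.impls[Api.tool_runtime]

    # 1. register a vector DB (was: memory_banks_impl.register_memory_bank)
    await vector_dbs_impl.register(
        vector_db_id="docs",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_id="faiss",
    )

    # 2. chunk and index documents (was: memory_impl.insert_documents)
    await tools_impl.rag_tool.insert_documents(
        documents=documents,
        vector_db_id="docs",
        chunk_size_in_tokens=512,
    )

    # 3. retrieve context (was: invoke_tool(tool_name="memory", ...))
    result = await tools_impl.rag_tool.query_context(
        content="What are the main topics covered?",
        vector_db_ids=["docs"],
    )
    return result.content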