[memory refactor][5/n] Migrate all vector_io providers

Ashwin Bharambe · 2025-01-21 21:28:57 -08:00
commit 5605917361 (parent 63f37f9b7c)
11 changed files with 233 additions and 343 deletions
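
At a glance, this commit is a mechanical rename of the old Memory/MemoryBanks surface onto the new VectorIO/VectorDBs surface. The mapping below is distilled from the renames visible in the hunks that follow (a reading aid, not an exhaustive API reference):

    # Old surface                             New surface
    # Memory / MemoryBanksProtocolPrivate ->  VectorIO / VectorDBsProtocolPrivate
    # MemoryBank                          ->  VectorDB
    # BankWithIndex                       ->  VectorDBWithIndex
    # QueryDocumentsResponse              ->  QueryChunksResponse
    # register_memory_bank(memory_bank)   ->  register_vector_db(vector_db)
    # unregister_memory_bank(bank_id)     ->  unregister_vector_db(vector_db_id)
    # insert_documents(bank_id, docs)     ->  insert_chunks(vector_db_id, chunks)
    # query_documents(bank_id, query)     ->  query_chunks(vector_db_id, query)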

========================================

@@ -413,8 +413,8 @@ class ChatAgent(ShieldRunnerMixin):
         session_info = await self.storage.get_session_info(session_id)
         # if the session has a memory bank id, let the memory tool use it
-        if session_info.memory_bank_id:
-            vector_db_ids.append(session_info.memory_bank_id)
+        if session_info.vector_db_id:
+            vector_db_ids.append(session_info.vector_db_id)

         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(

@@ -829,7 +829,7 @@ class ChatAgent(ShieldRunnerMixin):
                 msg = await attachment_message(self.tempdir, url_items)
                 input_messages.append(msg)
                 # Since memory is present, add all the data to the memory bank
-                await self.add_to_session_memory_bank(session_id, documents)
+                await self.add_to_session_vector_db(session_id, documents)
             elif code_interpreter_tool:
                 # if only code_interpreter is available, we download the URLs to a tempdir
                 # and attach the path to them as a message to inference with the

@@ -838,7 +838,7 @@ class ChatAgent(ShieldRunnerMixin):
                 input_messages.append(msg)
             elif memory_tool:
                 # if only memory is available, we load the data from the URLs and content items to the memory bank
-                await self.add_to_session_memory_bank(session_id, documents)
+                await self.add_to_session_vector_db(session_id, documents)
             else:
                 # if no memory or code_interpreter tool is available,
                 # we try to load the data from the URLs and content items as a message to inference

@@ -848,31 +848,31 @@ class ChatAgent(ShieldRunnerMixin):
                     + await load_data_from_urls(url_items)
                 )

-    async def _ensure_memory_bank(self, session_id: str) -> str:
+    async def _ensure_vector_db(self, session_id: str) -> str:
         session_info = await self.storage.get_session_info(session_id)
         if session_info is None:
             raise ValueError(f"Session {session_id} not found")

-        if session_info.memory_bank_id is None:
-            bank_id = f"memory_bank_{session_id}"
+        if session_info.vector_db_id is None:
+            vector_db_id = f"vector_db_{session_id}"
             # TODO: the semantic for registration is definitely not "creation"
             # so we need to fix it if we expect the agent to create a new vector db
             # for each session
             await self.vector_io_api.register_vector_db(
-                vector_db_id=bank_id,
+                vector_db_id=vector_db_id,
                 embedding_model="all-MiniLM-L6-v2",
             )
-            await self.storage.add_memory_bank_to_session(session_id, bank_id)
+            await self.storage.add_vector_db_to_session(session_id, vector_db_id)
         else:
-            bank_id = session_info.memory_bank_id
-        return bank_id
+            vector_db_id = session_info.vector_db_id
+        return vector_db_id

-    async def add_to_session_memory_bank(
+    async def add_to_session_vector_db(
         self, session_id: str, data: List[Document]
     ) -> None:
-        vector_db_id = await self._ensure_memory_bank(session_id)
+        vector_db_id = await self._ensure_vector_db(session_id)
         documents = [
             RAGDocument(
                 document_id=str(uuid.uuid4()),
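
Net effect in the agent: one vector DB is lazily provisioned per session. A usage sketch under the interfaces shown above (agent, the session id, and the document lists are placeholders; this assumes a fully wired llama_stack instance):

    # First call registers "vector_db_<session_id>" via vector_io_api and
    # records the id with storage.add_vector_db_to_session(...).
    await agent.add_to_session_vector_db("session-123", docs)

    # Later calls find session_info.vector_db_id already set, so
    # _ensure_vector_db() returns the existing id without re-registering.
    await agent.add_to_session_vector_db("session-123", more_docs)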

========================================

@@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
 class AgentSessionInfo(BaseModel):
     session_id: str
     session_name: str
-    memory_bank_id: Optional[str] = None
+    vector_db_id: Optional[str] = None
     started_at: datetime

@@ -52,12 +52,12 @@ class AgentPersistence:
         return AgentSessionInfo(**json.loads(value))

-    async def add_memory_bank_to_session(self, session_id: str, bank_id: str):
+    async def add_vector_db_to_session(self, session_id: str, vector_db_id: str):
         session_info = await self.get_session_info(session_id)
         if session_info is None:
             raise ValueError(f"Session {session_id} not found")

-        session_info.memory_bank_id = bank_id
+        session_info.vector_db_id = vector_db_id
         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
             value=session_info.model_dump_json(),
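
A sketch of the resulting persistence round-trip (persistence stands in for an AgentPersistence wired to a kvstore as above):

    info = await persistence.get_session_info("session-123")
    assert info.vector_db_id is None  # fresh session, no vector DB yet

    await persistence.add_vector_db_to_session("session-123", "vector_db_session-123")

    # The id is serialized into the kvstore entry via model_dump_json(),
    # so a later lookup sees it:
    info = await persistence.get_session_info("session-123")
    assert info.vector_db_id == "vector_db_session-123"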

========================================

@@ -29,10 +29,9 @@ from llama_stack.apis.inference import (
     SamplingParams,
     ToolChoice,
     ToolDefinition,
-    ToolPromptFormat,
     UserMessage,
 )
-from llama_stack.apis.memory import MemoryBank
-from llama_stack.apis.memory_banks import BankParams, VectorMemoryBank
 from llama_stack.apis.safety import RunShieldResponse
 from llama_stack.apis.tools import (
     Tool,

@@ -40,8 +39,9 @@ from llama_stack.apis.tools import (
     ToolGroup,
     ToolHost,
     ToolInvocationResult,
+    ToolPromptFormat,
 )
+from llama_stack.apis.vector_io import QueryChunksResponse
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )

@@ -110,68 +110,22 @@ class MockSafetyAPI:
         return RunShieldResponse(violation=None)

-class MockMemoryAPI:
+class MockVectorIOAPI:
     def __init__(self):
-        self.memory_banks = {}
-        self.documents = {}
-
-    async def create_memory_bank(self, name, config, url=None):
-        bank_id = f"bank_{len(self.memory_banks)}"
-        bank = MemoryBank(bank_id, name, config, url)
-        self.memory_banks[bank_id] = bank
-        self.documents[bank_id] = {}
-        return bank
-
-    async def list_memory_banks(self):
-        return list(self.memory_banks.values())
-
-    async def get_memory_bank(self, bank_id):
-        return self.memory_banks.get(bank_id)
-
-    async def drop_memory_bank(self, bank_id):
-        if bank_id in self.memory_banks:
-            del self.memory_banks[bank_id]
-            del self.documents[bank_id]
-        return bank_id
-
-    async def insert_documents(self, bank_id, documents, ttl_seconds=None):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        for doc in documents:
-            self.documents[bank_id][doc.document_id] = doc
-
-    async def update_documents(self, bank_id, documents):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        for doc in documents:
-            if doc.document_id in self.documents[bank_id]:
-                self.documents[bank_id][doc.document_id] = doc
-
-    async def query_documents(self, bank_id, query, params=None):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        # Simple mock implementation: return all documents
-        chunks = [
-            {"content": doc.content, "token_count": 10, "document_id": doc.document_id}
-            for doc in self.documents[bank_id].values()
-        ]
+        self.chunks = {}
+
+    async def insert_chunks(self, vector_db_id, chunks, ttl_seconds=None):
+        for chunk in chunks:
+            metadata = chunk.metadata
+            self.chunks[vector_db_id][metadata["document_id"]] = chunk
+
+    async def query_chunks(self, vector_db_id, query, params=None):
+        if vector_db_id not in self.chunks:
+            raise ValueError(f"Bank {vector_db_id} not found")
+        chunks = list(self.chunks[vector_db_id].values())

         scores = [1.0] * len(chunks)
-        return {"chunks": chunks, "scores": scores}
-
-    async def get_documents(self, bank_id, document_ids):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        return [
-            self.documents[bank_id][doc_id]
-            for doc_id in document_ids
-            if doc_id in self.documents[bank_id]
-        ]
-
-    async def delete_documents(self, bank_id, document_ids):
-        if bank_id not in self.documents:
-            raise ValueError(f"Bank {bank_id} not found")
-        for doc_id in document_ids:
-            self.documents[bank_id].pop(doc_id, None)
+        return QueryChunksResponse(chunks=chunks, scores=scores)

 class MockToolGroupsAPI:

@@ -241,31 +195,6 @@ class MockToolRuntimeAPI:
         return ToolInvocationResult(content={"result": "Mock tool result"})

-class MockMemoryBanksAPI:
-    async def list_memory_banks(self) -> List[MemoryBank]:
-        return []
-
-    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]:
-        return None
-
-    async def register_memory_bank(
-        self,
-        memory_bank_id: str,
-        params: BankParams,
-        provider_id: Optional[str] = None,
-        provider_memory_bank_id: Optional[str] = None,
-    ) -> MemoryBank:
-        return VectorMemoryBank(
-            identifier=memory_bank_id,
-            provider_resource_id=provider_memory_bank_id or memory_bank_id,
-            embedding_model="mock_model",
-            chunk_size_in_tokens=512,
-        )
-
-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        pass
-
 @pytest.fixture
 def mock_inference_api():
     return MockInferenceAPI()

@@ -277,8 +206,8 @@ def mock_safety_api():
 @pytest.fixture
-def mock_memory_api():
-    return MockMemoryAPI()
+def mock_vector_io_api():
+    return MockVectorIOAPI()

 @pytest.fixture

@@ -291,17 +220,11 @@ def mock_tool_runtime_api():
     return MockToolRuntimeAPI()

-@pytest.fixture
-def mock_memory_banks_api():
-    return MockMemoryBanksAPI()
-
 @pytest.fixture
 async def get_agents_impl(
     mock_inference_api,
     mock_safety_api,
-    mock_memory_api,
-    mock_memory_banks_api,
+    mock_vector_io_api,
     mock_tool_runtime_api,
     mock_tool_groups_api,
 ):

@@ -314,8 +237,7 @@ async def get_agents_impl(
         ),
         inference_api=mock_inference_api,
         safety_api=mock_safety_api,
-        memory_api=mock_memory_api,
-        memory_banks_api=mock_memory_banks_api,
+        vector_io_api=mock_vector_io_api,
         tool_runtime_api=mock_tool_runtime_api,
         tool_groups_api=mock_tool_groups_api,
     )

@@ -484,7 +406,7 @@ async def test_chat_agent_tools(
         toolgroups_for_turn=[
             AgentToolGroupWithArgs(
                 name=MEMORY_TOOLGROUP,
-                args={"memory_banks": ["test_memory_bank"]},
+                args={"vector_dbs": ["test_vector_db"]},
             )
         ]
     )
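
One caveat about the new mock: insert_chunks indexes self.chunks[vector_db_id] directly, so as written a test must seed that inner dict before inserting, or the assignment raises KeyError. A hypothetical usage sketch (chunk is a placeholder object carrying a .metadata dict, as the mock expects):

    mock = MockVectorIOAPI()
    mock.chunks["db1"] = {}  # seed the per-DB dict; insert_chunks does not create it

    await mock.insert_chunks("db1", [chunk])  # stored under metadata["document_id"]
    resp = await mock.query_chunks("db1", query="anything")
    # resp is a QueryChunksResponse holding every stored chunk, each scored 1.0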

========================================

@@ -6,25 +6,20 @@
 import asyncio
 import json
 import logging
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from urllib.parse import urlparse

 import chromadb
 from numpy.typing import NDArray

 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
-from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )

 from .config import ChromaRemoteImplConfig

@@ -61,7 +56,7 @@ class ChromaIndex(EmbeddingIndex):
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         results = await maybe_await(
             self.collection.query(
                 query_embeddings=[embedding.tolist()],

@@ -85,13 +80,13 @@ class ChromaIndex(EmbeddingIndex):
             chunks.append(chunk)
             scores.append(1.0 / float(dist))

-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)

     async def delete(self):
         await maybe_await(self.client.delete_collection(self.collection.name))

-class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
+class ChromaMemoryAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(
         self,
         config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],

@@ -123,60 +118,58 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank: MemoryBank,
+        vector_db: VectorDB,
     ) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
-
         collection = await maybe_await(
             self.client.get_or_create_collection(
-                name=memory_bank.identifier,
-                metadata={"bank": memory_bank.model_dump_json()},
+                name=vector_db.identifier,
+                metadata={"vector_db": vector_db.model_dump_json()},
             )
         )
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            memory_bank, ChromaIndex(self.client, collection), self.inference_api
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db, ChromaIndex(self.client, collection), self.inference_api
         )

-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        await self.cache[memory_bank_id].index.delete()
-        del self.cache[memory_bank_id]
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]

-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
-        ttl_seconds: Optional[int] = None,
+        vector_db_id: str,
+        chunks: List[Chunk],
+        embeddings: NDArray,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
-        await index.insert_documents(documents)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        await index.insert_chunks(chunks, embeddings)

-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
-        return await index.query_documents(query, params)
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        return await index.query_chunks(query, params)

-    async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
+    async def _get_and_cache_vector_db_index(
+        self, vector_db_id: str
+    ) -> VectorDBWithIndex:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]

-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if not bank:
-            raise ValueError(f"Bank {bank_id} not found in Llama Stack")
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        if not vector_db:
+            raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")

-        collection = await maybe_await(self.client.get_collection(bank_id))
+        collection = await maybe_await(self.client.get_collection(vector_db_id))
         if not collection:
-            raise ValueError(f"Bank {bank_id} not found in Chroma")
+            raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")

-        index = BankWithIndex(
-            bank, ChromaIndex(self.client, collection), self.inference_api
+        index = VectorDBWithIndex(
+            vector_db, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[bank_id] = index
+        self.cache[vector_db_id] = index
         return index
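
Both ChromaIndex here and PGVectorIndex below convert a raw distance into the returned score as 1.0 / distance, so smaller distances rank higher. A self-contained illustration of that convention:

    distances = [0.25, 0.5, 2.0]  # as returned by the vector store
    scores = [1.0 / float(d) for d in distances]
    assert scores == [4.0, 2.0, 0.5]  # the nearest hit gets the largest score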

========================================

@@ -12,21 +12,16 @@ from numpy.typing import NDArray
 from psycopg2 import sql
 from psycopg2.extras import execute_values, Json
-from pydantic import BaseModel, parse_obj_as
+from pydantic import BaseModel, TypeAdapter

 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )

 from .config import PGVectorConfig

@@ -50,20 +45,20 @@ def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
         """
     )

-    values = [(key, Json(model.dict())) for key, model in keys_models]
+    values = [(key, Json(model.model_dump())) for key, model in keys_models]
     execute_values(cur, query, values, template="(%s, %s)")

 def load_models(cur, cls):
     cur.execute("SELECT key, data FROM metadata_store")
     rows = cur.fetchall()
-    return [parse_obj_as(cls, row["data"]) for row in rows]
+    return [TypeAdapter(cls).validate_python(row["data"]) for row in rows]

 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
+    def __init__(self, vector_db: VectorDB, dimension: int, cursor):
         self.cursor = cursor
-        self.table_name = f"vector_store_{bank.identifier}"
+        self.table_name = f"vector_store_{vector_db.identifier}"

         self.cursor.execute(
             f"""

@@ -85,7 +80,7 @@ class PGVectorIndex(EmbeddingIndex):
             values.append(
                 (
                     f"{chunk.document_id}:chunk-{i}",
-                    Json(chunk.dict()),
+                    Json(chunk.model_dump()),
                     embeddings[i].tolist(),
                 )
             )

@@ -101,7 +96,7 @@ class PGVectorIndex(EmbeddingIndex):
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         self.cursor.execute(
             f"""
             SELECT document, embedding <-> %s::vector AS distance

@@ -119,13 +114,13 @@ class PGVectorIndex(EmbeddingIndex):
             chunks.append(Chunk(**doc))
             scores.append(1.0 / float(dist))

-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)

     async def delete(self):
         self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")

-class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
+class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.inference_api = inference_api

@@ -167,46 +162,45 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
-
-        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
-        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
-        self.cache[memory_bank.identifier] = BankWithIndex(
-            memory_bank, index, self.inference_api
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        upsert_models(self.cursor, [(vector_db.identifier, vector_db)])
+
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db, index, self.inference_api
         )

-    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
-        await self.cache[memory_bank_id].index.delete()
-        del self.cache[memory_bank_id]
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        await self.cache[vector_db_id].index.delete()
+        del self.cache[vector_db_id]

-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
-        await index.insert_documents(documents)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        await index.insert_chunks(chunks)

-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
-        return await index.query_documents(query, params)
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        return await index.query_chunks(query, params)

-        self.inference_api = inference_api
-
-    async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
-
-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
-        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
-        return self.cache[bank_id]
+    async def _get_and_cache_vector_db_index(
+        self, vector_db_id: str
+    ) -> VectorDBWithIndex:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]
+
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
+        self.cache[vector_db_id] = VectorDBWithIndex(
+            vector_db, index, self.inference_api
+        )
+        return self.cache[vector_db_id]
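
Besides the rename, this file also moves from pydantic v1 idioms to their v2 equivalents: parse_obj_as(cls, data) becomes TypeAdapter(cls).validate_python(data), and .dict() becomes .model_dump(). A self-contained illustration with a stand-in model (not the real VectorDB type):

    from pydantic import BaseModel, TypeAdapter

    class VectorDBModel(BaseModel):  # hypothetical stand-in
        identifier: str
        embedding_dimension: int

    row = {"identifier": "db1", "embedding_dimension": 384}

    obj = TypeAdapter(VectorDBModel).validate_python(row)  # v2 parse_obj_as
    assert obj.model_dump() == row                         # v2 .dict()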

========================================

@@ -13,19 +13,14 @@ from qdrant_client import AsyncQdrantClient, models
 from qdrant_client.models import PointStruct

 from llama_stack.apis.inference import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
-from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )
+
+from .config import QdrantConfig

 log = logging.getLogger(__name__)
 CHUNK_ID_KEY = "_chunk_id"

@@ -76,7 +71,7 @@ class QdrantIndex(EmbeddingIndex):
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,

@@ -101,10 +96,10 @@ class QdrantIndex(EmbeddingIndex):
             chunks.append(chunk)
             scores.append(point.score)

-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)

-class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
+class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))

@@ -117,58 +112,56 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         self.client.close()

-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank: MemoryBank,
+        vector_db: VectorDB,
     ) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
-        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=QdrantIndex(self.client, memory_bank.identifier),
+        index = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=QdrantIndex(self.client, vector_db.identifier),
             inference_api=self.inference_api,
         )

-        self.cache[memory_bank.identifier] = index
+        self.cache[vector_db.identifier] = index

-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
+    async def _get_and_cache_vector_db_index(
+        self, vector_db_id: str
+    ) -> Optional[VectorDBWithIndex]:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]

-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if not bank:
-            raise ValueError(f"Bank {bank_id} not found")
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        if not vector_db:
+            raise ValueError(f"Vector DB {vector_db_id} not found")

-        index = BankWithIndex(
-            bank=bank,
-            index=QdrantIndex(client=self.client, collection_name=bank_id),
+        index = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
             inference_api=self.inference_api,
         )
-        self.cache[bank_id] = index
+        self.cache[vector_db_id] = index
         return index

-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")

-        await index.insert_documents(documents)
+        await index.insert_chunks(chunks)

-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")

-        return await index.query_documents(query, params)
+        return await index.query_chunks(query, params)
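
All four adapters now share the same lookup shape: consult an in-memory cache, otherwise fetch the VectorDB record from vector_db_store, wrap it with a backend index in VectorDBWithIndex, and cache the result. A generic sketch of that pattern (store and make_index are hypothetical stand-ins, not llama_stack names):

    from typing import Dict

    class CachingVectorDBLookup:
        def __init__(self, store, make_index):
            self.store = store            # exposes async get_vector_db(id)
            self.make_index = make_index  # builds the per-backend wrapper
            self.cache: Dict[str, object] = {}

        async def get(self, vector_db_id: str):
            if vector_db_id in self.cache:  # fast path: already wrapped
                return self.cache[vector_db_id]
            vector_db = await self.store.get_vector_db(vector_db_id)
            if not vector_db:  # mirrors the adapters' error message
                raise ValueError(f"Vector DB {vector_db_id} not found")
            self.cache[vector_db_id] = self.make_index(vector_db)
            return self.cache[vector_db_id]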

========================================

@@ -4,19 +4,22 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.apis.memory import Memory
-from llama_stack.apis.memory_banks import MemoryBank
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import VectorIO

 from .config import SampleConfig

-class SampleMemoryImpl(Memory):
+class SampleMemoryImpl(VectorIO):
     def __init__(self, config: SampleConfig):
         self.config = config

-    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
-        # these are the memory banks the Llama Stack will use to route requests to this provider
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        # these are the vector dbs the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass

     async def initialize(self):
         pass
+
+    async def shutdown(self):
+        pass

========================================

@@ -15,18 +15,13 @@ from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter

 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.memory import (
-    Chunk,
-    Memory,
-    MemoryBankDocument,
-    QueryDocumentsResponse,
-)
-from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
+from llama_stack.apis.vector_dbs import VectorDB
+from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    BankWithIndex,
     EmbeddingIndex,
+    VectorDBWithIndex,
 )

 from .config import WeaviateConfig, WeaviateRequestProviderData

@@ -49,7 +44,7 @@ class WeaviateIndex(EmbeddingIndex):
             data_objects.append(
                 wvc.data.DataObject(
                     properties={
-                        "chunk_content": chunk.json(),
+                        "chunk_content": chunk.model_dump_json(),
                     },
                     vector=embeddings[i].tolist(),
                 )

@@ -63,7 +58,7 @@ class WeaviateIndex(EmbeddingIndex):
     async def query(
         self, embedding: NDArray, k: int, score_threshold: float
-    ) -> QueryDocumentsResponse:
+    ) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)

         results = collection.query.near_vector(

@@ -86,7 +81,7 @@ class WeaviateIndex(EmbeddingIndex):
             chunks.append(chunk)
             scores.append(1.0 / doc.metadata.distance)

-        return QueryDocumentsResponse(chunks=chunks, scores=scores)
+        return QueryChunksResponse(chunks=chunks, scores=scores)

     async def delete(self, chunk_ids: List[str]) -> None:
         collection = self.client.collections.get(self.collection_name)

@@ -96,9 +91,9 @@ class WeaviateIndex(EmbeddingIndex):
 class WeaviateMemoryAdapter(
-    Memory,
+    VectorIO,
     NeedsRequestProviderData,
-    MemoryBanksProtocolPrivate,
+    VectorDBsProtocolPrivate,
 ):
     def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config

@@ -129,20 +124,16 @@ class WeaviateMemoryAdapter(
         for client in self.client_cache.values():
             client.close()

-    async def register_memory_bank(
+    async def register_vector_db(
         self,
-        memory_bank: MemoryBank,
+        vector_db: VectorDB,
     ) -> None:
-        assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector.value
-        ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
-
         client = self._get_client()

         # Create collection if it doesn't exist
-        if not client.collections.exists(memory_bank.identifier):
+        if not client.collections.exists(vector_db.identifier):
             client.collections.create(
-                name=memory_bank.identifier,
+                name=vector_db.identifier,
                 vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                 properties=[
                     wvc.config.Property(

@@ -152,52 +143,54 @@ class WeaviateMemoryAdapter(
                 ],
             )

-        self.cache[memory_bank.identifier] = BankWithIndex(
-            memory_bank,
-            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[vector_db.identifier] = VectorDBWithIndex(
+            vector_db,
+            WeaviateIndex(client=client, collection_name=vector_db.identifier),
             self.inference_api,
         )

-    async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
-        if bank_id in self.cache:
-            return self.cache[bank_id]
+    async def _get_and_cache_vector_db_index(
+        self, vector_db_id: str
+    ) -> Optional[VectorDBWithIndex]:
+        if vector_db_id in self.cache:
+            return self.cache[vector_db_id]

-        bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        if not bank:
-            raise ValueError(f"Bank {bank_id} not found")
+        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+        if not vector_db:
+            raise ValueError(f"Vector DB {vector_db_id} not found")

         client = self._get_client()
-        if not client.collections.exists(bank.identifier):
-            raise ValueError(f"Collection with name `{bank.identifier}` not found")
+        if not client.collections.exists(vector_db.identifier):
+            raise ValueError(f"Collection with name `{vector_db.identifier}` not found")

-        index = BankWithIndex(
-            bank=bank,
-            index=WeaviateIndex(client=client, collection_name=bank_id),
+        index = VectorDBWithIndex(
+            vector_db=vector_db,
+            index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
             inference_api=self.inference_api,
         )
-        self.cache[bank_id] = index
+        self.cache[vector_db_id] = index
         return index

-    async def insert_documents(
+    async def insert_chunks(
         self,
-        bank_id: str,
-        documents: List[MemoryBankDocument],
+        vector_db_id: str,
+        chunks: List[Chunk],
         ttl_seconds: Optional[int] = None,
     ) -> None:
-        index = await self._get_and_cache_bank_index(bank_id)
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")

-        await index.insert_documents(documents)
+        await index.insert_chunks(chunks)

-    async def query_documents(
+    async def query_chunks(
         self,
-        bank_id: str,
+        vector_db_id: str,
         query: InterleavedContent,
         params: Optional[Dict[str, Any]] = None,
-    ) -> QueryDocumentsResponse:
-        index = await self._get_and_cache_bank_index(bank_id)
+    ) -> QueryChunksResponse:
+        index = await self._get_and_cache_vector_db_index(vector_db_id)
         if not index:
-            raise ValueError(f"Bank {bank_id} not found")
+            raise ValueError(f"Vector DB {vector_db_id} not found")

-        return await index.query_documents(query, params)
+        return await index.query_chunks(query, params)
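
Weaviate stores each chunk as serialized JSON in the object's chunk_content property; chunk.json() is the pydantic v1 spelling and chunk.model_dump_json() is its v2 replacement. A self-contained round-trip with a stand-in model (not the real Chunk type):

    import json
    from pydantic import BaseModel

    class ChunkModel(BaseModel):  # hypothetical stand-in
        document_id: str
        content: str

    chunk = ChunkModel(document_id="doc-1", content="hello")
    payload = chunk.model_dump_json()             # stored into the Weaviate object
    restored = ChunkModel(**json.loads(payload))  # recovered on the query path
    assert restored == chunk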

========================================

@@ -53,7 +53,7 @@ async def eval_stack(
         "inference",
         "agents",
         "safety",
-        "memory",
+        "vector_io",
         "tool_runtime",
     ]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")

@@ -69,7 +69,7 @@ async def eval_stack(
             Api.scoring,
             Api.agents,
             Api.safety,
-            Api.memory,
+            Api.vector_io,
             Api.tool_runtime,
         ],
         providers,

========================================

@@ -83,7 +83,7 @@ async def tools_stack(
     providers = {}
     provider_data = {}
-    for key in ["inference", "memory", "tool_runtime"]:
+    for key in ["inference", "vector_io", "tool_runtime"]:
         fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
         providers[key] = fixture.providers
         if key == "inference":

@@ -117,7 +117,12 @@ async def tools_stack(
     )

     test_stack = await construct_stack_for_test(
-        [Api.tool_groups, Api.inference, Api.memory, Api.tool_runtime],
+        [
+            Api.tool_groups,
+            Api.inference,
+            Api.vector_io,
+            Api.tool_runtime,
+        ],
         providers,
         provider_data,
         models=models,

========================================

@@ -8,10 +8,7 @@ import os

 import pytest

-from llama_stack.apis.inference import UserMessage
-from llama_stack.apis.memory import MemoryBankDocument
-from llama_stack.apis.memory_banks import VectorMemoryBankParams
-from llama_stack.apis.tools import ToolInvocationResult
+from llama_stack.apis.tools import RAGDocument, RAGQueryResult, ToolInvocationResult
 from llama_stack.providers.datatypes import Api

@@ -36,7 +33,7 @@ def sample_documents():
         "lora_finetune.rst",
     ]
     return [
-        MemoryBankDocument(
+        RAGDocument(
             document_id=f"num-{i}",
             content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
             mime_type="text/plain",

@@ -57,7 +54,7 @@ class TestTools:
         # Execute the tool
         response = await tools_impl.invoke_tool(
-            tool_name="web_search", args={"query": sample_search_query}
+            tool_name="web_search", kwargs={"query": sample_search_query}
         )

         # Verify the response

@@ -75,7 +72,7 @@ class TestTools:
         tools_impl = tools_stack.impls[Api.tool_runtime]

         response = await tools_impl.invoke_tool(
-            tool_name="wolfram_alpha", args={"query": sample_wolfram_alpha_query}
+            tool_name="wolfram_alpha", kwargs={"query": sample_wolfram_alpha_query}
         )

         # Verify the response

@@ -85,43 +82,33 @@ class TestTools:
         assert isinstance(response.content, str)

     @pytest.mark.asyncio
-    async def test_memory_tool(self, tools_stack, sample_documents):
+    async def test_rag_tool(self, tools_stack, sample_documents):
         """Test the memory tool functionality."""
-        memory_banks_impl = tools_stack.impls[Api.memory_banks]
-        memory_impl = tools_stack.impls[Api.memory]
+        vector_dbs_impl = tools_stack.impls[Api.vector_dbs]
         tools_impl = tools_stack.impls[Api.tool_runtime]

         # Register memory bank
-        await memory_banks_impl.register_memory_bank(
-            memory_bank_id="test_bank",
-            params=VectorMemoryBankParams(
-                embedding_model="all-MiniLM-L6-v2",
-                chunk_size_in_tokens=512,
-                overlap_size_in_tokens=64,
-            ),
+        await vector_dbs_impl.register(
+            vector_db_id="test_bank",
+            embedding_model="all-MiniLM-L6-v2",
+            embedding_dimension=384,
             provider_id="faiss",
         )

         # Insert documents into memory
-        await memory_impl.insert_documents(
-            bank_id="test_bank",
+        await tools_impl.rag_tool.insert_documents(
             documents=sample_documents,
+            vector_db_id="test_bank",
+            chunk_size_in_tokens=512,
         )

         # Execute the memory tool
-        response = await tools_impl.invoke_tool(
-            tool_name="memory",
-            args={
-                "messages": [
-                    UserMessage(
-                        content="What are the main topics covered in the documentation?",
-                    )
-                ],
-                "memory_bank_ids": ["test_bank"],
-            },
+        response = await tools_impl.rag_tool.query_context(
+            content="What are the main topics covered in the documentation?",
+            vector_db_ids=["test_bank"],
         )

         # Verify the response
-        assert isinstance(response, ToolInvocationResult)
+        assert isinstance(response, RAGQueryResult)
         assert response.content is not None
         assert len(response.content) > 0
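
The rewritten test doubles as a recipe for the new RAG surface. Condensed, the three calls are (assuming impls resolved from a test stack exactly as in the diff above):

    # 1. Register a vector DB (replaces register_memory_bank + VectorMemoryBankParams).
    await vector_dbs_impl.register(
        vector_db_id="test_bank",
        embedding_model="all-MiniLM-L6-v2",
        embedding_dimension=384,
        provider_id="faiss",
    )

    # 2. Ingest through the RAG tool (replaces memory_impl.insert_documents).
    await tools_impl.rag_tool.insert_documents(
        documents=documents,
        vector_db_id="test_bank",
        chunk_size_in_tokens=512,
    )

    # 3. Query context directly (replaces invoke_tool(tool_name="memory", ...)).
    result = await tools_impl.rag_tool.query_context(
        content="What are the main topics covered in the documentation?",
        vector_db_ids=["test_bank"],
    )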