[memory refactor][5/n] Migrate all vector_io providers (#835)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

This PR finishes off all the stragglers and migrates everything to the
new naming.
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:17:59 -08:00 committed by GitHub
parent 63f37f9b7c
commit c9e5578151
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
78 changed files with 504 additions and 623 deletions

View file

@ -15,18 +15,13 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
@ -49,7 +44,7 @@ class WeaviateIndex(EmbeddingIndex):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_content": chunk.json(),
"chunk_content": chunk.model_dump_json(),
},
vector=embeddings[i].tolist(),
)
@ -63,7 +58,7 @@ class WeaviateIndex(EmbeddingIndex):
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
@ -86,7 +81,7 @@ class WeaviateIndex(EmbeddingIndex):
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: List[str]) -> None:
collection = self.client.collections.get(self.collection_name)
@ -96,9 +91,9 @@ class WeaviateIndex(EmbeddingIndex):
class WeaviateMemoryAdapter(
Memory,
VectorIO,
NeedsRequestProviderData,
MemoryBanksProtocolPrivate,
VectorDBsProtocolPrivate,
):
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
self.config = config
@ -129,20 +124,16 @@ class WeaviateMemoryAdapter(
for client in self.client_cache.values():
client.close()
async def register_memory_bank(
async def register_vector_db(
self,
memory_bank: MemoryBank,
vector_db: VectorDB,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
client = self._get_client()
# Create collection if it doesn't exist
if not client.collections.exists(memory_bank.identifier):
if not client.collections.exists(vector_db.identifier):
client.collections.create(
name=memory_bank.identifier,
name=vector_db.identifier,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
@ -152,52 +143,54 @@ class WeaviateMemoryAdapter(
],
)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank,
WeaviateIndex(client=client, collection_name=memory_bank.identifier),
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db,
WeaviateIndex(client=client, collection_name=vector_db.identifier),
self.inference_api,
)
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
client = self._get_client()
if not client.collections.exists(bank.identifier):
raise ValueError(f"Collection with name `{bank.identifier}` not found")
if not client.collections.exists(vector_db.identifier):
raise ValueError(f"Collection with name `{vector_db.identifier}` not found")
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id),
index = VectorDBWithIndex(
vector_db=vector_db,
index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[bank_id] = index
self.cache[vector_db_id] = index
return index
async def insert_documents(
async def insert_chunks(
self,
bank_id: str,
documents: List[MemoryBankDocument],
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_documents(documents)
await index.insert_chunks(chunks)
async def query_documents(
async def query_chunks(
self,
bank_id: str,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_documents(query, params)
return await index.query_chunks(query, params)