[memory refactor][5/n] Migrate all vector_io providers (#835)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

This PR finishes off all the stragglers and migrates everything to the
new naming.
This commit is contained in:
Ashwin Bharambe 2025-01-22 10:17:59 -08:00 committed by GitHub
parent 63f37f9b7c
commit c9e5578151
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
78 changed files with 504 additions and 623 deletions

View file

@ -14,8 +14,8 @@ from .config import ChromaRemoteImplConfig
async def get_adapter_impl(
config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
):
from .chroma import ChromaMemoryAdapter
from .chroma import ChromaVectorIOAdapter
impl = ChromaMemoryAdapter(config, deps[Api.inference])
impl = ChromaVectorIOAdapter(config, deps[Api.inference])
await impl.initialize()
return impl

View file

@ -6,25 +6,20 @@
import asyncio
import json
import logging
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import ChromaRemoteImplConfig
@ -61,7 +56,7 @@ class ChromaIndex(EmbeddingIndex):
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
) -> QueryChunksResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
@ -85,19 +80,19 @@ class ChromaIndex(EmbeddingIndex):
chunks.append(chunk)
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
self,
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
inference_api: Api.inference,
) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
@ -123,60 +118,58 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_memory_bank(
async def register_vector_db(
self,
memory_bank: MemoryBank,
vector_db: VectorDB,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
collection = await maybe_await(
self.client.get_or_create_collection(
name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()},
name=vector_db.identifier,
metadata={"vector_db": vector_db.model_dump_json()},
)
)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, ChromaIndex(self.client, collection), self.inference_api
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db, ChromaIndex(self.client, collection), self.inference_api
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
async def unregister_vector_db(self, vector_db_id: str) -> None:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_documents(
async def insert_chunks(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
vector_db_id: str,
chunks: List[Chunk],
embeddings: NDArray,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
index = await self._get_and_cache_vector_db_index(vector_db_id)
await index.insert_documents(documents)
await index.insert_chunks(chunks, embeddings)
async def query_documents(
async def query_chunks(
self,
bank_id: str,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
return await index.query_documents(query, params)
return await index.query_chunks(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(bank_id))
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(vector_db_id))
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(
bank, ChromaIndex(self.client, collection), self.inference_api
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
index = VectorDBWithIndex(
vector_db, ChromaIndex(self.client, collection), self.inference_api
)
self.cache[bank_id] = index
self.cache[vector_db_id] = index
return index

View file

@ -12,21 +12,16 @@ from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel, parse_obj_as
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import PGVectorConfig
@ -50,20 +45,20 @@ def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
"""
)
values = [(key, Json(model.dict())) for key, model in keys_models]
values = [(key, Json(model.model_dump())) for key, model in keys_models]
execute_values(cur, query, values, template="(%s, %s)")
def load_models(cur, cls):
cur.execute("SELECT key, data FROM metadata_store")
rows = cur.fetchall()
return [parse_obj_as(cls, row["data"]) for row in rows]
return [TypeAdapter(cls).validate_python(row["data"]) for row in rows]
class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
def __init__(self, vector_db: VectorDB, dimension: int, cursor):
self.cursor = cursor
self.table_name = f"vector_store_{bank.identifier}"
self.table_name = f"vector_store_{vector_db.identifier}"
self.cursor.execute(
f"""
@ -85,7 +80,7 @@ class PGVectorIndex(EmbeddingIndex):
values.append(
(
f"{chunk.document_id}:chunk-{i}",
Json(chunk.dict()),
Json(chunk.model_dump()),
embeddings[i].tolist(),
)
)
@ -101,7 +96,7 @@ class PGVectorIndex(EmbeddingIndex):
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
) -> QueryChunksResponse:
self.cursor.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance
@ -119,13 +114,13 @@ class PGVectorIndex(EmbeddingIndex):
chunks.append(Chunk(**doc))
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self):
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
@ -167,46 +162,45 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
async def register_vector_db(self, vector_db: VectorDB) -> None:
upsert_models(self.cursor, [(vector_db.identifier, vector_db)])
upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, index, self.inference_api
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db, index, self.inference_api
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
async def unregister_vector_db(self, vector_db_id: str) -> None:
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
async def insert_documents(
async def insert_chunks(
self,
bank_id: str,
documents: List[MemoryBankDocument],
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
index = await self._get_and_cache_vector_db_index(vector_db_id)
await index.insert_chunks(chunks)
async def query_documents(
async def query_chunks(
self,
bank_id: str,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
return await index.query_chunks(query, params)
self.inference_api = inference_api
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> VectorDBWithIndex:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
return self.cache[bank_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.cursor)
self.cache[vector_db_id] = VectorDBWithIndex(
vector_db, index, self.inference_api
)
return self.cache[vector_db_id]

View file

@ -13,19 +13,14 @@ from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import QdrantConfig
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
@ -76,7 +71,7 @@ class QdrantIndex(EmbeddingIndex):
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
) -> QueryChunksResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,
@ -101,10 +96,10 @@ class QdrantIndex(EmbeddingIndex):
chunks.append(chunk)
scores.append(point.score)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
return QueryChunksResponse(chunks=chunks, scores=scores)
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
@ -117,58 +112,56 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def shutdown(self) -> None:
self.client.close()
async def register_memory_bank(
async def register_vector_db(
self,
memory_bank: MemoryBank,
vector_db: VectorDB,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
index = BankWithIndex(
bank=memory_bank,
index=QdrantIndex(self.client, memory_bank.identifier),
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(self.client, vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[memory_bank.identifier] = index
self.cache[vector_db.identifier] = index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
index = BankWithIndex(
bank=bank,
index=QdrantIndex(client=self.client, collection_name=bank_id),
index = VectorDBWithIndex(
vector_db=vector_db,
index=QdrantIndex(client=self.client, collection_name=vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[bank_id] = index
self.cache[vector_db_id] = index
return index
async def insert_documents(
async def insert_chunks(
self,
bank_id: str,
documents: List[MemoryBankDocument],
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_documents(documents)
await index.insert_chunks(chunks)
async def query_documents(
async def query_chunks(
self,
bank_id: str,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_documents(query, params)
return await index.query_chunks(query, params)

View file

@ -4,19 +4,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBank
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import VectorIO
from .config import SampleConfig
class SampleMemoryImpl(Memory):
class SampleMemoryImpl(VectorIO):
def __init__(self, config: SampleConfig):
self.config = config
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
# these are the memory banks the Llama Stack will use to route requests to this provider
async def register_vector_db(self, vector_db: VectorDB) -> None:
# these are the vector dbs the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def initialize(self):
pass
async def shutdown(self):
pass

View file

@ -15,18 +15,13 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
@ -49,7 +44,7 @@ class WeaviateIndex(EmbeddingIndex):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_content": chunk.json(),
"chunk_content": chunk.model_dump_json(),
},
vector=embeddings[i].tolist(),
)
@ -63,7 +58,7 @@ class WeaviateIndex(EmbeddingIndex):
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
@ -86,7 +81,7 @@ class WeaviateIndex(EmbeddingIndex):
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: List[str]) -> None:
collection = self.client.collections.get(self.collection_name)
@ -96,9 +91,9 @@ class WeaviateIndex(EmbeddingIndex):
class WeaviateMemoryAdapter(
Memory,
VectorIO,
NeedsRequestProviderData,
MemoryBanksProtocolPrivate,
VectorDBsProtocolPrivate,
):
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
self.config = config
@ -129,20 +124,16 @@ class WeaviateMemoryAdapter(
for client in self.client_cache.values():
client.close()
async def register_memory_bank(
async def register_vector_db(
self,
memory_bank: MemoryBank,
vector_db: VectorDB,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
client = self._get_client()
# Create collection if it doesn't exist
if not client.collections.exists(memory_bank.identifier):
if not client.collections.exists(vector_db.identifier):
client.collections.create(
name=memory_bank.identifier,
name=vector_db.identifier,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
@ -152,52 +143,54 @@ class WeaviateMemoryAdapter(
],
)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank,
WeaviateIndex(client=client, collection_name=memory_bank.identifier),
self.cache[vector_db.identifier] = VectorDBWithIndex(
vector_db,
WeaviateIndex(client=client, collection_name=vector_db.identifier),
self.inference_api,
)
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
async def _get_and_cache_vector_db_index(
self, vector_db_id: str
) -> Optional[VectorDBWithIndex]:
if vector_db_id in self.cache:
return self.cache[vector_db_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
client = self._get_client()
if not client.collections.exists(bank.identifier):
raise ValueError(f"Collection with name `{bank.identifier}` not found")
if not client.collections.exists(vector_db.identifier):
raise ValueError(f"Collection with name `{vector_db.identifier}` not found")
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id),
index = VectorDBWithIndex(
vector_db=vector_db,
index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
inference_api=self.inference_api,
)
self.cache[bank_id] = index
self.cache[vector_db_id] = index
return index
async def insert_documents(
async def insert_chunks(
self,
bank_id: str,
documents: List[MemoryBankDocument],
vector_db_id: str,
chunks: List[Chunk],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
raise ValueError(f"Vector DB {vector_db_id} not found")
await index.insert_documents(documents)
await index.insert_chunks(chunks)
async def query_documents(
async def query_chunks(
self,
bank_id: str,
vector_db_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_documents(query, params)
return await index.query_chunks(query, params)