[memory refactor][1/n] Rename Memory -> VectorIO, MemoryBanks -> VectorDBs (#828)

See https://github.com/meta-llama/llama-stack/issues/827 for the broader
design.

This is the first part:

- Delete the other kinds of memory banks (keyvalue, keyword, graph) for now;
  we will introduce a keyvalue store API as part of this design but not
  use it in the RAG tool yet.
- Rename the APIs (see the sketch below).
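
For orientation, here is the provider-facing surface being renamed: a minimal sketch whose method signatures are copied from the adapters in this diff. The rename maps the Memory protocol to VectorIO and the MemoryBanks registry to VectorDBs:

# Sketch only; signatures taken from the adapters below.
from typing import Any, Dict, List, Optional

from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import MemoryBankDocument, QueryDocumentsResponse

class Memory:  # becomes VectorIO after this refactor
    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None: ...

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedContent,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse: ...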
Commit 3ae8585b65 (parent 35a00d004a) by Ashwin Bharambe, 2025-01-22 09:59:30 -08:00; committed by GitHub.
37 changed files with 175 additions and 296 deletions


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import ChromaRemoteImplConfig
async def get_adapter_impl(
config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
):
from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl


@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import ChromaRemoteImplConfig
log = logging.getLogger(__name__)
ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]
# this is a helper to allow us to use async and non-async chroma clients interchangeably
async def maybe_await(result):
if asyncio.iscoroutine(result):
return await result
return result
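# For example, collection methods on chromadb.AsyncHttpClient return coroutines,
# while the same methods on chromadb.PersistentClient return values directly;
# maybe_await lets the adapter drive either client with identical code.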
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: ChromaClientType, collection):
self.client = client
self.collection = collection
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
await maybe_await(
self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks],
embeddings=embeddings,
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
)
)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
)
)
distances = results["distances"][0]
documents = results["documents"][0]
chunks = []
scores = []
for dist, doc in zip(distances, documents):
try:
doc = json.loads(doc)
chunk = Chunk(**doc)
except Exception:
log.exception(f"Failed to parse document: {doc}")
continue
chunks.append(chunk)
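            # Chroma returns a distance; invert it so closer matches score higher.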
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self):
await maybe_await(self.client.delete_collection(self.collection.name))
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(
self,
config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
inference_api: Api.inference,
) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
self.config = config
self.inference_api = inference_api
self.client = None
self.cache = {}
async def initialize(self) -> None:
if isinstance(self.config, ChromaRemoteImplConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.client = await chromadb.AsyncHttpClient(
host=parsed.hostname, port=parsed.port
)
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
async def shutdown(self) -> None:
pass
async def register_memory_bank(
self,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
collection = await maybe_await(
self.client.get_or_create_collection(
name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()},
)
)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, ChromaIndex(self.client, collection), self.inference_api
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await maybe_await(self.client.get_collection(bank_id))
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(
bank, ChromaIndex(self.client, collection), self.inference_api
)
self.cache[bank_id] = index
return index
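
A hypothetical usage sketch (not part of this diff) showing the adapter's call flow. stub_inference_api and the pre-registered bank "my-bank" are assumptions; in a real deployment the provider registry supplies the inference dependency and banks are registered through the MemoryBanks API.

from llama_stack.apis.memory import MemoryBankDocument

async def demo(stub_inference_api):
    config = ChromaRemoteImplConfig(url="http://localhost:8000")
    adapter = ChromaMemoryAdapter(config, stub_inference_api)
    await adapter.initialize()
    # Assumes "my-bank" was registered via register_memory_bank() beforehand.
    await adapter.insert_documents(
        bank_id="my-bank",
        documents=[MemoryBankDocument(document_id="doc-1", content="hello world")],
    )
    response = await adapter.query_documents(bank_id="my-bank", query="hello")
    print(response.chunks, response.scores)

# Drive it with asyncio.run(demo(stub)) given a suitable inference stub.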


@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class ChromaRemoteImplConfig(BaseModel):
url: str
@classmethod
def sample_config(cls) -> Dict[str, Any]:
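        # Assumption: the {env.CHROMADB_URL} placeholder is substituted from the
        # environment when the stack run configuration is loaded.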
return {"url": "{env.CHROMADB_URL}"}


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import PGVectorConfig
async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
from .pgvector import PGVectorMemoryAdapter
impl = PGVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl


@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class PGVectorConfig(BaseModel):
host: str = Field(default="localhost")
port: int = Field(default=5432)
db: str = Field(default="postgres")
user: str = Field(default="postgres")
password: str = Field(default="mysecretpassword")


@@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional, Tuple
import psycopg2
from numpy.typing import NDArray
from psycopg2 import sql
from psycopg2.extras import execute_values, Json
from pydantic import BaseModel, parse_obj_as
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import PGVectorConfig
log = logging.getLogger(__name__)
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
result = cur.fetchone()
return result[0] if result else None
def upsert_models(cur, keys_models: List[Tuple[str, BaseModel]]):
query = sql.SQL(
"""
INSERT INTO metadata_store (key, data)
VALUES %s
ON CONFLICT (key) DO UPDATE
SET data = EXCLUDED.data
"""
)
values = [(key, Json(model.dict())) for key, model in keys_models]
execute_values(cur, query, values, template="(%s, %s)")
def load_models(cur, cls):
cur.execute("SELECT key, data FROM metadata_store")
rows = cur.fetchall()
return [parse_obj_as(cls, row["data"]) for row in rows]
class PGVectorIndex(EmbeddingIndex):
def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
self.cursor = cursor
self.table_name = f"vector_store_{bank.identifier}"
self.cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id TEXT PRIMARY KEY,
document JSONB,
embedding vector({dimension})
)
"""
)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
values = []
for i, chunk in enumerate(chunks):
values.append(
(
f"{chunk.document_id}:chunk-{i}",
Json(chunk.dict()),
embeddings[i].tolist(),
)
)
query = sql.SQL(
f"""
INSERT INTO {self.table_name} (id, document, embedding)
VALUES %s
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
"""
)
execute_values(self.cursor, query, values, template="(%s, %s, %s::vector)")
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
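        # <-> is pgvector's Euclidean (L2) distance operator, so ordering by it
        # returns the nearest embeddings first.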
self.cursor.execute(
f"""
SELECT document, embedding <-> %s::vector AS distance
FROM {self.table_name}
ORDER BY distance
LIMIT %s
""",
(embedding.tolist(), k),
)
results = self.cursor.fetchall()
chunks = []
scores = []
for doc, dist in results:
chunks.append(Chunk(**doc))
scores.append(1.0 / float(dist))
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self):
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.cursor = None
self.conn = None
self.cache = {}
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
try:
self.conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
)
self.conn.autocommit = True
self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
version = check_extension_version(self.cursor)
if version:
log.info(f"Vector extension version: {version}")
else:
raise RuntimeError("Vector extension is not installed.")
self.cursor.execute(
"""
CREATE TABLE IF NOT EXISTS metadata_store (
key TEXT PRIMARY KEY,
data JSONB
)
"""
)
except Exception as e:
log.exception("Could not connect to PGVector database server")
raise RuntimeError("Could not connect to PGVector database server") from e
async def shutdown(self) -> None:
pass
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, index, self.inference_api
)
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
return await index.query_documents(query, params)
async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
return self.cache[bank_id]


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import QdrantConfig
async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
from .qdrant import QdrantVectorMemoryAdapter
impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl


@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class QdrantConfig(BaseModel):
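    # These fields are forwarded verbatim to qdrant_client.AsyncQdrantClient
    # (see QdrantVectorMemoryAdapter.__init__ below).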
location: Optional[str] = None
url: Optional[str] = None
port: Optional[int] = 6333
grpc_port: int = 6334
prefer_grpc: bool = False
https: Optional[bool] = None
api_key: Optional[str] = None
prefix: Optional[str] = None
timeout: Optional[int] = None
host: Optional[str] = None
path: Optional[str] = None


@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
from qdrant_client.models import PointStruct
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.remote.memory.qdrant.config import QdrantConfig
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
log = logging.getLogger(__name__)
CHUNK_ID_KEY = "_chunk_id"
def convert_id(_id: str) -> str:
"""
Converts any string into a UUID string based on a seed.
    Qdrant accepts UUID strings and unsigned integers as point IDs.
We use a seed to convert each string into a UUID string deterministically.
This allows us to overwrite the same point with the original ID.
"""
return str(uuid.uuid5(uuid.NAMESPACE_DNS, _id))
class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str):
self.client = client
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
if not await self.client.collection_exists(self.collection_name):
await self.client.create_collection(
self.collection_name,
vectors_config=models.VectorParams(
size=len(embeddings[0]), distance=models.Distance.COSINE
),
)
points = []
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
chunk_id = f"{chunk.document_id}:chunk-{i}"
points.append(
PointStruct(
id=convert_id(chunk_id),
vector=embedding,
payload={"chunk_content": chunk.model_dump()}
| {CHUNK_ID_KEY: chunk_id},
)
)
await self.client.upsert(collection_name=self.collection_name, points=points)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = (
await self.client.query_points(
collection_name=self.collection_name,
query=embedding.tolist(),
limit=k,
with_payload=True,
score_threshold=score_threshold,
)
).points
chunks, scores = [], []
for point in results:
assert isinstance(point, models.ScoredPoint)
assert point.payload is not None
try:
chunk = Chunk(**point.payload["chunk_content"])
except Exception:
log.exception("Failed to parse chunk")
continue
chunks.append(chunk)
scores.append(point.score)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
self.config = config
self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
self.cache = {}
self.inference_api = inference_api
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
        await self.client.close()
async def register_memory_bank(
self,
memory_bank: MemoryBank,
) -> None:
assert (
            memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
index = BankWithIndex(
bank=memory_bank,
index=QdrantIndex(self.client, memory_bank.identifier),
inference_api=self.inference_api,
)
self.cache[memory_bank.identifier] = index
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
index = BankWithIndex(
bank=bank,
index=QdrantIndex(client=self.client, collection_name=bank_id),
inference_api=self.inference_api,
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)


@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SampleConfig
async def get_adapter_impl(config: SampleConfig, _deps) -> Any:
from .sample import SampleMemoryImpl
impl = SampleMemoryImpl(config)
await impl.initialize()
return impl


@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class SampleConfig(BaseModel):
host: str = "localhost"
port: int = 9999


@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBank
from .config import SampleConfig
class SampleMemoryImpl(Memory):
def __init__(self, config: SampleConfig):
self.config = config
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
# these are the memory banks the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass
async def initialize(self):
pass


@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401
async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
from .weaviate import WeaviateMemoryAdapter
impl = WeaviateMemoryAdapter(config, deps[Api.inference])
await impl.initialize()
return impl


@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
class WeaviateRequestProviderData(BaseModel):
weaviate_api_key: str
weaviate_cluster_url: str
class WeaviateConfig(BaseModel):
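    # Intentionally empty: connection details (cluster URL, API key) arrive
    # per-request via WeaviateRequestProviderData rather than from static config.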
pass


@@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any, Dict, List, Optional
import weaviate
import weaviate.classes as wvc
from numpy.typing import NDArray
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import WeaviateConfig, WeaviateRequestProviderData
log = logging.getLogger(__name__)
class WeaviateIndex(EmbeddingIndex):
def __init__(self, client: weaviate.Client, collection_name: str):
self.client = client
self.collection_name = collection_name
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
data_objects = []
for i, chunk in enumerate(chunks):
data_objects.append(
wvc.data.DataObject(
properties={
"chunk_content": chunk.json(),
},
vector=embeddings[i].tolist(),
)
)
# Inserting chunks into a prespecified Weaviate collection
collection = self.client.collections.get(self.collection_name)
# TODO: make this async friendly
collection.data.insert_many(data_objects)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
collection = self.client.collections.get(self.collection_name)
results = collection.query.near_vector(
near_vector=embedding.tolist(),
limit=k,
return_metadata=wvc.query.MetadataQuery(distance=True),
)
chunks = []
scores = []
for doc in results.objects:
chunk_json = doc.properties["chunk_content"]
try:
chunk_dict = json.loads(chunk_json)
chunk = Chunk(**chunk_dict)
except Exception:
log.exception(f"Failed to parse document: {chunk_json}")
continue
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self, chunk_ids: List[str]) -> None:
collection = self.client.collections.get(self.collection_name)
collection.data.delete_many(
where=Filter.by_property("id").contains_any(chunk_ids)
)
class WeaviateMemoryAdapter(
Memory,
NeedsRequestProviderData,
MemoryBanksProtocolPrivate,
):
def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
self.config = config
self.inference_api = inference_api
self.client_cache = {}
self.cache = {}
def _get_client(self) -> weaviate.Client:
provider_data = self.get_request_provider_data()
assert provider_data is not None, "Request provider data must be set"
assert isinstance(provider_data, WeaviateRequestProviderData)
key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
if key in self.client_cache:
return self.client_cache[key]
client = weaviate.connect_to_weaviate_cloud(
cluster_url=provider_data.weaviate_cluster_url,
auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
)
self.client_cache[key] = client
return client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
for client in self.client_cache.values():
client.close()
async def register_memory_bank(
self,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
client = self._get_client()
# Create collection if it doesn't exist
if not client.collections.exists(memory_bank.identifier):
client.collections.create(
name=memory_bank.identifier,
vectorizer_config=wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
],
)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank,
WeaviateIndex(client=client, collection_name=memory_bank.identifier),
self.inference_api,
)
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found")
client = self._get_client()
if not client.collections.exists(bank.identifier):
raise ValueError(f"Collection with name `{bank.identifier}` not found")
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(client=client, collection_name=bank_id),
inference_api=self.inference_api,
)
self.cache[bank_id] = index
return index
async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
ttl_seconds: Optional[int] = None,
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
await index.insert_documents(documents)
async def query_documents(
self,
bank_id: str,
query: InterleavedContent,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")
return await index.query_documents(query, params)