forked from phoenix-oss/llama-stack-mirror
Allow using an "inline" version of Chroma using PersistentClient (#567)
The same code is used (inside providers/remote/memory/chroma/chroma.py) but it is driven by separate configurations and changes which Chroma client to use. Note that the dependencies are separate (`chromadb-client` vs `chromadb` -- the latter is a _much_ heavier package.) ``` pytest -s -v -m chroma memory/test_memory.py --env CHROMA_DB_PATH=/tmp/chroma_test pytest -s -v -m chroma memory/test_memory.py --env CHROMA_URL=http://localhost:6001 ```
This commit is contained in:
parent
41487e6ed1
commit
b7cb06f004
12 changed files with 127 additions and 88 deletions
|
@ -3,7 +3,7 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
|
@ -12,21 +12,31 @@ from urllib.parse import urlparse
|
|||
import chromadb
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
from .config import ChromaRemoteImplConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]
|
||||
|
||||
|
||||
# this is a helper to allow us to use async and non-async chroma clients interchangeably
|
||||
async def maybe_await(result):
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
class ChromaIndex(EmbeddingIndex):
|
||||
def __init__(self, client: chromadb.AsyncHttpClient, collection):
|
||||
def __init__(self, client: ChromaClientType, collection):
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
|
||||
|
@ -35,19 +45,23 @@ class ChromaIndex(EmbeddingIndex):
|
|||
embeddings
|
||||
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
|
||||
await self.collection.add(
|
||||
documents=[chunk.json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
|
||||
await maybe_await(
|
||||
self.collection.add(
|
||||
documents=[chunk.model_dump_json() for chunk in chunks],
|
||||
embeddings=embeddings,
|
||||
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
|
||||
)
|
||||
)
|
||||
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
results = await self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
n_results=k,
|
||||
include=["documents", "distances"],
|
||||
results = await maybe_await(
|
||||
self.collection.query(
|
||||
query_embeddings=[embedding.tolist()],
|
||||
n_results=k,
|
||||
include=["documents", "distances"],
|
||||
)
|
||||
)
|
||||
distances = results["distances"][0]
|
||||
documents = results["documents"][0]
|
||||
|
@ -68,31 +82,33 @@ class ChromaIndex(EmbeddingIndex):
|
|||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
await self.client.delete_collection(self.collection.name)
|
||||
await maybe_await(self.client.delete_collection(self.collection.name))
|
||||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
if parsed.path and parsed.path != "/":
|
||||
raise ValueError("URL should not contain a path")
|
||||
|
||||
self.host = parsed.hostname
|
||||
self.port = parsed.port
|
||||
|
||||
def __init__(
|
||||
self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
|
||||
) -> None:
|
||||
log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
|
||||
self.config = config
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
|
||||
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
|
||||
except Exception as e:
|
||||
log.exception("Could not connect to Chroma server")
|
||||
raise RuntimeError("Could not connect to Chroma server") from e
|
||||
if isinstance(self.config, ChromaRemoteImplConfig):
|
||||
log.info(f"Connecting to Chroma server at: {self.config.url}")
|
||||
url = self.config.url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
||||
if parsed.path and parsed.path != "/":
|
||||
raise ValueError("URL should not contain a path")
|
||||
|
||||
self.client = await chromadb.AsyncHttpClient(
|
||||
host=parsed.hostname, port=parsed.port
|
||||
)
|
||||
else:
|
||||
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
|
||||
self.client = chromadb.PersistentClient(path=self.config.db_path)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
@ -105,33 +121,17 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
metadata={"bank": memory_bank.model_dump_json()},
|
||||
collection = await maybe_await(
|
||||
self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
metadata={"bank": memory_bank.model_dump_json()},
|
||||
)
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
try:
|
||||
data = json.loads(collection.metadata["bank"])
|
||||
bank = parse_obj_as(VectorMemoryBank, data)
|
||||
except Exception:
|
||||
log.exception(f"Failed to parse bank: {collection.metadata}")
|
||||
continue
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
await self.cache[memory_bank_id].index.delete()
|
||||
del self.cache[memory_bank_id]
|
||||
|
@ -163,7 +163,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
bank = await self.memory_bank_store.get_memory_bank(bank_id)
|
||||
if not bank:
|
||||
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
|
||||
collection = await self.client.get_collection(bank_id)
|
||||
collection = await maybe_await(self.client.get_collection(bank_id))
|
||||
if not collection:
|
||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue