Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-28 02:53:30 +00:00
This PR does the following: 1) adds the ability to generate embeddings in all supported inference providers; 2) moves all the memory providers to use the inference API, and improves the memory tests to set up the inference stack correctly and use the embedding models. This is a merge of #589 and #598.
175 lines
5.8 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import chromadb
from numpy.typing import NDArray

from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
    BankWithIndex,
    EmbeddingIndex,
)

from .config import ChromaRemoteImplConfig

log = logging.getLogger(__name__)


ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]


# Helper that lets async and non-async Chroma clients be used interchangeably:
# await the result only if it is a coroutine, otherwise return it unchanged.
async def maybe_await(result):
    if asyncio.iscoroutine(result):
        return await result
    return result
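

# Illustrative usage (not part of this module): the same call site works for
# both client types. With chromadb.PersistentClient, get_or_create_collection()
# returns a Collection directly; with chromadb.AsyncHttpClient, the same method
# returns a coroutine that must be awaited. maybe_await() hides that difference:
#
#     collection = await maybe_await(client.get_or_create_collection(name="docs"))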


class ChromaIndex(EmbeddingIndex):
    def __init__(self, client: ChromaClientType, collection):
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"

        await maybe_await(
            self.collection.add(
                documents=[chunk.model_dump_json() for chunk in chunks],
                embeddings=embeddings,
                ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
            )
        )
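
    # Storage layout note: each chunk is stored as its JSON serialization under
    # the id "{document_id}:chunk-{index}", e.g. "doc-123:chunk-0" for a
    # hypothetical document "doc-123". query() below reverses this by
    # json.loads-ing each stored document back into a Chunk.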

    async def query(
        self, embedding: NDArray, k: int, score_threshold: float
    ) -> QueryDocumentsResponse:
        results = await maybe_await(
            self.collection.query(
                query_embeddings=[embedding.tolist()],
                n_results=k,
                include=["documents", "distances"],
            )
        )
        distances = results["distances"][0]
        documents = results["documents"][0]

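        # Chroma returns distances (smaller = more similar) while the Memory API
        # reports scores (larger = more similar), so each distance is inverted
        # below via 1/d. This assumes a strictly positive distance; an exact
        # match with distance 0 would raise ZeroDivisionError. Note also that
        # score_threshold is accepted but not applied in this implementation.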
        chunks = []
        scores = []
        for dist, doc in zip(distances, documents):
            try:
                doc = json.loads(doc)
                chunk = Chunk(**doc)
            except Exception:
                log.exception(f"Failed to parse document: {doc}")
                continue

            chunks.append(chunk)
            scores.append(1.0 / float(dist))

        return QueryDocumentsResponse(chunks=chunks, scores=scores)

    async def delete(self):
        await maybe_await(self.client.delete_collection(self.collection.name))
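

# Minimal sketch of driving ChromaIndex directly (illustrative only; in the
# provider, embeddings come from the inference API via BankWithIndex, and the
# chunks/embeddings/query_embedding names below are placeholders):
#
#     client = chromadb.PersistentClient(path="/tmp/chroma")
#     collection = client.get_or_create_collection(name="my-bank")
#     index = ChromaIndex(client, collection)
#     await index.add_chunks(chunks, embeddings)
#     response = await index.query(query_embedding, k=5, score_threshold=0.0)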


class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
    def __init__(
        self,
        config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
        inference_api: Api.inference,
    ) -> None:
        log.info(f"Initializing ChromaMemoryAdapter with config: {config}")
        self.config = config
        self.inference_api = inference_api

        self.client = None
        self.cache = {}
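
    # self.cache maps a memory bank identifier to its BankWithIndex so repeated
    # operations on the same bank reuse one ChromaIndex; self.client is created
    # lazily in initialize() because the remote client has to be awaited.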

    async def initialize(self) -> None:
        if isinstance(self.config, ChromaRemoteImplConfig):
            log.info(f"Connecting to Chroma server at: {self.config.url}")
            url = self.config.url.rstrip("/")
            parsed = urlparse(url)

            if parsed.path and parsed.path != "/":
                raise ValueError("URL should not contain a path")

            self.client = await chromadb.AsyncHttpClient(
                host=parsed.hostname, port=parsed.port
            )
        else:
            log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
            self.client = chromadb.PersistentClient(path=self.config.db_path)

    async def shutdown(self) -> None:
        pass
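
    # For example (hypothetical values): ChromaRemoteImplConfig(url="http://localhost:8000")
    # yields an awaited AsyncHttpClient on host "localhost", port 8000, while
    # ChromaInlineImplConfig(db_path="/tmp/chroma") yields a local
    # PersistentClient. A url such as "http://localhost:8000/extra" is rejected
    # because the path component must be empty.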

    async def register_memory_bank(
        self,
        memory_bank: MemoryBank,
    ) -> None:
        assert (
            memory_bank.memory_bank_type == MemoryBankType.vector.value
        ), f"Only vector banks are supported, got {memory_bank.memory_bank_type}"

        collection = await maybe_await(
            self.client.get_or_create_collection(
                name=memory_bank.identifier,
                metadata={"bank": memory_bank.model_dump_json()},
            )
        )
        self.cache[memory_bank.identifier] = BankWithIndex(
            memory_bank, ChromaIndex(self.client, collection), self.inference_api
        )

    async def unregister_memory_bank(self, memory_bank_id: str) -> None:
        await self.cache[memory_bank_id].index.delete()
        del self.cache[memory_bank_id]

    async def insert_documents(
        self,
        bank_id: str,
        documents: List[MemoryBankDocument],
        ttl_seconds: Optional[int] = None,
    ) -> None:
        index = await self._get_and_cache_bank_index(bank_id)

        await index.insert_documents(documents)
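
    # Note: ttl_seconds is accepted for API compatibility but is not used here;
    # this provider does not expire documents automatically.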

    async def query_documents(
        self,
        bank_id: str,
        query: InterleavedTextMedia,
        params: Optional[Dict[str, Any]] = None,
    ) -> QueryDocumentsResponse:
        index = await self._get_and_cache_bank_index(bank_id)

        return await index.query_documents(query, params)
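
    # Callers pass raw text/media here rather than precomputed vectors; per the
    # PR description above, the query is embedded inside BankWithIndex via the
    # inference API before ChromaIndex.query() is reached.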

    async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
        if bank_id in self.cache:
            return self.cache[bank_id]

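        # Cache miss: look the bank up in the Llama Stack registry, then in
        # Chroma. memory_bank_store is not assigned in __init__; the assumption
        # here is that the Llama Stack runtime attaches it to the provider
        # before this path runs.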
        bank = await self.memory_bank_store.get_memory_bank(bank_id)
        if not bank:
            raise ValueError(f"Bank {bank_id} not found in Llama Stack")

        collection = await maybe_await(self.client.get_collection(bank_id))
        if not collection:
            raise ValueError(f"Bank {bank_id} not found in Chroma")

        index = BankWithIndex(
            bank, ChromaIndex(self.client, collection), self.inference_api
        )
        self.cache[bank_id] = index
        return index
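

# Minimal end-to-end sketch (hypothetical wiring; in practice the Llama Stack
# resolver constructs the adapter and injects inference_api and the bank
# registry, and the inference_api/bank/documents names below are placeholders):
#
#     config = ChromaRemoteImplConfig(url="http://localhost:8000")
#     adapter = ChromaMemoryAdapter(config, inference_api)
#     await adapter.initialize()
#     await adapter.register_memory_bank(bank)
#     await adapter.insert_documents(bank.identifier, documents)
#     response = await adapter.query_documents(bank.identifier, "what is chroma?")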