remove mixin and test fixes

This commit is contained in:
Dinesh Yeduguru 2024-12-09 15:00:12 -08:00
parent 5bbeb985ca
commit 0e451525e5
9 changed files with 140 additions and 69 deletions

View file

@ -15,12 +15,10 @@ from numpy.typing import NDArray
from pydantic import parse_obj_as
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
InferenceEmbeddingMixin,
)
log = logging.getLogger(__name__)
@ -72,7 +70,7 @@ class ChromaIndex(EmbeddingIndex):
await self.client.delete_collection(self.collection.name)
class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str, inference_api: Api.inference) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
@ -111,8 +109,8 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPr
name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()},
)
self.cache[memory_bank.identifier] = self._create_bank_with_index(
memory_bank, ChromaIndex(self.client, collection)
self.cache[memory_bank.identifier] = BankWithIndex(
memory_bank, ChromaIndex(self.client, collection), self.inference_api
)
async def list_memory_banks(self) -> List[MemoryBank]:
@ -125,9 +123,10 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPr
log.exception(f"Failed to parse bank: {collection.metadata}")
continue
self.cache[bank.identifier] = self._create_bank_with_index(
self.cache[bank.identifier] = BankWithIndex(
bank,
ChromaIndex(self.client, collection),
self.inference_api,
)
return [i.bank for i in self.cache.values()]
@ -166,6 +165,8 @@ class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPr
collection = await self.client.get_collection(bank_id)
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = self._create_bank_with_index(bank, ChromaIndex(self.client, collection))
index = BankWithIndex(
bank, ChromaIndex(self.client, collection), self.inference_api
)
self.cache[bank_id] = index
return index