mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 23:19:48 +00:00
user inference api to generate embeddings in vector store
This commit is contained in:
parent
96accc1216
commit
5bbeb985ca
12 changed files with 134 additions and 96 deletions
|
|
@ -20,6 +20,7 @@ from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
|||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
InferenceEmbeddingMixin,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -71,8 +72,8 @@ class ChromaIndex(EmbeddingIndex):
|
|||
await self.client.delete_collection(self.collection.name)
|
||||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
class ChromaMemoryAdapter(InferenceEmbeddingMixin, Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str, inference_api: Api.inference) -> None:
|
||||
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
|
||||
url = url.rstrip("/")
|
||||
parsed = urlparse(url)
|
||||
|
|
@ -82,6 +83,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
self.host = parsed.hostname
|
||||
self.port = parsed.port
|
||||
self.inference_api = inference_api
|
||||
|
||||
self.client = None
|
||||
self.cache = {}
|
||||
|
|
@ -109,10 +111,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
name=memory_bank.identifier,
|
||||
metadata={"bank": memory_bank.model_dump_json()},
|
||||
)
|
||||
bank_index = BankWithIndex(
|
||||
bank=memory_bank, index=ChromaIndex(self.client, collection)
|
||||
self.cache[memory_bank.identifier] = self._create_bank_with_index(
|
||||
memory_bank, ChromaIndex(self.client, collection)
|
||||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
collections = await self.client.list_collections()
|
||||
|
|
@ -124,11 +125,10 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
log.exception(f"Failed to parse bank: {collection.metadata}")
|
||||
continue
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=ChromaIndex(self.client, collection),
|
||||
self.cache[bank.identifier] = self._create_bank_with_index(
|
||||
bank,
|
||||
ChromaIndex(self.client, collection),
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
|
|
@ -166,6 +166,6 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
collection = await self.client.get_collection(bank_id)
|
||||
if not collection:
|
||||
raise ValueError(f"Bank {bank_id} not found in Chroma")
|
||||
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
|
||||
index = self._create_bank_with_index(bank, ChromaIndex(self.client, collection))
|
||||
self.cache[bank_id] = index
|
||||
return index
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue