mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 14:02:36 +00:00
add support for provider update and unregister for memory banks
This commit is contained in:
parent
9b75e92852
commit
e8b699797c
11 changed files with 240 additions and 65 deletions
|
|
@ -67,6 +67,9 @@ class ChromaIndex(EmbeddingIndex):
|
|||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
await self.client.delete_collection(self.collection.name)
|
||||
|
||||
|
||||
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, url: str) -> None:
|
||||
|
|
@ -134,6 +137,14 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
await self.cache[memory_bank_id].index.delete()
|
||||
del self.cache[memory_bank_id]
|
||||
|
||||
async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
|
||||
await self.unregister_memory_bank(memory_bank.identifier)
|
||||
await self.register_memory_bank(memory_bank)
|
||||
|
||||
async def insert_documents(
|
||||
self,
|
||||
bank_id: str,
|
||||
|
|
|
|||
|
|
@ -112,6 +112,9 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
return QueryDocumentsResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
self.cursor.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
||||
|
||||
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: PGVectorConfig) -> None:
|
||||
|
|
@ -177,6 +180,14 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
|
||||
await self.cache[memory_bank_id].index.delete()
|
||||
del self.cache[memory_bank_id]
|
||||
|
||||
async def update_memory_bank(self, memory_bank: MemoryBank) -> None:
|
||||
await self.unregister_memory_bank(memory_bank.identifier)
|
||||
await self.register_memory_bank(memory_bank)
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
banks = load_models(self.cursor, VectorMemoryBank)
|
||||
for bank in banks:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue