forked from phoenix-oss/llama-stack-mirror
migrate memory banks to Resource and new registration (#411)
* migrate memory banks to Resource and new registration * address feedback * address feedback * fix tests * pgvector fix * pgvector fix v2 * remove auto discovery * change register signature to make params required * update client * client fix * use annotated union to parse * remove base MemoryBank inheritence --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
6b9850e11b
commit
38cce97597
19 changed files with 240 additions and 129 deletions
|
@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
|
@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
try:
|
||||
data = json.loads(collection.metadata["bank"])
|
||||
bank = parse_obj_as(MemoryBankDef, data)
|
||||
bank = parse_obj_as(VectorMemoryBank, data)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ def load_models(cur, cls):
|
|||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
|
||||
def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
|
||||
self.cursor = cursor
|
||||
self.table_name = f"vector_store_{bank.identifier}"
|
||||
|
||||
|
@ -121,6 +121,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
try:
|
||||
self.conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
|
@ -157,11 +158,11 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
|
@ -176,8 +177,8 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
banks = load_models(self.cursor, MemoryBankDef)
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
banks = load_models(self.cursor, VectorMemoryBank)
|
||||
for bank in banks:
|
||||
if bank.identifier not in self.cache:
|
||||
index = BankWithIndex(
|
||||
|
|
|
@ -12,6 +12,7 @@ from numpy.typing import NDArray
|
|||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
@ -112,11 +113,11 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
|
@ -125,7 +126,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
# Qdrant doesn't have collection level metadata to store the bank properties
|
||||
# So we only return from the cache value
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
|
|
@ -114,11 +114,11 @@ class WeaviateMemoryAdapter(
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
|
@ -141,7 +141,7 @@ class WeaviateMemoryAdapter(
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
# TODO: right now the Llama Stack is the source of truth for these banks. That is
|
||||
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
|
||||
# list() happens at Stack startup when the Weaviate client (credentials) is not
|
||||
|
@ -157,8 +157,8 @@ class WeaviateMemoryAdapter(
|
|||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
client = self._get_client()
|
||||
if not client.collections.exists(bank_id):
|
||||
raise ValueError(f"Collection with name `{bank_id}` not found")
|
||||
if not client.collections.exists(bank.identifier):
|
||||
raise ValueError(f"Collection with name `{bank.identifier}` not found")
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue