forked from phoenix-oss/llama-stack-mirror
migrate memory banks to Resource and new registration (#411)
* migrate memory banks to Resource and new registration * address feedback * address feedback * fix tests * pgvector fix * pgvector fix v2 * remove auto discovery * change register signature to make params required * update client * client fix * use annotated union to parse * remove base MemoryBank inheritence --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
6b9850e11b
commit
38cce97597
19 changed files with 240 additions and 129 deletions
|
@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_stack.apis.datasets import DatasetDef
|
||||
from llama_stack.apis.eval_tasks import EvalTaskDef
|
||||
from llama_stack.apis.memory_banks import MemoryBankDef
|
||||
from llama_stack.apis.memory_banks.memory_banks import MemoryBank
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.apis.scoring_functions import ScoringFnDef
|
||||
from llama_stack.apis.shields import Shield
|
||||
|
@ -51,9 +51,9 @@ class ShieldsProtocolPrivate(Protocol):
|
|||
|
||||
|
||||
class MemoryBanksProtocolPrivate(Protocol):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]: ...
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
||||
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ...
|
||||
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
|
||||
|
||||
|
||||
class DatasetsProtocolPrivate(Protocol):
|
||||
|
|
|
@ -641,7 +641,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if session_info.memory_bank_id is None:
|
||||
bank_id = f"memory_bank_{session_id}"
|
||||
memory_bank = VectorMemoryBankDef(
|
||||
memory_bank = VectorMemoryBank(
|
||||
identifier=bank_id,
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
|
|
|
@ -83,7 +83,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
stored_banks = await self.kvstore.range(start_key, end_key)
|
||||
|
||||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBankDef.model_validate_json(bank_data)
|
||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION)
|
||||
)
|
||||
|
@ -95,10 +95,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
|
||||
# Store in kvstore
|
||||
|
@ -114,7 +114,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
||||
async def insert_documents(
|
||||
|
|
|
@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
|
@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
try:
|
||||
data = json.loads(collection.metadata["bank"])
|
||||
bank = parse_obj_as(MemoryBankDef, data)
|
||||
bank = parse_obj_as(VectorMemoryBank, data)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ def load_models(cur, cls):
|
|||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, bank: MemoryBankDef, dimension: int, cursor):
|
||||
def __init__(self, bank: VectorMemoryBank, dimension: int, cursor):
|
||||
self.cursor = cursor
|
||||
self.table_name = f"vector_store_{bank.identifier}"
|
||||
|
||||
|
@ -121,6 +121,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
self.cache = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
print(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
try:
|
||||
self.conn = psycopg2.connect(
|
||||
host=self.config.host,
|
||||
|
@ -157,11 +158,11 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
upsert_models(
|
||||
self.cursor,
|
||||
|
@ -176,8 +177,8 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
banks = load_models(self.cursor, MemoryBankDef)
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
banks = load_models(self.cursor, VectorMemoryBank)
|
||||
for bank in banks:
|
||||
if bank.identifier not in self.cache:
|
||||
index = BankWithIndex(
|
||||
|
|
|
@ -12,6 +12,7 @@ from numpy.typing import NDArray
|
|||
from qdrant_client import AsyncQdrantClient, models
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
|
@ -112,11 +113,11 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
|
@ -125,7 +126,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
# Qdrant doesn't have collection level metadata to store the bank properties
|
||||
# So we only return from the cache value
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
|
|
@ -114,11 +114,11 @@ class WeaviateMemoryAdapter(
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
client = self._get_client()
|
||||
|
||||
|
@ -141,7 +141,7 @@ class WeaviateMemoryAdapter(
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
# TODO: right now the Llama Stack is the source of truth for these banks. That is
|
||||
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
|
||||
# list() happens at Stack startup when the Weaviate client (credentials) is not
|
||||
|
@ -157,8 +157,8 @@ class WeaviateMemoryAdapter(
|
|||
raise ValueError(f"Bank {bank_id} not found")
|
||||
|
||||
client = self._get_client()
|
||||
if not client.collections.exists(bank_id):
|
||||
raise ValueError(f"Collection with name `{bank_id}` not found")
|
||||
if not client.collections.exists(bank.identifier):
|
||||
raise ValueError(f"Collection with name `{bank.identifier}` not found")
|
||||
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
|
|
|
@ -10,11 +10,10 @@ import tempfile
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
|
||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
@ -78,7 +77,23 @@ def memory_weaviate() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_chroma() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="chroma",
|
||||
provider_type="remote::chromadb",
|
||||
config=RemoteProviderConfig(
|
||||
host=get_env_or_fail("CHROMA_HOST"),
|
||||
port=get_env_or_fail("CHROMA_PORT"),
|
||||
).model_dump(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -8,6 +8,7 @@ import pytest
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -43,14 +44,15 @@ def sample_documents():
|
|||
|
||||
|
||||
async def register_memory_bank(banks_impl: MemoryBanks):
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
return await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestMemory:
|
||||
|
@ -68,20 +70,28 @@ class TestMemory:
|
|||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, banks_impl = memory_stack
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank_no_provider",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
)
|
||||
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
bank = await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank_no_provider",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
|
||||
# register same memory bank with same id again will fail
|
||||
await banks_impl.register_memory_bank(bank)
|
||||
await banks_impl.register_memory_bank(
|
||||
memory_bank_id="test_bank_no_provider",
|
||||
params=VectorMemoryBankParams(
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
chunk_size_in_tokens=512,
|
||||
overlap_size_in_tokens=64,
|
||||
),
|
||||
)
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 1
|
||||
|
|
|
@ -148,7 +148,7 @@ class EmbeddingIndex(ABC):
|
|||
|
||||
@dataclass
|
||||
class BankWithIndex:
|
||||
bank: MemoryBankDef
|
||||
bank: VectorMemoryBank
|
||||
index: EmbeddingIndex
|
||||
|
||||
async def insert_documents(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue