diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 613844f5e..f2602ddde 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -271,7 +271,7 @@ class Session(BaseModel): turns: List[Turn] started_at: datetime - memory_bank: Optional[MemoryBankDef] = None + memory_bank: Optional[MemoryBank] = None class AgentConfigCommon(BaseModel): diff --git a/llama_stack/apis/memory/client.py b/llama_stack/apis/memory/client.py index a791dfa86..5cfed8518 100644 --- a/llama_stack/apis/memory/client.py +++ b/llama_stack/apis/memory/client.py @@ -75,14 +75,22 @@ class MemoryClient(Memory): async def run_main(host: str, port: int, stream: bool): banks_client = MemoryBanksClient(f"http://{host}:{port}") - bank = VectorMemoryBankDef( + bank = VectorMemoryBank( identifier="test_bank", provider_id="", embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, ) - await banks_client.register_memory_bank(bank) + await banks_client.register_memory_bank( + bank.identifier, + VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + provider_resource_id=bank.identifier, + ) retrieved_bank = await banks_client.get_memory_bank(bank.identifier) assert retrieved_bank is not None diff --git a/llama_stack/apis/memory/memory.py b/llama_stack/apis/memory/memory.py index 9047820ac..48b6e2241 100644 --- a/llama_stack/apis/memory/memory.py +++ b/llama_stack/apis/memory/memory.py @@ -39,7 +39,7 @@ class QueryDocumentsResponse(BaseModel): class MemoryBankStore(Protocol): - def get_memory_bank(self, bank_id: str) -> Optional[MemoryBankDef]: ... + def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ... @runtime_checkable diff --git a/llama_stack/apis/memory_banks/client.py b/llama_stack/apis/memory_banks/client.py index 69be35d02..308ee42f4 100644 --- a/llama_stack/apis/memory_banks/client.py +++ b/llama_stack/apis/memory_banks/client.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import json from typing import Any, Dict, List, Optional @@ -26,13 +25,13 @@ def deserialize_memory_bank_def( raise ValueError("Memory bank type not specified") type = j["type"] if type == MemoryBankType.vector.value: - return VectorMemoryBankDef(**j) + return VectorMemoryBank(**j) elif type == MemoryBankType.keyvalue.value: - return KeyValueMemoryBankDef(**j) + return KeyValueMemoryBank(**j) elif type == MemoryBankType.keyword.value: - return KeywordMemoryBankDef(**j) + return KeywordMemoryBank(**j) elif type == MemoryBankType.graph.value: - return GraphMemoryBankDef(**j) + return GraphMemoryBank(**j) else: raise ValueError(f"Unknown memory bank type: {type}") @@ -47,7 +46,7 @@ class MemoryBanksClient(MemoryBanks): async def shutdown(self) -> None: pass - async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: + async def list_memory_banks(self) -> List[MemoryBank]: async with httpx.AsyncClient() as client: response = await client.get( f"{self.base_url}/memory_banks/list", @@ -57,13 +56,20 @@ class MemoryBanksClient(MemoryBanks): return [deserialize_memory_bank_def(x) for x in response.json()] async def register_memory_bank( - self, memory_bank: MemoryBankDefWithProvider + self, + memory_bank_id: str, + params: BankParams, + provider_resource_id: Optional[str] = None, + provider_id: Optional[str] = None, ) -> None: async with httpx.AsyncClient() as client: response = await client.post( f"{self.base_url}/memory_banks/register", json={ - "memory_bank": json.loads(memory_bank.json()), + "memory_bank_id": memory_bank_id, + "provider_resource_id": provider_resource_id, + "provider_id": provider_id, + "params": params.dict(), }, headers={"Content-Type": "application/json"}, ) @@ -71,13 +77,13 @@ class MemoryBanksClient(MemoryBanks): async def get_memory_bank( self, - identifier: str, - ) -> Optional[MemoryBankDefWithProvider]: + memory_bank_id: str, + ) -> Optional[MemoryBank]: async with httpx.AsyncClient() as client: response = await client.get( f"{self.base_url}/memory_banks/get", params={ - "identifier": identifier, + "memory_bank_id": memory_bank_id, }, headers={"Content-Type": "application/json"}, ) @@ -94,12 +100,12 @@ async def run_main(host: str, port: int, stream: bool): # register memory bank for the first time response = await client.register_memory_bank( - VectorMemoryBankDef( - identifier="test_bank2", + memory_bank_id="test_bank2", + params=VectorMemoryBankParams( embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, overlap_size_in_tokens=64, - ) + ), ) cprint(f"register_memory_bank response={response}", "blue") diff --git a/llama_stack/apis/memory_banks/memory_banks.py b/llama_stack/apis/memory_banks/memory_banks.py index df116d3c2..303104f25 100644 --- a/llama_stack/apis/memory_banks/memory_banks.py +++ b/llama_stack/apis/memory_banks/memory_banks.py @@ -5,11 +5,21 @@ # the root directory of this source tree. from enum import Enum -from typing import List, Literal, Optional, Protocol, runtime_checkable, Union +from typing import ( + Annotated, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod + from pydantic import BaseModel, Field -from typing_extensions import Annotated + +from llama_stack.apis.resource import Resource, ResourceType @json_schema_type @@ -20,59 +30,98 @@ class MemoryBankType(Enum): graph = "graph" -class CommonDef(BaseModel): - identifier: str - # Hack: move this out later - provider_id: str = "" - - @json_schema_type -class VectorMemoryBankDef(CommonDef): - type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value +class VectorMemoryBank(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value + memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value embedding_model: str chunk_size_in_tokens: int overlap_size_in_tokens: Optional[int] = None @json_schema_type -class KeyValueMemoryBankDef(CommonDef): - type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value +class KeyValueMemoryBank(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value + memory_bank_type: Literal[MemoryBankType.keyvalue.value] = ( + MemoryBankType.keyvalue.value + ) @json_schema_type -class KeywordMemoryBankDef(CommonDef): - type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value +class KeywordMemoryBank(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value + memory_bank_type: Literal[MemoryBankType.keyword.value] = ( + MemoryBankType.keyword.value + ) @json_schema_type -class GraphMemoryBankDef(CommonDef): - type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value +class GraphMemoryBank(Resource): + type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value + memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value -MemoryBankDef = Annotated[ +@json_schema_type +class VectorMemoryBankParams(BaseModel): + memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value + embedding_model: str + chunk_size_in_tokens: int + overlap_size_in_tokens: Optional[int] = None + + +@json_schema_type +class KeyValueMemoryBankParams(BaseModel): + memory_bank_type: Literal[MemoryBankType.keyvalue.value] = ( + MemoryBankType.keyvalue.value + ) + + +@json_schema_type +class KeywordMemoryBankParams(BaseModel): + memory_bank_type: Literal[MemoryBankType.keyword.value] = ( + MemoryBankType.keyword.value + ) + + +@json_schema_type +class GraphMemoryBankParams(BaseModel): + memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value + + +MemoryBank = Annotated[ Union[ - VectorMemoryBankDef, - KeyValueMemoryBankDef, - KeywordMemoryBankDef, - GraphMemoryBankDef, + VectorMemoryBank, + KeyValueMemoryBank, + KeywordMemoryBank, + GraphMemoryBank, ], - Field(discriminator="type"), + Field(discriminator="memory_bank_type"), ] -MemoryBankDefWithProvider = MemoryBankDef +BankParams = Annotated[ + Union[ + VectorMemoryBankParams, + KeyValueMemoryBankParams, + KeywordMemoryBankParams, + GraphMemoryBankParams, + ], + Field(discriminator="memory_bank_type"), +] @runtime_checkable class MemoryBanks(Protocol): @webmethod(route="/memory_banks/list", method="GET") - async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: ... + async def list_memory_banks(self) -> List[MemoryBank]: ... @webmethod(route="/memory_banks/get", method="GET") - async def get_memory_bank( - self, identifier: str - ) -> Optional[MemoryBankDefWithProvider]: ... + async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ... @webmethod(route="/memory_banks/register", method="POST") async def register_memory_bank( - self, memory_bank: MemoryBankDefWithProvider - ) -> None: ... + self, + memory_bank_id: str, + params: BankParams, + provider_id: Optional[str] = None, + provider_memorybank_id: Optional[str] = None, + ) -> MemoryBank: ... diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index a2eafe273..ebc511b02 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -33,7 +33,7 @@ RoutingKey = Union[str, List[str]] RoutableObject = Union[ Model, Shield, - MemoryBankDef, + MemoryBank, DatasetDef, ScoringFnDef, ] @@ -43,7 +43,7 @@ RoutableObjectWithProvider = Annotated[ Union[ Model, Shield, - MemoryBankDefWithProvider, + MemoryBank, DatasetDefWithProvider, ScoringFnDefWithProvider, ], diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index c8c906af7..5f6395e0d 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -7,8 +7,8 @@ from typing import Any, AsyncGenerator, Dict, List, Optional from llama_stack.apis.datasetio.datasetio import DatasetIO +from llama_stack.apis.memory_banks.memory_banks import BankParams from llama_stack.distribution.datatypes import RoutingTable - from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 @@ -32,8 +32,19 @@ class MemoryRouter(Memory): async def shutdown(self) -> None: pass - async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: - await self.routing_table.register_memory_bank(memory_bank) + async def register_memory_bank( + self, + memory_bank_id: str, + params: BankParams, + provider_id: Optional[str] = None, + provider_memorybank_id: Optional[str] = None, + ) -> None: + await self.routing_table.register_memory_bank( + memory_bank_id, + params, + provider_id, + provider_memorybank_id, + ) async def insert_documents( self, diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 721134bd4..aa61580b2 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -6,6 +6,8 @@ from typing import Any, Dict, List, Optional +from pydantic import parse_obj_as + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.models import * # noqa: F403 @@ -89,8 +91,6 @@ class CommonRoutingTableImpl(RoutingTable): elif api == Api.memory: p.memory_bank_store = self - memory_banks = await p.list_memory_banks() - await add_objects(memory_banks, pid, None) elif api == Api.datasetio: p.dataset_store = self @@ -188,12 +188,6 @@ class CommonRoutingTableImpl(RoutingTable): objs = await self.dist_registry.get_all() return [obj for obj in objs if obj.type == type] - async def get_all_with_types( - self, types: List[str] - ) -> List[RoutableObjectWithProvider]: - objs = await self.dist_registry.get_all() - return [obj for obj in objs if obj.type in types] - class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[Model]: @@ -233,7 +227,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[Shield]: - return await self.get_all_with_type("shield") + return await self.get_all_with_type(ResourceType.shield.value) async def get_shield(self, identifier: str) -> Optional[Shield]: return await self.get_object_by_identifier(identifier) @@ -270,25 +264,41 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): - async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: - return await self.get_all_with_types( - [ - MemoryBankType.vector.value, - MemoryBankType.keyvalue.value, - MemoryBankType.keyword.value, - MemoryBankType.graph.value, - ] - ) + async def list_memory_banks(self) -> List[MemoryBank]: + return await self.get_all_with_type(ResourceType.memory_bank.value) - async def get_memory_bank( - self, identifier: str - ) -> Optional[MemoryBankDefWithProvider]: - return await self.get_object_by_identifier(identifier) + async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: + return await self.get_object_by_identifier(memory_bank_id) async def register_memory_bank( - self, memory_bank: MemoryBankDefWithProvider - ) -> None: + self, + memory_bank_id: str, + params: BankParams, + provider_id: Optional[str] = None, + provider_memorybank_id: Optional[str] = None, + ) -> MemoryBank: + if provider_memorybank_id is None: + provider_memorybank_id = memory_bank_id + if provider_id is None: + # If provider_id not specified, use the only provider if it supports this shield type + if len(self.impls_by_provider_id) == 1: + provider_id = list(self.impls_by_provider_id.keys())[0] + else: + raise ValueError( + "No provider specified and multiple providers available. Please specify a provider_id." + ) + memory_bank = parse_obj_as( + MemoryBank, + { + "identifier": memory_bank_id, + "type": ResourceType.memory_bank.value, + "provider_id": provider_id, + "provider_resource_id": provider_memorybank_id, + **params.model_dump(), + }, + ) await self.register_object(memory_bank) + return memory_bank class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index b2f7ada86..e5b64bdc6 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -10,7 +10,7 @@ import pytest import pytest_asyncio from llama_stack.distribution.store import * # noqa F403 from llama_stack.apis.inference import Model -from llama_stack.apis.memory_banks import VectorMemoryBankDef +from llama_stack.apis.memory_banks import VectorMemoryBank from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from llama_stack.distribution.datatypes import * # noqa F403 @@ -39,7 +39,7 @@ async def cached_registry(config): @pytest.fixture def sample_bank(): - return VectorMemoryBankDef( + return VectorMemoryBank( identifier="test_bank", embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, @@ -113,7 +113,7 @@ async def test_cached_registry_updates(config): cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) await cached_registry.initialize() - new_bank = VectorMemoryBankDef( + new_bank = VectorMemoryBank( identifier="test_bank_2", embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=256, @@ -144,7 +144,7 @@ async def test_duplicate_provider_registration(config): cached_registry = CachedDiskDistributionRegistry(await kvstore_impl(config)) await cached_registry.initialize() - original_bank = VectorMemoryBankDef( + original_bank = VectorMemoryBank( identifier="test_bank_2", embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=256, @@ -153,7 +153,7 @@ async def test_duplicate_provider_registration(config): ) await cached_registry.register(original_bank) - duplicate_bank = VectorMemoryBankDef( + duplicate_bank = VectorMemoryBank( identifier="test_bank_2", embedding_model="different-model", chunk_size_in_tokens=128, diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 7aa2b976f..ed2033494 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetDef from llama_stack.apis.eval_tasks import EvalTaskDef -from llama_stack.apis.memory_banks import MemoryBankDef +from llama_stack.apis.memory_banks.memory_banks import MemoryBank from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFnDef from llama_stack.apis.shields import Shield @@ -51,9 +51,9 @@ class ShieldsProtocolPrivate(Protocol): class MemoryBanksProtocolPrivate(Protocol): - async def list_memory_banks(self) -> List[MemoryBankDef]: ... + async def list_memory_banks(self) -> List[MemoryBank]: ... - async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None: ... + async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ... class DatasetsProtocolPrivate(Protocol): diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index cbc7490fd..a36a2c24f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -641,7 +641,7 @@ class ChatAgent(ShieldRunnerMixin): if session_info.memory_bank_id is None: bank_id = f"memory_bank_{session_id}" - memory_bank = VectorMemoryBankDef( + memory_bank = VectorMemoryBank( identifier=bank_id, embedding_model="all-MiniLM-L6-v2", chunk_size_in_tokens=512, diff --git a/llama_stack/providers/inline/memory/faiss/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py index c362eeedb..0ab1b1f78 100644 --- a/llama_stack/providers/inline/memory/faiss/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -83,7 +83,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): stored_banks = await self.kvstore.range(start_key, end_key) for bank_data in stored_banks: - bank = VectorMemoryBankDef.model_validate_json(bank_data) + bank = VectorMemoryBank.model_validate_json(bank_data) index = BankWithIndex( bank=bank, index=FaissIndex(ALL_MINILM_L6_V2_DIMENSION) ) @@ -95,10 +95,10 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): async def register_memory_bank( self, - memory_bank: MemoryBankDef, + memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.type == MemoryBankType.vector.value + memory_bank.memory_bank_type == MemoryBankType.vector.value ), f"Only vector banks are supported {memory_bank.type}" # Store in kvstore @@ -114,7 +114,7 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): ) self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[MemoryBankDef]: + async def list_memory_banks(self) -> List[MemoryBank]: return [i.bank for i in self.cache.values()] async def insert_documents( diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 7c206d531..0611d9aa2 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def register_memory_bank( self, - memory_bank: MemoryBankDef, + memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.type == MemoryBankType.vector.value - ), f"Only vector banks are supported {memory_bank.type}" + memory_bank.memory_bank_type == MemoryBankType.vector.value + ), f"Only vector banks are supported {memory_bank.memory_bank_type}" collection = await self.client.get_or_create_collection( name=memory_bank.identifier, @@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): ) self.cache[memory_bank.identifier] = bank_index - async def list_memory_banks(self) -> List[MemoryBankDef]: + async def list_memory_banks(self) -> List[MemoryBank]: collections = await self.client.list_collections() for collection in collections: try: data = json.loads(collection.metadata["bank"]) - bank = parse_obj_as(MemoryBankDef, data) + bank = parse_obj_as(VectorMemoryBank, data) except Exception: import traceback diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index 0d188d944..9acfef2dc 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -52,7 +52,7 @@ def load_models(cur, cls): class PGVectorIndex(EmbeddingIndex): - def __init__(self, bank: MemoryBankDef, dimension: int, cursor): + def __init__(self, bank: VectorMemoryBank, dimension: int, cursor): self.cursor = cursor self.table_name = f"vector_store_{bank.identifier}" @@ -121,6 +121,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): self.cache = {} async def initialize(self) -> None: + print(f"Initializing PGVector memory adapter with config: {self.config}") try: self.conn = psycopg2.connect( host=self.config.host, @@ -157,11 +158,11 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def register_memory_bank( self, - memory_bank: MemoryBankDef, + memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.type == MemoryBankType.vector.value - ), f"Only vector banks are supported {memory_bank.type}" + memory_bank.memory_bank_type == MemoryBankType.vector.value + ), f"Only vector banks are supported {memory_bank.memory_bank_type}" upsert_models( self.cursor, @@ -176,8 +177,8 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): ) self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[MemoryBankDef]: - banks = load_models(self.cursor, MemoryBankDef) + async def list_memory_banks(self) -> List[MemoryBank]: + banks = load_models(self.cursor, VectorMemoryBank) for bank in banks: if bank.identifier not in self.cache: index = BankWithIndex( diff --git a/llama_stack/providers/remote/memory/qdrant/qdrant.py b/llama_stack/providers/remote/memory/qdrant/qdrant.py index 0f0df3dca..27923a7c5 100644 --- a/llama_stack/providers/remote/memory/qdrant/qdrant.py +++ b/llama_stack/providers/remote/memory/qdrant/qdrant.py @@ -12,6 +12,7 @@ from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct +from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.apis.memory import * # noqa: F403 @@ -112,11 +113,11 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def register_memory_bank( self, - memory_bank: MemoryBankDef, + memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.type == MemoryBankType.vector.value - ), f"Only vector banks are supported {memory_bank.type}" + memory_bank.memory_bank_type == MemoryBankType.vector + ), f"Only vector banks are supported {memory_bank.memory_bank_type}" index = BankWithIndex( bank=memory_bank, @@ -125,7 +126,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[MemoryBankDef]: + async def list_memory_banks(self) -> List[MemoryBank]: # Qdrant doesn't have collection level metadata to store the bank properties # So we only return from the cache value return [i.bank for i in self.cache.values()] diff --git a/llama_stack/providers/remote/memory/weaviate/weaviate.py b/llama_stack/providers/remote/memory/weaviate/weaviate.py index 16fa03679..2844402b5 100644 --- a/llama_stack/providers/remote/memory/weaviate/weaviate.py +++ b/llama_stack/providers/remote/memory/weaviate/weaviate.py @@ -114,11 +114,11 @@ class WeaviateMemoryAdapter( async def register_memory_bank( self, - memory_bank: MemoryBankDef, + memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.type == MemoryBankType.vector.value - ), f"Only vector banks are supported {memory_bank.type}" + memory_bank.memory_bank_type == MemoryBankType.vector + ), f"Only vector banks are supported {memory_bank.memory_bank_type}" client = self._get_client() @@ -141,7 +141,7 @@ class WeaviateMemoryAdapter( ) self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[MemoryBankDef]: + async def list_memory_banks(self) -> List[MemoryBank]: # TODO: right now the Llama Stack is the source of truth for these banks. That is # not ideal. It should be Weaviate which is the source of truth. Unfortunately, # list() happens at Stack startup when the Weaviate client (credentials) is not @@ -157,8 +157,8 @@ class WeaviateMemoryAdapter( raise ValueError(f"Bank {bank_id} not found") client = self._get_client() - if not client.collections.exists(bank_id): - raise ValueError(f"Collection with name `{bank_id}` not found") + if not client.collections.exists(bank.identifier): + raise ValueError(f"Collection with name `{bank.identifier}` not found") index = BankWithIndex( bank=bank, diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c0931b009..482049045 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,11 +10,10 @@ import tempfile import pytest import pytest_asyncio -from llama_stack.distribution.datatypes import Api, Provider +from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig - from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture @@ -78,7 +77,23 @@ def memory_weaviate() -> ProviderFixture: ) -MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"] +@pytest.fixture(scope="session") +def memory_chroma() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="chroma", + provider_type="remote::chromadb", + config=RemoteProviderConfig( + host=get_env_or_fail("CHROMA_HOST"), + port=get_env_or_fail("CHROMA_PORT"), + ).model_dump(), + ) + ] + ) + + +MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/memory/test_memory.py b/llama_stack/providers/tests/memory/test_memory.py index ee3110dea..a1befa6b0 100644 --- a/llama_stack/providers/tests/memory/test_memory.py +++ b/llama_stack/providers/tests/memory/test_memory.py @@ -8,6 +8,7 @@ import pytest from llama_stack.apis.memory import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 +from llama_stack.apis.memory_banks.memory_banks import VectorMemoryBankParams # How to run this test: # @@ -43,14 +44,15 @@ def sample_documents(): async def register_memory_bank(banks_impl: MemoryBanks): - bank = VectorMemoryBankDef( - identifier="test_bank", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - await banks_impl.register_memory_bank(bank) + return await banks_impl.register_memory_bank( + memory_bank_id="test_bank", + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) class TestMemory: @@ -68,20 +70,28 @@ class TestMemory: # NOTE: this needs you to ensure that you are starting from a clean state # but so far we don't have an unregister API unfortunately, so be careful _, banks_impl = memory_stack - bank = VectorMemoryBankDef( - identifier="test_bank_no_provider", - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - await banks_impl.register_memory_bank(bank) + bank = await banks_impl.register_memory_bank( + memory_bank_id="test_bank_no_provider", + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) response = await banks_impl.list_memory_banks() assert isinstance(response, list) assert len(response) == 1 # register same memory bank with same id again will fail - await banks_impl.register_memory_bank(bank) + await banks_impl.register_memory_bank( + memory_bank_id="test_bank_no_provider", + params=VectorMemoryBankParams( + embedding_model="all-MiniLM-L6-v2", + chunk_size_in_tokens=512, + overlap_size_in_tokens=64, + ), + ) response = await banks_impl.list_memory_banks() assert isinstance(response, list) assert len(response) == 1 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 8e2a1550d..ba7ed231e 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -148,7 +148,7 @@ class EmbeddingIndex(ABC): @dataclass class BankWithIndex: - bank: MemoryBankDef + bank: VectorMemoryBank index: EmbeddingIndex async def insert_documents(