From 4206f07adab5811e348099c1e5e899845f3102ea Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Mon, 11 Nov 2024 11:13:33 -0800 Subject: [PATCH] remove auto discovery --- .../distribution/routers/routing_tables.py | 2 -- .../providers/remote/memory/chroma/chroma.py | 10 ++++----- .../remote/memory/pgvector/pgvector.py | 4 ++-- .../providers/tests/memory/fixtures.py | 21 ++++++++++++++++--- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 41e84a3ac..a23051c6d 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -90,8 +90,6 @@ class CommonRoutingTableImpl(RoutingTable): elif api == Api.memory: p.memory_bank_store = self - memory_banks = await p.list_memory_banks() - await add_objects(memory_banks, pid, None) elif api == Api.datasetio: p.dataset_store = self diff --git a/llama_stack/providers/remote/memory/chroma/chroma.py b/llama_stack/providers/remote/memory/chroma/chroma.py index 7c206d531..0611d9aa2 100644 --- a/llama_stack/providers/remote/memory/chroma/chroma.py +++ b/llama_stack/providers/remote/memory/chroma/chroma.py @@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def register_memory_bank( self, - memory_bank: MemoryBankDef, + memory_bank: MemoryBank, ) -> None: assert ( - memory_bank.type == MemoryBankType.vector.value - ), f"Only vector banks are supported {memory_bank.type}" + memory_bank.memory_bank_type == MemoryBankType.vector.value + ), f"Only vector banks are supported {memory_bank.memory_bank_type}" collection = await self.client.get_or_create_collection( name=memory_bank.identifier, @@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate): ) self.cache[memory_bank.identifier] = bank_index - async def list_memory_banks(self) -> List[MemoryBankDef]: + async def list_memory_banks(self) -> List[MemoryBank]: collections = await self.client.list_collections() for collection in collections: try: data = json.loads(collection.metadata["bank"]) - bank = parse_obj_as(MemoryBankDef, data) + bank = parse_obj_as(VectorMemoryBank, data) except Exception: import traceback diff --git a/llama_stack/providers/remote/memory/pgvector/pgvector.py b/llama_stack/providers/remote/memory/pgvector/pgvector.py index fea336f62..9acfef2dc 100644 --- a/llama_stack/providers/remote/memory/pgvector/pgvector.py +++ b/llama_stack/providers/remote/memory/pgvector/pgvector.py @@ -158,7 +158,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): async def register_memory_bank( self, - memory_bank: VectorMemoryBank, + memory_bank: MemoryBank, ) -> None: assert ( memory_bank.memory_bank_type == MemoryBankType.vector.value @@ -177,7 +177,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate): ) self.cache[memory_bank.identifier] = index - async def list_memory_banks(self) -> List[VectorMemoryBank]: + async def list_memory_banks(self) -> List[MemoryBank]: banks = load_models(self.cursor, VectorMemoryBank) for bank in banks: if bank.identifier not in self.cache: diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index c0931b009..482049045 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -10,11 +10,10 @@ import tempfile import pytest import pytest_asyncio -from llama_stack.distribution.datatypes import Api, Provider +from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig - from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture @@ -78,7 +77,23 @@ def memory_weaviate() -> ProviderFixture: ) -MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"] +@pytest.fixture(scope="session") +def memory_chroma() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="chroma", + provider_type="remote::chromadb", + config=RemoteProviderConfig( + host=get_env_or_fail("CHROMA_HOST"), + port=get_env_or_fail("CHROMA_PORT"), + ).model_dump(), + ) + ] + ) + + +MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"] @pytest_asyncio.fixture(scope="session")