remove auto discovery

This commit is contained in:
Dinesh Yeduguru 2024-11-11 11:13:33 -08:00
parent f1a2996812
commit 4206f07ada
4 changed files with 25 additions and 12 deletions

View file

@ -90,8 +90,6 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.memory: elif api == Api.memory:
p.memory_bank_store = self p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
await add_objects(memory_banks, pid, None)
elif api == Api.datasetio: elif api == Api.datasetio:
p.dataset_store = self p.dataset_store = self

View file

@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def register_memory_bank( async def register_memory_bank(
self, self,
memory_bank: MemoryBankDef, memory_bank: MemoryBank,
) -> None: ) -> None:
assert ( assert (
memory_bank.type == MemoryBankType.vector.value memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}" ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
collection = await self.client.get_or_create_collection( collection = await self.client.get_or_create_collection(
name=memory_bank.identifier, name=memory_bank.identifier,
@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
) )
self.cache[memory_bank.identifier] = bank_index self.cache[memory_bank.identifier] = bank_index
async def list_memory_banks(self) -> List[MemoryBankDef]: async def list_memory_banks(self) -> List[MemoryBank]:
collections = await self.client.list_collections() collections = await self.client.list_collections()
for collection in collections: for collection in collections:
try: try:
data = json.loads(collection.metadata["bank"]) data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(MemoryBankDef, data) bank = parse_obj_as(VectorMemoryBank, data)
except Exception: except Exception:
import traceback import traceback

View file

@ -158,7 +158,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def register_memory_bank( async def register_memory_bank(
self, self,
memory_bank: VectorMemoryBank, memory_bank: MemoryBank,
) -> None: ) -> None:
assert ( assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value memory_bank.memory_bank_type == MemoryBankType.vector.value
@ -177,7 +177,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
) )
self.cache[memory_bank.identifier] = index self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[VectorMemoryBank]: async def list_memory_banks(self) -> List[MemoryBank]:
banks = load_models(self.cursor, VectorMemoryBank) banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks: for bank in banks:
if bank.identifier not in self.cache: if bank.identifier not in self.cache:

View file

@ -10,11 +10,10 @@ import tempfile
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture from ..conftest import ProviderFixture, remote_stack_fixture
@ -78,7 +77,23 @@ def memory_weaviate() -> ProviderFixture:
) )
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"] @pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="chroma",
provider_type="remote::chromadb",
config=RemoteProviderConfig(
host=get_env_or_fail("CHROMA_HOST"),
port=get_env_or_fail("CHROMA_PORT"),
).model_dump(),
)
]
)
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")