remove auto discovery

This commit is contained in:
Dinesh Yeduguru 2024-11-11 11:13:33 -08:00
parent f1a2996812
commit 4206f07ada
4 changed files with 25 additions and 12 deletions

View file

@ -90,8 +90,6 @@ class CommonRoutingTableImpl(RoutingTable):
elif api == Api.memory:
p.memory_bank_store = self
memory_banks = await p.list_memory_banks()
await add_objects(memory_banks, pid, None)
elif api == Api.datasetio:
p.dataset_store = self

View file

@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def register_memory_bank(
self,
memory_bank: MemoryBankDef,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.type}"
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
collection = await self.client.get_or_create_collection(
name=memory_bank.identifier,
@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
)
self.cache[memory_bank.identifier] = bank_index
async def list_memory_banks(self) -> List[MemoryBankDef]:
async def list_memory_banks(self) -> List[MemoryBank]:
collections = await self.client.list_collections()
for collection in collections:
try:
data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(MemoryBankDef, data)
bank = parse_obj_as(VectorMemoryBank, data)
except Exception:
import traceback

View file

@ -158,7 +158,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
async def register_memory_bank(
self,
memory_bank: VectorMemoryBank,
memory_bank: MemoryBank,
) -> None:
assert (
memory_bank.memory_bank_type == MemoryBankType.vector.value
@ -177,7 +177,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[VectorMemoryBank]:
async def list_memory_banks(self) -> List[MemoryBank]:
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks:
if bank.identifier not in self.cache:

View file

@ -10,11 +10,10 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
from ..conftest import ProviderFixture, remote_stack_fixture
@ -78,7 +77,23 @@ def memory_weaviate() -> ProviderFixture:
)
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]
@pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="chroma",
provider_type="remote::chromadb",
config=RemoteProviderConfig(
host=get_env_or_fail("CHROMA_HOST"),
port=get_env_or_fail("CHROMA_PORT"),
).model_dump(),
)
]
)
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
@pytest_asyncio.fixture(scope="session")