mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
remove auto discovery
This commit is contained in:
parent
f1a2996812
commit
4206f07ada
4 changed files with 25 additions and 12 deletions
|
@ -90,8 +90,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
|
||||
elif api == Api.memory:
|
||||
p.memory_bank_store = self
|
||||
memory_banks = await p.list_memory_banks()
|
||||
await add_objects(memory_banks, pid, None)
|
||||
|
||||
elif api == Api.datasetio:
|
||||
p.dataset_store = self
|
||||
|
|
|
@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: MemoryBankDef,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.type}"
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=memory_bank.identifier,
|
||||
|
@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = bank_index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
collections = await self.client.list_collections()
|
||||
for collection in collections:
|
||||
try:
|
||||
data = json.loads(collection.metadata["bank"])
|
||||
bank = parse_obj_as(MemoryBankDef, data)
|
||||
bank = parse_obj_as(VectorMemoryBank, data)
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
|
|
|
@ -158,7 +158,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
|
||||
async def register_memory_bank(
|
||||
self,
|
||||
memory_bank: VectorMemoryBank,
|
||||
memory_bank: MemoryBank,
|
||||
) -> None:
|
||||
assert (
|
||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||
|
@ -177,7 +177,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
|||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[VectorMemoryBank]:
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
banks = load_models(self.cursor, VectorMemoryBank)
|
||||
for bank in banks:
|
||||
if bank.identifier not in self.cache:
|
||||
|
|
|
@ -10,11 +10,10 @@ import tempfile
|
|||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
|
||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
@ -78,7 +77,23 @@ def memory_weaviate() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_chroma() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="chroma",
|
||||
provider_type="remote::chromadb",
|
||||
config=RemoteProviderConfig(
|
||||
host=get_env_or_fail("CHROMA_HOST"),
|
||||
port=get_env_or_fail("CHROMA_PORT"),
|
||||
).model_dump(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue