mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 00:05:18 +00:00
remove auto discovery
This commit is contained in:
parent
f1a2996812
commit
4206f07ada
4 changed files with 25 additions and 12 deletions
|
@ -90,8 +90,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
|
|
||||||
elif api == Api.memory:
|
elif api == Api.memory:
|
||||||
p.memory_bank_store = self
|
p.memory_bank_store = self
|
||||||
memory_banks = await p.list_memory_banks()
|
|
||||||
await add_objects(memory_banks, pid, None)
|
|
||||||
|
|
||||||
elif api == Api.datasetio:
|
elif api == Api.datasetio:
|
||||||
p.dataset_store = self
|
p.dataset_store = self
|
||||||
|
|
|
@ -98,11 +98,11 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
|
||||||
async def register_memory_bank(
|
async def register_memory_bank(
|
||||||
self,
|
self,
|
||||||
memory_bank: MemoryBankDef,
|
memory_bank: MemoryBank,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert (
|
assert (
|
||||||
memory_bank.type == MemoryBankType.vector.value
|
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||||
), f"Only vector banks are supported {memory_bank.type}"
|
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
|
||||||
|
|
||||||
collection = await self.client.get_or_create_collection(
|
collection = await self.client.get_or_create_collection(
|
||||||
name=memory_bank.identifier,
|
name=memory_bank.identifier,
|
||||||
|
@ -113,12 +113,12 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = bank_index
|
self.cache[memory_bank.identifier] = bank_index
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[MemoryBankDef]:
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||||
collections = await self.client.list_collections()
|
collections = await self.client.list_collections()
|
||||||
for collection in collections:
|
for collection in collections:
|
||||||
try:
|
try:
|
||||||
data = json.loads(collection.metadata["bank"])
|
data = json.loads(collection.metadata["bank"])
|
||||||
bank = parse_obj_as(MemoryBankDef, data)
|
bank = parse_obj_as(VectorMemoryBank, data)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
|
@ -158,7 +158,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
|
|
||||||
async def register_memory_bank(
|
async def register_memory_bank(
|
||||||
self,
|
self,
|
||||||
memory_bank: VectorMemoryBank,
|
memory_bank: MemoryBank,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert (
|
assert (
|
||||||
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
memory_bank.memory_bank_type == MemoryBankType.vector.value
|
||||||
|
@ -177,7 +177,7 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = index
|
self.cache[memory_bank.identifier] = index
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[VectorMemoryBank]:
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||||
banks = load_models(self.cursor, VectorMemoryBank)
|
banks = load_models(self.cursor, VectorMemoryBank)
|
||||||
for bank in banks:
|
for bank in banks:
|
||||||
if bank.identifier not in self.cache:
|
if bank.identifier not in self.cache:
|
||||||
|
|
|
@ -10,11 +10,10 @@ import tempfile
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
|
||||||
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
|
||||||
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
|
||||||
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
||||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||||
|
@ -78,7 +77,23 @@ def memory_weaviate() -> ProviderFixture:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]
|
@pytest.fixture(scope="session")
|
||||||
|
def memory_chroma() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
providers=[
|
||||||
|
Provider(
|
||||||
|
provider_id="chroma",
|
||||||
|
provider_type="remote::chromadb",
|
||||||
|
config=RemoteProviderConfig(
|
||||||
|
host=get_env_or_fail("CHROMA_HOST"),
|
||||||
|
port=get_env_or_fail("CHROMA_PORT"),
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote", "chroma"]
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue