mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
Fix memory to use the newer fixture organization
This commit is contained in:
parent
dd049d5727
commit
62dd3b376c
8 changed files with 102 additions and 137 deletions
|
@ -129,4 +129,5 @@ def pytest_itemcollected(item):
|
|||
pytest_plugins = [
|
||||
"llama_stack.providers.tests.inference.fixtures",
|
||||
"llama_stack.providers.tests.safety.fixtures",
|
||||
"llama_stack.providers.tests.memory.fixtures",
|
||||
]
|
||||
|
|
|
@ -4,87 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
||||
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from ..conftest import ProviderFixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def meta_reference() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
config=FaissImplConfig().model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pgvector() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorConfig(
|
||||
host=os.getenv("PGVECTOR_HOST", "localhost"),
|
||||
port=os.getenv("PGVECTOR_PORT", 5432),
|
||||
db=get_env_or_fail("PGVECTOR_DB"),
|
||||
user=get_env_or_fail("PGVECTOR_USER"),
|
||||
password=get_env_or_fail("PGVECTOR_PASSWORD"),
|
||||
).model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateConfig().model_dump(),
|
||||
),
|
||||
provider_data=dict(
|
||||
weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
|
||||
weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
|
||||
|
||||
PROVIDER_PARAMS = [
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in MEMORY_FIXTURES
|
||||
]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=PROVIDER_PARAMS,
|
||||
)
|
||||
async def stack_impls(request):
|
||||
fixture_name = request.param
|
||||
fixture = request.getfixturevalue(fixture_name)
|
||||
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.memory],
|
||||
{"memory": [fixture.provider.model_dump()]},
|
||||
fixture.provider_data,
|
||||
)
|
||||
|
||||
return impls[Api.memory], impls[Api.memory_banks]
|
||||
from .fixtures import MEMORY_FIXTURES
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
|
|
87
llama_stack/providers/tests/memory/fixtures.py
Normal file
87
llama_stack/providers/tests/memory/fixtures.py
Normal file
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider
|
||||
from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig
|
||||
from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
||||
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from ..conftest import ProviderFixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def meta_reference() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="meta-reference",
|
||||
provider_type="meta-reference",
|
||||
config=FaissImplConfig().model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def pgvector() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="pgvector",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorConfig(
|
||||
host=os.getenv("PGVECTOR_HOST", "localhost"),
|
||||
port=os.getenv("PGVECTOR_PORT", 5432),
|
||||
db=get_env_or_fail("PGVECTOR_DB"),
|
||||
user=get_env_or_fail("PGVECTOR_USER"),
|
||||
password=get_env_or_fail("PGVECTOR_PASSWORD"),
|
||||
).model_dump(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
provider=Provider(
|
||||
provider_id="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
config=WeaviateConfig().model_dump(),
|
||||
),
|
||||
provider_data=dict(
|
||||
weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
|
||||
weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
|
||||
|
||||
PROVIDER_PARAMS = [
|
||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||
for fixture_name in MEMORY_FIXTURES
|
||||
]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(
|
||||
scope="session",
|
||||
params=PROVIDER_PARAMS,
|
||||
)
|
||||
async def memory_stack(request):
|
||||
fixture_name = request.param
|
||||
fixture = request.getfixturevalue(fixture_name)
|
||||
|
||||
impls = await resolve_impls_for_test_v2(
|
||||
[Api.memory],
|
||||
{"memory": [fixture.provider.model_dump()]},
|
||||
fixture.provider_data,
|
||||
)
|
||||
|
||||
return impls[Api.memory], impls[Api.memory_banks]
|
|
@ -1,29 +0,0 @@
|
|||
providers:
|
||||
- provider_id: test-faiss
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
- provider_id: test-chromadb
|
||||
provider_type: remote::chromadb
|
||||
config:
|
||||
host: localhost
|
||||
port: 6001
|
||||
- provider_id: test-remote
|
||||
provider_type: remote
|
||||
config:
|
||||
host: localhost
|
||||
port: 7002
|
||||
- provider_id: test-weaviate
|
||||
provider_type: remote::weaviate
|
||||
config: {}
|
||||
- provider_id: test-qdrant
|
||||
provider_type: remote::qdrant
|
||||
config:
|
||||
host: localhost
|
||||
port: 6333
|
||||
# if a provider needs private keys from the client, they use the
|
||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
||||
# this is a place to provide such data.
|
||||
provider_data:
|
||||
"test-weaviate":
|
||||
weaviate_api_key: 0xdeadbeefputrealapikeyhere
|
||||
weaviate_cluster_url: http://foobarbaz
|
|
@ -8,7 +8,7 @@ import pytest
|
|||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
from .conftest import PROVIDER_PARAMS
|
||||
from .fixtures import PROVIDER_PARAMS
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
|
@ -55,25 +55,25 @@ async def register_memory_bank(banks_impl: MemoryBanks):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stack_impls",
|
||||
"memory_stack",
|
||||
PROVIDER_PARAMS,
|
||||
indirect=True,
|
||||
)
|
||||
class TestMemory:
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_list(self, stack_impls):
|
||||
async def test_banks_list(self, memory_stack):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, banks_impl = stack_impls
|
||||
_, banks_impl = memory_stack
|
||||
response = await banks_impl.list_memory_banks()
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_banks_register(self, stack_impls):
|
||||
async def test_banks_register(self, memory_stack):
|
||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||
# but so far we don't have an unregister API unfortunately, so be careful
|
||||
_, banks_impl = stack_impls
|
||||
_, banks_impl = memory_stack
|
||||
bank = VectorMemoryBankDef(
|
||||
identifier="test_bank_no_provider",
|
||||
embedding_model="all-MiniLM-L6-v2",
|
||||
|
@ -93,8 +93,8 @@ class TestMemory:
|
|||
assert len(response) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_documents(self, stack_impls, sample_documents):
|
||||
memory_impl, banks_impl = stack_impls
|
||||
async def test_query_documents(self, memory_stack, sample_documents):
|
||||
memory_impl, banks_impl = memory_stack
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||
|
|
|
@ -49,6 +49,10 @@ def pytest_configure(config):
|
|||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
# We use this method to make sure we have built-in simple combos for safety tests
|
||||
# But a user can also pass in a custom combination via the CLI by doing
|
||||
# `--providers inference=together,safety=meta_reference`
|
||||
|
||||
if "safety_stack" in metafunc.fixturenames:
|
||||
# print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
|
||||
available_fixtures = {
|
||||
|
|
|
@ -64,6 +64,7 @@ SAFETY_FIXTURES = ["meta_reference", "together"]
|
|||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def safety_stack(inference_model, safety_model, request):
|
||||
# We need an inference + safety fixture to test safety
|
||||
fixture_dict = request.param
|
||||
inference_fixture = request.getfixturevalue(
|
||||
f"inference_{fixture_dict['inference']}"
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
providers:
|
||||
inference:
|
||||
- provider_id: together
|
||||
provider_type: remote::together
|
||||
config: {}
|
||||
- provider_id: tgi
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://127.0.0.1:7002
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
safety:
|
||||
- provider_id: meta-reference
|
||||
provider_type: meta-reference
|
||||
config:
|
||||
llama_guard_shield:
|
||||
model: Llama-Guard-3-1B
|
Loading…
Add table
Add a link
Reference in a new issue