mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
Fix memory to use the newer fixture organization
This commit is contained in:
parent
dd049d5727
commit
62dd3b376c
8 changed files with 102 additions and 137 deletions
|
@ -129,4 +129,5 @@ def pytest_itemcollected(item):
|
||||||
pytest_plugins = [
|
pytest_plugins = [
|
||||||
"llama_stack.providers.tests.inference.fixtures",
|
"llama_stack.providers.tests.inference.fixtures",
|
||||||
"llama_stack.providers.tests.safety.fixtures",
|
"llama_stack.providers.tests.safety.fixtures",
|
||||||
|
"llama_stack.providers.tests.memory.fixtures",
|
||||||
]
|
]
|
||||||
|
|
|
@ -4,87 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
from .fixtures import MEMORY_FIXTURES
|
||||||
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import Api, Provider
|
|
||||||
from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig
|
|
||||||
from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
|
||||||
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
|
||||||
|
|
||||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
|
||||||
from ..conftest import ProviderFixture
|
|
||||||
from ..env import get_env_or_fail
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def meta_reference() -> ProviderFixture:
|
|
||||||
return ProviderFixture(
|
|
||||||
provider=Provider(
|
|
||||||
provider_id="meta-reference",
|
|
||||||
provider_type="meta-reference",
|
|
||||||
config=FaissImplConfig().model_dump(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def pgvector() -> ProviderFixture:
|
|
||||||
return ProviderFixture(
|
|
||||||
provider=Provider(
|
|
||||||
provider_id="pgvector",
|
|
||||||
provider_type="remote::pgvector",
|
|
||||||
config=PGVectorConfig(
|
|
||||||
host=os.getenv("PGVECTOR_HOST", "localhost"),
|
|
||||||
port=os.getenv("PGVECTOR_PORT", 5432),
|
|
||||||
db=get_env_or_fail("PGVECTOR_DB"),
|
|
||||||
user=get_env_or_fail("PGVECTOR_USER"),
|
|
||||||
password=get_env_or_fail("PGVECTOR_PASSWORD"),
|
|
||||||
).model_dump(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def weaviate() -> ProviderFixture:
|
|
||||||
return ProviderFixture(
|
|
||||||
provider=Provider(
|
|
||||||
provider_id="weaviate",
|
|
||||||
provider_type="remote::weaviate",
|
|
||||||
config=WeaviateConfig().model_dump(),
|
|
||||||
),
|
|
||||||
provider_data=dict(
|
|
||||||
weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
|
|
||||||
weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
|
|
||||||
|
|
||||||
PROVIDER_PARAMS = [
|
|
||||||
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
|
||||||
for fixture_name in MEMORY_FIXTURES
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(
|
|
||||||
scope="session",
|
|
||||||
params=PROVIDER_PARAMS,
|
|
||||||
)
|
|
||||||
async def stack_impls(request):
|
|
||||||
fixture_name = request.param
|
|
||||||
fixture = request.getfixturevalue(fixture_name)
|
|
||||||
|
|
||||||
impls = await resolve_impls_for_test_v2(
|
|
||||||
[Api.memory],
|
|
||||||
{"memory": [fixture.provider.model_dump()]},
|
|
||||||
fixture.provider_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
return impls[Api.memory], impls[Api.memory_banks]
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
|
|
87
llama_stack/providers/tests/memory/fixtures.py
Normal file
87
llama_stack/providers/tests/memory/fixtures.py
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api, Provider
|
||||||
|
from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig
|
||||||
|
from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
||||||
|
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
||||||
|
|
||||||
|
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||||
|
from ..conftest import ProviderFixture
|
||||||
|
from ..env import get_env_or_fail
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def meta_reference() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
provider=Provider(
|
||||||
|
provider_id="meta-reference",
|
||||||
|
provider_type="meta-reference",
|
||||||
|
config=FaissImplConfig().model_dump(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def pgvector() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
provider=Provider(
|
||||||
|
provider_id="pgvector",
|
||||||
|
provider_type="remote::pgvector",
|
||||||
|
config=PGVectorConfig(
|
||||||
|
host=os.getenv("PGVECTOR_HOST", "localhost"),
|
||||||
|
port=os.getenv("PGVECTOR_PORT", 5432),
|
||||||
|
db=get_env_or_fail("PGVECTOR_DB"),
|
||||||
|
user=get_env_or_fail("PGVECTOR_USER"),
|
||||||
|
password=get_env_or_fail("PGVECTOR_PASSWORD"),
|
||||||
|
).model_dump(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def weaviate() -> ProviderFixture:
|
||||||
|
return ProviderFixture(
|
||||||
|
provider=Provider(
|
||||||
|
provider_id="weaviate",
|
||||||
|
provider_type="remote::weaviate",
|
||||||
|
config=WeaviateConfig().model_dump(),
|
||||||
|
),
|
||||||
|
provider_data=dict(
|
||||||
|
weaviate_api_key=get_env_or_fail("WEAVIATE_API_KEY"),
|
||||||
|
weaviate_cluster_url=get_env_or_fail("WEAVIATE_CLUSTER_URL"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
|
||||||
|
|
||||||
|
PROVIDER_PARAMS = [
|
||||||
|
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
|
||||||
|
for fixture_name in MEMORY_FIXTURES
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(
|
||||||
|
scope="session",
|
||||||
|
params=PROVIDER_PARAMS,
|
||||||
|
)
|
||||||
|
async def memory_stack(request):
|
||||||
|
fixture_name = request.param
|
||||||
|
fixture = request.getfixturevalue(fixture_name)
|
||||||
|
|
||||||
|
impls = await resolve_impls_for_test_v2(
|
||||||
|
[Api.memory],
|
||||||
|
{"memory": [fixture.provider.model_dump()]},
|
||||||
|
fixture.provider_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
return impls[Api.memory], impls[Api.memory_banks]
|
|
@ -1,29 +0,0 @@
|
||||||
providers:
|
|
||||||
- provider_id: test-faiss
|
|
||||||
provider_type: meta-reference
|
|
||||||
config: {}
|
|
||||||
- provider_id: test-chromadb
|
|
||||||
provider_type: remote::chromadb
|
|
||||||
config:
|
|
||||||
host: localhost
|
|
||||||
port: 6001
|
|
||||||
- provider_id: test-remote
|
|
||||||
provider_type: remote
|
|
||||||
config:
|
|
||||||
host: localhost
|
|
||||||
port: 7002
|
|
||||||
- provider_id: test-weaviate
|
|
||||||
provider_type: remote::weaviate
|
|
||||||
config: {}
|
|
||||||
- provider_id: test-qdrant
|
|
||||||
provider_type: remote::qdrant
|
|
||||||
config:
|
|
||||||
host: localhost
|
|
||||||
port: 6333
|
|
||||||
# if a provider needs private keys from the client, they use the
|
|
||||||
# "get_request_provider_data" function (see distribution/request_headers.py)
|
|
||||||
# this is a place to provide such data.
|
|
||||||
provider_data:
|
|
||||||
"test-weaviate":
|
|
||||||
weaviate_api_key: 0xdeadbeefputrealapikeyhere
|
|
||||||
weaviate_cluster_url: http://foobarbaz
|
|
|
@ -8,7 +8,7 @@ import pytest
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||||
from .conftest import PROVIDER_PARAMS
|
from .fixtures import PROVIDER_PARAMS
|
||||||
|
|
||||||
# How to run this test:
|
# How to run this test:
|
||||||
#
|
#
|
||||||
|
@ -55,25 +55,25 @@ async def register_memory_bank(banks_impl: MemoryBanks):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"stack_impls",
|
"memory_stack",
|
||||||
PROVIDER_PARAMS,
|
PROVIDER_PARAMS,
|
||||||
indirect=True,
|
indirect=True,
|
||||||
)
|
)
|
||||||
class TestMemory:
|
class TestMemory:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_list(self, stack_impls):
|
async def test_banks_list(self, memory_stack):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
_, banks_impl = stack_impls
|
_, banks_impl = memory_stack
|
||||||
response = await banks_impl.list_memory_banks()
|
response = await banks_impl.list_memory_banks()
|
||||||
assert isinstance(response, list)
|
assert isinstance(response, list)
|
||||||
assert len(response) == 0
|
assert len(response) == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_banks_register(self, stack_impls):
|
async def test_banks_register(self, memory_stack):
|
||||||
# NOTE: this needs you to ensure that you are starting from a clean state
|
# NOTE: this needs you to ensure that you are starting from a clean state
|
||||||
# but so far we don't have an unregister API unfortunately, so be careful
|
# but so far we don't have an unregister API unfortunately, so be careful
|
||||||
_, banks_impl = stack_impls
|
_, banks_impl = memory_stack
|
||||||
bank = VectorMemoryBankDef(
|
bank = VectorMemoryBankDef(
|
||||||
identifier="test_bank_no_provider",
|
identifier="test_bank_no_provider",
|
||||||
embedding_model="all-MiniLM-L6-v2",
|
embedding_model="all-MiniLM-L6-v2",
|
||||||
|
@ -93,8 +93,8 @@ class TestMemory:
|
||||||
assert len(response) == 1
|
assert len(response) == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_documents(self, stack_impls, sample_documents):
|
async def test_query_documents(self, memory_stack, sample_documents):
|
||||||
memory_impl, banks_impl = stack_impls
|
memory_impl, banks_impl = memory_stack
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await memory_impl.insert_documents("test_bank", sample_documents)
|
await memory_impl.insert_documents("test_bank", sample_documents)
|
||||||
|
|
|
@ -49,6 +49,10 @@ def pytest_configure(config):
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
def pytest_generate_tests(metafunc):
|
||||||
|
# We use this method to make sure we have built-in simple combos for safety tests
|
||||||
|
# But a user can also pass in a custom combination via the CLI by doing
|
||||||
|
# `--providers inference=together,safety=meta_reference`
|
||||||
|
|
||||||
if "safety_stack" in metafunc.fixturenames:
|
if "safety_stack" in metafunc.fixturenames:
|
||||||
# print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
|
# print(f"metafunc.fixturenames: {metafunc.fixturenames}, {metafunc}")
|
||||||
available_fixtures = {
|
available_fixtures = {
|
||||||
|
|
|
@ -64,6 +64,7 @@ SAFETY_FIXTURES = ["meta_reference", "together"]
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
@pytest_asyncio.fixture(scope="session")
|
||||||
async def safety_stack(inference_model, safety_model, request):
|
async def safety_stack(inference_model, safety_model, request):
|
||||||
|
# We need an inference + safety fixture to test safety
|
||||||
fixture_dict = request.param
|
fixture_dict = request.param
|
||||||
inference_fixture = request.getfixturevalue(
|
inference_fixture = request.getfixturevalue(
|
||||||
f"inference_{fixture_dict['inference']}"
|
f"inference_{fixture_dict['inference']}"
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
providers:
|
|
||||||
inference:
|
|
||||||
- provider_id: together
|
|
||||||
provider_type: remote::together
|
|
||||||
config: {}
|
|
||||||
- provider_id: tgi
|
|
||||||
provider_type: remote::tgi
|
|
||||||
config:
|
|
||||||
url: http://127.0.0.1:7002
|
|
||||||
- provider_id: meta-reference
|
|
||||||
provider_type: meta-reference
|
|
||||||
config:
|
|
||||||
model: Llama-Guard-3-1B
|
|
||||||
safety:
|
|
||||||
- provider_id: meta-reference
|
|
||||||
provider_type: meta-reference
|
|
||||||
config:
|
|
||||||
llama_guard_shield:
|
|
||||||
model: Llama-Guard-3-1B
|
|
Loading…
Add table
Add a link
Reference in a new issue