llama-stack-mirror/llama_stack/providers/tests/memory/fixtures.py
Ashwin Bharambe ffedb81c11
Significantly simpler and malleable test setup (#360)
* Significantly simpler and malleable test setup

* convert memory tests

* refactor fixtures and add support for composable fixtures

* Fix memory to use the newer fixture organization

* Get agents tests working

* Safety tests work

* yet another refactor to make this more general

now it accepts --inference-model, --safety-model options also

* get multiple providers working for meta-reference (for inference + safety)

* Add README.md

---------

Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
2024-11-04 17:36:43 -08:00

85 lines
2.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.memory.pgvector import PGVectorConfig
from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
def memory_meta_reference() -> ProviderFixture:
    """Provide the "meta-reference" memory provider for the test session.

    Returns:
        A ProviderFixture with a single Provider whose config is the
        result of ``FaissImplConfig().model_dump()`` (all defaults).
    """
    faiss_config = FaissImplConfig().model_dump()
    provider = Provider(
        provider_id="meta-reference",
        provider_type="meta-reference",
        config=faiss_config,
    )
    return ProviderFixture(providers=[provider])
@pytest.fixture(scope="session")
def memory_pgvector() -> ProviderFixture:
    """Provide the remote pgvector memory provider for the test session.

    Connection settings are read from the environment:

    * ``PGVECTOR_HOST`` -- optional, falls back to ``"localhost"``.
    * ``PGVECTOR_PORT`` -- optional, falls back to ``5432``.
    * ``PGVECTOR_DB``, ``PGVECTOR_USER``, ``PGVECTOR_PASSWORD`` -- looked up
      through ``get_env_or_fail`` (presumably aborts when unset; confirm
      against ``..env``).

    Returns:
        A ProviderFixture with a single ``remote::pgvector`` Provider.

    Raises:
        ValueError: if ``PGVECTOR_PORT`` is set but is not a valid integer.
    """
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="pgvector",
                provider_type="remote::pgvector",
                config=PGVectorConfig(
                    host=os.getenv("PGVECTOR_HOST", "localhost"),
                    # os.getenv returns a str when the variable is set, so
                    # parse explicitly: the port is then always an int no
                    # matter whether the value came from the environment or
                    # from the default.
                    port=int(os.getenv("PGVECTOR_PORT", "5432")),
                    db=get_env_or_fail("PGVECTOR_DB"),
                    user=get_env_or_fail("PGVECTOR_USER"),
                    password=get_env_or_fail("PGVECTOR_PASSWORD"),
                ).model_dump(),
            )
        ],
    )
@pytest.fixture(scope="session")
def memory_weaviate() -> ProviderFixture:
    """Provide the remote Weaviate memory provider for the test session.

    The provider itself uses a default ``WeaviateConfig``; the API key and
    cluster URL are read from ``WEAVIATE_API_KEY`` and
    ``WEAVIATE_CLUSTER_URL`` via ``get_env_or_fail`` and passed separately
    as ``provider_data``.

    Returns:
        A ProviderFixture with one ``remote::weaviate`` Provider plus the
        credential dictionary in ``provider_data``.
    """
    weaviate_provider = Provider(
        provider_id="weaviate",
        provider_type="remote::weaviate",
        config=WeaviateConfig().model_dump(),
    )
    # Looked up after the provider is built, API key before cluster URL.
    credentials = {
        "weaviate_api_key": get_env_or_fail("WEAVIATE_API_KEY"),
        "weaviate_cluster_url": get_env_or_fail("WEAVIATE_CLUSTER_URL"),
    }
    return ProviderFixture(
        providers=[weaviate_provider],
        provider_data=credentials,
    )
# Suffixes of the provider fixtures in this module; `memory_stack` resolves
# each one to the fixture named "memory_<suffix>".
MEMORY_FIXTURES = [
    "meta_reference",
    "pgvector",
    "weaviate",
]
@pytest_asyncio.fixture(scope="session")
async def memory_stack(request):
    """Resolve the memory implementations for the selected provider fixture.

    Args:
        request: pytest fixture request; ``request.param`` holds the suffix
            of the provider fixture to use, looked up as ``memory_<suffix>``.

    Returns:
        A 2-tuple ``(impls[Api.memory], impls[Api.memory_banks])`` taken from
        the mapping returned by ``resolve_impls_for_test_v2``.
    """
    provider_fixture = request.getfixturevalue(f"memory_{request.param}")
    resolved = await resolve_impls_for_test_v2(
        [Api.memory],
        {"memory": provider_fixture.providers},
        provider_fixture.provider_data,
    )
    # NOTE(review): only Api.memory is requested above, yet Api.memory_banks
    # is read too -- presumably the resolver also returns it; confirm.
    memory_impl = resolved[Api.memory]
    memory_banks_impl = resolved[Api.memory_banks]
    return memory_impl, memory_banks_impl