mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 18:32:40 +00:00
Summary: Part of * https://github.com/meta-llama/llama-stack/issues/436 This change adds a minimalistic support to creating memory provider test fake. For more details about the approach, follow the issue link from above. Test Plan: Run tests using the "test_fake" mark: ``` pytest llama_stack/providers/tests/memory/test_memory.py -m "test_fake" /opt/homebrew/Caskroom/miniconda/base/envs/llama-stack/lib/python3.11/site-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ====================================================================================================== test session starts ====================================================================================================== platform darwin -- Python 3.11.10, pytest-8.3.3, pluggy-1.5.0 rootdir: /llama-stack configfile: pyproject.toml plugins: asyncio-0.24.0, anyio-4.6.2.post1 asyncio: mode=Mode.STRICT, default_loop_scope=None collected 18 items / 15 deselected / 3 selected llama_stack/providers/tests/memory/test_memory.py ... [100%] ========================================================================================= 3 passed, 15 deselected, 6 warnings in 0.03s ========================================================================================== ```
128 lines
4.8 KiB
Python
128 lines
4.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
import json
|
|
import tempfile
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from llama_stack.distribution.datatypes import * # noqa: F403
|
|
from llama_stack.distribution.build import print_pip_install_help
|
|
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
|
from llama_stack.distribution.distribution import get_provider_registry
|
|
from llama_stack.distribution.request_headers import set_request_provider_data
|
|
from llama_stack.distribution.resolver import (
|
|
resolve_remote_stack_impls,
|
|
resolve_test_fake_stack_impls,
|
|
)
|
|
from llama_stack.distribution.stack import construct_stack
|
|
from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig
|
|
|
|
|
|
class TestStack(BaseModel):
    """A stack assembled for tests: resolved API implementations plus the run config they were built from."""

    # Resolved implementation objects, keyed by API.
    impls: Dict[Api, Any]
    # The stack run configuration the implementations were resolved from.
    run_config: StackRunConfig
|
|
|
|
|
|
async def construct_stack_for_test(
    apis: List[Api],
    providers: Dict[str, List[Provider]],
    provider_data: Optional[Dict[str, Any]] = None,
    models: Optional[List[ModelInput]] = None,
    shields: Optional[List[ShieldInput]] = None,
    memory_banks: Optional[List[MemoryBankInput]] = None,
    datasets: Optional[List[DatasetInput]] = None,
    scoring_fns: Optional[List[ScoringFnInput]] = None,
    eval_tasks: Optional[List[EvalTaskInput]] = None,
) -> TestStack:
    """Build a stack of API implementations for use in tests.

    Implementations are resolved in this order of preference:

    1. ``test::test-fake`` providers: test fakes are injected as the
       resolved implementations.
    2. ``test::remote`` providers: implementations for an already-running
       remote stack (no resources are registered during setup).
    3. Otherwise, the regular providers via ``construct_stack``.

    Args:
        apis: APIs to include in the stack.
        providers: provider configurations, keyed by API name.
        provider_data: if given, serialized to JSON and installed as the
            ``X-LlamaStack-ProviderData`` request provider data.
        models, shields, memory_banks, datasets, scoring_fns, eval_tasks:
            resources placed in the run config; ``None`` means an empty list.

    Returns:
        A ``TestStack`` holding the resolved implementations and the parsed
        run config.

    Raises:
        ModuleNotFoundError: if resolving a provider fails on a missing
            module; pip install hints for ``providers`` are printed first.
    """
    # The metadata store needs a file that outlives this call, hence
    # delete=False. Only the path is needed, so close our handle right away
    # instead of leaking an open file descriptor.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as sqlite_file:
        db_path = sqlite_file.name

    run_config_dict = dict(
        built_at=datetime.now(),
        image_name="test-fixture",
        apis=apis,
        providers=providers,
        metadata_store=SqliteKVStoreConfig(db_path=db_path),
        models=models or [],
        shields=shields or [],
        memory_banks=memory_banks or [],
        datasets=datasets or [],
        scoring_fns=scoring_fns or [],
        eval_tasks=eval_tasks or [],
    )
    run_config = parse_and_maybe_upgrade_config(run_config_dict)

    try:
        impls = None

        # "Resolve" implementations when using test::test-fake providers.
        # This actually injects test fakes as the resolved API implementations.
        test_fake_config = test_fake_provider_config(run_config)
        if test_fake_config:
            impls = await resolve_test_fake_stack_impls(test_fake_config)

        # Resolve implementations when using test::remote providers.
        if not impls:
            remote_config = remote_provider_config(run_config)
            if remote_config:
                # We don't register resources for a remote stack as part of the
                # fixture setup because the stack is already "up". If a test
                # needs to register resources, it can always do so manually.
                impls = await resolve_remote_stack_impls(
                    remote_config, run_config.apis
                )

        # If neither of the above applied, resolve implementations as normal
        # providers.
        if not impls:
            impls = await construct_stack(run_config, get_provider_registry())

        test_stack = TestStack(impls=impls, run_config=run_config)
    except ModuleNotFoundError:
        print_pip_install_help(providers)
        # Bare raise keeps the original traceback intact.
        raise

    if provider_data:
        set_request_provider_data(
            {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
        )

    return test_stack
|
|
|
|
|
|
# When running one or more test fakes, the stack must contain only test fake
# providers, because the caller then uses resolve_test_fake_stack_impls(),
# which cannot resolve any other provider type. This can be refactored if
# that constraint ever needs to change.
|
|
def test_fake_provider_config(
    run_config: StackRunConfig,
) -> Optional[TestFakeProviderConfig]:
    """Return the ``test::test-fake`` provider config from *run_config*, if any.

    Every provider of every API is scanned. If several test-fake providers
    are present, the config of the last one encountered is returned.

    Args:
        run_config: the parsed stack run configuration.

    Returns:
        A ``TestFakeProviderConfig`` built from the matching provider's
        ``config``, or ``None`` when no test-fake provider is configured.

    Raises:
        AssertionError: (when assertions are enabled) if a test-fake provider
            is mixed with any other provider type.
    """
    test_fake_config = None
    has_non_test_fake = False
    for api_providers in run_config.providers.values():
        for provider in api_providers:
            if provider.provider_type == "test::test-fake":
                test_fake_config = TestFakeProviderConfig(**provider.config)
            else:
                has_non_test_fake = True

    if test_fake_config is not None:
        # resolve_test_fake_stack_impls() can only resolve fakes, so a mixed
        # stack cannot be supported.
        assert not has_non_test_fake, "Test fake stack cannot have non-fake providers"

    return test_fake_config
|
|
|
|
|
|
def remote_provider_config(
    run_config: StackRunConfig,
) -> Optional[RemoteProviderConfig]:
    """Return the ``test::remote`` provider config from *run_config*, if any.

    A ``RemoteProviderConfig`` is built for each ``test::remote`` provider
    found across all APIs; the one built last is returned. When a remote
    config is found, every configured provider must be ``test::remote``.

    Args:
        run_config: the parsed stack run configuration.

    Returns:
        The last ``RemoteProviderConfig`` built, or ``None`` when no
        ``test::remote`` provider is configured.
    """
    found = None
    only_remote = True
    every_provider = (
        entry for group in run_config.providers.values() for entry in group
    )
    for entry in every_provider:
        if entry.provider_type != "test::remote":
            only_remote = False
            continue
        found = RemoteProviderConfig(**entry.config)

    if found:
        # A remote stack is served entirely by the remote endpoint, so local
        # providers cannot be mixed in.
        assert only_remote, "Remote stack cannot have non-remote providers"

    return found
|