From 22aedd0277a2ac8576859b6cdc2e0bb36616432e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 12 Nov 2024 21:31:02 -0800 Subject: [PATCH] fixed agent persistence test, more cleanup --- llama_stack/distribution/resolver.py | 5 +- llama_stack/distribution/stack.py | 10 +- .../providers/tests/agents/fixtures.py | 7 +- .../tests/agents/test_agent_persistence.py | 148 ------------------ .../providers/tests/agents/test_agents.py | 21 +-- .../tests/agents/test_persistence.py | 122 +++++++++++++++ llama_stack/providers/tests/agents/utils.py | 17 ++ .../providers/tests/datasetio/fixtures.py | 6 +- llama_stack/providers/tests/eval/fixtures.py | 6 +- .../providers/tests/inference/fixtures.py | 6 +- .../providers/tests/memory/fixtures.py | 6 +- llama_stack/providers/tests/resolver.py | 126 +++------------ .../providers/tests/safety/fixtures.py | 26 ++- .../providers/tests/scoring/fixtures.py | 6 +- 14 files changed, 202 insertions(+), 310 deletions(-) delete mode 100644 llama_stack/providers/tests/agents/test_agent_persistence.py create mode 100644 llama_stack/providers/tests/agents/test_persistence.py create mode 100644 llama_stack/providers/tests/agents/utils.py diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index b95cc5418..4c74b0d1f 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -78,10 +78,13 @@ class ProviderWithSpec(Provider): spec: ProviderSpec +ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]] + + # TODO: this code is not very straightforward to follow and needs one more round of refactoring async def resolve_impls( run_config: StackRunConfig, - provider_registry: Dict[Api, Dict[str, ProviderSpec]], + provider_registry: ProviderRegistry, dist_registry: DistributionRegistry, ) -> Dict[Api, Any]: """ diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 6a80d7a48..1cffd7749 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import * # noqa: F403 from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.datatypes import Api @@ -94,10 +94,14 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]): # Produces a stack of providers for the given run config. Not all APIs may be # asked for in the run config. -async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]: +async def construct_stack( + run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None +) -> Dict[Api, Any]: dist_registry, _ = await create_dist_registry( run_config.metadata_store, run_config.image_name ) - impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) + impls = await resolve_impls( + run_config, provider_registry or get_provider_registry(), dist_registry + ) await register_resources(run_config, impls) return impls diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 322cae798..1f89b909a 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -16,7 +16,7 @@ from llama_stack.providers.inline.agents.meta_reference import ( MetaReferenceAgentsImplConfig, ) -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture @@ -73,7 +73,7 @@ async def agents_stack(request, inference_model, safety_shield): inference_models = ( inference_model if isinstance(inference_model, list) else [inference_model] ) - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.agents, Api.inference, Api.safety, Api.memory], providers, provider_data, @@ -85,5 +85,4 @@ async def agents_stack(request, inference_model, safety_shield): ], shields=[safety_shield], ) - - return impls[Api.agents], impls[Api.memory] + return test_stack diff --git a/llama_stack/providers/tests/agents/test_agent_persistence.py b/llama_stack/providers/tests/agents/test_agent_persistence.py deleted file mode 100644 index a15887b33..000000000 --- a/llama_stack/providers/tests/agents/test_agent_persistence.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import pytest -import pytest_asyncio - -from llama_stack.apis.agents import * # noqa: F403 -from llama_stack.providers.tests.resolver import resolve_impls_for_test -from llama_stack.providers.datatypes import * # noqa: F403 - -from dotenv import load_dotenv - -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig - -# How to run this test: -# -# 1. Ensure you have a conda environment with the right dependencies installed. -# This includes `pytest` and `pytest-asyncio`. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \ -# --tb=short --disable-warnings -# ``` - -load_dotenv() - - -@pytest_asyncio.fixture(scope="session") -async def agents_settings(): - impls = await resolve_impls_for_test( - Api.agents, deps=[Api.inference, Api.memory, Api.safety] - ) - - return { - "impl": impls[Api.agents], - "memory_impl": impls[Api.memory], - "common_params": { - "model": "Llama3.1-8B-Instruct", - "instructions": "You are a helpful assistant.", - }, - } - - -@pytest.fixture -def sample_messages(): - return [ - UserMessage(content="What's the weather like today?"), - ] - - -@pytest.mark.asyncio -async def test_delete_agents_and_sessions(agents_settings, sample_messages): - agents_impl = agents_settings["impl"] - # First, create an agent - agent_config = AgentConfig( - model=agents_settings["common_params"]["model"], - instructions=agents_settings["common_params"]["instructions"], - enable_session_persistence=True, - sampling_params=SamplingParams(temperature=0.7, top_p=0.95), - input_shields=[], - output_shields=[], - tools=[], - max_infer_iters=5, - ) - - create_response = await agents_impl.create_agent(agent_config) - agent_id = create_response.agent_id - - # Create a session - session_create_response = await agents_impl.create_agent_session( - agent_id, "Test Session" - ) - session_id = session_create_response.session_id - persistence_store = await kvstore_impl(agents_settings["persistence"]) - - await agents_impl.delete_agents_session(agent_id, session_id) - session_response = await persistence_store.get(f"session:{agent_id}:{session_id}") - - await agents_impl.delete_agents(agent_id) - agent_response = await persistence_store.get(f"agent:{agent_id}") - - assert session_response is None - assert agent_response is None - - -async def test_get_agent_turns_and_steps(agents_settings, sample_messages): - agents_impl = agents_settings["impl"] - - # First, create an agent - agent_config = AgentConfig( - model=agents_settings["common_params"]["model"], - instructions=agents_settings["common_params"]["instructions"], - enable_session_persistence=True, - sampling_params=SamplingParams(temperature=0.7, top_p=0.95), - input_shields=[], - output_shields=[], - tools=[], - max_infer_iters=5, - ) - - create_response = await agents_impl.create_agent(agent_config) - agent_id = create_response.agent_id - - # Create a session - session_create_response = await agents_impl.create_agent_session( - agent_id, "Test Session" - ) - session_id = session_create_response.session_id - - # Create and execute a turn - turn_request = dict( - agent_id=agent_id, - session_id=session_id, - messages=sample_messages, - stream=True, - ) - - turn_response = [ - chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) - ] - - final_event = turn_response[-1].event.payload - turn_id = final_event.turn.turn_id - persistence_store = await kvstore_impl(SqliteKVStoreConfig()) - turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") - response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id) - - assert isinstance(response, Turn) - assert response == final_event.turn - assert turn == final_event.turn - - steps = final_event.turn.steps - step_id = steps[0].step_id - step_response = await agents_impl.get_agents_step( - agent_id, session_id, turn_id, step_id - ) - - assert isinstance(step_response.step, Step) - assert step_response.step == steps[0] diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index bdfa38bc4..60c047058 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import * # noqa: F403 # -m "meta_reference" from .fixtures import pick_inference_model +from .utils import create_agent_session @pytest.fixture @@ -67,24 +68,12 @@ def query_attachment_messages(): ] -async def create_agent_session(agents_impl, agent_config): - create_response = await agents_impl.create_agent(agent_config) - agent_id = create_response.agent_id - - # Create a session - session_create_response = await agents_impl.create_agent_session( - agent_id, "Test Session" - ) - session_id = session_create_response.session_id - return agent_id, session_id - - class TestAgents: @pytest.mark.asyncio async def test_agent_turns_with_safety( self, safety_shield, agents_stack, common_params ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] agent_id, session_id = await create_agent_session( agents_impl, AgentConfig( @@ -127,7 +116,7 @@ class TestAgents: async def test_create_agent_turn( self, agents_stack, sample_messages, common_params ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] agent_id, session_id = await create_agent_session( agents_impl, AgentConfig(**common_params) @@ -158,7 +147,7 @@ class TestAgents: query_attachment_messages, common_params, ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] urls = [ "memory_optimizations.rst", "chat.rst", @@ -226,7 +215,7 @@ class TestAgents: async def test_create_agent_turn_with_brave_search( self, agents_stack, search_query_messages, common_params ): - agents_impl, _ = agents_stack + agents_impl = agents_stack.impls[Api.agents] if "BRAVE_SEARCH_API_KEY" not in os.environ: pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test") diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py new file mode 100644 index 000000000..97094cd7a --- /dev/null +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.providers.datatypes import * # noqa: F403 + +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig +from .fixtures import pick_inference_model + +from .utils import create_agent_session + + +@pytest.fixture +def sample_messages(): + return [ + UserMessage(content="What's the weather like today?"), + ] + + +@pytest.fixture +def common_params(inference_model): + inference_model = pick_inference_model(inference_model) + + return dict( + model=inference_model, + instructions="You are a helpful assistant.", + enable_session_persistence=True, + sampling_params=SamplingParams(temperature=0.7, top_p=0.95), + input_shields=[], + output_shields=[], + tools=[], + max_infer_iters=5, + ) + + +class TestAgentPersistence: + @pytest.mark.asyncio + async def test_delete_agents_and_sessions(self, agents_stack, common_params): + agents_impl = agents_stack.impls[Api.agents] + agent_id, session_id = await create_agent_session( + agents_impl, + AgentConfig( + **{ + **common_params, + "input_shields": [], + "output_shields": [], + } + ), + ) + + run_config = agents_stack.run_config + provider_config = run_config.providers["agents"][0].config + persistence_store = await kvstore_impl( + SqliteKVStoreConfig(**provider_config["persistence_store"]) + ) + + await agents_impl.delete_agents_session(agent_id, session_id) + session_response = await persistence_store.get( + f"session:{agent_id}:{session_id}" + ) + + await agents_impl.delete_agents(agent_id) + agent_response = await persistence_store.get(f"agent:{agent_id}") + + assert session_response is None + assert agent_response is None + + @pytest.mark.asyncio + async def test_get_agent_turns_and_steps( + self, agents_stack, sample_messages, common_params + ): + agents_impl = agents_stack.impls[Api.agents] + + agent_id, session_id = await create_agent_session( + agents_impl, + AgentConfig( + **{ + **common_params, + "input_shields": [], + "output_shields": [], + } + ), + ) + + # Create and execute a turn + turn_request = dict( + agent_id=agent_id, + session_id=session_id, + messages=sample_messages, + stream=True, + ) + + turn_response = [ + chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) + ] + + final_event = turn_response[-1].event.payload + turn_id = final_event.turn.turn_id + + provider_config = agents_stack.run_config.providers["agents"][0].config + persistence_store = await kvstore_impl( + SqliteKVStoreConfig(**provider_config["persistence_store"]) + ) + turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") + response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id) + + assert isinstance(response, Turn) + assert response == final_event.turn + assert turn == final_event.turn.model_dump_json() + + steps = final_event.turn.steps + step_id = steps[0].step_id + step_response = await agents_impl.get_agents_step( + agent_id, session_id, turn_id, step_id + ) + + assert step_response.step == steps[0] diff --git a/llama_stack/providers/tests/agents/utils.py b/llama_stack/providers/tests/agents/utils.py new file mode 100644 index 000000000..048877991 --- /dev/null +++ b/llama_stack/providers/tests/agents/utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +async def create_agent_session(agents_impl, agent_config): + create_response = await agents_impl.create_agent(agent_config) + agent_id = create_response.agent_id + + # Create a session + session_create_response = await agents_impl.create_agent_session( + agent_id, "Test Session" + ) + session_id = session_create_response.session_id + return agent_id, session_id diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py index 6f20bf96a..60f89de46 100644 --- a/llama_stack/providers/tests/datasetio/fixtures.py +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -9,7 +9,7 @@ import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -52,10 +52,10 @@ async def datasetio_stack(request): fixture_name = request.param fixture = request.getfixturevalue(f"datasetio_{fixture_name}") - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.datasetio], {"datasetio": fixture.providers}, fixture.provider_data, ) - return impls[Api.datasetio], impls[Api.datasets] + return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets] diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py index 4a359213b..a6b404d0c 100644 --- a/llama_stack/providers/tests/eval/fixtures.py +++ b/llama_stack/providers/tests/eval/fixtures.py @@ -9,7 +9,7 @@ import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -46,10 +46,10 @@ async def eval_stack(request): if fixture.provider_data: provider_data.update(fixture.provider_data) - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.eval, Api.datasetio, Api.inference, Api.scoring], providers, provider_data, ) - return impls + return test_stack.impls diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 01d7e4892..a53ddf639 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -182,11 +182,11 @@ INFERENCE_FIXTURES = [ async def inference_stack(request, inference_model): fixture_name = request.param inference_fixture = request.getfixturevalue(f"inference_{fixture_name}") - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.inference], {"inference": inference_fixture.providers}, inference_fixture.provider_data, models=[ModelInput(model_id=inference_model)], ) - return (impls[Api.inference], impls[Api.models]) + return test_stack.impls[Api.inference], test_stack.impls[Api.models] diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 456e354b2..c9559b61c 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConf from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -101,10 +101,10 @@ async def memory_stack(request): fixture_name = request.param fixture = request.getfixturevalue(f"memory_{fixture_name}") - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.memory], {"memory": fixture.providers}, fixture.provider_data, ) - return impls[Api.memory], impls[Api.memory_banks] + return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks] diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 46e0435af..df927926e 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -5,38 +5,26 @@ # the root directory of this source tree. import json -import os import tempfile from datetime import datetime from typing import Any, Dict, List, Optional -import yaml - from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data -from llama_stack.distribution.resolver import resolve_impls, resolve_remote_stack_impls +from llama_stack.distribution.resolver import resolve_remote_stack_impls from llama_stack.distribution.stack import construct_stack from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig -async def construct_stack_for_test(run_config: StackRunConfig): - remote_config = remote_provider_config(run_config) - if not remote_config: - return await construct_stack(run_config) - - impls = await resolve_remote_stack_impls(remote_config, run_config.apis) - - # we don't register resources for a remote stack as part of the fixture setup - # because the stack is already "up". if a test needs to register resources, it - # can do so manually always. - - return impls +class TestStack(BaseModel): + impls: Dict[Api, Any] + run_config: StackRunConfig -async def resolve_impls_for_test_v2( +async def construct_stack_for_test( apis: List[Api], providers: Dict[str, List[Provider]], provider_data: Optional[Dict[str, Any]] = None, @@ -46,7 +34,7 @@ async def resolve_impls_for_test_v2( datasets: Optional[List[DatasetInput]] = None, scoring_fns: Optional[List[ScoringFnInput]] = None, eval_tasks: Optional[List[EvalTaskInput]] = None, -): +) -> TestStack: sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") run_config = dict( built_at=datetime.now(), @@ -63,7 +51,18 @@ async def resolve_impls_for_test_v2( ) run_config = parse_and_maybe_upgrade_config(run_config) try: - impls = await construct_stack_for_test(run_config) + remote_config = remote_provider_config(run_config) + if not remote_config: + # TODO: add to provider registry by creating interesting mocks or fakes + impls = await construct_stack(run_config, get_provider_registry()) + else: + # we don't register resources for a remote stack as part of the fixture setup + # because the stack is already "up". if a test needs to register resources, it + # can do so manually always. + + impls = await resolve_remote_stack_impls(remote_config, run_config.apis) + + test_stack = TestStack(impls=impls, run_config=run_config) except ModuleNotFoundError as e: print_pip_install_help(providers) raise e @@ -73,7 +72,7 @@ async def resolve_impls_for_test_v2( {"X-LlamaStack-ProviderData": json.dumps(provider_data)} ) - return impls + return test_stack def remote_provider_config( @@ -92,90 +91,3 @@ def remote_provider_config( assert not has_non_remote, "Remote stack cannot have non-remote providers" return remote_config - - -async def resolve_impls_for_test(api: Api, deps: List[Api] = None): - if "PROVIDER_CONFIG" not in os.environ: - raise ValueError( - "You must set PROVIDER_CONFIG to a YAML file containing provider config" - ) - - with open(os.environ["PROVIDER_CONFIG"], "r") as f: - config_dict = yaml.safe_load(f) - - providers = read_providers(api, config_dict) - - chosen = choose_providers(providers, api, deps) - run_config = dict( - built_at=datetime.now(), - image_name="test-fixture", - apis=[api] + (deps or []), - providers=chosen, - ) - run_config = parse_and_maybe_upgrade_config(run_config) - try: - impls = await resolve_impls(run_config, get_provider_registry()) - except ModuleNotFoundError as e: - print_pip_install_help(providers) - raise e - - if "provider_data" in config_dict: - provider_id = chosen[api.value][0].provider_id - provider_data = config_dict["provider_data"].get(provider_id, {}) - if provider_data: - set_request_provider_data( - {"X-LlamaStack-ProviderData": json.dumps(provider_data)} - ) - - return impls - - -def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]: - if "providers" not in config_dict: - raise ValueError("Config file should contain a `providers` key") - - providers = config_dict["providers"] - if isinstance(providers, dict): - return providers - elif isinstance(providers, list): - return { - api.value: providers, - } - else: - raise ValueError( - "Config file should contain a list of providers or dict(api to providers)" - ) - - -def choose_providers( - providers: Dict[str, Any], api: Api, deps: List[Api] = None -) -> Dict[str, Provider]: - chosen = {} - if api.value not in providers: - raise ValueError(f"No providers found for `{api}`?") - chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")] - - for dep in deps or []: - if dep.value not in providers: - raise ValueError(f"No providers specified for `{dep}` in config?") - chosen[dep.value] = [Provider(**x) for x in providers[dep.value]] - - return chosen - - -def pick_provider(api: Api, providers: List[Any], key: str) -> Provider: - providers_by_id = {x["provider_id"]: x for x in providers} - if len(providers_by_id) == 0: - raise ValueError(f"No providers found for `{api}` in config file") - - if key in os.environ: - provider_id = os.environ[key] - if provider_id not in providers_by_id: - raise ValueError(f"Provider ID {provider_id} not found in config file") - provider = providers_by_id[provider_id] - else: - provider = list(providers_by_id.values())[0] - provider_id = provider["provider_id"] - print(f"No provider ID specified, picking first `{provider_id}`") - - return Provider(**provider) diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index ade201b11..a706316dd 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -102,22 +102,16 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"] async def safety_stack(inference_model, safety_shield, request): # We need an inference + safety fixture to test safety fixture_dict = request.param - inference_fixture = request.getfixturevalue( - f"inference_{fixture_dict['inference']}" - ) - safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}") - providers = { - "inference": inference_fixture.providers, - "safety": safety_fixture.providers, - } + providers = {} provider_data = {} - if inference_fixture.provider_data: - provider_data.update(inference_fixture.provider_data) - if safety_fixture.provider_data: - provider_data.update(safety_fixture.provider_data) + for key in ["inference", "safety"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.safety, Api.shields, Api.inference], providers, provider_data, @@ -125,5 +119,5 @@ async def safety_stack(inference_model, safety_shield, request): shields=[safety_shield], ) - shield = await impls[Api.shields].get_shield(safety_shield.shield_id) - return impls[Api.safety], impls[Api.shields], shield + shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id) + return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py index ee6999043..d89b211ef 100644 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture @@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model): if fixture.provider_data: provider_data.update(fixture.provider_data) - impls = await resolve_impls_for_test_v2( + test_stack = await construct_stack_for_test( [Api.scoring, Api.datasetio, Api.inference], providers, provider_data, @@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model): ], ) - return impls + return test_stack.impls