fixed agent persistence test, more cleanup

Ashwin Bharambe 2024-11-12 21:31:02 -08:00
parent 4f3b009980
commit 22aedd0277
14 changed files with 202 additions and 310 deletions

View file

@@ -78,10 +78,13 @@ class ProviderWithSpec(Provider):
     spec: ProviderSpec


+ProviderRegistry = Dict[Api, Dict[str, ProviderSpec]]
+
+
 # TODO: this code is not very straightforward to follow and needs one more round of refactoring
 async def resolve_impls(
     run_config: StackRunConfig,
-    provider_registry: Dict[Api, Dict[str, ProviderSpec]],
+    provider_registry: ProviderRegistry,
     dist_registry: DistributionRegistry,
 ) -> Dict[Api, Any]:
     """

View file

@@ -30,7 +30,7 @@ from llama_stack.apis.eval_tasks import *  # noqa: F403

 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
-from llama_stack.distribution.resolver import resolve_impls
+from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.providers.datatypes import Api

@@ -94,10 +94,14 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):

 # Produces a stack of providers for the given run config. Not all APIs may be
 # asked for in the run config.
-async def construct_stack(run_config: StackRunConfig) -> Dict[Api, Any]:
+async def construct_stack(
+    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
+) -> Dict[Api, Any]:
     dist_registry, _ = await create_dist_registry(
         run_config.metadata_store, run_config.image_name
     )
-    impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
+    impls = await resolve_impls(
+        run_config, provider_registry or get_provider_registry(), dist_registry
+    )
     await register_resources(run_config, impls)
     return impls

View file

@@ -16,7 +16,7 @@ from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

 from ..conftest import ProviderFixture, remote_stack_fixture

@@ -73,7 +73,7 @@ async def agents_stack(request, inference_model, safety_shield):
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.agents, Api.inference, Api.safety, Api.memory],
         providers,
         provider_data,
@@ -85,5 +85,4 @@ async def agents_stack(request, inference_model, safety_shield):
         ],
         shields=[safety_shield],
     )
-
-    return impls[Api.agents], impls[Api.memory]
+    return test_stack
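Tests consuming the fixture change accordingly: instead of unpacking an `(agents, memory)` tuple, they index into the returned stack's `impls`, as the `test_agents.py` hunks below show. In sketch form:

```python
# Sketch: consuming the updated fixture; Api comes in via the star imports
# these test modules already use.
async def test_example(agents_stack):
    agents_impl = agents_stack.impls[Api.agents]
    memory_impl = agents_stack.impls[Api.memory]
    ...
```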

View file

@@ -1,148 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.agents import *  # noqa: F403
-from llama_stack.providers.tests.resolver import resolve_impls_for_test
-from llama_stack.providers.datatypes import *  # noqa: F403
-
-from dotenv import load_dotenv
-
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
-
-# How to run this test:
-#
-# 1. Ensure you have a conda environment with the right dependencies installed.
-#    This includes `pytest` and `pytest-asyncio`.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-#   PROVIDER_CONFIG=provider_config.yaml \
-#   pytest -s llama_stack/providers/tests/agents/test_agent_persistence.py \
-#   --tb=short --disable-warnings
-# ```
-
-load_dotenv()
-
-
-@pytest_asyncio.fixture(scope="session")
-async def agents_settings():
-    impls = await resolve_impls_for_test(
-        Api.agents, deps=[Api.inference, Api.memory, Api.safety]
-    )
-
-    return {
-        "impl": impls[Api.agents],
-        "memory_impl": impls[Api.memory],
-        "common_params": {
-            "model": "Llama3.1-8B-Instruct",
-            "instructions": "You are a helpful assistant.",
-        },
-    }
-
-
-@pytest.fixture
-def sample_messages():
-    return [
-        UserMessage(content="What's the weather like today?"),
-    ]
-
-
-@pytest.mark.asyncio
-async def test_delete_agents_and_sessions(agents_settings, sample_messages):
-    agents_impl = agents_settings["impl"]
-
-    # First, create an agent
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    persistence_store = await kvstore_impl(agents_settings["persistence"])
-
-    await agents_impl.delete_agents_session(agent_id, session_id)
-    session_response = await persistence_store.get(f"session:{agent_id}:{session_id}")
-
-    await agents_impl.delete_agents(agent_id)
-    agent_response = await persistence_store.get(f"agent:{agent_id}")
-
-    assert session_response is None
-    assert agent_response is None
-
-
-async def test_get_agent_turns_and_steps(agents_settings, sample_messages):
-    agents_impl = agents_settings["impl"]
-
-    # First, create an agent
-    agent_config = AgentConfig(
-        model=agents_settings["common_params"]["model"],
-        instructions=agents_settings["common_params"]["instructions"],
-        enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
-        input_shields=[],
-        output_shields=[],
-        tools=[],
-        max_infer_iters=5,
-    )
-
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-
-    # Create and execute a turn
-    turn_request = dict(
-        agent_id=agent_id,
-        session_id=session_id,
-        messages=sample_messages,
-        stream=True,
-    )
-
-    turn_response = [
-        chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
-    ]
-
-    final_event = turn_response[-1].event.payload
-    turn_id = final_event.turn.turn_id
-
-    persistence_store = await kvstore_impl(SqliteKVStoreConfig())
-
-    turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
-    response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
-
-    assert isinstance(response, Turn)
-    assert response == final_event.turn
-    assert turn == final_event.turn
-
-    steps = final_event.turn.steps
-    step_id = steps[0].step_id
-    step_response = await agents_impl.get_agents_step(
-        agent_id, session_id, turn_id, step_id
-    )
-
-    assert isinstance(step_response.step, Step)
-    assert step_response.step == steps[0]

View file

@@ -17,6 +17,7 @@ from llama_stack.providers.datatypes import *  # noqa: F403
 #   -m "meta_reference"

 from .fixtures import pick_inference_model
+from .utils import create_agent_session


 @pytest.fixture
@@ -67,24 +68,12 @@ def query_attachment_messages():
     ]


-async def create_agent_session(agents_impl, agent_config):
-    create_response = await agents_impl.create_agent(agent_config)
-    agent_id = create_response.agent_id
-
-    # Create a session
-    session_create_response = await agents_impl.create_agent_session(
-        agent_id, "Test Session"
-    )
-    session_id = session_create_response.session_id
-    return agent_id, session_id
-
-
 class TestAgents:
     @pytest.mark.asyncio
     async def test_agent_turns_with_safety(
         self, safety_shield, agents_stack, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
         agent_id, session_id = await create_agent_session(
             agents_impl,
             AgentConfig(
@@ -127,7 +116,7 @@ class TestAgents:
     async def test_create_agent_turn(
         self, agents_stack, sample_messages, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
         agent_id, session_id = await create_agent_session(
             agents_impl, AgentConfig(**common_params)
@@ -158,7 +147,7 @@ class TestAgents:
         query_attachment_messages,
         common_params,
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
         urls = [
             "memory_optimizations.rst",
             "chat.rst",
@@ -226,7 +215,7 @@ class TestAgents:
     async def test_create_agent_turn_with_brave_search(
         self, agents_stack, search_query_messages, common_params
     ):
-        agents_impl, _ = agents_stack
+        agents_impl = agents_stack.impls[Api.agents]
         if "BRAVE_SEARCH_API_KEY" not in os.environ:
             pytest.skip("BRAVE_SEARCH_API_KEY not set, skipping test")

View file

@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import pytest
+
+from llama_stack.apis.agents import *  # noqa: F403
+from llama_stack.providers.datatypes import *  # noqa: F403
+from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
+
+from .fixtures import pick_inference_model
+from .utils import create_agent_session
+
+
+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What's the weather like today?"),
+    ]
+
+
+@pytest.fixture
+def common_params(inference_model):
+    inference_model = pick_inference_model(inference_model)
+
+    return dict(
+        model=inference_model,
+        instructions="You are a helpful assistant.",
+        enable_session_persistence=True,
+        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        input_shields=[],
+        output_shields=[],
+        tools=[],
+        max_infer_iters=5,
+    )
+
+
+class TestAgentPersistence:
+    @pytest.mark.asyncio
+    async def test_delete_agents_and_sessions(self, agents_stack, common_params):
+        agents_impl = agents_stack.impls[Api.agents]
+        agent_id, session_id = await create_agent_session(
+            agents_impl,
+            AgentConfig(
+                **{
+                    **common_params,
+                    "input_shields": [],
+                    "output_shields": [],
+                }
+            ),
+        )
+
+        run_config = agents_stack.run_config
+        provider_config = run_config.providers["agents"][0].config
+        persistence_store = await kvstore_impl(
+            SqliteKVStoreConfig(**provider_config["persistence_store"])
+        )
+
+        await agents_impl.delete_agents_session(agent_id, session_id)
+        session_response = await persistence_store.get(
+            f"session:{agent_id}:{session_id}"
+        )
+
+        await agents_impl.delete_agents(agent_id)
+        agent_response = await persistence_store.get(f"agent:{agent_id}")
+
+        assert session_response is None
+        assert agent_response is None
+
+    @pytest.mark.asyncio
+    async def test_get_agent_turns_and_steps(
+        self, agents_stack, sample_messages, common_params
+    ):
+        agents_impl = agents_stack.impls[Api.agents]
+        agent_id, session_id = await create_agent_session(
+            agents_impl,
+            AgentConfig(
+                **{
+                    **common_params,
+                    "input_shields": [],
+                    "output_shields": [],
+                }
+            ),
+        )
+
+        # Create and execute a turn
+        turn_request = dict(
+            agent_id=agent_id,
+            session_id=session_id,
+            messages=sample_messages,
+            stream=True,
+        )
+
+        turn_response = [
+            chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
+        ]
+
+        final_event = turn_response[-1].event.payload
+        turn_id = final_event.turn.turn_id
+
+        provider_config = agents_stack.run_config.providers["agents"][0].config
+        persistence_store = await kvstore_impl(
+            SqliteKVStoreConfig(**provider_config["persistence_store"])
+        )
+
+        turn = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
+        response = await agents_impl.get_agents_turn(agent_id, session_id, turn_id)
+
+        assert isinstance(response, Turn)
+        assert response == final_event.turn
+        assert turn == final_event.turn.model_dump_json()
+
+        steps = final_event.turn.steps
+        step_id = steps[0].step_id
+        step_response = await agents_impl.get_agents_step(
+            agent_id, session_id, turn_id, step_id
+        )
+
+        assert step_response.step == steps[0]
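Note the `model_dump_json()` comparison: the kvstore returns the raw JSON string the provider wrote, not a `Turn` object. Assuming `Turn` is a Pydantic model (which the `model_dump_json()` call implies), the stored value round-trips like this sketch:

```python
import json

# Sketch: the store holds serialized JSON, so parse it back before
# comparing against the in-memory Turn object.
raw = await persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}")
restored = Turn(**json.loads(raw))
assert restored == final_event.turn
```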

View file

@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+async def create_agent_session(agents_impl, agent_config):
+    create_response = await agents_impl.create_agent(agent_config)
+    agent_id = create_response.agent_id
+
+    # Create a session
+    session_create_response = await agents_impl.create_agent_session(
+        agent_id, "Test Session"
+    )
+    session_id = session_create_response.session_id
+    return agent_id, session_id
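This helper centralizes the create-agent-plus-create-session boilerplate that both agents test modules now share; a typical call site, mirroring the tests above:

```python
# Sketch: how the test classes invoke the shared helper.
agent_id, session_id = await create_agent_session(
    agents_impl, AgentConfig(**common_params)
)
```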

View file

@@ -9,7 +9,7 @@ import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test

 from ..conftest import ProviderFixture, remote_stack_fixture

@@ -52,10 +52,10 @@ async def datasetio_stack(request):
     fixture_name = request.param
     fixture = request.getfixturevalue(f"datasetio_{fixture_name}")

-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.datasetio],
         {"datasetio": fixture.providers},
         fixture.provider_data,
     )

-    return impls[Api.datasetio], impls[Api.datasets]
+    return test_stack.impls[Api.datasetio], test_stack.impls[Api.datasets]

View file

@@ -9,7 +9,7 @@ import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test

 from ..conftest import ProviderFixture, remote_stack_fixture

@@ -46,10 +46,10 @@ async def eval_stack(request):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)

-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.eval, Api.datasetio, Api.inference, Api.scoring],
         providers,
         provider_data,
     )

-    return impls
+    return test_stack.impls

View file

@@ -21,7 +21,7 @@ from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.providers.remote.inference.together import TogetherImplConfig
 from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test

 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail

@@ -182,11 +182,11 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
         models=[ModelInput(model_id=inference_model)],
     )

-    return (impls[Api.inference], impls[Api.models])
+    return test_stack.impls[Api.inference], test_stack.impls[Api.models]

View file

@@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig

 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail

@@ -101,10 +101,10 @@ async def memory_stack(request):
     fixture_name = request.param
     fixture = request.getfixturevalue(f"memory_{fixture_name}")

-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.memory],
         {"memory": fixture.providers},
         fixture.provider_data,
     )

-    return impls[Api.memory], impls[Api.memory_banks]
+    return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]

View file

@@ -5,38 +5,26 @@
 # the root directory of this source tree.
 import json
-import os
 import tempfile
 from datetime import datetime
 from typing import Any, Dict, List, Optional

-import yaml
-
 from llama_stack.distribution.datatypes import *  # noqa: F403

 from llama_stack.distribution.build import print_pip_install_help
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
 from llama_stack.distribution.distribution import get_provider_registry
 from llama_stack.distribution.request_headers import set_request_provider_data
-from llama_stack.distribution.resolver import resolve_impls, resolve_remote_stack_impls
+from llama_stack.distribution.resolver import resolve_remote_stack_impls
 from llama_stack.distribution.stack import construct_stack
 from llama_stack.providers.utils.kvstore import SqliteKVStoreConfig


-async def construct_stack_for_test(run_config: StackRunConfig):
-    remote_config = remote_provider_config(run_config)
-    if not remote_config:
-        return await construct_stack(run_config)
-
-    impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
-
-    # we don't register resources for a remote stack as part of the fixture setup
-    # because the stack is already "up". if a test needs to register resources, it
-    # can do so manually always.
-    return impls
+class TestStack(BaseModel):
+    impls: Dict[Api, Any]
+    run_config: StackRunConfig


-async def resolve_impls_for_test_v2(
+async def construct_stack_for_test(
     apis: List[Api],
     providers: Dict[str, List[Provider]],
     provider_data: Optional[Dict[str, Any]] = None,
@@ -46,7 +34,7 @@ async def resolve_impls_for_test_v2(
     datasets: Optional[List[DatasetInput]] = None,
     scoring_fns: Optional[List[ScoringFnInput]] = None,
     eval_tasks: Optional[List[EvalTaskInput]] = None,
-):
+) -> TestStack:
     sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
     run_config = dict(
         built_at=datetime.now(),
@@ -63,7 +51,18 @@ async def resolve_impls_for_test_v2(
     )
     run_config = parse_and_maybe_upgrade_config(run_config)
     try:
-        impls = await construct_stack_for_test(run_config)
+        remote_config = remote_provider_config(run_config)
+        if not remote_config:
+            # TODO: add to provider registry by creating interesting mocks or fakes
+            impls = await construct_stack(run_config, get_provider_registry())
+        else:
+            # we don't register resources for a remote stack as part of the fixture setup
+            # because the stack is already "up". if a test needs to register resources, it
+            # can do so manually always.
+            impls = await resolve_remote_stack_impls(remote_config, run_config.apis)
+
+        test_stack = TestStack(impls=impls, run_config=run_config)
     except ModuleNotFoundError as e:
         print_pip_install_help(providers)
         raise e

@@ -73,7 +72,7 @@ async def resolve_impls_for_test_v2(
             {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
         )

-    return impls
+    return test_stack


 def remote_provider_config(
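`construct_stack_for_test` (the renamed `resolve_impls_for_test_v2`) now returns a `TestStack` carrying both the impls and the generated run config, which is what lets the persistence test above dig the provider's `persistence_store` settings out of `run_config`. A sketch of a fixture using it, mirroring the datasetio and inference fixtures in this commit:

```python
# Sketch: a provider fixture built on the renamed helper.
test_stack = await construct_stack_for_test(
    [Api.inference],
    {"inference": fixture.providers},
    fixture.provider_data,
)
inference_impl = test_stack.impls[Api.inference]
run_config = test_stack.run_config  # e.g. to read provider configs in a test
```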
@@ -92,90 +91,3 @@ def remote_provider_config(
     assert not has_non_remote, "Remote stack cannot have non-remote providers"

     return remote_config
-
-
-async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
-    if "PROVIDER_CONFIG" not in os.environ:
-        raise ValueError(
-            "You must set PROVIDER_CONFIG to a YAML file containing provider config"
-        )
-
-    with open(os.environ["PROVIDER_CONFIG"], "r") as f:
-        config_dict = yaml.safe_load(f)
-
-    providers = read_providers(api, config_dict)
-    chosen = choose_providers(providers, api, deps)
-    run_config = dict(
-        built_at=datetime.now(),
-        image_name="test-fixture",
-        apis=[api] + (deps or []),
-        providers=chosen,
-    )
-    run_config = parse_and_maybe_upgrade_config(run_config)
-    try:
-        impls = await resolve_impls(run_config, get_provider_registry())
-    except ModuleNotFoundError as e:
-        print_pip_install_help(providers)
-        raise e
-
-    if "provider_data" in config_dict:
-        provider_id = chosen[api.value][0].provider_id
-        provider_data = config_dict["provider_data"].get(provider_id, {})
-        if provider_data:
-            set_request_provider_data(
-                {"X-LlamaStack-ProviderData": json.dumps(provider_data)}
-            )
-
-    return impls
-
-
-def read_providers(api: Api, config_dict: Dict[str, Any]) -> Dict[str, Any]:
-    if "providers" not in config_dict:
-        raise ValueError("Config file should contain a `providers` key")
-
-    providers = config_dict["providers"]
-    if isinstance(providers, dict):
-        return providers
-    elif isinstance(providers, list):
-        return {
-            api.value: providers,
-        }
-    else:
-        raise ValueError(
-            "Config file should contain a list of providers or dict(api to providers)"
-        )
-
-
-def choose_providers(
-    providers: Dict[str, Any], api: Api, deps: List[Api] = None
-) -> Dict[str, Provider]:
-    chosen = {}
-    if api.value not in providers:
-        raise ValueError(f"No providers found for `{api}`?")
-    chosen[api.value] = [pick_provider(api, providers[api.value], "PROVIDER_ID")]
-
-    for dep in deps or []:
-        if dep.value not in providers:
-            raise ValueError(f"No providers specified for `{dep}` in config?")
-        chosen[dep.value] = [Provider(**x) for x in providers[dep.value]]
-
-    return chosen
-
-
-def pick_provider(api: Api, providers: List[Any], key: str) -> Provider:
-    providers_by_id = {x["provider_id"]: x for x in providers}
-    if len(providers_by_id) == 0:
-        raise ValueError(f"No providers found for `{api}` in config file")
-
-    if key in os.environ:
-        provider_id = os.environ[key]
-        if provider_id not in providers_by_id:
-            raise ValueError(f"Provider ID {provider_id} not found in config file")
-        provider = providers_by_id[provider_id]
-    else:
-        provider = list(providers_by_id.values())[0]
-        provider_id = provider["provider_id"]
-        print(f"No provider ID specified, picking first `{provider_id}`")
-
-    return Provider(**provider)

View file

@@ -16,7 +16,7 @@ from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig
 from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig
 from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test

 from ..conftest import ProviderFixture, remote_stack_fixture
 from ..env import get_env_or_fail

@@ -102,22 +102,16 @@ SAFETY_FIXTURES = ["llama_guard", "bedrock", "remote"]
 async def safety_stack(inference_model, safety_shield, request):
     # We need an inference + safety fixture to test safety
     fixture_dict = request.param
-    inference_fixture = request.getfixturevalue(
-        f"inference_{fixture_dict['inference']}"
-    )
-    safety_fixture = request.getfixturevalue(f"safety_{fixture_dict['safety']}")

-    providers = {
-        "inference": inference_fixture.providers,
-        "safety": safety_fixture.providers,
-    }
+    providers = {}
     provider_data = {}
-    if inference_fixture.provider_data:
-        provider_data.update(inference_fixture.provider_data)
-    if safety_fixture.provider_data:
-        provider_data.update(safety_fixture.provider_data)
+    for key in ["inference", "safety"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)

-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.safety, Api.shields, Api.inference],
         providers,
         provider_data,
@@ -125,5 +119,5 @@ async def safety_stack(inference_model, safety_shield, request):
         shields=[safety_shield],
     )

-    shield = await impls[Api.shields].get_shield(safety_shield.shield_id)
-    return impls[Api.safety], impls[Api.shields], shield
+    shield = await test_stack.impls[Api.shields].get_shield(safety_shield.shield_id)
+    return test_stack.impls[Api.safety], test_stack.impls[Api.shields], shield

View file

@@ -11,7 +11,7 @@ from llama_stack.apis.models import ModelInput

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
+from llama_stack.providers.tests.resolver import construct_stack_for_test

 from ..conftest import ProviderFixture, remote_stack_fixture

@@ -74,7 +74,7 @@ async def scoring_stack(request, inference_model):
         if fixture.provider_data:
             provider_data.update(fixture.provider_data)

-    impls = await resolve_impls_for_test_v2(
+    test_stack = await construct_stack_for_test(
         [Api.scoring, Api.datasetio, Api.inference],
         providers,
         provider_data,
@@ -88,4 +88,4 @@ async def scoring_stack(request, inference_model):
         ],
     )

-    return impls
+    return test_stack.impls