From 1c63ec981a48b8abb5e48a8de38d4e7bc67440c9 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 4 Mar 2025 09:42:00 -0800 Subject: [PATCH] feat(test): allow specifying simple ad-hoc distributions in LLAMA_STACK_CONFIG --- tests/api/conftest.py | 68 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/tests/api/conftest.py b/tests/api/conftest.py index 52064fed4..dfe22dcc8 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -6,14 +6,22 @@ import copy import logging import os +import tempfile from pathlib import Path +from typing import List import pytest +import yaml from llama_stack_client import LlamaStackClient from llama_stack import LlamaStackAsLibraryClient from llama_stack.apis.datatypes import Api +from llama_stack.distribution.datatypes import Provider, StackRunConfig +from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.stack import replace_env_vars +from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.tests.env import get_env_or_fail +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from .fixtures.recordable_mock import RecordableMock from .report import Report @@ -99,11 +107,67 @@ def provider_data(): return provider_data if len(provider_data) > 0 else None +def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str: + """ + Create an adhoc distribution from a list of API providers. + + The list should be of the form "api=provider", e.g. "inference=fireworks". If you have + multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference" + """ + + api_providers = adhoc_config_spec.replace(";", ",").split(",") + provider_registry = get_provider_registry() + + provider_configs_by_api = {} + for api_provider in api_providers: + api_str, provider = api_provider.split("=") + api = Api(api_str) + + providers_by_type = provider_registry[api] + provider_spec = providers_by_type.get(provider) + if not provider_spec: + provider_spec = providers_by_type.get(f"inline::{provider}") + if not provider_spec: + provider_spec = providers_by_type.get(f"remote::{provider}") + + if not provider_spec: + raise ValueError( + f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}" + ) + + # call method "sample_run_config" on the provider spec config class + provider_config_type = instantiate_class_type(provider_spec.config_class) + provider_config = replace_env_vars(provider_config_type.sample_run_config()) + + provider_configs_by_api[api_str] = [ + Provider( + provider_id=provider, + provider_type=provider_spec.provider_type, + config=provider_config, + ) + ] + sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") + with open(run_config_file.name, "w") as f: + config = StackRunConfig( + image_name="distro-test", + apis=list(provider_configs_by_api.keys()), + metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name), + providers=provider_configs_by_api, + ) + yaml.dump(config.model_dump(), f) + + return run_config_file.name + + @pytest.fixture(scope="session") -def llama_stack_client(provider_data, text_model_id): +def llama_stack_client(request, provider_data, text_model_id): if os.environ.get("LLAMA_STACK_CONFIG"): + config = get_env_or_fail("LLAMA_STACK_CONFIG") + if "=" in config: + config = distro_from_adhoc_config_spec(config) client = LlamaStackAsLibraryClient( - get_env_or_fail("LLAMA_STACK_CONFIG"), + config, provider_data=provider_data, skip_logger_removal=True, )