feat(test): allow specifying simple ad-hoc distributions in LLAMA_STACK_CONFIG

This commit is contained in:
Ashwin Bharambe 2025-03-04 09:42:00 -08:00
parent cb085d56c6
commit 1c63ec981a

View file

@ -6,14 +6,22 @@
import copy import copy
import logging import logging
import os import os
import tempfile
from pathlib import Path from pathlib import Path
from typing import List
import pytest import pytest
import yaml
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient
from llama_stack import LlamaStackAsLibraryClient from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.tests.env import get_env_or_fail from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from .fixtures.recordable_mock import RecordableMock from .fixtures.recordable_mock import RecordableMock
from .report import Report from .report import Report
@ -99,11 +107,67 @@ def provider_data():
return provider_data if len(provider_data) > 0 else None return provider_data if len(provider_data) > 0 else None
def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = get_provider_registry()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config())
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
with open(run_config_file.name, "w") as f:
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
providers=provider_configs_by_api,
)
yaml.dump(config.model_dump(), f)
return run_config_file.name
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def llama_stack_client(provider_data, text_model_id): def llama_stack_client(request, provider_data, text_model_id):
if os.environ.get("LLAMA_STACK_CONFIG"): if os.environ.get("LLAMA_STACK_CONFIG"):
config = get_env_or_fail("LLAMA_STACK_CONFIG")
if "=" in config:
config = distro_from_adhoc_config_spec(config)
client = LlamaStackAsLibraryClient( client = LlamaStackAsLibraryClient(
get_env_or_fail("LLAMA_STACK_CONFIG"), config,
provider_data=provider_data, provider_data=provider_data,
skip_logger_removal=True, skip_logger_removal=True,
) )