feat(test): allow specifying simple ad-hoc distributions in LLAMA_STACK_CONFIG

This commit is contained in:
Ashwin Bharambe 2025-03-04 09:42:00 -08:00
parent cb085d56c6
commit 1c63ec981a

View file

@ -6,14 +6,22 @@
import copy
import logging
import os
import tempfile
from pathlib import Path
from typing import List
import pytest
import yaml
from llama_stack_client import LlamaStackClient
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.apis.datatypes import Api
from llama_stack.distribution.datatypes import Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.tests.env import get_env_or_fail
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from .fixtures.recordable_mock import RecordableMock
from .report import Report
@ -99,11 +107,67 @@ def provider_data():
return provider_data if len(provider_data) > 0 else None
def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
"""
Create an adhoc distribution from a list of API providers.
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
"""
api_providers = adhoc_config_spec.replace(";", ",").split(",")
provider_registry = get_provider_registry()
provider_configs_by_api = {}
for api_provider in api_providers:
api_str, provider = api_provider.split("=")
api = Api(api_str)
providers_by_type = provider_registry[api]
provider_spec = providers_by_type.get(provider)
if not provider_spec:
provider_spec = providers_by_type.get(f"inline::{provider}")
if not provider_spec:
provider_spec = providers_by_type.get(f"remote::{provider}")
if not provider_spec:
raise ValueError(
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
)
# call method "sample_run_config" on the provider spec config class
provider_config_type = instantiate_class_type(provider_spec.config_class)
provider_config = replace_env_vars(provider_config_type.sample_run_config())
provider_configs_by_api[api_str] = [
Provider(
provider_id=provider,
provider_type=provider_spec.provider_type,
config=provider_config,
)
]
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
with open(run_config_file.name, "w") as f:
config = StackRunConfig(
image_name="distro-test",
apis=list(provider_configs_by_api.keys()),
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
providers=provider_configs_by_api,
)
yaml.dump(config.model_dump(), f)
return run_config_file.name
@pytest.fixture(scope="session")
def llama_stack_client(provider_data, text_model_id):
def llama_stack_client(request, provider_data, text_model_id):
if os.environ.get("LLAMA_STACK_CONFIG"):
config = get_env_or_fail("LLAMA_STACK_CONFIG")
if "=" in config:
config = distro_from_adhoc_config_spec(config)
client = LlamaStackAsLibraryClient(
get_env_or_fail("LLAMA_STACK_CONFIG"),
config,
provider_data=provider_data,
skip_logger_removal=True,
)