forked from phoenix-oss/llama-stack-mirror
feat(test): allow specifying simple ad-hoc distributions in LLAMA_STACK_CONFIG
This commit is contained in:
parent
cb085d56c6
commit
1c63ec981a
1 changed files with 66 additions and 2 deletions
|
@ -6,14 +6,22 @@
|
|||
import copy
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
from llama_stack.apis.datatypes import Api
|
||||
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.tests.env import get_env_or_fail
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from .fixtures.recordable_mock import RecordableMock
|
||||
from .report import Report
|
||||
|
@ -99,11 +107,67 @@ def provider_data():
|
|||
return provider_data if len(provider_data) > 0 else None
|
||||
|
||||
|
||||
def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
|
||||
"""
|
||||
Create an adhoc distribution from a list of API providers.
|
||||
|
||||
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
|
||||
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
|
||||
"""
|
||||
|
||||
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
||||
provider_registry = get_provider_registry()
|
||||
|
||||
provider_configs_by_api = {}
|
||||
for api_provider in api_providers:
|
||||
api_str, provider = api_provider.split("=")
|
||||
api = Api(api_str)
|
||||
|
||||
providers_by_type = provider_registry[api]
|
||||
provider_spec = providers_by_type.get(provider)
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"inline::{provider}")
|
||||
if not provider_spec:
|
||||
provider_spec = providers_by_type.get(f"remote::{provider}")
|
||||
|
||||
if not provider_spec:
|
||||
raise ValueError(
|
||||
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
|
||||
)
|
||||
|
||||
# call method "sample_run_config" on the provider spec config class
|
||||
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||
provider_config = replace_env_vars(provider_config_type.sample_run_config())
|
||||
|
||||
provider_configs_by_api[api_str] = [
|
||||
Provider(
|
||||
provider_id=provider,
|
||||
provider_type=provider_spec.provider_type,
|
||||
config=provider_config,
|
||||
)
|
||||
]
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
|
||||
with open(run_config_file.name, "w") as f:
|
||||
config = StackRunConfig(
|
||||
image_name="distro-test",
|
||||
apis=list(provider_configs_by_api.keys()),
|
||||
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
|
||||
providers=provider_configs_by_api,
|
||||
)
|
||||
yaml.dump(config.model_dump(), f)
|
||||
|
||||
return run_config_file.name
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client(provider_data, text_model_id):
|
||||
def llama_stack_client(request, provider_data, text_model_id):
|
||||
if os.environ.get("LLAMA_STACK_CONFIG"):
|
||||
config = get_env_or_fail("LLAMA_STACK_CONFIG")
|
||||
if "=" in config:
|
||||
config = distro_from_adhoc_config_spec(config)
|
||||
client = LlamaStackAsLibraryClient(
|
||||
get_env_or_fail("LLAMA_STACK_CONFIG"),
|
||||
config,
|
||||
provider_data=provider_data,
|
||||
skip_logger_removal=True,
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue