forked from phoenix-oss/llama-stack-mirror
feat(test): allow specifying simple ad-hoc distributions in LLAMA_STACK_CONFIG
This commit is contained in:
parent
cb085d56c6
commit
1c63ec981a
1 changed files with 66 additions and 2 deletions
|
@ -6,14 +6,22 @@
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import yaml
|
||||||
from llama_stack_client import LlamaStackClient
|
from llama_stack_client import LlamaStackClient
|
||||||
|
|
||||||
from llama_stack import LlamaStackAsLibraryClient
|
from llama_stack import LlamaStackAsLibraryClient
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
|
from llama_stack.distribution.datatypes import Provider, StackRunConfig
|
||||||
|
from llama_stack.distribution.distribution import get_provider_registry
|
||||||
|
from llama_stack.distribution.stack import replace_env_vars
|
||||||
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.providers.tests.env import get_env_or_fail
|
from llama_stack.providers.tests.env import get_env_or_fail
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
from .fixtures.recordable_mock import RecordableMock
|
from .fixtures.recordable_mock import RecordableMock
|
||||||
from .report import Report
|
from .report import Report
|
||||||
|
@ -99,11 +107,67 @@ def provider_data():
|
||||||
return provider_data if len(provider_data) > 0 else None
|
return provider_data if len(provider_data) > 0 else None
|
||||||
|
|
||||||
|
|
||||||
|
def distro_from_adhoc_config_spec(adhoc_config_spec: str) -> str:
|
||||||
|
"""
|
||||||
|
Create an adhoc distribution from a list of API providers.
|
||||||
|
|
||||||
|
The list should be of the form "api=provider", e.g. "inference=fireworks". If you have
|
||||||
|
multiple pairs, separate them with commas or semicolons, e.g. "inference=fireworks,safety=llama-guard,agents=meta-reference"
|
||||||
|
"""
|
||||||
|
|
||||||
|
api_providers = adhoc_config_spec.replace(";", ",").split(",")
|
||||||
|
provider_registry = get_provider_registry()
|
||||||
|
|
||||||
|
provider_configs_by_api = {}
|
||||||
|
for api_provider in api_providers:
|
||||||
|
api_str, provider = api_provider.split("=")
|
||||||
|
api = Api(api_str)
|
||||||
|
|
||||||
|
providers_by_type = provider_registry[api]
|
||||||
|
provider_spec = providers_by_type.get(provider)
|
||||||
|
if not provider_spec:
|
||||||
|
provider_spec = providers_by_type.get(f"inline::{provider}")
|
||||||
|
if not provider_spec:
|
||||||
|
provider_spec = providers_by_type.get(f"remote::{provider}")
|
||||||
|
|
||||||
|
if not provider_spec:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provider {provider} (or remote::{provider} or inline::{provider}) not found for API {api}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# call method "sample_run_config" on the provider spec config class
|
||||||
|
provider_config_type = instantiate_class_type(provider_spec.config_class)
|
||||||
|
provider_config = replace_env_vars(provider_config_type.sample_run_config())
|
||||||
|
|
||||||
|
provider_configs_by_api[api_str] = [
|
||||||
|
Provider(
|
||||||
|
provider_id=provider,
|
||||||
|
provider_type=provider_spec.provider_type,
|
||||||
|
config=provider_config,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||||
|
run_config_file = tempfile.NamedTemporaryFile(delete=False, suffix=".yaml")
|
||||||
|
with open(run_config_file.name, "w") as f:
|
||||||
|
config = StackRunConfig(
|
||||||
|
image_name="distro-test",
|
||||||
|
apis=list(provider_configs_by_api.keys()),
|
||||||
|
metadata_store=SqliteKVStoreConfig(db_path=sqlite_file.name),
|
||||||
|
providers=provider_configs_by_api,
|
||||||
|
)
|
||||||
|
yaml.dump(config.model_dump(), f)
|
||||||
|
|
||||||
|
return run_config_file.name
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def llama_stack_client(provider_data, text_model_id):
|
def llama_stack_client(request, provider_data, text_model_id):
|
||||||
if os.environ.get("LLAMA_STACK_CONFIG"):
|
if os.environ.get("LLAMA_STACK_CONFIG"):
|
||||||
|
config = get_env_or_fail("LLAMA_STACK_CONFIG")
|
||||||
|
if "=" in config:
|
||||||
|
config = distro_from_adhoc_config_spec(config)
|
||||||
client = LlamaStackAsLibraryClient(
|
client = LlamaStackAsLibraryClient(
|
||||||
get_env_or_fail("LLAMA_STACK_CONFIG"),
|
config,
|
||||||
provider_data=provider_data,
|
provider_data=provider_data,
|
||||||
skip_logger_removal=True,
|
skip_logger_removal=True,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue