mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
# What does this PR do? <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> This PR is responsible for attaching prompts to storage stores in run configs. It allows to specify prompts as stores in different distributions. The need of this functionality was initiated in #3514 > Note, #3514 is divided on three separate PRs. Current PR is the first of three. <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> ## Test Plan <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> Manual testing and updated CI unit tests Prerequisites: 1. `uv run --with llama-stack llama stack list-deps starter | xargs -L1 uv pip install` 2. `llama stack run starter ` ``` INFO 2025-10-23 15:36:17,387 llama_stack.cli.stack.run:100 cli: Using run configuration: /Users/ianmiller/llama-stack/llama_stack/distributions/starter/run.yaml INFO 2025-10-23 15:36:17,423 llama_stack.cli.stack.run:157 cli: HTTPS enabled with certificates: Key: None Cert: None INFO 2025-10-23 15:36:17,424 llama_stack.cli.stack.run:159 cli: Listening on ['::', '0.0.0.0']:8321 INFO 2025-10-23 15:36:17,749 llama_stack.core.server.server:521 core::server: Run configuration: INFO 2025-10-23 15:36:17,756 llama_stack.core.server.server:524 core::server: apis: - agents - batches - datasetio - eval - files - inference - post_training - safety - scoring - tool_runtime - vector_io image_name: starter providers: agents: - config: persistence: agent_state: backend: kv_default namespace: agents responses: backend: sql_default max_write_queue_size: 10000 num_writers: 4 table_name: responses provider_id: meta-reference provider_type: inline::meta-reference batches: - config: kvstore: backend: kv_default namespace: batches provider_id: reference provider_type: inline::reference datasetio: - config: kvstore: backend: kv_default namespace: 
datasetio::huggingface provider_id: huggingface provider_type: remote::huggingface - config: kvstore: backend: kv_default namespace: datasetio::localfs provider_id: localfs provider_type: inline::localfs eval: - config: kvstore: backend: kv_default namespace: eval provider_id: meta-reference provider_type: inline::meta-reference files: - config: metadata_store: backend: sql_default table_name: files_metadata storage_dir: /Users/ianmiller/.llama/distributions/starter/files provider_id: meta-reference-files provider_type: inline::localfs inference: - config: api_key: '********' url: https://api.fireworks.ai/inference/v1 provider_id: fireworks provider_type: remote::fireworks - config: api_key: '********' url: https://api.together.xyz/v1 provider_id: together provider_type: remote::together - config: {} provider_id: bedrock provider_type: remote::bedrock - config: api_key: '********' base_url: https://api.openai.com/v1 provider_id: openai provider_type: remote::openai - config: api_key: '********' provider_id: anthropic provider_type: remote::anthropic - config: api_key: '********' provider_id: gemini provider_type: remote::gemini - config: api_key: '********' url: https://api.groq.com provider_id: groq provider_type: remote::groq - config: api_key: '********' url: https://api.sambanova.ai/v1 provider_id: sambanova provider_type: remote::sambanova - config: {} provider_id: sentence-transformers provider_type: inline::sentence-transformers post_training: - config: checkpoint_format: meta provider_id: torchtune-cpu provider_type: inline::torchtune-cpu safety: - config: excluded_categories: [] provider_id: llama-guard provider_type: inline::llama-guard - config: {} provider_id: code-scanner provider_type: inline::code-scanner scoring: - config: {} provider_id: basic provider_type: inline::basic - config: {} provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: openai_api_key: '********' provider_id: braintrust provider_type: inline::braintrust 
tool_runtime: - config: api_key: '********' max_results: 3 provider_id: brave-search provider_type: remote::brave-search - config: api_key: '********' max_results: 3 provider_id: tavily-search provider_type: remote::tavily-search - config: {} provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} provider_id: model-context-protocol provider_type: remote::model-context-protocol vector_io: - config: persistence: backend: kv_default namespace: vector_io::faiss provider_id: faiss provider_type: inline::faiss - config: db_path: /Users/ianmiller/.llama/distributions/starter/sqlite_vec.db persistence: backend: kv_default namespace: vector_io::sqlite_vec provider_id: sqlite-vec provider_type: inline::sqlite-vec registered_resources: benchmarks: [] datasets: [] models: [] scoring_fns: [] shields: [] tool_groups: - provider_id: tavily-search toolgroup_id: builtin::websearch - provider_id: rag-runtime toolgroup_id: builtin::rag vector_stores: [] server: port: 8321 storage: backends: kv_default: db_path: /Users/ianmiller/.llama/distributions/starter/kvstore.db type: kv_sqlite sql_default: db_path: /Users/ianmiller/.llama/distributions/starter/sql_store.db type: sql_sqlite stores: conversations: backend: sql_default table_name: openai_conversations inference: backend: sql_default max_write_queue_size: 10000 num_writers: 4 table_name: inference_store metadata: backend: kv_default namespace: registry prompts: backend: kv_default namespace: prompts telemetry: enabled: true vector_stores: default_embedding_model: model_id: nomic-ai/nomic-embed-text-v1.5 provider_id: sentence-transformers default_provider_id: faiss version: 2 INFO 2025-10-23 15:36:20,032 llama_stack.providers.utils.inference.inference_store:74 inference: Write queue disabled for SQLite to avoid concurrency issues WARNING 2025-10-23 15:36:20,422 llama_stack.providers.inline.telemetry.meta_reference.telemetry:84 telemetry: OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry INFO 2025-10-23 
15:36:22,379 llama_stack.providers.utils.inference.openai_mixin:436 providers::utils: OpenAIInferenceAdapter.list_provider_model_ids() returned 105 models INFO 2025-10-23 15:36:22,703 uvicorn.error:84 uncategorized: Started server process [17328] INFO 2025-10-23 15:36:22,704 uvicorn.error:48 uncategorized: Waiting for application startup. INFO 2025-10-23 15:36:22,706 llama_stack.core.server.server:179 core::server: Starting up Llama Stack server (version: 0.3.0) INFO 2025-10-23 15:36:22,707 llama_stack.core.stack:470 core: starting registry refresh task INFO 2025-10-23 15:36:22,708 uvicorn.error:62 uncategorized: Application startup complete. INFO 2025-10-23 15:36:22,708 uvicorn.error:216 uncategorized: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit) ``` As you can see, prompts are attached to stores in config Testing: 1. Create prompt: ``` curl -X POST http://localhost:8321/v1/prompts \ -H "Content-Type: application/json" \ -d '{ "prompt": "Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.", "variables": ["name", "company", "role", "tone"] }' ``` `{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}% ` 2. Get prompt: `curl -X GET http://localhost:8321/v1/prompts/pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f` `{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}% ` 3. 
Query sqlite KV storage to check created prompt: ``` sqlite> .mode column sqlite> .headers on sqlite> SELECT * FROM kvstore WHERE key LIKE 'prompts:v1:%'; key value expiration ------------------------------------------------------------ ------------------------------------------------------------ ---------- prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e {"prompt_id": "pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab 163f:1 5f6e163f", "prompt": "Hello {{name}}! You are working at {{c ompany}}. Your role is {{role}} at {{company}}. Remember, {{ name}}, to be {{tone}}.", "version": 1, "variables": ["name" , "company", "role", "tone"], "is_default": false} prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e 1 163f:default sqlite> ```
880 lines
33 KiB
Python
880 lines
33 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import yaml
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
|
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
|
|
from llama_stack.core.storage.datatypes import (
|
|
InferenceStoreReference,
|
|
KVStoreReference,
|
|
ServerStoresConfig,
|
|
SqliteKVStoreConfig,
|
|
SqliteSqlStoreConfig,
|
|
SqlStoreReference,
|
|
StorageConfig,
|
|
)
|
|
from llama_stack.providers.datatypes import ProviderSpec
|
|
|
|
|
|
class SampleConfig(BaseModel):
    # Minimal provider config model used as a stand-in by the tests in this module.

    # Single string field; falls back to "bar" when no value is supplied.
    foo: str = Field(default="bar", description="foo")

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
        """Return the sample run configuration; any keyword arguments are ignored."""
        sample: dict[str, Any] = {"foo": "baz"}
        return sample
|
|
|
|
|
|
def _default_storage() -> StorageConfig:
    """Build the storage configuration used by default in these tests.

    Two SQLite backends are declared, both pointing at ``:memory:``. The
    metadata and prompts stores reference the KV backend; the inference and
    conversations stores reference the SQL backend.
    """
    kv_backend = "kv_default"
    sql_backend = "sql_default"

    backends = {
        kv_backend: SqliteKVStoreConfig(db_path=":memory:"),
        sql_backend: SqliteSqlStoreConfig(db_path=":memory:"),
    }
    stores = ServerStoresConfig(
        metadata=KVStoreReference(backend=kv_backend, namespace="registry"),
        inference=InferenceStoreReference(backend=sql_backend, table_name="inference_store"),
        conversations=SqlStoreReference(backend=sql_backend, table_name="conversations"),
        prompts=KVStoreReference(backend=kv_backend, namespace="prompts"),
    )
    return StorageConfig(backends=backends, stores=stores)
|
|
|
|
|
|
def make_stack_config(**overrides) -> StackRunConfig:
    """Create a ``StackRunConfig`` from test defaults, letting keyword overrides win.

    A ``storage`` keyword replaces the value produced by ``_default_storage()``;
    every other keyword is layered on top of the default image name, APIs and
    providers before the config object is constructed.
    """
    storage = overrides.pop("storage", _default_storage())
    settings = {
        "image_name": "test_image",
        "apis": [],
        "providers": {},
        "storage": storage,
        # Remaining overrides replace the defaults above or add new fields.
        **overrides,
    }
    return StackRunConfig(**settings)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_providers():
    """Patch the inference registry's ``available_providers`` to return one test provider.

    Yields the mock object while the patch is active.
    """
    target = "llama_stack.providers.registry.inference.available_providers"
    with patch(target) as mock:
        test_spec = ProviderSpec(
            provider_type="test_provider",
            api=Api.inference,
            adapter_type="test_adapter",
            config_class="test_provider.config.TestProviderConfig",
        )
        mock.return_value = [test_spec]
        yield mock
|
|
|
|
|
|
@pytest.fixture
|
|
def base_config(tmp_path):
    """Return a ``StackRunConfig`` with a single sample inference provider.

    ``external_providers_dir`` is set to the string form of ``tmp_path``.
    """
    sample_provider = Provider(
        provider_id="sample_provider",
        provider_type="sample",
        config=SampleConfig.sample_run_config(),
    )
    return make_stack_config(
        apis=["inference"],
        providers={"inference": [sample_provider]},
        external_providers_dir=str(tmp_path),
    )
|
|
|
|
|
|
@pytest.fixture
|
|
def provider_spec_yaml():
    """Provide the YAML text of a remote provider spec shared by several tests."""
    # The empty first and last entries reproduce the leading and trailing newline
    # of the text block.
    spec_lines = (
        "",
        "adapter_type: test_provider",
        "config_class: test_provider.config.TestProviderConfig",
        "module: test_provider",
        "api_dependencies:",
        "- safety",
        "",
    )
    return "\n".join(spec_lines)
|
|
|
|
|
|
@pytest.fixture
|
|
def inline_provider_spec_yaml():
    """Provide the YAML text of an inline provider spec shared by several tests."""
    # The empty first and last entries reproduce the leading and trailing newline
    # of the text block.
    spec_lines = (
        "",
        "module: test_provider",
        "config_class: test_provider.config.TestProviderConfig",
        "pip_packages:",
        "- test-package",
        "api_dependencies:",
        "- safety",
        "optional_api_dependencies:",
        "- vector_io",
        "provider_data_validator: test_provider.validator.TestValidator",
        "container_image: test-image:latest",
        "",
    )
    return "\n".join(spec_lines)
|
|
|
|
|
|
@pytest.fixture
|
|
def api_directories(tmp_path):
    """Create ``remote/inference`` and ``inline/inference`` below ``tmp_path``.

    Returns the two directories as ``(remote_inference_dir, inline_inference_dir)``.
    """
    created = []
    # Remote first, inline second: the order of the returned pair.
    for kind in ("remote", "inline"):
        inference_dir = tmp_path / kind / "inference"
        inference_dir.mkdir(parents=True, exist_ok=True)
        created.append(inference_dir)

    remote_inference_dir, inline_inference_dir = created
    return remote_inference_dir, inline_inference_dir
|
|
|
|
|
|
def make_import_module_side_effect(
    builtin_provider_spec=None,
    external_module=None,
    raise_for_external=False,
    missing_get_provider_spec=False,
):
    """Build a replacement for ``importlib.import_module`` to use as a mock ``side_effect``.

    The returned callable recognises two module names:

    * ``llama_stack.providers.registry.inference`` -> a namespace whose
      ``available_providers()`` returns ``[builtin_provider_spec]``, or a
      default test ``ProviderSpec`` when that argument is falsy.
    * ``external_test.provider`` -> raises ``ModuleNotFoundError`` when
      ``raise_for_external`` is set, returns an empty namespace when
      ``missing_get_provider_spec`` is set, otherwise returns ``external_module``.

    Any other name raises ``ModuleNotFoundError``.
    """
    from types import SimpleNamespace

    def _available_providers():
        # Prefer the caller's spec; otherwise build the default test spec lazily.
        if builtin_provider_spec:
            return [builtin_provider_spec]
        return [
            ProviderSpec(
                api=Api.inference,
                provider_type="test_provider",
                config_class="test_provider.config.TestProviderConfig",
                module="test_provider",
            )
        ]

    def import_module_side_effect(name):
        if name == "llama_stack.providers.registry.inference":
            return SimpleNamespace(available_providers=_available_providers)
        # Unknown modules and the "external module is missing" scenario fail alike.
        if name != "external_test.provider" or raise_for_external:
            raise ModuleNotFoundError(name)
        if missing_get_provider_spec:
            return SimpleNamespace()
        return external_module

    return import_module_side_effect
|
|
|
|
|
|
class TestProviderRegistry:
    """Test suite for provider registry functionality."""

    @staticmethod
    def _external_test_config():
        """Build a run config whose only inference provider names the ``external_test`` module."""
        return make_stack_config(
            image_name="test_image",
            providers={
                "inference": [
                    Provider(
                        provider_id="external_test",
                        provider_type="external_test",
                        config={},
                        module="external_test",
                    )
                ]
            },
        )

    def test_builtin_providers(self, mock_providers):
        """Test loading built-in providers."""
        registry = get_provider_registry(None)

        assert Api.inference in registry
        assert "test_provider" in registry[Api.inference]
        assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
        assert registry[Api.inference]["test_provider"].api == Api.inference

    def test_internal_apis_excluded(self):
        """Test that internal APIs are not providable and every providable API has a registry module."""
        import importlib

        apis = providable_apis()

        for internal_api in INTERNAL_APIS:
            assert internal_api not in apis, f"Internal API {internal_api} should not be in providable_apis"

        for api in apis:
            module_name = f"llama_stack.providers.registry.{api.name.lower()}"
            try:
                importlib.import_module(module_name)
            except ImportError as err:
                # Surface the missing registry module as a test failure, keeping the cause.
                raise AssertionError(
                    f"API {api} is in providable_apis but has no provider registry module ({module_name})"
                ) from err

    def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
        """Test loading external remote providers from YAML files."""
        remote_dir, _ = api_directories
        (remote_dir / "test_provider.yaml").write_text(provider_spec_yaml, encoding="utf-8")

        registry = get_provider_registry(base_config)

        # Check membership before indexing so a missing API fails with a clear assertion.
        assert Api.inference in registry
        assert len(registry[Api.inference]) == 2
        assert "remote::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["remote::test_provider"]
        assert provider.adapter_type == "test_provider"
        assert provider.module == "test_provider"
        assert provider.config_class == "test_provider.config.TestProviderConfig"
        assert Api.safety in provider.api_dependencies

    def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
        """Test loading external inline providers from YAML files."""
        _, inline_dir = api_directories
        (inline_dir / "test_provider.yaml").write_text(inline_provider_spec_yaml, encoding="utf-8")

        registry = get_provider_registry(base_config)

        # Check membership before indexing so a missing API fails with a clear assertion.
        assert Api.inference in registry
        assert len(registry[Api.inference]) == 2
        assert "inline::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["inline::test_provider"]
        assert provider.provider_type == "inline::test_provider"
        assert provider.module == "test_provider"
        assert provider.config_class == "test_provider.config.TestProviderConfig"
        assert provider.pip_packages == ["test-package"]
        assert Api.safety in provider.api_dependencies
        assert Api.vector_io in provider.optional_api_dependencies
        assert provider.provider_data_validator == "test_provider.validator.TestValidator"
        assert provider.container_image == "test-image:latest"

    def test_invalid_yaml(self, api_directories, mock_providers, base_config):
        """Test handling of invalid YAML files."""
        remote_dir, inline_dir = api_directories
        for directory in (remote_dir, inline_dir):
            (directory / "invalid.yaml").write_text("invalid: yaml: content: -", encoding="utf-8")

        with pytest.raises(yaml.YAMLError):
            get_provider_registry(base_config)

    def test_missing_directory(self, mock_providers):
        """Test handling of missing external providers directory."""
        config = make_stack_config(
            apis=["inference"],
            providers={
                "inference": [
                    Provider(
                        provider_id="sample_provider",
                        provider_type="sample",
                        config=SampleConfig.sample_run_config(),
                    )
                ]
            },
            external_providers_dir="/nonexistent/dir",
        )
        with pytest.raises(FileNotFoundError):
            get_provider_registry(config)

    def test_empty_api_directory(self, api_directories, mock_providers, base_config):
        """Test handling of empty API directory."""
        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 1  # Only built-in provider

    def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed remote provider spec (missing required fields)."""
        remote_dir, _ = api_directories
        malformed_spec = """
adapter_type: test_provider
# Missing required fields
api_dependencies:
- safety
"""
        (remote_dir / "malformed.yaml").write_text(malformed_spec, encoding="utf-8")

        with pytest.raises(ValidationError):
            get_provider_registry(base_config)

    def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed inline provider spec (missing required fields)."""
        _, inline_dir = api_directories
        malformed_spec = """
module: test_provider
# Missing required config_class
pip_packages:
- test-package
"""
        (inline_dir / "malformed.yaml").write_text(malformed_spec, encoding="utf-8")

        with pytest.raises(ValidationError) as exc_info:
            get_provider_registry(base_config)
        assert "config_class" in str(exc_info.value)

    def test_external_provider_from_module_success(self, mock_providers):
        """Test loading an external provider from a module (success path)."""
        from types import SimpleNamespace

        from llama_stack.providers.datatypes import Api, ProviderSpec

        # Simulate a provider module with get_provider_spec
        fake_spec = ProviderSpec(
            api=Api.inference,
            provider_type="external_test",
            config_class="external_test.config.ExternalTestConfig",
            module="external_test",
        )
        fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)

        import_module_side_effect = make_import_module_side_effect(external_module=fake_module)

        with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
            config = self._external_test_config()
            registry = get_provider_registry(config)
            assert Api.inference in registry
            assert "external_test" in registry[Api.inference]
            provider = registry[Api.inference]["external_test"]
            assert provider.module == "external_test"
            assert provider.config_class == "external_test.config.ExternalTestConfig"
            mock_import.assert_any_call("llama_stack.providers.registry.inference")
            mock_import.assert_any_call("external_test.provider")

    def test_external_provider_from_module_not_found(self, mock_providers):
        """Test handling ModuleNotFoundError for missing provider module."""
        import_module_side_effect = make_import_module_side_effect(raise_for_external=True)

        with patch("importlib.import_module", side_effect=import_module_side_effect):
            config = self._external_test_config()
            with pytest.raises(ValueError) as exc_info:
                get_provider_registry(config)
            assert "get_provider_spec not found" in str(exc_info.value)

    def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
        """Test handling missing get_provider_spec in provider module (should raise AttributeError)."""
        import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)

        with patch("importlib.import_module", side_effect=import_module_side_effect):
            config = self._external_test_config()
            with pytest.raises(AttributeError):
                get_provider_registry(config)

    def test_external_provider_from_module_building(self, mock_providers):
        """Test loading an external provider from a module during build (building=True, partial spec)."""
        from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
        from llama_stack.providers.datatypes import Api

        # No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec
        build_config = BuildConfig(
            version=2,
            image_type="container",
            image_name="test_image",
            distribution_spec=DistributionSpec(
                description="test",
                providers={
                    "inference": [
                        BuildProvider(
                            provider_type="external_test",
                            module="external_test",
                        )
                    ]
                },
            ),
        )
        registry = get_provider_registry(build_config)
        assert Api.inference in registry
        assert "external_test" in registry[Api.inference]
        provider = registry[Api.inference]["external_test"]
        assert provider.module == "external_test"
        assert provider.is_external is True
        # config_class is empty string in partial spec
        assert provider.config_class == ""
|
|
|
|
|
|
class TestGetExternalProvidersFromModule:
|
|
"""Test suite for installing external providers from module."""
|
|
|
|
def test_stackrunconfig_provider_without_module(self, mock_providers):
    """A provider entry that declares no ``module`` must add nothing to the registry."""
    from llama_stack.core.distribution import get_external_providers_from_module

    fake_import = make_import_module_side_effect()

    with patch("importlib.import_module", side_effect=fake_import):
        module_less_provider = Provider(
            provider_id="no_module",
            provider_type="no_module",
            config={},
        )
        run_config = make_stack_config(
            image_name="test_image",
            providers={"inference": [module_less_provider]},
        )
        empty_registry = {Api.inference: {}}
        result = get_external_providers_from_module(empty_registry, run_config, building=False)

    # The inference entry must still be empty.
    assert len(result[Api.inference]) == 0
|
|
|
|
def test_stackrunconfig_with_version_spec(self, mock_providers):
    """A ``module`` value carrying a version pin (``package==1.0.0``) is still resolved.

    The fake importer only answers for ``versioned_test.provider``; the spec
    registered under ``versioned_test`` must keep the pinned module string.
    """
    from types import SimpleNamespace

    from llama_stack.core.distribution import get_external_providers_from_module
    from llama_stack.providers.datatypes import ProviderSpec

    pinned_module = "versioned_test==1.0.0"
    versioned_spec = ProviderSpec(
        api=Api.inference,
        provider_type="versioned_test",
        config_class="versioned_test.config.VersionedTestConfig",
        module=pinned_module,
    )
    provider_module = SimpleNamespace(get_provider_spec=lambda: versioned_spec)

    def import_side_effect(name):
        if name != "versioned_test.provider":
            raise ModuleNotFoundError(name)
        return provider_module

    with patch("importlib.import_module", side_effect=import_side_effect):
        versioned_provider = Provider(
            provider_id="versioned",
            provider_type="versioned_test",
            config={},
            module=pinned_module,
        )
        run_config = make_stack_config(
            image_name="test_image",
            providers={"inference": [versioned_provider]},
        )
        result = get_external_providers_from_module({Api.inference: {}}, run_config, building=False)

    inference_specs = result[Api.inference]
    assert "versioned_test" in inference_specs
    assert inference_specs["versioned_test"].module == "versioned_test==1.0.0"
|
|
|
|
def test_buildconfig_does_not_import_module(self, mock_providers):
    """With a ``BuildConfig`` and ``building=True`` the provider module is never imported."""
    from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
    from llama_stack.core.distribution import get_external_providers_from_module

    build_provider = BuildProvider(provider_type="build_test", module="build_test==1.0.0")
    distribution = DistributionSpec(
        description="test",
        providers={"inference": [build_provider]},
    )
    build_config = BuildConfig(
        version=2,
        image_type="container",
        image_name="test_image",
        distribution_spec=distribution,
    )

    # import_module is replaced by a bare mock so any call would be recorded.
    with patch("importlib.import_module") as mock_import:
        result = get_external_providers_from_module({Api.inference: {}}, build_config, building=True)

        # No import may have happened while building.
        mock_import.assert_not_called()

    # A partial spec is expected instead: external, with an empty config_class.
    inference_specs = result[Api.inference]
    assert "build_test" in inference_specs
    partial_spec = inference_specs["build_test"]
    assert partial_spec.module == "build_test==1.0.0"
    assert partial_spec.is_external is True
    assert partial_spec.config_class == ""
    assert partial_spec.api == Api.inference
|
|
|
|
def test_buildconfig_multiple_providers(self, mock_providers):
    """Two build providers listed under the same API both end up in the registry."""
    from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
    from llama_stack.core.distribution import get_external_providers_from_module

    provider_names = ("provider1", "provider2")
    build_providers = [BuildProvider(provider_type=name, module=name) for name in provider_names]
    build_config = BuildConfig(
        version=2,
        image_type="container",
        image_name="test_image",
        distribution_spec=DistributionSpec(
            description="test",
            providers={"inference": build_providers},
        ),
    )

    with patch("importlib.import_module") as mock_import:
        result = get_external_providers_from_module({Api.inference: {}}, build_config, building=True)

        mock_import.assert_not_called()

    for name in provider_names:
        assert name in result[Api.inference]
|
|
|
|
def test_distributionspec_does_not_import_module(self, mock_providers):
    """With a bare ``DistributionSpec`` and ``building=True`` the provider module is never imported."""
    from llama_stack.core.datatypes import BuildProvider, DistributionSpec
    from llama_stack.core.distribution import get_external_providers_from_module

    build_provider = BuildProvider(provider_type="dist_test", module="dist_test==2.0.0")
    dist_spec = DistributionSpec(
        description="test distribution",
        providers={"inference": [build_provider]},
    )

    # import_module is replaced by a bare mock so any call would be recorded.
    with patch("importlib.import_module") as mock_import:
        result = get_external_providers_from_module({Api.inference: {}}, dist_spec, building=True)

        # No import may have happened while building.
        mock_import.assert_not_called()

    # A partial spec is expected instead: external, with an empty config_class.
    inference_specs = result[Api.inference]
    assert "dist_test" in inference_specs
    partial_spec = inference_specs["dist_test"]
    assert partial_spec.module == "dist_test==2.0.0"
    assert partial_spec.is_external is True
    assert partial_spec.config_class == ""
|
|
|
|
def test_list_return_from_get_provider_spec(self, mock_providers):
    """Test when get_provider_spec returns a list of specs.

    The module offers two specs, but the run config requests only
    ``list_test``; that spec must be registered and the other one must not.
    """
    from types import SimpleNamespace

    from llama_stack.core.distribution import get_external_providers_from_module
    from llama_stack.providers.datatypes import ProviderSpec

    spec1 = ProviderSpec(
        api=Api.inference,
        provider_type="list_test",
        config_class="list_test.config.Config1",
        module="list_test",
    )
    spec2 = ProviderSpec(
        api=Api.inference,
        provider_type="list_test_remote",
        config_class="list_test.config.Config2",
        module="list_test",
    )

    fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])

    def import_side_effect(name):
        if name == "list_test.provider":
            return fake_module
        raise ModuleNotFoundError(name)

    with patch("importlib.import_module", side_effect=import_side_effect):
        config = make_stack_config(
            image_name="test_image",
            providers={
                "inference": [
                    Provider(
                        provider_id="list_test",
                        provider_type="list_test",
                        config={},
                        module="list_test",
                    )
                ]
            },
        )
        registry = {Api.inference: {}}
        result = get_external_providers_from_module(registry, config, building=False)

        # Only the matching provider_type should be added
        assert "list_test" in result[Api.inference]
        assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"
        # The spec that was not requested must not be registered.
        assert "list_test_remote" not in result[Api.inference]
|
|
|
|
def test_list_return_filters_by_provider_type(self, mock_providers):
    """Specs from a list return are kept only when their ``provider_type`` was requested."""
    from types import SimpleNamespace

    from llama_stack.core.distribution import get_external_providers_from_module
    from llama_stack.providers.datatypes import ProviderSpec

    # One spec per (provider_type, config_class) pair, in the order the module returns them.
    offered_specs = [
        ProviderSpec(
            api=Api.inference,
            provider_type=provider_type,
            config_class=config_class,
            module="test",
        )
        for provider_type, config_class in (("wanted", "test.Config1"), ("unwanted", "test.Config2"))
    ]
    provider_module = SimpleNamespace(get_provider_spec=lambda: offered_specs)

    def import_side_effect(name):
        if name != "test.provider":
            raise ModuleNotFoundError(name)
        return provider_module

    with patch("importlib.import_module", side_effect=import_side_effect):
        wanted_provider = Provider(
            provider_id="wanted",
            provider_type="wanted",
            config={},
            module="test",
        )
        run_config = make_stack_config(
            image_name="test_image",
            providers={"inference": [wanted_provider]},
        )
        result = get_external_providers_from_module({Api.inference: {}}, run_config, building=False)

    # Only the requested provider_type may appear in the registry.
    inference_specs = result[Api.inference]
    assert "wanted" in inference_specs
    assert "unwanted" not in inference_specs
|
|
|
|
def test_list_return_adds_multiple_provider_types(self, mock_providers):
    """Every ``provider_type`` requested by the config is registered from a list return."""
    from types import SimpleNamespace

    from llama_stack.core.distribution import get_external_providers_from_module
    from llama_stack.providers.datatypes import ProviderSpec

    # The module offers a remote and an inline variant; each has its own config class.
    expected_config_classes = {
        "remote::ollama": "test.RemoteConfig",
        "inline::ollama": "test.InlineConfig",
    }
    offered_specs = [
        ProviderSpec(
            api=Api.inference,
            provider_type=provider_type,
            config_class=config_class,
            module="test",
        )
        for provider_type, config_class in expected_config_classes.items()
    ]
    provider_module = SimpleNamespace(get_provider_spec=lambda: offered_specs)

    def import_side_effect(name):
        if name != "test.provider":
            raise ModuleNotFoundError(name)
        return provider_module

    with patch("importlib.import_module", side_effect=import_side_effect):
        requested_providers = [
            Provider(
                provider_id="remote_ollama",
                provider_type="remote::ollama",
                config={},
                module="test",
            ),
            Provider(
                provider_id="inline_ollama",
                provider_type="inline::ollama",
                config={},
                module="test",
            ),
        ]
        run_config = make_stack_config(
            image_name="test_image",
            providers={"inference": requested_providers},
        )
        result = get_external_providers_from_module({Api.inference: {}}, run_config, building=False)

    # Both variants must be present, each mapped to its own config class.
    inference_specs = result[Api.inference]
    assert "remote::ollama" in inference_specs
    assert "inline::ollama" in inference_specs
    assert inference_specs["remote::ollama"].config_class == "test.RemoteConfig"
    assert inference_specs["inline::ollama"].config_class == "test.InlineConfig"
|
|
|
|
def test_module_not_found_raises_value_error(self, mock_providers):
    """Check that a failed provider-module import surfaces as a ValueError.

    ``importlib.import_module`` is patched to raise ``ModuleNotFoundError``
    for every name, and the run config references a provider whose module is
    ``missing_module``. The call must raise ``ValueError`` whose message
    contains "get_provider_spec not found".
    """
    from llama_stack.core.distribution import get_external_providers_from_module

    def import_side_effect(name):
        # Every import fails; the original special-cased
        # "missing_module.provider" but raised the same error on both paths.
        raise ModuleNotFoundError(name)

    with patch("importlib.import_module", side_effect=import_side_effect):
        config = make_stack_config(
            image_name="test_image",
            providers={
                "inference": [
                    Provider(
                        provider_id="missing",
                        provider_type="missing",
                        config={},
                        module="missing_module",
                    )
                ]
            },
        )
        registry = {Api.inference: {}}

        with pytest.raises(ValueError) as exc_info:
            get_external_providers_from_module(registry, config, building=False)

        # Checked after the raises block so the assertion actually runs.
        assert "get_provider_spec not found" in str(exc_info.value)
|
|
|
|
def test_generic_exception_is_raised(self, mock_providers):
    """Check that an error raised by ``get_provider_spec`` propagates unchanged.

    The fake ``error_module.provider`` module exposes a ``get_provider_spec``
    that raises ``RuntimeError("Something went wrong")``. The call must raise
    ``RuntimeError`` carrying that same message.
    """
    from types import SimpleNamespace

    from llama_stack.core.distribution import get_external_providers_from_module

    def failing_get_provider_spec():
        raise RuntimeError("Something went wrong")

    fake_module = SimpleNamespace(get_provider_spec=failing_get_provider_spec)

    def import_side_effect(name):
        # Only the provider submodule of the fake package is importable.
        if name != "error_module.provider":
            raise ModuleNotFoundError(name)
        return fake_module

    with patch("importlib.import_module", side_effect=import_side_effect):
        error_provider = Provider(
            provider_id="error",
            provider_type="error",
            config={},
            module="error_module",
        )
        config = make_stack_config(
            image_name="test_image",
            providers={"inference": [error_provider]},
        )
        registry = {Api.inference: {}}

        with pytest.raises(RuntimeError) as exc_info:
            get_external_providers_from_module(registry, config, building=False)

        raised_message = str(exc_info.value)
        assert "Something went wrong" in raised_message
|
|
|
|
def test_empty_provider_list(self, mock_providers):
    """Check that a config with no providers leaves the registry empty.

    The run config has an empty ``providers`` mapping; the returned registry
    must compare equal to the one passed in and hold no inference entries.
    """
    from llama_stack.core.distribution import get_external_providers_from_module

    registry = {Api.inference: {}}
    result = get_external_providers_from_module(
        registry,
        make_stack_config(image_name="test_image", providers={}),
        building=False,
    )

    # Should return registry unchanged
    assert result == registry
    assert not result[Api.inference]
|
|
|
|
def test_multiple_apis_with_providers(self, mock_providers):
    """Check that providers for two different APIs are each registered.

    Two fake modules are importable: ``inf_test.provider`` returns a single
    inference spec and ``safe_test.provider`` returns a single safety spec
    (each as a bare spec, not a list). With one provider configured per API,
    each spec's provider_type must appear under its own API in the result.
    """
    from types import SimpleNamespace

    from llama_stack.core.distribution import get_external_providers_from_module
    from llama_stack.providers.datatypes import ProviderSpec

    inference_spec = ProviderSpec(
        api=Api.inference,
        provider_type="inf_test",
        config_class="inf.Config",
        module="inf_test",
    )
    safety_spec = ProviderSpec(
        api=Api.safety,
        provider_type="safe_test",
        config_class="safe.Config",
        module="safe_test",
    )

    # Importable module name -> the spec its get_provider_spec() returns.
    spec_by_module = {
        "inf_test.provider": inference_spec,
        "safe_test.provider": safety_spec,
    }

    def import_side_effect(name):
        if name not in spec_by_module:
            raise ModuleNotFoundError(name)
        spec = spec_by_module[name]
        # A new namespace per import, matching one stub module per call.
        return SimpleNamespace(get_provider_spec=lambda: spec)

    with patch("importlib.import_module", side_effect=import_side_effect):
        inference_provider = Provider(
            provider_id="inf",
            provider_type="inf_test",
            config={},
            module="inf_test",
        )
        safety_provider = Provider(
            provider_id="safe",
            provider_type="safe_test",
            config={},
            module="safe_test",
        )
        config = make_stack_config(
            image_name="test_image",
            providers={
                "inference": [inference_provider],
                "safety": [safety_provider],
            },
        )
        registry = {Api.inference: {}, Api.safety: {}}
        result = get_external_providers_from_module(registry, config, building=False)

        assert "inf_test" in result[Api.inference]
        assert "safe_test" in result[Api.safety]
|