llama-stack-mirror/tests/unit/distribution/test_distribution.py
feat(prompts): attach prompts to storage stores in run configs (#3893)
# What does this PR do?
<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->
This PR attaches prompts to the storage stores in run configs, making it possible
to declare a prompts store in the different distributions. The need for this
functionality originated in #3514.

> Note: #3514 is split into three separate PRs. This PR is the first of the
three.
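
For orientation, this is the shape of the new store entry; a minimal `run.yaml` fragment mirroring the starter configuration captured in the test plan below (paths abbreviated):

```
storage:
  backends:
    kv_default:
      type: kv_sqlite
      db_path: ~/.llama/distributions/starter/kvstore.db
  stores:
    # prompts join metadata/inference/conversations as a first-class store reference
    prompts:
      backend: kv_default
      namespace: prompts
```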

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

## Test Plan
<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->
Manual testing and updated CI unit tests

Prerequisites:

1. `uv run --with llama-stack llama stack list-deps starter | xargs -L1 uv pip install`

2. `llama stack run starter`

```
INFO     2025-10-23 15:36:17,387 llama_stack.cli.stack.run:100 cli: Using run configuration:                            
         /Users/ianmiller/llama-stack/llama_stack/distributions/starter/run.yaml                                        
INFO     2025-10-23 15:36:17,423 llama_stack.cli.stack.run:157 cli: HTTPS enabled with certificates:                    
           Key: None                                                                                                    
           Cert: None                                                                                                   
INFO     2025-10-23 15:36:17,424 llama_stack.cli.stack.run:159 cli: Listening on ['::', '0.0.0.0']:8321                 
INFO     2025-10-23 15:36:17,749 llama_stack.core.server.server:521 core::server: Run configuration:                    
INFO     2025-10-23 15:36:17,756 llama_stack.core.server.server:524 core::server: apis:                                 
         - agents                                                                                                       
         - batches                                                                                                      
         - datasetio                                                                                                    
         - eval                                                                                                         
         - files                                                                                                        
         - inference                                                                                                    
         - post_training                                                                                                
         - safety                                                                                                       
         - scoring                                                                                                      
         - tool_runtime                                                                                                 
         - vector_io                                                                                                    
         image_name: starter                                                                                            
         providers:                                                                                                     
           agents:                                                                                                      
           - config:                                                                                                    
               persistence:                                                                                             
                 agent_state:                                                                                           
                   backend: kv_default                                                                                  
                   namespace: agents                                                                                    
                 responses:                                                                                             
                   backend: sql_default                                                                                 
                   max_write_queue_size: 10000                                                                          
                   num_writers: 4                                                                                       
                   table_name: responses                                                                                
             provider_id: meta-reference                                                                                
             provider_type: inline::meta-reference                                                                      
           batches:                                                                                                     
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: batches                                                                                     
             provider_id: reference                                                                                     
             provider_type: inline::reference                                                                           
           datasetio:                                                                                                   
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: datasetio::huggingface                                                                      
             provider_id: huggingface                                                                                   
             provider_type: remote::huggingface                                                                         
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: datasetio::localfs                                                                          
             provider_id: localfs                                                                                       
             provider_type: inline::localfs                                                                             
           eval:                                                                                                        
           - config:                                                                                                    
               kvstore:                                                                                                 
                 backend: kv_default                                                                                    
                 namespace: eval                                                                                        
             provider_id: meta-reference                                                                                
             provider_type: inline::meta-reference                                                                      
           files:                                                                                                       
           - config:                                                                                                    
               metadata_store:                                                                                          
                 backend: sql_default                                                                                   
                 table_name: files_metadata                                                                             
               storage_dir: /Users/ianmiller/.llama/distributions/starter/files                                         
             provider_id: meta-reference-files                                                                          
             provider_type: inline::localfs                                                                             
           inference:                                                                                                   
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.fireworks.ai/inference/v1                                                               
             provider_id: fireworks                                                                                     
             provider_type: remote::fireworks                                                                           
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.together.xyz/v1                                                                         
             provider_id: together                                                                                      
             provider_type: remote::together                                                                            
           - config: {}                                                                                                 
             provider_id: bedrock                                                                                       
             provider_type: remote::bedrock                                                                             
           - config:                                                                                                    
               api_key: '********'                                                                                      
               base_url: https://api.openai.com/v1                                                                      
             provider_id: openai                                                                                        
             provider_type: remote::openai                                                                              
           - config:                                                                                                    
               api_key: '********'                                                                                      
             provider_id: anthropic                                                                                     
             provider_type: remote::anthropic                                                                           
           - config:                                                                                                    
               api_key: '********'                                                                                      
             provider_id: gemini                                                                                        
             provider_type: remote::gemini                                                                              
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.groq.com                                                                                
             provider_id: groq                                                                                          
             provider_type: remote::groq                                                                                
           - config:                                                                                                    
               api_key: '********'                                                                                      
               url: https://api.sambanova.ai/v1                                                                         
             provider_id: sambanova                                                                                     
             provider_type: remote::sambanova                                                                           
           - config: {}                                                                                                 
             provider_id: sentence-transformers                                                                         
             provider_type: inline::sentence-transformers                                                               
           post_training:                                                                                               
           - config:                                                                                                    
               checkpoint_format: meta                                                                                  
             provider_id: torchtune-cpu                                                                                 
             provider_type: inline::torchtune-cpu                                                                       
           safety:                                                                                                      
           - config:                                                                                                    
               excluded_categories: []                                                                                  
             provider_id: llama-guard                                                                                   
             provider_type: inline::llama-guard                                                                         
           - config: {}                                                                                                 
             provider_id: code-scanner                                                                                  
             provider_type: inline::code-scanner                                                                        
           scoring:                                                                                                     
           - config: {}                                                                                                 
             provider_id: basic                                                                                         
             provider_type: inline::basic                                                                               
           - config: {}                                                                                                 
             provider_id: llm-as-judge                                                                                  
             provider_type: inline::llm-as-judge                                                                        
           - config:                                                                                                    
               openai_api_key: '********'                                                                               
             provider_id: braintrust                                                                                    
             provider_type: inline::braintrust                                                                          
           tool_runtime:                                                                                                
           - config:                                                                                                    
               api_key: '********'                                                                                      
               max_results: 3                                                                                           
             provider_id: brave-search                                                                                  
             provider_type: remote::brave-search                                                                        
           - config:                                                                                                    
               api_key: '********'                                                                                      
               max_results: 3                                                                                           
             provider_id: tavily-search                                                                                 
             provider_type: remote::tavily-search                                                                       
           - config: {}                                                                                                 
             provider_id: rag-runtime                                                                                   
             provider_type: inline::rag-runtime                                                                         
           - config: {}                                                                                                 
             provider_id: model-context-protocol                                                                        
             provider_type: remote::model-context-protocol                                                              
           vector_io:                                                                                                   
           - config:                                                                                                    
               persistence:                                                                                             
                 backend: kv_default                                                                                    
                 namespace: vector_io::faiss                                                                            
             provider_id: faiss                                                                                         
             provider_type: inline::faiss                                                                               
           - config:                                                                                                    
               db_path: /Users/ianmiller/.llama/distributions/starter/sqlite_vec.db                                     
               persistence:                                                                                             
                 backend: kv_default                                                                                    
                 namespace: vector_io::sqlite_vec                                                                       
             provider_id: sqlite-vec                                                                                    
             provider_type: inline::sqlite-vec                                                                          
         registered_resources:                                                                                          
           benchmarks: []                                                                                               
           datasets: []                                                                                                 
           models: []                                                                                                   
           scoring_fns: []                                                                                              
           shields: []                                                                                                  
           tool_groups:                                                                                                 
           - provider_id: tavily-search                                                                                 
             toolgroup_id: builtin::websearch                                                                           
           - provider_id: rag-runtime                                                                                   
             toolgroup_id: builtin::rag                                                                                 
           vector_stores: []                                                                                            
         server:                                                                                                        
           port: 8321                                                                                                   
         storage:                                                                                                       
           backends:                                                                                                    
             kv_default:                                                                                                
               db_path: /Users/ianmiller/.llama/distributions/starter/kvstore.db                                        
               type: kv_sqlite                                                                                          
             sql_default:                                                                                               
               db_path: /Users/ianmiller/.llama/distributions/starter/sql_store.db                                      
               type: sql_sqlite                                                                                         
           stores:                                                                                                      
             conversations:                                                                                             
               backend: sql_default                                                                                     
               table_name: openai_conversations                                                                         
             inference:                                                                                                 
               backend: sql_default                                                                                     
               max_write_queue_size: 10000                                                                              
               num_writers: 4                                                                                           
               table_name: inference_store                                                                              
             metadata:                                                                                                  
               backend: kv_default                                                                                      
               namespace: registry                                                                                      
             prompts:                                                                                                   
               backend: kv_default                                                                                      
               namespace: prompts                                                                                       
         telemetry:                                                                                                     
           enabled: true                                                                                                
         vector_stores:                                                                                                 
           default_embedding_model:                                                                                     
             model_id: nomic-ai/nomic-embed-text-v1.5                                                                   
             provider_id: sentence-transformers                                                                         
           default_provider_id: faiss                                                                                   
         version: 2                                                                                                     
                                                                                                                        
INFO     2025-10-23 15:36:20,032 llama_stack.providers.utils.inference.inference_store:74 inference: Write queue        
         disabled for SQLite to avoid concurrency issues                                                                
WARNING  2025-10-23 15:36:20,422 llama_stack.providers.inline.telemetry.meta_reference.telemetry:84 telemetry:          
         OTEL_EXPORTER_OTLP_ENDPOINT is not set, skipping telemetry                                                     
INFO     2025-10-23 15:36:22,379 llama_stack.providers.utils.inference.openai_mixin:436 providers::utils:               
         OpenAIInferenceAdapter.list_provider_model_ids() returned 105 models                                           
INFO     2025-10-23 15:36:22,703 uvicorn.error:84 uncategorized: Started server process [17328]                         
INFO     2025-10-23 15:36:22,704 uvicorn.error:48 uncategorized: Waiting for application startup.                       
INFO     2025-10-23 15:36:22,706 llama_stack.core.server.server:179 core::server: Starting up Llama Stack server        
         (version: 0.3.0)                                                                                               
INFO     2025-10-23 15:36:22,707 llama_stack.core.stack:470 core: starting registry refresh task                        
INFO     2025-10-23 15:36:22,708 uvicorn.error:62 uncategorized: Application startup complete.                          
INFO     2025-10-23 15:36:22,708 uvicorn.error:216 uncategorized: Uvicorn running on http://['::', '0.0.0.0']:8321      
         (Press CTRL+C to quit)   
```
As the `storage` section above shows, prompts are now attached to the `stores` in the run config.
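
The same wiring can be expressed programmatically with the storage datatypes this PR touches. A minimal sketch, condensed from the `_default_storage()` helper in the unit tests below (in-memory SQLite backends for brevity):

```
from llama_stack.core.storage.datatypes import (
    InferenceStoreReference,
    KVStoreReference,
    ServerStoresConfig,
    SqliteKVStoreConfig,
    SqliteSqlStoreConfig,
    SqlStoreReference,
    StorageConfig,
)

storage = StorageConfig(
    backends={
        "kv_default": SqliteKVStoreConfig(db_path=":memory:"),
        "sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
    },
    stores=ServerStoresConfig(
        metadata=KVStoreReference(backend="kv_default", namespace="registry"),
        inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
        conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
        # the new piece: prompts get their own KV store reference
        prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
    ),
)
```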

Testing:

1. Create prompt:

```
curl -X POST http://localhost:8321/v1/prompts \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.",
    "variables": ["name", "company", "role", "tone"]
  }'
```

```
{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}
```

2. Get prompt:

```
curl -X GET http://localhost:8321/v1/prompts/pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f
```

```
{"prompt":"Hello {{name}}! You are working at {{company}}. Your role is {{role}} at {{company}}. Remember, {{name}}, to be {{tone}}.","version":1,"prompt_id":"pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e163f","variables":["name","company","role","tone"],"is_default":false}
```

3. Query the SQLite KV store to verify the created prompt:

```
sqlite> .mode column
sqlite> .headers on
sqlite> SELECT * FROM kvstore WHERE key LIKE 'prompts:v1:%';
key                                                           value                                                         expiration
------------------------------------------------------------  ------------------------------------------------------------  ----------
prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e  {"prompt_id": "pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab            
163f:1                                                        5f6e163f", "prompt": "Hello {{name}}! You are working at {{c            
                                                              ompany}}. Your role is {{role}} at {{company}}. Remember, {{            
                                                              name}}, to be {{tone}}.", "version": 1, "variables": ["name"            
                                                              , "company", "role", "tone"], "is_default": false}                      

prompts:v1:pmpt_a90e09e67acfe23776f2778c603eb6c17e139dab5f6e  1                                                                       
163f:default                                                                                                                          
sqlite> 
```
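
From the rows above, the key layout is `prompts:v1:<prompt_id>:<version>` for the prompt body and `prompts:v1:<prompt_id>:default` for the default-version pointer. A small sketch that decodes the entries with the Python standard library (db path taken from the run configuration above; adjust for your machine):

```
import json
import os
import sqlite3

# kv_default backend path from the starter run configuration above
db_path = os.path.expanduser("~/.llama/distributions/starter/kvstore.db")

conn = sqlite3.connect(db_path)
for key, value, _expiration in conn.execute(
    "SELECT key, value, expiration FROM kvstore WHERE key LIKE 'prompts:v1:%'"
):
    if key.endswith(":default"):
        # pointer row: value is the default version number
        print(key, "-> default version", value)
    else:
        # versioned row: value is the serialized prompt
        print(key, "->", json.loads(value)["prompt"])
conn.close()
```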

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any
from unittest.mock import patch

import pytest
import yaml
from pydantic import BaseModel, Field, ValidationError

from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.distribution import INTERNAL_APIS, get_provider_registry, providable_apis
from llama_stack.core.storage.datatypes import (
    InferenceStoreReference,
    KVStoreReference,
    ServerStoresConfig,
    SqliteKVStoreConfig,
    SqliteSqlStoreConfig,
    SqlStoreReference,
    StorageConfig,
)
from llama_stack.providers.datatypes import ProviderSpec


class SampleConfig(BaseModel):
    foo: str = Field(
        default="bar",
        description="foo",
    )

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
        return {
            "foo": "baz",
        }


def _default_storage() -> StorageConfig:
    return StorageConfig(
        backends={
            "kv_default": SqliteKVStoreConfig(db_path=":memory:"),
            "sql_default": SqliteSqlStoreConfig(db_path=":memory:"),
        },
        stores=ServerStoresConfig(
            metadata=KVStoreReference(backend="kv_default", namespace="registry"),
            inference=InferenceStoreReference(backend="sql_default", table_name="inference_store"),
            conversations=SqlStoreReference(backend="sql_default", table_name="conversations"),
            prompts=KVStoreReference(backend="kv_default", namespace="prompts"),
        ),
    )


def make_stack_config(**overrides) -> StackRunConfig:
    storage = overrides.pop("storage", _default_storage())
    defaults = dict(
        image_name="test_image",
        apis=[],
        providers={},
        storage=storage,
    )
    defaults.update(overrides)
    return StackRunConfig(**defaults)


@pytest.fixture
def mock_providers():
    """Mock the available_providers function to return test providers."""
    with patch("llama_stack.providers.registry.inference.available_providers") as mock:
        mock.return_value = [
            ProviderSpec(
                provider_type="test_provider",
                api=Api.inference,
                adapter_type="test_adapter",
                config_class="test_provider.config.TestProviderConfig",
            )
        ]
        yield mock


@pytest.fixture
def base_config(tmp_path):
    """Create a base StackRunConfig with common settings."""
    return make_stack_config(
        apis=["inference"],
        providers={
            "inference": [
                Provider(
                    provider_id="sample_provider",
                    provider_type="sample",
                    config=SampleConfig.sample_run_config(),
                )
            ]
        },
        external_providers_dir=str(tmp_path),
    )


@pytest.fixture
def provider_spec_yaml():
    """Common provider spec YAML for testing."""
    return """
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
api_dependencies:
- safety
"""


@pytest.fixture
def inline_provider_spec_yaml():
    """Common inline provider spec YAML for testing."""
    return """
module: test_provider
config_class: test_provider.config.TestProviderConfig
pip_packages:
- test-package
api_dependencies:
- safety
optional_api_dependencies:
- vector_io
provider_data_validator: test_provider.validator.TestValidator
container_image: test-image:latest
"""


@pytest.fixture
def api_directories(tmp_path):
    """Create the API directory structure for testing."""
    # Create remote provider directory
    remote_inference_dir = tmp_path / "remote" / "inference"
    remote_inference_dir.mkdir(parents=True, exist_ok=True)

    # Create inline provider directory
    inline_inference_dir = tmp_path / "inline" / "inference"
    inline_inference_dir.mkdir(parents=True, exist_ok=True)

    return remote_inference_dir, inline_inference_dir


def make_import_module_side_effect(
    builtin_provider_spec=None,
    external_module=None,
    raise_for_external=False,
    missing_get_provider_spec=False,
):
    from types import SimpleNamespace

    def import_module_side_effect(name):
        if name == "llama_stack.providers.registry.inference":
            mock_builtin = SimpleNamespace(
                available_providers=lambda: [
                    builtin_provider_spec
                    or ProviderSpec(
                        api=Api.inference,
                        provider_type="test_provider",
                        config_class="test_provider.config.TestProviderConfig",
                        module="test_provider",
                    )
                ]
            )
            return mock_builtin
        elif name == "external_test.provider":
            if raise_for_external:
                raise ModuleNotFoundError(name)
            if missing_get_provider_spec:
                return SimpleNamespace()
            return external_module
        else:
            raise ModuleNotFoundError(name)

    return import_module_side_effect


class TestProviderRegistry:
    """Test suite for provider registry functionality."""

    def test_builtin_providers(self, mock_providers):
        """Test loading built-in providers."""
        registry = get_provider_registry(None)

        assert Api.inference in registry
        assert "test_provider" in registry[Api.inference]
        assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
        assert registry[Api.inference]["test_provider"].api == Api.inference

    def test_internal_apis_excluded(self):
        """Test that internal APIs are excluded and APIs without provider registries are marked as internal."""
        import importlib

        apis = providable_apis()

        for internal_api in INTERNAL_APIS:
            assert internal_api not in apis, f"Internal API {internal_api} should not be in providable_apis"

        for api in apis:
            module_name = f"llama_stack.providers.registry.{api.name.lower()}"
            try:
                importlib.import_module(module_name)
            except ImportError as err:
                raise AssertionError(
                    f"API {api} is in providable_apis but has no provider registry module ({module_name})"
                ) from err

    def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
        """Test loading external remote providers from YAML files."""
        remote_dir, _ = api_directories
        with open(remote_dir / "test_provider.yaml", "w") as f:
            f.write(provider_spec_yaml)

        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 2

        assert Api.inference in registry
        assert "remote::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["remote::test_provider"]
        assert provider.adapter_type == "test_provider"
        assert provider.module == "test_provider"
        assert provider.config_class == "test_provider.config.TestProviderConfig"
        assert Api.safety in provider.api_dependencies

    def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
        """Test loading external inline providers from YAML files."""
        _, inline_dir = api_directories
        with open(inline_dir / "test_provider.yaml", "w") as f:
            f.write(inline_provider_spec_yaml)

        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 2

        assert Api.inference in registry
        assert "inline::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["inline::test_provider"]
        assert provider.provider_type == "inline::test_provider"
        assert provider.module == "test_provider"
        assert provider.config_class == "test_provider.config.TestProviderConfig"
        assert provider.pip_packages == ["test-package"]
        assert Api.safety in provider.api_dependencies
        assert Api.vector_io in provider.optional_api_dependencies
        assert provider.provider_data_validator == "test_provider.validator.TestValidator"
        assert provider.container_image == "test-image:latest"

    def test_invalid_yaml(self, api_directories, mock_providers, base_config):
        """Test handling of invalid YAML files."""
        remote_dir, inline_dir = api_directories
        with open(remote_dir / "invalid.yaml", "w") as f:
            f.write("invalid: yaml: content: -")
        with open(inline_dir / "invalid.yaml", "w") as f:
            f.write("invalid: yaml: content: -")

        with pytest.raises(yaml.YAMLError):
            get_provider_registry(base_config)

    def test_missing_directory(self, mock_providers):
        """Test handling of missing external providers directory."""
        config = make_stack_config(
            apis=["inference"],
            providers={
                "inference": [
                    Provider(
                        provider_id="sample_provider",
                        provider_type="sample",
                        config=SampleConfig.sample_run_config(),
                    )
                ]
            },
            external_providers_dir="/nonexistent/dir",
        )
        with pytest.raises(FileNotFoundError):
            get_provider_registry(config)

    def test_empty_api_directory(self, api_directories, mock_providers, base_config):
        """Test handling of empty API directory."""
        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 1  # Only built-in provider

    def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed remote provider spec (missing required fields)."""
        remote_dir, _ = api_directories
        malformed_spec = """
adapter_type: test_provider
# Missing required fields
api_dependencies:
- safety
"""
        with open(remote_dir / "malformed.yaml", "w") as f:
            f.write(malformed_spec)

        with pytest.raises(ValidationError):
            get_provider_registry(base_config)

    def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed inline provider spec (missing required fields)."""
        _, inline_dir = api_directories
        malformed_spec = """
module: test_provider
# Missing required config_class
pip_packages:
- test-package
"""
        with open(inline_dir / "malformed.yaml", "w") as f:
            f.write(malformed_spec)

        with pytest.raises(ValidationError) as exc_info:
            get_provider_registry(base_config)
        assert "config_class" in str(exc_info.value)

    def test_external_provider_from_module_success(self, mock_providers):
        """Test loading an external provider from a module (success path)."""
        from types import SimpleNamespace

        from llama_stack.providers.datatypes import Api, ProviderSpec

        # Simulate a provider module with get_provider_spec
        fake_spec = ProviderSpec(
            api=Api.inference,
            provider_type="external_test",
            config_class="external_test.config.ExternalTestConfig",
            module="external_test",
        )
        fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)

        import_module_side_effect = make_import_module_side_effect(external_module=fake_module)

        with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import:
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="external_test",
                            provider_type="external_test",
                            config={},
                            module="external_test",
                        )
                    ]
                },
            )
            registry = get_provider_registry(config)
            assert Api.inference in registry
            assert "external_test" in registry[Api.inference]
            provider = registry[Api.inference]["external_test"]
            assert provider.module == "external_test"
            assert provider.config_class == "external_test.config.ExternalTestConfig"
            mock_import.assert_any_call("llama_stack.providers.registry.inference")
            mock_import.assert_any_call("external_test.provider")

    def test_external_provider_from_module_not_found(self, mock_providers):
        """Test handling ModuleNotFoundError for missing provider module."""
        import_module_side_effect = make_import_module_side_effect(raise_for_external=True)

        with patch("importlib.import_module", side_effect=import_module_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="external_test",
                            provider_type="external_test",
                            config={},
                            module="external_test",
                        )
                    ]
                },
            )
            with pytest.raises(ValueError) as exc_info:
                get_provider_registry(config)
            assert "get_provider_spec not found" in str(exc_info.value)

    def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers):
        """Test handling missing get_provider_spec in provider module (should raise AttributeError)."""
        import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True)

        with patch("importlib.import_module", side_effect=import_module_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="external_test",
                            provider_type="external_test",
                            config={},
                            module="external_test",
                        )
                    ]
                },
            )
            with pytest.raises(AttributeError):
                get_provider_registry(config)

    def test_external_provider_from_module_building(self, mock_providers):
        """Test loading an external provider from a module during build (building=True, partial spec)."""
        from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
        from llama_stack.providers.datatypes import Api

        # No importlib patch needed; the module should not be imported when `config`
        # is a BuildConfig or DistributionSpec
        build_config = BuildConfig(
            version=2,
            image_type="container",
            image_name="test_image",
            distribution_spec=DistributionSpec(
                description="test",
                providers={
                    "inference": [
                        BuildProvider(
                            provider_type="external_test",
                            module="external_test",
                        )
                    ]
                },
            ),
        )
        registry = get_provider_registry(build_config)
        assert Api.inference in registry
        assert "external_test" in registry[Api.inference]
        provider = registry[Api.inference]["external_test"]
        assert provider.module == "external_test"
        assert provider.is_external is True
        # config_class is empty string in partial spec
        assert provider.config_class == ""


class TestGetExternalProvidersFromModule:
    """Test suite for installing external providers from module."""

    def test_stackrunconfig_provider_without_module(self, mock_providers):
        """Test that providers without module attribute are skipped."""
        from llama_stack.core.distribution import get_external_providers_from_module

        import_module_side_effect = make_import_module_side_effect()

        with patch("importlib.import_module", side_effect=import_module_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="no_module",
                            provider_type="no_module",
                            config={},
                        )
                    ]
                },
            )
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, config, building=False)

            # Should not add anything to registry
            assert len(result[Api.inference]) == 0

    def test_stackrunconfig_with_version_spec(self, mock_providers):
        """Test provider with module containing version spec (e.g., package==1.0.0)."""
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec

        fake_spec = ProviderSpec(
            api=Api.inference,
            provider_type="versioned_test",
            config_class="versioned_test.config.VersionedTestConfig",
            module="versioned_test==1.0.0",
        )
        fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec)

        def import_side_effect(name):
            if name == "versioned_test.provider":
                return fake_module
            raise ModuleNotFoundError(name)

        with patch("importlib.import_module", side_effect=import_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="versioned",
                            provider_type="versioned_test",
                            config={},
                            module="versioned_test==1.0.0",
                        )
                    ]
                },
            )
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, config, building=False)

            assert "versioned_test" in result[Api.inference]
            assert result[Api.inference]["versioned_test"].module == "versioned_test==1.0.0"

    def test_buildconfig_does_not_import_module(self, mock_providers):
        """Test that BuildConfig does not import the module (building=True)."""
        from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
        from llama_stack.core.distribution import get_external_providers_from_module

        build_config = BuildConfig(
            version=2,
            image_type="container",
            image_name="test_image",
            distribution_spec=DistributionSpec(
                description="test",
                providers={
                    "inference": [
                        BuildProvider(
                            provider_type="build_test",
                            module="build_test==1.0.0",
                        )
                    ]
                },
            ),
        )

        # Should not call import_module at all when building
        with patch("importlib.import_module") as mock_import:
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, build_config, building=True)

            # Verify module was NOT imported
            mock_import.assert_not_called()

            # Verify partial spec was created
            assert "build_test" in result[Api.inference]
            provider = result[Api.inference]["build_test"]
            assert provider.module == "build_test==1.0.0"
            assert provider.is_external is True
            assert provider.config_class == ""
            assert provider.api == Api.inference

    def test_buildconfig_multiple_providers(self, mock_providers):
        """Test BuildConfig with multiple providers for the same API."""
        from llama_stack.core.datatypes import BuildConfig, BuildProvider, DistributionSpec
        from llama_stack.core.distribution import get_external_providers_from_module

        build_config = BuildConfig(
            version=2,
            image_type="container",
            image_name="test_image",
            distribution_spec=DistributionSpec(
                description="test",
                providers={
                    "inference": [
                        BuildProvider(provider_type="provider1", module="provider1"),
                        BuildProvider(provider_type="provider2", module="provider2"),
                    ]
                },
            ),
        )

        with patch("importlib.import_module") as mock_import:
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, build_config, building=True)

            mock_import.assert_not_called()
            assert "provider1" in result[Api.inference]
            assert "provider2" in result[Api.inference]

    def test_distributionspec_does_not_import_module(self, mock_providers):
        """Test that DistributionSpec does not import the module (building=True)."""
        from llama_stack.core.datatypes import BuildProvider, DistributionSpec
        from llama_stack.core.distribution import get_external_providers_from_module

        dist_spec = DistributionSpec(
            description="test distribution",
            providers={
                "inference": [
                    BuildProvider(
                        provider_type="dist_test",
                        module="dist_test==2.0.0",
                    )
                ]
            },
        )

        # Should not call import_module at all when building
        with patch("importlib.import_module") as mock_import:
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, dist_spec, building=True)

            # Verify module was NOT imported
            mock_import.assert_not_called()

            # Verify partial spec was created
            assert "dist_test" in result[Api.inference]
            provider = result[Api.inference]["dist_test"]
            assert provider.module == "dist_test==2.0.0"
            assert provider.is_external is True
            assert provider.config_class == ""

    def test_list_return_from_get_provider_spec(self, mock_providers):
        """Test when get_provider_spec returns a list of specs."""
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec

        spec1 = ProviderSpec(
            api=Api.inference,
            provider_type="list_test",
            config_class="list_test.config.Config1",
            module="list_test",
        )
        spec2 = ProviderSpec(
            api=Api.inference,
            provider_type="list_test_remote",
            config_class="list_test.config.Config2",
            module="list_test",
        )
        fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])

        def import_side_effect(name):
            if name == "list_test.provider":
                return fake_module
            raise ModuleNotFoundError(name)

        with patch("importlib.import_module", side_effect=import_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="list_test",
                            provider_type="list_test",
                            config={},
                            module="list_test",
                        )
                    ]
                },
            )
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, config, building=False)

            # Only the matching provider_type should be added
            assert "list_test" in result[Api.inference]
            assert result[Api.inference]["list_test"].config_class == "list_test.config.Config1"

    def test_list_return_filters_by_provider_type(self, mock_providers):
        """Test that list return filters specs by provider_type."""
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec

        spec1 = ProviderSpec(
            api=Api.inference,
            provider_type="wanted",
            config_class="test.Config1",
            module="test",
        )
        spec2 = ProviderSpec(
            api=Api.inference,
            provider_type="unwanted",
            config_class="test.Config2",
            module="test",
        )
        fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])

        def import_side_effect(name):
            if name == "test.provider":
                return fake_module
            raise ModuleNotFoundError(name)

        with patch("importlib.import_module", side_effect=import_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="wanted",
                            provider_type="wanted",
                            config={},
                            module="test",
                        )
                    ]
                },
            )
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, config, building=False)

            # Only the matching provider_type should be added
            assert "wanted" in result[Api.inference]
            assert "unwanted" not in result[Api.inference]

    def test_list_return_adds_multiple_provider_types(self, mock_providers):
        """Test that list return adds multiple different provider_types when config requests them."""
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec

        # Module returns both inline and remote variants
        spec1 = ProviderSpec(
            api=Api.inference,
            provider_type="remote::ollama",
            config_class="test.RemoteConfig",
            module="test",
        )
        spec2 = ProviderSpec(
            api=Api.inference,
            provider_type="inline::ollama",
            config_class="test.InlineConfig",
            module="test",
        )
        fake_module = SimpleNamespace(get_provider_spec=lambda: [spec1, spec2])

        def import_side_effect(name):
            if name == "test.provider":
                return fake_module
            raise ModuleNotFoundError(name)

        with patch("importlib.import_module", side_effect=import_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="remote_ollama",
                            provider_type="remote::ollama",
                            config={},
                            module="test",
                        ),
                        Provider(
                            provider_id="inline_ollama",
                            provider_type="inline::ollama",
                            config={},
                            module="test",
                        ),
                    ]
                },
            )
            registry = {Api.inference: {}}
            result = get_external_providers_from_module(registry, config, building=False)

            # Both provider types should be added to registry
            assert "remote::ollama" in result[Api.inference]
            assert "inline::ollama" in result[Api.inference]
            assert result[Api.inference]["remote::ollama"].config_class == "test.RemoteConfig"
            assert result[Api.inference]["inline::ollama"].config_class == "test.InlineConfig"

    def test_module_not_found_raises_value_error(self, mock_providers):
        """Test that ModuleNotFoundError raises ValueError with helpful message."""
        from llama_stack.core.distribution import get_external_providers_from_module

        def import_side_effect(name):
            if name == "missing_module.provider":
                raise ModuleNotFoundError(name)
            raise ModuleNotFoundError(name)

        with patch("importlib.import_module", side_effect=import_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="missing",
                            provider_type="missing",
                            config={},
                            module="missing_module",
                        )
                    ]
                },
            )
            registry = {Api.inference: {}}
            with pytest.raises(ValueError) as exc_info:
                get_external_providers_from_module(registry, config, building=False)
            assert "get_provider_spec not found" in str(exc_info.value)

    def test_generic_exception_is_raised(self, mock_providers):
        """Test that generic exceptions are properly raised."""
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module

        def bad_spec():
            raise RuntimeError("Something went wrong")

        fake_module = SimpleNamespace(get_provider_spec=bad_spec)

        def import_side_effect(name):
            if name == "error_module.provider":
                return fake_module
            raise ModuleNotFoundError(name)

        with patch("importlib.import_module", side_effect=import_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="error",
                            provider_type="error",
                            config={},
                            module="error_module",
                        )
                    ]
                },
            )
            registry = {Api.inference: {}}
            with pytest.raises(RuntimeError) as exc_info:
                get_external_providers_from_module(registry, config, building=False)
            assert "Something went wrong" in str(exc_info.value)

    def test_empty_provider_list(self, mock_providers):
        """Test with empty provider list."""
        from llama_stack.core.distribution import get_external_providers_from_module

        config = make_stack_config(
            image_name="test_image",
            providers={},
        )
        registry = {Api.inference: {}}
        result = get_external_providers_from_module(registry, config, building=False)

        # Should return registry unchanged
        assert result == registry
        assert len(result[Api.inference]) == 0

    def test_multiple_apis_with_providers(self, mock_providers):
        """Test multiple APIs with providers."""
        from types import SimpleNamespace

        from llama_stack.core.distribution import get_external_providers_from_module
        from llama_stack.providers.datatypes import ProviderSpec

        inference_spec = ProviderSpec(
            api=Api.inference,
            provider_type="inf_test",
            config_class="inf.Config",
            module="inf_test",
        )
        safety_spec = ProviderSpec(
            api=Api.safety,
            provider_type="safe_test",
            config_class="safe.Config",
            module="safe_test",
        )

        def import_side_effect(name):
            if name == "inf_test.provider":
                return SimpleNamespace(get_provider_spec=lambda: inference_spec)
            elif name == "safe_test.provider":
                return SimpleNamespace(get_provider_spec=lambda: safety_spec)
            raise ModuleNotFoundError(name)

        with patch("importlib.import_module", side_effect=import_side_effect):
            config = make_stack_config(
                image_name="test_image",
                providers={
                    "inference": [
                        Provider(
                            provider_id="inf",
                            provider_type="inf_test",
                            config={},
                            module="inf_test",
                        )
                    ],
                    "safety": [
                        Provider(
                            provider_id="safe",
                            provider_type="safe_test",
                            config={},
                            module="safe_test",
                        )
                    ],
                },
            )
            registry = {Api.inference: {}, Api.safety: {}}
            result = get_external_providers_from_module(registry, config, building=False)

            assert "inf_test" in result[Api.inference]
            assert "safe_test" in result[Api.safety]