feat: ability to execute external providers (#1672)

# What does this PR do?

Providers that live outside of the llama-stack codebase are now
supported.
A new property `external_providers_dir` has been added to the main
config and can be configured as follows:

```
external_providers_dir: /etc/llama-stack/providers.d/
```

Where the expected structure is:

```
providers.d/
  inference/
    custom_ollama.yaml
    vllm.yaml
  vector_io/
    qdrant.yaml
```

Where `custom_ollama.yaml` is:

```
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```
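
For context, here is a rough sketch of what the loader does with such a file; it mirrors the `_load_remote_provider_spec` helper added in this change (the file path below is illustrative):

```python
# Sketch: how an external remote provider spec becomes a registry entry.
# The path is an example; the loader discovers files under external_providers_dir.
import yaml

from llama_stack.providers.datatypes import AdapterSpec, Api, remote_provider_spec

with open("/etc/llama-stack/providers.d/inference/custom_ollama.yaml") as f:
    spec_data = yaml.safe_load(f)

spec = remote_provider_spec(
    api=Api.inference,
    adapter=AdapterSpec(**spec_data["adapter"]),
    api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
# The spec is then registered under the provider type "remote::custom_ollama".
```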

The provider package must, of course, be installed on the system. Here is the
`llama_stack_ollama_provider` example:

```
$ uv pip show llama-stack-ollama-provider
Using Python 3.10.16 environment at: /Users/leseb/Documents/AI/llama-stack/.venv
Name: llama-stack-ollama-provider
Version: 0.1.0
Location: /Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.10/site-packages
Editable project location: /private/var/folders/mq/rnm5w_7s2d3fxmtkx02knvhm0000gn/T/tmp.ZBHU5Ezxg4/ollama/llama-stack-ollama-provider
Requires:
Required-by:
```

Closes: https://github.com/meta-llama/llama-stack/issues/658

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-04-09 10:30:41 +02:00 committed by GitHub
parent 45e210fd0c
commit 389767010b
12 changed files with 875 additions and 7 deletions

View file

@ -0,0 +1,93 @@
name: Test External Providers

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test-external-providers:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"

      - name: Install Ollama
        run: |
          curl -fsSL https://ollama.com/install.sh | sh

      - name: Pull Ollama image
        run: |
          ollama pull llama3.2:3b-instruct-fp16

      - name: Start Ollama in background
        run: |
          nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &

      - name: Set Up Environment and Install Dependencies
        run: |
          uv sync --extra dev --extra test
          uv pip install -e .

      - name: Install Ollama custom provider
        run: |
          mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
          cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
          uv pip install tests/external-provider/llama-stack-provider-ollama

      - name: Create provider configuration
        run: |
          mkdir -p /tmp/providers.d/remote/inference
          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."
          for i in {1..30}; do
            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
              echo "Ollama is running!"
              exit 0
            fi
            sleep 1
          done
          echo "Ollama failed to start"
          ollama ps
          cat ollama.log
          exit 1
      - name: Start Llama Stack server in background
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          source .venv/bin/activate
          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
                echo "Llama Stack server is using custom Ollama provider"
                exit 0
              else
                echo "Llama Stack server is not using custom Ollama provider"
                exit 1
              fi
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1
      - name: Run inference tests
        run: |
          uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2

View file

@ -0,0 +1,234 @@
# External Providers
Llama Stack supports external providers that live outside of the main codebase. This allows you to:
- Create and maintain your own providers independently
- Share providers with others without contributing to the main codebase
- Keep provider-specific code separate from the core Llama Stack code
## Configuration
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```
## Directory Structure
The external providers directory should follow this structure:
```
providers.d/
  remote/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
  inline/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
```
Each YAML file in these directories defines a provider specification for that particular API.
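As a rough illustration, this is how spec files in that tree are discovered; it mirrors the loader's `glob` walk, and the directory path and API names below are only examples:

```python
# Sketch: list the external provider specs that would be picked up.
# external_providers_dir and the API subset are illustrative.
import glob
import os

external_providers_dir = "/etc/llama-stack/providers.d"
for provider_kind in ("remote", "inline"):
    for api_name in ("inference", "vector_io", "safety"):
        api_dir = os.path.join(external_providers_dir, provider_kind, api_name)
        for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
            provider_name = os.path.splitext(os.path.basename(spec_path))[0]
            print(f"{provider_kind}::{provider_name} -> {spec_path}")
```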
## Provider Types
Llama Stack supports two types of external providers:
1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
2. **Inline Providers**: Providers that run locally within the Llama Stack process
## Known External Providers
Here's a list of known external providers that you can use with Llama Stack:
| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
### Remote Provider Specification
Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:
```yaml
adapter:
  adapter_type: custom_ollama
  pip_packages:
    - ollama
    - aiohttp
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```
#### Adapter Configuration
The `adapter` section defines how to load and configure the provider:
- `adapter_type`: A unique identifier for this adapter
- `pip_packages`: List of Python packages required by the provider
- `config_class`: The full path to the configuration class
- `module`: The Python module containing the provider implementation
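The `config_class` is a regular configuration object that the provider package ships. As a minimal sketch (the class name follows the example above, but the field and default are hypothetical):

```python
# Hypothetical config class for the custom Ollama adapter.
# The `url` field and its default are illustrative, not required by Llama Stack.
from pydantic import BaseModel, Field


class OllamaImplConfig(BaseModel):
    url: str = Field(
        default="http://localhost:11434",
        description="Base URL of the Ollama server",
    )
```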
### Inline Provider Specification
Inline providers run locally within the Llama Stack process. Here's an example for a custom vector store provider:
```yaml
module: llama_stack_vector_provider
config_class: llama_stack_vector_provider.config.VectorStoreConfig
pip_packages:
  - faiss-cpu
  - numpy
api_dependencies:
  - inference
optional_api_dependencies:
  - vector_io
provider_data_validator: llama_stack_vector_provider.validator.VectorStoreValidator
container_image: custom-vector-store:latest # optional
```
#### Inline Provider Fields
- `module`: The Python module containing the provider implementation
- `config_class`: The full path to the configuration class
- `pip_packages`: List of Python packages required by the provider
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
- `provider_data_validator`: Optional validator for provider data
- `container_image`: Optional container image to use instead of pip packages
## Required Implementation
### Remote Providers
Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies
This function must return an instance of the provider's adapter class that implements the required protocol for the API.
Example:
```python
async def get_adapter_impl(
    config: OllamaImplConfig, deps: Dict[Api, Any]
) -> OllamaInferenceAdapter:
    return OllamaInferenceAdapter(config)
```
### Inline Providers
Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies
Example:
```python
async def get_provider_impl(
    config: VectorStoreConfig, deps: Dict[Api, Any]
) -> VectorStoreImpl:
    impl = VectorStoreImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl
```
## Dependencies
The provider package must be installed on the system. For example:
```bash
$ uv pip show llama-stack-ollama-provider
Name: llama-stack-ollama-provider
Version: 0.1.0
Location: /path/to/venv/lib/python3.10/site-packages
```
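Beyond `uv pip show`, a quick programmatic check that the provider module is importable from the environment the server will run in (the module name matches the spec above):

```python
# Sanity check: confirm the external provider module can be found.
import importlib.util

module_name = "llama_stack_ollama_provider"  # the `module` value from the spec
if importlib.util.find_spec(module_name) is None:
    raise SystemExit(f"Provider module '{module_name}' is not installed in this environment")
print(f"'{module_name}' is importable")
```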
## Example: Custom Ollama Provider
Here's a complete example of creating and using a custom Ollama provider:
1. First, create the provider package:
```bash
mkdir -p llama-stack-provider-ollama
cd llama-stack-provider-ollama
git init
uv init
```
2. Edit `pyproject.toml`:
```toml
[project]
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "Ollama provider for Llama Stack"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
```
3. Create the provider specification:
```yaml
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
  module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []
```
4. Install the provider:
```bash
uv pip install -e .
```
5. Configure Llama Stack to use external providers:
```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```
The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
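One way to confirm the spec is picked up, sketched against the `get_provider_registry()` signature introduced by this change (the provider entry and paths are illustrative):

```python
# Sketch: verify that the external provider appears in the registry.
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry

run_config = StackRunConfig(
    image_name="ollama",
    providers={
        "inference": [
            Provider(
                provider_id="custom_ollama",
                provider_type="remote::custom_ollama",
                config={"url": "http://localhost:11434"},
            )
        ]
    },
    external_providers_dir="/etc/llama-stack/providers.d",
)
registry = get_provider_registry(run_config)
assert "remote::custom_ollama" in registry[Api.inference]
```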
## Best Practices
1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.
2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using.
3. **Dependencies**: Only include the minimum required dependencies in your provider package.
4. **Documentation**: Include clear documentation in your provider package about:
   - Installation requirements
   - Configuration options
   - Usage examples
   - Any limitations or known issues
5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack.
   You can refer to the [integration tests guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md)
   for more information. Run the tests for the provider type you are developing.
## Troubleshooting
If your external provider isn't being loaded:
1. Check that the `external_providers_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages. Turn on debug logging to get more
   information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the provider package is installed in your Python environment.
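For points 2 and 3, a small standalone check you can run against a remote spec file (the path is an example; inline specs keep `module` and `config_class` at the top level instead):

```python
# Quick sanity check of a remote provider spec: parse the YAML and verify
# the keys the loader expects are present.
import yaml

spec_path = "/etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml"
with open(spec_path) as f:
    spec_data = yaml.safe_load(f)

adapter = spec_data.get("adapter") or {}
for key in ("adapter_type", "module", "config_class"):
    if key not in adapter:
        raise SystemExit(f"{spec_path}: missing adapter.{key}")
print(f"{spec_path} looks well-formed")
```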

View file

@ -11,6 +11,10 @@ Providers come in two flavors:
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
## External Providers
Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. See the [External Providers Guide](external) for details.
## Agents
Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
@ -50,6 +54,7 @@ The following providers (i.e., databases) are available for Vector IO:
```{toctree}
:maxdepth: 1
external
vector_io/faiss
vector_io/sqlite-vec
vector_io/chromadb

View file

@ -312,6 +312,11 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server",
)
external_providers_dir: Optional[str] = Field(
default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
)
class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION

View file

@ -4,12 +4,25 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import glob
import importlib
from typing import Dict, List
import os
from typing import Any, Dict, List
import yaml
from pydantic import BaseModel
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
    AdapterSpec,
    Api,
    InlineProviderSpec,
    ProviderSpec,
    remote_provider_spec,
)
logger = get_logger(name=__name__, category="core")
def stack_apis() -> List[Api]:
@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
    return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
    ret = {}


def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
    adapter = AdapterSpec(**spec_data["adapter"])
    spec = remote_provider_spec(
        api=api,
        adapter=adapter,
        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
    )
    return spec


def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
    spec = InlineProviderSpec(
        api=api,
        provider_type=f"inline::{provider_name}",
        pip_packages=spec_data.get("pip_packages", []),
        module=spec_data["module"],
        config_class=spec_data["config_class"],
        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
        provider_data_validator=spec_data.get("provider_data_validator"),
        container_image=spec_data.get("container_image"),
    )
    return spec
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
    """Get the provider registry, optionally including external providers.

    This function loads both built-in providers and external providers from YAML files.
    External providers are loaded from a directory structure like:

    providers.d/
      remote/
        inference/
          custom_ollama.yaml
          vllm.yaml
        vector_io/
          qdrant.yaml
        safety/
          llama-guard.yaml
      inline/
        inference/
          custom_ollama.yaml
          vllm.yaml
        vector_io/
          qdrant.yaml
        safety/
          llama-guard.yaml

    Args:
        config: Optional StackRunConfig containing the external providers directory path

    Returns:
        A dictionary mapping APIs to their available providers

    Raises:
        FileNotFoundError: If the external providers directory doesn't exist
        ValueError: If any provider spec is invalid
    """
    ret: Dict[Api, Dict[str, ProviderSpec]] = {}
    for api in providable_apis():
        name = api.name.lower()
        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
        ret[api] = {a.provider_type: a for a in module.available_providers()}
        logger.debug(f"Importing module {name}")
        try:
            module = importlib.import_module(f"llama_stack.providers.registry.{name}")
            ret[api] = {a.provider_type: a for a in module.available_providers()}
        except ImportError as e:
            logger.warning(f"Failed to import module {name}: {e}")

    if config and config.external_providers_dir:
        external_providers_dir = os.path.abspath(config.external_providers_dir)
        if not os.path.exists(external_providers_dir):
            raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
        logger.info(f"Loading external providers from {external_providers_dir}")

        for api in providable_apis():
            api_name = api.name.lower()

            # Process both remote and inline providers
            for provider_type in ["remote", "inline"]:
                api_dir = os.path.join(external_providers_dir, provider_type, api_name)
                if not os.path.exists(api_dir):
                    logger.debug(f"No {provider_type} provider directory found for {api_name}")
                    continue

                # Look for provider spec files in the API directory
                for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
                    provider_name = os.path.splitext(os.path.basename(spec_path))[0]
                    logger.info(f"Loading {provider_type} provider spec from {spec_path}")

                    try:
                        with open(spec_path) as f:
                            spec_data = yaml.safe_load(f)

                        if provider_type == "remote":
                            spec = _load_remote_provider_spec(spec_data, api)
                            provider_type_key = f"remote::{provider_name}"
                        else:
                            spec = _load_inline_provider_spec(spec_data, api, provider_name)
                            provider_type_key = f"inline::{provider_name}"

                        logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
                        if provider_type_key in ret[api]:
                            logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
                        ret[api][provider_type_key] = spec
                    except yaml.YAMLError as yaml_err:
                        logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
                        raise yaml_err
                    except Exception as e:
                        logger.error(f"Failed to load provider spec from {spec_path}: {e}")
                        raise e

    return ret

View file

@ -351,6 +351,7 @@ async def instantiate_provider(
    if not hasattr(provider_spec, "module"):
        raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
    module = importlib.import_module(provider_spec.module)
    args = []
    if isinstance(provider_spec, RemoteProviderSpec):

View file

@ -218,7 +218,7 @@ async def construct_stack(
    run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
) -> Dict[Api, Any]:
    dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
    await register_resources(run_config, impls)
    return impls

View file

@ -0,0 +1,3 @@
# Ollama external provider for Llama Stack
Template code to create a new external provider for Llama Stack.

View file

@ -0,0 +1,7 @@
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
  module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []

View file

@ -0,0 +1,44 @@
[project]
dependencies = [
"llama-stack",
"pydantic",
"ollama",
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
]
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "External provider for Ollama using the Llama Stack API"
readme = "README.md"
requires-python = ">=3.10"

View file

@ -0,0 +1,135 @@
version: '2'
image_name: ollama
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: custom_ollama
    provider_type: remote::custom_ollama
    config:
      url: ${env.OLLAMA_URL:http://localhost:11434}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
  - provider_id: wolfram-alpha
    provider_type: remote::wolfram-alpha
    config:
      api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
models:
- metadata: {}
  model_id: ${env.INFERENCE_MODEL}
  provider_id: custom_ollama
  model_type: llm
- metadata:
    embedding_dimension: 384
  model_id: all-MiniLM-L6-v2
  provider_id: custom_ollama
  provider_model_id: all-minilm:latest
  model_type: embedding
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
- toolgroup_id: builtin::wolfram_alpha
  provider_id: wolfram-alpha
server:
  port: 8321
external_providers_dir: /tmp/providers.d

View file

@ -0,0 +1,223 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from unittest.mock import patch
import pytest
import yaml
from pydantic import BaseModel, Field, ValidationError
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.providers.datatypes import ProviderSpec
class SampleConfig(BaseModel):
    foo: str = Field(
        default="bar",
        description="foo",
    )

    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
        return {
            "foo": "baz",
        }


@pytest.fixture
def mock_providers():
    """Mock the available_providers function to return test providers."""
    with patch("llama_stack.providers.registry.inference.available_providers") as mock:
        mock.return_value = [
            ProviderSpec(
                provider_type="test_provider",
                api=Api.inference,
                adapter_type="test_adapter",
                config_class="test_provider.config.TestProviderConfig",
            )
        ]
        yield mock


@pytest.fixture
def base_config(tmp_path):
    """Create a base StackRunConfig with common settings."""
    return StackRunConfig(
        image_name="test_image",
        providers={
            "inference": [
                Provider(
                    provider_id="sample_provider",
                    provider_type="sample",
                    config=SampleConfig.sample_run_config(),
                )
            ]
        },
        external_providers_dir=str(tmp_path),
    )


@pytest.fixture
def provider_spec_yaml():
    """Common provider spec YAML for testing."""
    return """
adapter:
  adapter_type: test_provider
  config_class: test_provider.config.TestProviderConfig
  module: test_provider
api_dependencies:
  - safety
"""


@pytest.fixture
def inline_provider_spec_yaml():
    """Common inline provider spec YAML for testing."""
    return """
module: test_provider
config_class: test_provider.config.TestProviderConfig
pip_packages:
  - test-package
api_dependencies:
  - safety
optional_api_dependencies:
  - vector_io
provider_data_validator: test_provider.validator.TestValidator
container_image: test-image:latest
"""


@pytest.fixture
def api_directories(tmp_path):
    """Create the API directory structure for testing."""
    # Create remote provider directory
    remote_inference_dir = tmp_path / "remote" / "inference"
    remote_inference_dir.mkdir(parents=True, exist_ok=True)

    # Create inline provider directory
    inline_inference_dir = tmp_path / "inline" / "inference"
    inline_inference_dir.mkdir(parents=True, exist_ok=True)

    return remote_inference_dir, inline_inference_dir


class TestProviderRegistry:
    """Test suite for provider registry functionality."""

    def test_builtin_providers(self, mock_providers):
        """Test loading built-in providers."""
        registry = get_provider_registry(None)

        assert Api.inference in registry
        assert "test_provider" in registry[Api.inference]
        assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
        assert registry[Api.inference]["test_provider"].api == Api.inference

    def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
        """Test loading external remote providers from YAML files."""
        remote_dir, _ = api_directories
        with open(remote_dir / "test_provider.yaml", "w") as f:
            f.write(provider_spec_yaml)

        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 2

        assert Api.inference in registry
        assert "remote::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["remote::test_provider"]
        assert provider.adapter.adapter_type == "test_provider"
        assert provider.adapter.module == "test_provider"
        assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
        assert Api.safety in provider.api_dependencies

    def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
        """Test loading external inline providers from YAML files."""
        _, inline_dir = api_directories
        with open(inline_dir / "test_provider.yaml", "w") as f:
            f.write(inline_provider_spec_yaml)

        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 2

        assert Api.inference in registry
        assert "inline::test_provider" in registry[Api.inference]
        provider = registry[Api.inference]["inline::test_provider"]
        assert provider.provider_type == "inline::test_provider"
        assert provider.module == "test_provider"
        assert provider.config_class == "test_provider.config.TestProviderConfig"
        assert provider.pip_packages == ["test-package"]
        assert Api.safety in provider.api_dependencies
        assert Api.vector_io in provider.optional_api_dependencies
        assert provider.provider_data_validator == "test_provider.validator.TestValidator"
        assert provider.container_image == "test-image:latest"

    def test_invalid_yaml(self, api_directories, mock_providers, base_config):
        """Test handling of invalid YAML files."""
        remote_dir, inline_dir = api_directories
        with open(remote_dir / "invalid.yaml", "w") as f:
            f.write("invalid: yaml: content: -")
        with open(inline_dir / "invalid.yaml", "w") as f:
            f.write("invalid: yaml: content: -")

        with pytest.raises(yaml.YAMLError):
            get_provider_registry(base_config)

    def test_missing_directory(self, mock_providers):
        """Test handling of missing external providers directory."""
        config = StackRunConfig(
            image_name="test_image",
            providers={
                "inference": [
                    Provider(
                        provider_id="sample_provider",
                        provider_type="sample",
                        config=SampleConfig.sample_run_config(),
                    )
                ]
            },
            external_providers_dir="/nonexistent/dir",
        )
        with pytest.raises(FileNotFoundError):
            get_provider_registry(config)

    def test_empty_api_directory(self, api_directories, mock_providers, base_config):
        """Test handling of empty API directory."""
        registry = get_provider_registry(base_config)
        assert len(registry[Api.inference]) == 1  # Only built-in provider

    def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed remote provider spec (missing required fields)."""
        remote_dir, _ = api_directories
        malformed_spec = """
adapter:
  adapter_type: test_provider
  # Missing required fields
api_dependencies:
  - safety
"""
        with open(remote_dir / "malformed.yaml", "w") as f:
            f.write(malformed_spec)

        with pytest.raises(ValidationError):
            get_provider_registry(base_config)

    def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
        """Test handling of malformed inline provider spec (missing required fields)."""
        _, inline_dir = api_directories
        malformed_spec = """
module: test_provider
# Missing required config_class
pip_packages:
  - test-package
"""
        with open(inline_dir / "malformed.yaml", "w") as f:
            f.write(malformed_spec)

        with pytest.raises(KeyError) as exc_info:
            get_provider_registry(base_config)
        assert "config_class" in str(exc_info.value)