From 389767010b0333c49cf6cb86122308a5ec474621 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Wed, 9 Apr 2025 10:30:41 +0200
Subject: [PATCH] feat: ability to execute external providers (#1672)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

# What does this PR do?

Providers that live outside of the llama-stack codebase are now
supported. A new property `external_providers_dir` has been added to
the main config and can be configured as follows:

```
external_providers_dir: /etc/llama-stack/providers.d/
```

Where the expected structure is:

```
providers.d/
  inference/
    custom_ollama.yaml
    vllm.yaml
  vector_io/
    qdrant.yaml
```

Where `custom_ollama.yaml` is:

```
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```

The package must of course be installed on the system; here is the
`llama_stack_ollama_provider` example:

```
$ uv pip show llama-stack-ollama-provider
Using Python 3.10.16 environment at: /Users/leseb/Documents/AI/llama-stack/.venv
Name: llama-stack-ollama-provider
Version: 0.1.0
Location: /Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.10/site-packages
Editable project location: /private/var/folders/mq/rnm5w_7s2d3fxmtkx02knvhm0000gn/T/tmp.ZBHU5Ezxg4/ollama/llama-stack-ollama-provider
Requires:
Required-by:
```

Closes: https://github.com/meta-llama/llama-stack/issues/658

Signed-off-by: Sébastien Han
---
 .github/workflows/test-external-providers.yml |  93 +++++++
 docs/source/providers/external.md             | 234 ++++++++++++++++++
 docs/source/providers/index.md                |   5 +
 llama_stack/distribution/datatypes.py         |   5 +
 llama_stack/distribution/distribution.py      | 130 +++++++++-
 llama_stack/distribution/resolver.py          |   1 +
 llama_stack/distribution/stack.py             |   2 +-
 .../llama-stack-provider-ollama/README.md     |   3 +
 .../custom_ollama.yaml                        |   7 +
 .../pyproject.toml                            |  44 ++++
 .../llama-stack-provider-ollama/run.yaml      | 135 ++++++++++
 tests/unit/distribution/test_distribution.py  | 223 +++++++++++++++++
 12 files changed, 875 insertions(+), 7 deletions(-)
 create mode 100644 .github/workflows/test-external-providers.yml
 create mode 100644 docs/source/providers/external.md
 create mode 100644 tests/external-provider/llama-stack-provider-ollama/README.md
 create mode 100644 tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml
 create mode 100644 tests/external-provider/llama-stack-provider-ollama/pyproject.toml
 create mode 100644 tests/external-provider/llama-stack-provider-ollama/run.yaml
 create mode 100644 tests/unit/distribution/test_distribution.py

diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-providers.yml
new file mode 100644
index 000000000..2ead8f845
--- /dev/null
+++ b/.github/workflows/test-external-providers.yml
@@ -0,0 +1,93 @@
+name: Test External Providers
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test-external-providers:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install Ollama
+        run: |
+          curl -fsSL https://ollama.com/install.sh | sh
+
+      - name: Pull Ollama model
+        run: |
+          ollama pull llama3.2:3b-instruct-fp16
+
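+      # Run the model once with --keepalive=30m so it stays resident for the
+      # whole job and the inference tests below don't pay a cold-start penalty.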
+      - name: Start Ollama in background
+        run: |
+          nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &
+
+      - name: Set Up Environment and Install Dependencies
+        run: |
+          uv sync --extra dev --extra test
+          uv pip install -e .
+
+      - name: Install Ollama custom provider
+        run: |
+          mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
+          cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
+          uv pip install tests/external-provider/llama-stack-provider-ollama
+
+      - name: Create provider configuration
+        run: |
+          mkdir -p /tmp/providers.d/remote/inference
+          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
+
+      - name: Wait for Ollama to start
+        run: |
+          echo "Waiting for Ollama..."
+          for i in {1..30}; do
+            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
+              echo "Ollama is running!"
+              exit 0
+            fi
+            sleep 1
+          done
+          echo "Ollama failed to start"
+          ollama ps
+          cat ollama.log
+          exit 1
+
+      - name: Start Llama Stack server in background
+        env:
+          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
+        run: |
+          source .venv/bin/activate
+          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
+
+      - name: Wait for Llama Stack server to be ready
+        run: |
+          echo "Waiting for Llama Stack server..."
+          for i in {1..30}; do
+            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
+              echo "Llama Stack server is up!"
+              if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+                echo "Llama Stack server is using custom Ollama provider"
+                exit 0
+              else
+                echo "Llama Stack server is not using custom Ollama provider"
+                exit 1
+              fi
+            fi
+            sleep 1
+          done
+          echo "Llama Stack server failed to start"
+          cat server.log
+          exit 1
+
+      - name: Run inference tests
+        run: |
+          uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
diff --git a/docs/source/providers/external.md b/docs/source/providers/external.md
new file mode 100644
index 000000000..90fc77979
--- /dev/null
+++ b/docs/source/providers/external.md
@@ -0,0 +1,234 @@
+# External Providers
+
+Llama Stack supports external providers that live outside of the main codebase. This allows you to:
+- Create and maintain your own providers independently
+- Share providers with others without contributing to the main codebase
+- Keep provider-specific code separate from the core Llama Stack code
+
+## Configuration
+
+To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
+
+```yaml
+external_providers_dir: /etc/llama-stack/providers.d/
+```
+
+## Directory Structure
+
+The external providers directory should follow this structure:
+
+```
+providers.d/
+  remote/
+    inference/
+      custom_ollama.yaml
+      vllm.yaml
+    vector_io/
+      qdrant.yaml
+    safety/
+      llama-guard.yaml
+  inline/
+    inference/
+      custom_ollama.yaml
+      vllm.yaml
+    vector_io/
+      qdrant.yaml
+    safety/
+      llama-guard.yaml
+```
+
+Each YAML file in these directories defines a provider specification for that particular API.
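+
+For example, to scaffold the layout for a single remote inference provider (a sketch; adjust the base path to match your `external_providers_dir`):
+
+```bash
+mkdir -p /etc/llama-stack/providers.d/remote/inference
+cp custom_ollama.yaml /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
+```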
+
+## Provider Types
+
+Llama Stack supports two types of external providers:
+
+1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
+2. **Inline Providers**: Providers that run locally within the Llama Stack process
+
+## Known External Providers
+
+Here's a list of known external providers that you can use with Llama Stack:
+
+| Type | Name | Description | Repository |
+|------|------|-------------|------------|
+| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
+
+### Remote Provider Specification
+
+Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:
+
+```yaml
+adapter:
+  adapter_type: custom_ollama
+  pip_packages:
+  - ollama
+  - aiohttp
+  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
+  module: llama_stack_ollama_provider
+api_dependencies: []
+optional_api_dependencies: []
+```
+
+#### Adapter Configuration
+
+The `adapter` section defines how to load and configure the provider:
+
+- `adapter_type`: A unique identifier for this adapter
+- `pip_packages`: List of Python packages required by the provider
+- `config_class`: The full path to the configuration class
+- `module`: The Python module containing the provider implementation
+
+### Inline Provider Specification
+
+Inline providers run locally within the Llama Stack process. Here's an example for a custom vector store provider:
+
+```yaml
+module: llama_stack_vector_provider
+config_class: llama_stack_vector_provider.config.VectorStoreConfig
+pip_packages:
+  - faiss-cpu
+  - numpy
+api_dependencies:
+  - inference
+optional_api_dependencies:
+  - vector_io
+provider_data_validator: llama_stack_vector_provider.validator.VectorStoreValidator
+container_image: custom-vector-store:latest # optional
+```
+
+#### Inline Provider Fields
+
+- `module`: The Python module containing the provider implementation
+- `config_class`: The full path to the configuration class
+- `pip_packages`: List of Python packages required by the provider
+- `api_dependencies`: List of Llama Stack APIs that this provider depends on
+- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
+- `provider_data_validator`: Optional validator for provider data
+- `container_image`: Optional container image to use instead of pip packages
+
+## Required Implementation
+
+### Remote Providers
+
+Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
+1. `config`: An instance of the provider's config class
+2. `deps`: A dictionary of API dependencies
+
+This function must return an instance of the provider's adapter class that implements the required protocol for the API.
+
+Example:
+```python
+async def get_adapter_impl(
+    config: OllamaImplConfig, deps: Dict[Api, Any]
+) -> OllamaInferenceAdapter:
+    return OllamaInferenceAdapter(config)
+```
+
+### Inline Providers
+
+Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments:
+1. `config`: An instance of the provider's config class
+2. `deps`: A dictionary of API dependencies
+
+Example:
+```python
+async def get_provider_impl(
+    config: VectorStoreConfig, deps: Dict[Api, Any]
+) -> VectorStoreImpl:
+    impl = VectorStoreImpl(config, deps[Api.inference])
+    await impl.initialize()
+    return impl
+```
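+
+In both cases the `config` argument is an instance of your spec's `config_class`, which is an ordinary pydantic model. A minimal sketch (illustrative only, not part of this PR):
+
+```python
+from pydantic import BaseModel, Field
+
+
+class OllamaImplConfig(BaseModel):
+    url: str = Field(
+        default="http://localhost:11434",
+        description="Base URL of the Ollama server",
+    )
+```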
+
+## Dependencies
+
+The provider package must be installed on the system. For example:
+
+```bash
+$ uv pip show llama-stack-ollama-provider
+Name: llama-stack-ollama-provider
+Version: 0.1.0
+Location: /path/to/venv/lib/python3.10/site-packages
+```
+
+## Example: Custom Ollama Provider
+
+Here's a complete example of creating and using a custom Ollama provider:
+
+1. First, create the provider package:
+
+```bash
+mkdir -p llama-stack-provider-ollama
+cd llama-stack-provider-ollama
+git init
+uv init
+```
+
+2. Edit `pyproject.toml`:
+
+```toml
+[project]
+name = "llama-stack-provider-ollama"
+version = "0.1.0"
+description = "Ollama provider for Llama Stack"
+requires-python = ">=3.10"
+dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
+```
+
+3. Create the provider specification:
+
+```yaml
+# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
+adapter:
+  adapter_type: custom_ollama
+  pip_packages: ["ollama", "aiohttp"]
+  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
+  module: llama_stack_provider_ollama
+api_dependencies: []
+optional_api_dependencies: []
+```
+
+4. Install the provider:
+
+```bash
+uv pip install -e .
+```
+
+5. Configure Llama Stack to use external providers:
+
+```yaml
+external_providers_dir: /etc/llama-stack/providers.d/
+```
+
+The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
+
+## Best Practices
+
+1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.
+
+2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using.
+
+3. **Dependencies**: Only include the minimum required dependencies in your provider package.
+
+4. **Documentation**: Include clear documentation in your provider package about:
+   - Installation requirements
+   - Configuration options
+   - Usage examples
+   - Any limitations or known issues
+
+5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack.
+   You can refer to the [integration tests guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more information. Run the tests for the provider type you are developing.
+
+## Troubleshooting
+
+If your external provider isn't being loaded:
+
+1. Check that the `external_providers_dir` path is correct and accessible.
+2. Verify that the YAML files are properly formatted.
+3. Ensure all required Python packages are installed.
+4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
+5. Verify that the provider package is installed in your Python environment.
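+
+You can also ask the running server which providers it actually registered (a sketch; assumes the default port and that your Llama Stack version exposes the providers listing):
+
+```bash
+curl http://localhost:8321/v1/providers
+```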
diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md
index f8997a281..75faf7c00 100644
--- a/docs/source/providers/index.md
+++ b/docs/source/providers/index.md
@@ -11,6 +11,10 @@ Providers come in two flavors:
 
 Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
 
+## External Providers
+
+Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. See the [External Providers Guide](external) for details.
+
 ## Agents
 
 Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
@@ -50,6 +54,7 @@ The following providers (i.e., databases) are available for Vector IO:
 ```{toctree}
 :maxdepth: 1
 
+external
 vector_io/faiss
 vector_io/sqlite-vec
 vector_io/chromadb
diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py
index 48f1925dd..b24b0ec50 100644
--- a/llama_stack/distribution/datatypes.py
+++ b/llama_stack/distribution/datatypes.py
@@ -312,6 +312,11 @@ a default SQLite store will be used.""",
         description="Configuration for the HTTP(S) server",
     )
 
+    external_providers_dir: Optional[str] = Field(
+        default=None,
+        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
+    )
+
 
 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py
index ddb727663..d4447139c 100644
--- a/llama_stack/distribution/distribution.py
+++ b/llama_stack/distribution/distribution.py
@@ -4,12 +4,25 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import glob
 import importlib
-from typing import Dict, List
+import os
+from typing import Any, Dict, List
 
+import yaml
 from pydantic import BaseModel
 
-from llama_stack.providers.datatypes import Api, ProviderSpec
+from llama_stack.distribution.datatypes import StackRunConfig
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
+
+logger = get_logger(name=__name__, category="core")
 
 
 def stack_apis() -> List[Api]:
@@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
 
 
-def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
-    ret = {}
+def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
+    adapter = AdapterSpec(**spec_data["adapter"])
+    spec = remote_provider_spec(
+        api=api,
+        adapter=adapter,
+        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
+    )
+    return spec
+
+
+def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
+    spec = InlineProviderSpec(
+        api=api,
+        provider_type=f"inline::{provider_name}",
+        pip_packages=spec_data.get("pip_packages", []),
+        module=spec_data["module"],
+        config_class=spec_data["config_class"],
+        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
+        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
+        provider_data_validator=spec_data.get("provider_data_validator"),
+        container_image=spec_data.get("container_image"),
+    )
+    return spec
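+
+
+# Illustrative only: for the custom_ollama.yaml example from the docs,
+# _load_remote_provider_spec(spec_data, Api.inference) builds a spec
+# equivalent to:
+#
+#   remote_provider_spec(
+#       api=Api.inference,
+#       adapter=AdapterSpec(
+#           adapter_type="custom_ollama",
+#           pip_packages=["ollama", "aiohttp"],
+#           config_class="llama_stack_ollama_provider.config.OllamaImplConfig",
+#           module="llama_stack_ollama_provider",
+#       ),
+#       api_dependencies=[],
+#   )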
+
+
+def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
+    """Get the provider registry, optionally including external providers.
+
+    This function loads both built-in providers and external providers from YAML files.
+    External providers are loaded from a directory structure like:
+
+    providers.d/
+      remote/
+        inference/
+          custom_ollama.yaml
+          vllm.yaml
+        vector_io/
+          qdrant.yaml
+        safety/
+          llama-guard.yaml
+      inline/
+        inference/
+          custom_ollama.yaml
+          vllm.yaml
+        vector_io/
+          qdrant.yaml
+        safety/
+          llama-guard.yaml
+
+    Args:
+        config: Optional StackRunConfig containing the external providers directory path
+
+    Returns:
+        A dictionary mapping APIs to their available providers
+
+    Raises:
+        FileNotFoundError: If the external providers directory doesn't exist
+        ValueError: If any provider spec is invalid
+    """
+
+    ret: Dict[Api, Dict[str, ProviderSpec]] = {}
     for api in providable_apis():
         name = api.name.lower()
-        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
-        ret[api] = {a.provider_type: a for a in module.available_providers()}
+        logger.debug(f"Importing module {name}")
+        try:
+            module = importlib.import_module(f"llama_stack.providers.registry.{name}")
+            ret[api] = {a.provider_type: a for a in module.available_providers()}
+        except ImportError as e:
+            logger.warning(f"Failed to import module {name}: {e}")
+
+    if config and config.external_providers_dir:
+        external_providers_dir = os.path.abspath(config.external_providers_dir)
+        if not os.path.exists(external_providers_dir):
+            raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
+        logger.info(f"Loading external providers from {external_providers_dir}")
+
+        for api in providable_apis():
+            api_name = api.name.lower()
+
+            # Process both remote and inline providers
+            for provider_type in ["remote", "inline"]:
+                api_dir = os.path.join(external_providers_dir, provider_type, api_name)
+                if not os.path.exists(api_dir):
+                    logger.debug(f"No {provider_type} provider directory found for {api_name}")
+                    continue
+
+                # Look for provider spec files in the API directory
+                for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
+                    provider_name = os.path.splitext(os.path.basename(spec_path))[0]
+                    logger.info(f"Loading {provider_type} provider spec from {spec_path}")
+
+                    try:
+                        with open(spec_path) as f:
+                            spec_data = yaml.safe_load(f)
+
+                        if provider_type == "remote":
+                            spec = _load_remote_provider_spec(spec_data, api)
+                            provider_type_key = f"remote::{provider_name}"
+                        else:
+                            spec = _load_inline_provider_spec(spec_data, api, provider_name)
+                            provider_type_key = f"inline::{provider_name}"
+
+                        logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
+                        if provider_type_key in ret[api]:
+                            logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
+                        ret[api][provider_type_key] = spec
+                    except yaml.YAMLError as yaml_err:
+                        logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
+                        raise yaml_err
+                    except Exception as e:
+                        logger.error(f"Failed to load provider spec from {spec_path}: {e}")
+                        raise e
     return ret
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py
index 25fe3f184..33ad343ec 100644
--- a/llama_stack/distribution/resolver.py
+++ b/llama_stack/distribution/resolver.py
@@ -351,6 +351,7 @@ async def instantiate_provider(
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
 
+    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
     module = importlib.import_module(provider_spec.module)
     args = []
     if isinstance(provider_spec, RemoteProviderSpec):
diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py
index 9c9289a77..d70878db4 100644
--- a/llama_stack/distribution/stack.py
+++ b/llama_stack/distribution/stack.py
@@ -218,7 +218,7 @@ async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
+    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
     await register_resources(run_config, impls)
     return impls
diff --git a/tests/external-provider/llama-stack-provider-ollama/README.md b/tests/external-provider/llama-stack-provider-ollama/README.md
new file mode 100644
index 000000000..8bd2b6a87
--- /dev/null
+++ b/tests/external-provider/llama-stack-provider-ollama/README.md
@@ -0,0 +1,3 @@
+# Ollama external provider for Llama Stack
+
+Template code to create a new external provider for Llama Stack.
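+
+To try it the way the CI workflow does (a sketch, run from the repository root; the adapter source is copied in first because this package is only a template):
+
+```bash
+mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
+cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
+uv pip install tests/external-provider/llama-stack-provider-ollama
+llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml
+```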
diff --git a/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml b/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml
new file mode 100644
index 000000000..f0960b4d8
--- /dev/null
+++ b/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml
@@ -0,0 +1,7 @@
+adapter:
+  adapter_type: custom_ollama
+  pip_packages: ["ollama", "aiohttp"]
+  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
+  module: llama_stack_provider_ollama
+api_dependencies: []
+optional_api_dependencies: []
diff --git a/tests/external-provider/llama-stack-provider-ollama/pyproject.toml b/tests/external-provider/llama-stack-provider-ollama/pyproject.toml
new file mode 100644
index 000000000..ddebc54b0
--- /dev/null
+++ b/tests/external-provider/llama-stack-provider-ollama/pyproject.toml
@@ -0,0 +1,44 @@
+[project]
+dependencies = [
+  "llama-stack",
+  "pydantic",
+  "ollama",
+  "aiohttp",
+  "aiosqlite",
+  "autoevals",
+  "blobfile",
+  "chardet",
+  "chromadb-client",
+  "datasets",
+  "faiss-cpu",
+  "fastapi",
+  "fire",
+  "httpx",
+  "matplotlib",
+  "mcp",
+  "nltk",
+  "numpy",
+  "openai",
+  "opentelemetry-exporter-otlp-proto-http",
+  "opentelemetry-sdk",
+  "pandas",
+  "pillow",
+  "psycopg2-binary",
+  "pymongo",
+  "pypdf",
+  "redis",
+  "requests",
+  "scikit-learn",
+  "scipy",
+  "sentencepiece",
+  "tqdm",
+  "transformers",
+  "tree_sitter",
+  "uvicorn",
+]
+
+name = "llama-stack-provider-ollama"
+version = "0.1.0"
+description = "External provider for Ollama using the Llama Stack API"
+readme = "README.md"
+requires-python = ">=3.10"
diff --git a/tests/external-provider/llama-stack-provider-ollama/run.yaml b/tests/external-provider/llama-stack-provider-ollama/run.yaml
new file mode 100644
index 000000000..7a3636c4d
--- /dev/null
+++ b/tests/external-provider/llama-stack-provider-ollama/run.yaml
@@ -0,0 +1,135 @@
+version: '2'
+image_name: ollama
+apis:
+- agents
+- datasetio
+- eval
+- inference
+- safety
+- scoring
+- telemetry
+- tool_runtime
+- vector_io
+providers:
+  inference:
+  - provider_id: custom_ollama
+    provider_type: remote::custom_ollama
+    config:
+      url: ${env.OLLAMA_URL:http://localhost:11434}
+  vector_io:
+  - provider_id: faiss
+    provider_type: inline::faiss
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
+  safety:
+  - provider_id: llama-guard
+    provider_type: inline::llama-guard
+    config:
+      excluded_categories: []
+  agents:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      persistence_store:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
+  telemetry:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
+      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
+      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
+  eval:
+  - provider_id: meta-reference
+    provider_type: inline::meta-reference
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
+  datasetio:
+  - provider_id: huggingface
+    provider_type: remote::huggingface
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db
+  - provider_id: localfs
+    provider_type: inline::localfs
+    config:
+      kvstore:
+        type: sqlite
+        namespace: null
+        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
+  scoring:
+  - provider_id: basic
+    provider_type: inline::basic
+    config: {}
+  - provider_id: llm-as-judge
+    provider_type: inline::llm-as-judge
+    config: {}
+  - provider_id: braintrust
+    provider_type: inline::braintrust
+    config:
+      openai_api_key: ${env.OPENAI_API_KEY:}
+  tool_runtime:
+  - provider_id: brave-search
+    provider_type: remote::brave-search
+    config:
+      api_key: ${env.BRAVE_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: tavily-search
+    provider_type: remote::tavily-search
+    config:
+      api_key: ${env.TAVILY_SEARCH_API_KEY:}
+      max_results: 3
+  - provider_id: code-interpreter
+    provider_type: inline::code-interpreter
+    config: {}
+  - provider_id: rag-runtime
+    provider_type: inline::rag-runtime
+    config: {}
+  - provider_id: model-context-protocol
+    provider_type: remote::model-context-protocol
+    config: {}
+  - provider_id: wolfram-alpha
+    provider_type: remote::wolfram-alpha
+    config:
+      api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
+metadata_store:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
+models:
+- metadata: {}
+  model_id: ${env.INFERENCE_MODEL}
+  provider_id: custom_ollama
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: custom_ollama
+  provider_model_id: all-minilm:latest
+  model_type: embedding
+shields: []
+vector_dbs: []
+datasets: []
+scoring_fns: []
+benchmarks: []
+tool_groups:
+- toolgroup_id: builtin::websearch
+  provider_id: tavily-search
+- toolgroup_id: builtin::rag
+  provider_id: rag-runtime
+- toolgroup_id: builtin::code_interpreter
+  provider_id: code-interpreter
+- toolgroup_id: builtin::wolfram_alpha
+  provider_id: wolfram-alpha
+server:
+  port: 8321
+external_providers_dir: /tmp/providers.d
diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py
new file mode 100644
index 000000000..a4daffb82
--- /dev/null
+++ b/tests/unit/distribution/test_distribution.py
@@ -0,0 +1,223 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
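+
+# The tests below exercise get_provider_registry() with a temporary
+# external_providers_dir: built-in provider discovery, loading remote and
+# inline external specs, and the failure modes (invalid YAML, missing
+# directory, malformed specs).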
+
+from typing import Any, Dict
+from unittest.mock import patch
+
+import pytest
+import yaml
+from pydantic import BaseModel, Field, ValidationError
+
+from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
+from llama_stack.distribution.distribution import get_provider_registry
+from llama_stack.providers.datatypes import ProviderSpec
+
+
+class SampleConfig(BaseModel):
+    foo: str = Field(
+        default="bar",
+        description="foo",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
+        return {
+            "foo": "baz",
+        }
+
+
+@pytest.fixture
+def mock_providers():
+    """Mock the available_providers function to return test providers."""
+    with patch("llama_stack.providers.registry.inference.available_providers") as mock:
+        mock.return_value = [
+            ProviderSpec(
+                provider_type="test_provider",
+                api=Api.inference,
+                adapter_type="test_adapter",
+                config_class="test_provider.config.TestProviderConfig",
+            )
+        ]
+        yield mock
+
+
+@pytest.fixture
+def base_config(tmp_path):
+    """Create a base StackRunConfig with common settings."""
+    return StackRunConfig(
+        image_name="test_image",
+        providers={
+            "inference": [
+                Provider(
+                    provider_id="sample_provider",
+                    provider_type="sample",
+                    config=SampleConfig.sample_run_config(),
+                )
+            ]
+        },
+        external_providers_dir=str(tmp_path),
+    )
+
+
+@pytest.fixture
+def provider_spec_yaml():
+    """Common provider spec YAML for testing."""
+    return """
+adapter:
+  adapter_type: test_provider
+  config_class: test_provider.config.TestProviderConfig
+  module: test_provider
+api_dependencies:
+  - safety
+"""
+
+
+@pytest.fixture
+def inline_provider_spec_yaml():
+    """Common inline provider spec YAML for testing."""
+    return """
+module: test_provider
+config_class: test_provider.config.TestProviderConfig
+pip_packages:
+  - test-package
+api_dependencies:
+  - safety
+optional_api_dependencies:
+  - vector_io
+provider_data_validator: test_provider.validator.TestValidator
+container_image: test-image:latest
+"""
+
+
+@pytest.fixture
+def api_directories(tmp_path):
+    """Create the API directory structure for testing."""
+    # Create remote provider directory
+    remote_inference_dir = tmp_path / "remote" / "inference"
+    remote_inference_dir.mkdir(parents=True, exist_ok=True)
+
+    # Create inline provider directory
+    inline_inference_dir = tmp_path / "inline" / "inference"
+    inline_inference_dir.mkdir(parents=True, exist_ok=True)
+
+    return remote_inference_dir, inline_inference_dir
+
+
+class TestProviderRegistry:
+    """Test suite for provider registry functionality."""
+
+    def test_builtin_providers(self, mock_providers):
+        """Test loading built-in providers."""
+        registry = get_provider_registry(None)
+
+        assert Api.inference in registry
+        assert "test_provider" in registry[Api.inference]
+        assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
+        assert registry[Api.inference]["test_provider"].api == Api.inference
+
+    def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
+        """Test loading external remote providers from YAML files."""
+        remote_dir, _ = api_directories
+        with open(remote_dir / "test_provider.yaml", "w") as f:
+            f.write(provider_spec_yaml)
+
+        registry = get_provider_registry(base_config)
+        assert len(registry[Api.inference]) == 2
+
+        assert Api.inference in registry
+        assert "remote::test_provider" in registry[Api.inference]
+        provider = registry[Api.inference]["remote::test_provider"]
+        assert provider.adapter.adapter_type == "test_provider"
+        assert provider.adapter.module == "test_provider"
+        assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
+        assert Api.safety in provider.api_dependencies
"test_provider" + assert provider.adapter.module == "test_provider" + assert provider.adapter.config_class == "test_provider.config.TestProviderConfig" + assert Api.safety in provider.api_dependencies + + def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml): + """Test loading external inline providers from YAML files.""" + _, inline_dir = api_directories + with open(inline_dir / "test_provider.yaml", "w") as f: + f.write(inline_provider_spec_yaml) + + registry = get_provider_registry(base_config) + assert len(registry[Api.inference]) == 2 + + assert Api.inference in registry + assert "inline::test_provider" in registry[Api.inference] + provider = registry[Api.inference]["inline::test_provider"] + assert provider.provider_type == "inline::test_provider" + assert provider.module == "test_provider" + assert provider.config_class == "test_provider.config.TestProviderConfig" + assert provider.pip_packages == ["test-package"] + assert Api.safety in provider.api_dependencies + assert Api.vector_io in provider.optional_api_dependencies + assert provider.provider_data_validator == "test_provider.validator.TestValidator" + assert provider.container_image == "test-image:latest" + + def test_invalid_yaml(self, api_directories, mock_providers, base_config): + """Test handling of invalid YAML files.""" + remote_dir, inline_dir = api_directories + with open(remote_dir / "invalid.yaml", "w") as f: + f.write("invalid: yaml: content: -") + with open(inline_dir / "invalid.yaml", "w") as f: + f.write("invalid: yaml: content: -") + + with pytest.raises(yaml.YAMLError): + get_provider_registry(base_config) + + def test_missing_directory(self, mock_providers): + """Test handling of missing external providers directory.""" + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="sample_provider", + provider_type="sample", + config=SampleConfig.sample_run_config(), + ) + ] + }, + external_providers_dir="/nonexistent/dir", + ) + with pytest.raises(FileNotFoundError): + get_provider_registry(config) + + def test_empty_api_directory(self, api_directories, mock_providers, base_config): + """Test handling of empty API directory.""" + registry = get_provider_registry(base_config) + assert len(registry[Api.inference]) == 1 # Only built-in provider + + def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config): + """Test handling of malformed remote provider spec (missing required fields).""" + remote_dir, _ = api_directories + malformed_spec = """ +adapter: + adapter_type: test_provider + # Missing required fields +api_dependencies: + - safety +""" + with open(remote_dir / "malformed.yaml", "w") as f: + f.write(malformed_spec) + + with pytest.raises(ValidationError): + get_provider_registry(base_config) + + def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config): + """Test handling of malformed inline provider spec (missing required fields).""" + _, inline_dir = api_directories + malformed_spec = """ +module: test_provider +# Missing required config_class +pip_packages: + - test-package +""" + with open(inline_dir / "malformed.yaml", "w") as f: + f.write(malformed_spec) + + with pytest.raises(KeyError) as exc_info: + get_provider_registry(base_config) + assert "config_class" in str(exc_info.value)