diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external.yml similarity index 50% rename from .github/workflows/test-external-providers.yml rename to .github/workflows/test-external.yml index cdf18fab7..d4b222e70 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external.yml @@ -1,4 +1,4 @@ -name: Test External Providers +name: Test External API and Providers on: push: @@ -11,10 +11,10 @@ on: - 'uv.lock' - 'pyproject.toml' - 'requirements.txt' - - '.github/workflows/test-external-providers.yml' # This workflow + - '.github/workflows/test-external.yml' # This workflow jobs: - test-external-providers: + test-external: runs-on: ubuntu-latest strategy: matrix: @@ -28,24 +28,23 @@ jobs: - name: Install dependencies uses: ./.github/actions/setup-runner - - name: Apply image type to config file + - name: Create API configuration run: | - yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml - cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml - - - name: Setup directory for Ollama custom provider - run: | - mkdir -p tests/external-provider/llama-stack-provider-ollama/src/ - cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama + mkdir -p /home/runner/.llama/apis.d + cp tests/external/weather.yaml /home/runner/.llama/apis.d/weather.yaml - name: Create provider configuration run: | - mkdir -p /home/runner/.llama/providers.d/remote/inference - cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml + mkdir -p /home/runner/.llama/providers.d/remote/weather + cp tests/external/kaze.yaml /home/runner/.llama/providers.d/remote/weather/kaze.yaml + + - name: Print distro dependencies + run: | + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. 
llama stack build --config tests/external/build.yaml --print-deps-only - name: Build distro from config file run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml - name: Start Llama Stack server in background if: ${{ matrix.image-type }} == 'venv' @@ -55,19 +54,22 @@ jobs: # Use the virtual environment created by the build step (name comes from build config) source ci-test/bin/activate uv pip list - nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & + nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & - name: Wait for Llama Stack server to be ready run: | + echo "Waiting for Llama Stack server..." for i in {1..30}; do - if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then - echo "Waiting for Llama Stack server to load the provider..." - sleep 1 - else - echo "Provider loaded" + if curl -sSf http://localhost:8321/v1/health | grep -q "OK"; then + echo "Llama Stack server is up!" exit 0 fi + sleep 1 done - echo "Provider failed to load" + echo "Llama Stack server failed to start" cat server.log exit 1 + + - name: Test external API + run: | + curl -sSf http://localhost:8321/v1/weather/locations diff --git a/docs/source/apis/external.md b/docs/source/apis/external.md new file mode 100644 index 000000000..025267c33 --- /dev/null +++ b/docs/source/apis/external.md @@ -0,0 +1,392 @@ +# External APIs + +Llama Stack supports external APIs that live outside of the main codebase. 
This allows you to: +- Create and maintain your own APIs independently +- Share APIs with others without contributing to the main codebase +- Keep API-specific code separate from the core Llama Stack code + +## Configuration + +To enable external APIs, you need to configure the `external_apis_dir` in your Llama Stack configuration. This directory should contain your external API specifications: + +```yaml +external_apis_dir: ~/.llama/apis.d/ +``` + +## Directory Structure + +The external APIs directory should follow this structure: + +``` +apis.d/ + custom_api1.yaml + custom_api2.yaml +``` + +Each YAML file in these directories defines an API specification. + +## API Specification + +Here's an example of an external API specification for a weather API: + +```yaml +module: weather +api_dependencies: + - inference +protocol: WeatherAPI +name: weather +pip_packages: + - llama-stack-api-weather +``` + +### API Specification Fields + +- `module`: Python module containing the API implementation +- `protocol`: Name of the protocol class for the API +- `name`: Name of the API +- `pip_packages`: List of pip packages to install the API, typically a single package + +## Required Implementation + +External APIs must expose a `available_providers()` function in their module that returns a list of provider names: + +```python +# llama_stack_api_weather/api.py +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.weather, + provider_type="inline::darksky", + pip_packages=[], + module="llama_stack_provider_darksky", + config_class="llama_stack_provider_darksky.DarkSkyWeatherImplConfig", + ), + ] +``` + +A Protocol class like so: + +```python +# llama_stack_api_weather/api.py +from typing import Protocol + +from llama_stack.schema_utils import webmethod + + +class WeatherAPI(Protocol): + """ + A protocol for the Weather API. 
+ """ + + @webmethod(route="/locations", method="GET") + async def get_available_locations() -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... +``` + +## Example: Custom API + +Here's a complete example of creating and using a custom API: + +1. First, create the API package: + +```bash +mkdir -p llama-stack-api-weather +cd llama-stack-api-weather +mkdir src/llama_stack_api_weather +git init +uv init +``` + +2. Edit `pyproject.toml`: + +```toml +[project] +name = "llama-stack-api-weather" +version = "0.1.0" +description = "Weather API for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_api_weather", "llama_stack_api_weather.*"] +``` + +3. Create the initial files: + +```bash +touch src/llama_stack_api_weather/__init__.py +touch src/llama_stack_api_weather/api.py +``` + +```python +# llama-stack-api-weather/src/llama_stack_api_weather/__init__.py +"""Weather API for Llama Stack.""" + +from .api import WeatherAPI, available_providers + +__all__ = ["WeatherAPI", "available_providers"] +``` + +4. 
Create the API implementation: + +```python +# llama-stack-api-weather/src/llama_stack_api_weather/weather.py +from typing import Protocol + +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + ProviderSpec, + RemoteProviderSpec, +) +from llama_stack.schema_utils import webmethod + + +def available_providers() -> list[ProviderSpec]: + return [ + RemoteProviderSpec( + api=Api.weather, + provider_type="remote::kaze", + config_class="llama_stack_provider_kaze.KazeProviderConfig", + adapter=AdapterSpec( + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], + config_class="llama_stack_provider_kaze.KazeProviderConfig", + ), + ), + ] + + +class WeatherProvider(Protocol): + """ + A protocol for the Weather API. + """ + + @webmethod(route="/weather/locations", method="GET") + async def get_available_locations() -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... +``` + +5. Create the API specification: + +```yaml +# ~/.llama/apis.d/weather.yaml +module: llama_stack_api_weather +name: weather +pip_packages: ["llama-stack-api-weather"] +protocol: WeatherProvider + +``` + +6. Install the API package: + +```bash +uv pip install -e . +``` + +7. Configure Llama Stack to use external APIs: + +```yaml +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: {} +external_apis_dir: ~/.llama/apis.d +``` + +The API will now be available at `/v1/weather/locations`. + +## Example: custom provider for the weather API + +1. Create the provider package: + +```bash +mkdir -p llama-stack-provider-kaze +cd llama-stack-provider-kaze +uv init +``` + +2. 
Edit `pyproject.toml`: + +```toml +[project] +name = "llama-stack-provider-kaze" +version = "0.1.0" +description = "Kaze weather provider for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic", "aiohttp"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"] +``` + +3. Create the initial files: + +```bash +touch src/llama_stack_provider_kaze/__init__.py +touch src/llama_stack_provider_kaze/kaze.py +``` + +4. Create the provider implementation: + + +Initialization function: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py +"""Kaze weather provider for Llama Stack.""" + +from .config import KazeProviderConfig +from .kaze import WeatherKazeAdapter + +__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"] + + +async def get_adapter_impl(config: KazeProviderConfig, _deps): + from .kaze import WeatherKazeAdapter + + impl = WeatherKazeAdapter(config) + await impl.initialize() + return impl +``` + +Configuration: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py +from pydantic import BaseModel, Field + + +class KazeProviderConfig(BaseModel): + """Configuration for the Kaze weather provider.""" + + base_url: str = Field( + "https://api.kaze.io/v1", + description="Base URL for the Kaze weather API", + ) +``` + +Main implementation: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py +from llama_stack_api_weather.api import WeatherProvider + +from .config import KazeProviderConfig + + +class WeatherKazeAdapter(WeatherProvider): + """Kaze weather provider implementation.""" + + def __init__( + self, + config: KazeProviderConfig, + ) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def get_available_locations(self) -> dict[str, 
list[str]]: + """Get available weather locations.""" + return {"locations": ["Paris", "Tokyo"]} +``` + +5. Create the provider specification: + +```yaml +# ~/.llama/providers.d/remote/weather/kaze.yaml +adapter: + adapter_type: kaze + pip_packages: ["llama_stack_provider_kaze"] + config_class: llama_stack_provider_kaze.config.KazeProviderConfig + module: llama_stack_provider_kaze +optional_api_dependencies: [] +``` + +6. Install the provider package: + +```bash +uv pip install -e . +``` + +7. Configure Llama Stack to use the provider: + +```yaml +# ~/.llama/run-byoa.yaml +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: + weather: + - provider_id: kaze + provider_type: remote::kaze + config: {} +external_apis_dir: ~/.llama/apis.d +external_providers_dir: ~/.llama/providers.d +server: + port: 8321 +``` + +8. Run the server: + +```bash +python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml +``` + +9. Test the API: + +```bash +curl -sSf http://127.0.0.1:8321/v1/weather/locations +{"locations":["Paris","Tokyo"]}% +``` + +## Best Practices + +1. **Package Naming**: Use a clear and descriptive name for your API package. + +2. **Version Management**: Keep your API package versioned and compatible with the Llama Stack version you're using. + +3. **Dependencies**: Only include the minimum required dependencies in your API package. + +4. **Documentation**: Include clear documentation in your API package about: + - Installation requirements + - Configuration options + - API endpoints and usage + - Any limitations or known issues + +5. **Testing**: Include tests in your API package to ensure it works correctly with Llama Stack. + +## Troubleshooting + +If your external API isn't being loaded: + +1. Check that the `external_apis_dir` path is correct and accessible. +2. Verify that the YAML files are properly formatted. +3. Ensure all required Python packages are installed. +4. 
Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`. +5. Verify that the API package is installed in your Python environment. diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 63a764725..e6628f5d7 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -4,15 +4,83 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +from enum import Enum, EnumMeta -from pydantic import BaseModel +from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class DynamicApiMeta(EnumMeta): + def __new__(cls, name, bases, namespace): + # Store the original enum values + original_values = {k: v for k, v in namespace.items() if not k.startswith("_")} + + # Create the enum class + cls = super().__new__(cls, name, bases, namespace) + + # Store the original values for reference + cls._original_values = original_values + # Initialize _dynamic_values + cls._dynamic_values = {} + + return cls + + def __call__(cls, value): + try: + return super().__call__(value) + except ValueError as e: + # If this value was already dynamically added, return it + if value in cls._dynamic_values: + return cls._dynamic_values[value] + + # If the value doesn't exist, create a new enum member + # Create a new member name from the value + member_name = value.lower().replace("-", "_") + + # If this member name already exists in the enum, return the existing member + if member_name in cls._member_map_: + return cls._member_map_[member_name] + + # Instead of creating a new member, raise ValueError to force users to use Api.add() to + # register new APIs explicitly + raise ValueError(f"API '{value}' does not exist. 
Use Api.add() to register new APIs.") from e + + def __iter__(cls): + # Allow iteration over both static and dynamic members + yield from super().__iter__() + if hasattr(cls, "_dynamic_values"): + yield from cls._dynamic_values.values() + + def add(cls, value): + """ + Add a new API to the enum. + Used to register external APIs. + """ + member_name = value.lower().replace("-", "_") + + # If this member name already exists in the enum, return it + if member_name in cls._member_map_: + return cls._member_map_[member_name] + + # Create a new enum member + member = object.__new__(cls) + member._name_ = member_name + member._value_ = value + + # Add it to the enum class + cls._member_map_[member_name] = member + cls._member_names_.append(member_name) + cls._member_type_ = str + + # Store it in our dynamic values + cls._dynamic_values[value] = member + + return member + + @json_schema_type -class Api(Enum): +class Api(Enum, metaclass=DynamicApiMeta): providers = "providers" inference = "inference" safety = "safety" @@ -54,3 +122,12 @@ class Error(BaseModel): title: str detail: str instance: str | None = None + + +class ExternalApiSpec(BaseModel): + """Specification for an external API implementation.""" + + module: str = Field(..., description="Python module containing the API implementation") + name: str = Field(..., description="Name of the API") + pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API") + protocol: str = Field(..., description="Name of the protocol class for the API") diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 7ade6f17a..464accc0c 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -36,6 +36,7 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.resolver 
import InvalidProviderError from llama_stack.distribution.stack import replace_env_vars from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR @@ -390,6 +391,29 @@ def _run_stack_build_command_from_build_config( to_write = json.loads(build_config.model_dump_json()) f.write(yaml.dump(to_write, sort_keys=False)) + # We first install the external APIs so that the build process can use them and discover the + # providers dependencies + if build_config.external_apis_dir: + cprint("Installing external APIs", color="yellow", file=sys.stderr) + external_apis = load_external_apis(build_config) + if external_apis: + # install the external APIs + packages = [] + for _, api_spec in external_apis.items(): + if api_spec.pip_packages: + packages.extend(api_spec.pip_packages) + cprint( + f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}", + color="yellow", + file=sys.stderr, + ) + return_code = run_command(["uv", "pip", "install", *packages]) + if return_code != 0: + packages_str = ", ".join(packages) + raise RuntimeError( + f"Failed to install external APIs packages: {packages_str} (return code: {return_code})" + ) + return_code = build_image( build_config, build_file_path, diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 699ed72da..819bf4e94 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -14,6 +14,7 @@ from termcolor import cprint from llama_stack.distribution.datatypes import BuildConfig from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.utils.exec import run_command from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.providers.datatypes import Api @@ -105,6 +106,11 @@ def build_image( normal_deps, special_deps = get_provider_dependencies(build_config) normal_deps += 
SERVER_DEPENDENCIES + if build_config.external_apis_dir: + external_apis = load_external_apis(build_config) + if external_apis: + for _, api_spec in external_apis.items(): + normal_deps.extend(api_spec.pip_packages) if build_config.image_type == LlamaStackImageType.CONTAINER.value: script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh") diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index abc3f0065..99539084a 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -289,6 +289,11 @@ a default SQLite store will be used.""", description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", ) + external_apis_dir: Path | None = Field( + default=None, + description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.", + ) + @field_validator("external_providers_dir") @classmethod def validate_external_providers_dir(cls, v): @@ -320,6 +325,10 @@ class BuildConfig(BaseModel): default_factory=list, description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.", ) + external_apis_dir: Path | None = Field( + default=None, + description="Path to directory containing external API implementations. 
The APIs code and dependencies must be installed on the system.", + ) @field_validator("external_providers_dir") @classmethod diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index e37b2c443..929e11286 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -12,6 +12,7 @@ from typing import Any import yaml from pydantic import BaseModel +from llama_stack.distribution.external import load_external_apis from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( AdapterSpec, @@ -133,16 +134,34 @@ def get_provider_registry( ValueError: If any provider spec is invalid """ - ret: dict[Api, dict[str, ProviderSpec]] = {} + registry: dict[Api, dict[str, ProviderSpec]] = {} for api in providable_apis(): name = api.name.lower() logger.debug(f"Importing module {name}") try: module = importlib.import_module(f"llama_stack.providers.registry.{name}") - ret[api] = {a.provider_type: a for a in module.available_providers()} + registry[api] = {a.provider_type: a for a in module.available_providers()} except ImportError as e: logger.warning(f"Failed to import module {name}: {e}") + # Refresh providable APIs with external APIs if any + external_apis = load_external_apis(config) + for api, api_spec in external_apis.items(): + name = api_spec.name.lower() + logger.info(f"Importing external API {name} module {api_spec.module}") + try: + module = importlib.import_module(api_spec.module) + registry[api] = {a.provider_type: a for a in module.available_providers()} + except (ImportError, AttributeError) as e: + # Populate the registry with an empty dict to avoid breaking the provider registry + # This assume that the in-tree provider(s) are not available for this API which means + # that users will need to use external providers for this API. + registry[api] = {} + logger.error( + f"Failed to import external API {name}: {e}. 
Could not populate the in-tree provider(s) registry for {api.name}. \n" + "Install the API package to load any in-tree providers for this API." + ) + # Check if config has the external_providers_dir attribute if config and hasattr(config, "external_providers_dir") and config.external_providers_dir: external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) @@ -175,11 +194,9 @@ def get_provider_registry( else: spec = _load_inline_provider_spec(spec_data, api, provider_name) provider_type_key = f"inline::{provider_name}" - - logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") - if provider_type_key in ret[api]: + if provider_type_key in registry[api]: logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") - ret[api][provider_type_key] = spec + registry[api][provider_type_key] = spec logger.info(f"Successfully loaded external provider {provider_type_key}") except yaml.YAMLError as yaml_err: logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") @@ -187,4 +204,4 @@ def get_provider_registry( except Exception as e: logger.error(f"Failed to load provider spec from {spec_path}: {e}") raise e - return ret + return registry diff --git a/llama_stack/distribution/external.py b/llama_stack/distribution/external.py new file mode 100644 index 000000000..d59a01d33 --- /dev/null +++ b/llama_stack/distribution/external.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import yaml + +from llama_stack.apis.datatypes import Api, ExternalApiSpec +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="core") + + +def load_external_apis(config=None) -> dict[Api, ExternalApiSpec]: + """Load external API specifications from the configured directory. 
+ + Args: + config: StackRunConfig containing the external APIs directory path + + Returns: + A dictionary mapping API names to their specifications + """ + if not config: + return {} + + if not hasattr(config, "external_apis_dir"): + return {} + + if not config.external_apis_dir: + return {} + + external_apis_dir = config.external_apis_dir.expanduser().resolve() + if not external_apis_dir.is_dir(): + logger.error(f"External APIs directory is not a directory: {external_apis_dir}") + return {} + + logger.info(f"Loading external APIs from {external_apis_dir}") + external_apis: dict[Api, ExternalApiSpec] = {} + + # Look for YAML files in the external APIs directory + for yaml_path in external_apis_dir.glob("*.yaml"): + try: + with open(yaml_path) as f: + spec_data = yaml.safe_load(f) + + spec = ExternalApiSpec(**spec_data) + api = Api.add(spec.name) + logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}") + external_apis[api] = spec + except yaml.YAMLError as yaml_err: + logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}") + raise yaml_err + except Exception as e: + logger.error(f"Failed to load external API spec from {yaml_path}: {e}") + raise e + + return external_apis diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 5822070ad..7f7ab06ab 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -16,6 +16,7 @@ from llama_stack.apis.inspect import ( VersionInfo, ) from llama_stack.distribution.datatypes import StackRunConfig +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.server.routes import get_all_api_routes from llama_stack.providers.datatypes import HealthStatus @@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect): run_config: StackRunConfig = self.config.run_config ret = [] - all_endpoints = get_all_api_routes() + external_apis = load_external_apis(run_config) + all_endpoints = 
get_all_api_routes(external_apis) for api, endpoints in all_endpoints.items(): # Always include provider and inspect APIs, filter others based on run config if api.value in ["providers", "inspect"]: diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 3726bb3a5..c2a0b9fae 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets +from llama_stack.apis.datatypes import ExternalApiSpec from llama_stack.apis.eval import Eval from llama_stack.apis.files import Files from llama_stack.apis.inference import Inference, InferenceProvider @@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger @@ -59,8 +61,16 @@ class InvalidProviderError(Exception): pass -def api_protocol_map() -> dict[Api, Any]: - return { +def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]: + """Get a mapping of API types to their protocol classes. 
+ + Args: + external_apis: Optional dictionary of external API specifications + + Returns: + Dictionary mapping API types to their protocol classes + """ + protocols = { Api.providers: ProvidersAPI, Api.agents: Agents, Api.inference: Inference, @@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]: Api.files: Files, } + if external_apis: + for api, api_spec in external_apis.items(): + try: + module = importlib.import_module(api_spec.module) + api_class = getattr(module, api_spec.protocol) -def api_protocol_map_for_compliance_check() -> dict[Api, Any]: + protocols[api] = api_class + except (ImportError, AttributeError) as e: + logger.warning(f"Failed to load external API {api_spec.name}: {e}") + + return protocols + + +def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]: + external_apis = load_external_apis(config) return { - **api_protocol_map(), + **api_protocol_map(external_apis), Api.inference: InferenceProvider, } @@ -250,7 +273,7 @@ async def instantiate_providers( dist_registry: DistributionRegistry, run_config: StackRunConfig, policy: list[AccessRule], -) -> dict: +) -> dict[Api, Any]: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} @@ -356,7 +379,7 @@ async def instantiate_provider( impl.__provider_spec__ = provider_spec impl.__provider_config__ = config - protocols = api_protocol_map_for_compliance_check() + protocols = api_protocol_map_for_compliance_check(run_config) additional_protocols = additional_protocols_map() # TODO: check compliance for special tool groups # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/distribution/server/routes.py index ea66fec5a..682ef56c6 100644 --- 
a/llama_stack/distribution/server/routes.py +++ b/llama_stack/distribution/server/routes.py @@ -12,10 +12,9 @@ from typing import Any from aiohttp import hdrs from starlette.routing import Route +from llama_stack.apis.datatypes import Api, ExternalApiSpec from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION -from llama_stack.distribution.resolver import api_protocol_map -from llama_stack.providers.datatypes import Api EndpointFunc = Callable[..., Any] PathParams = dict[str, str] @@ -31,10 +30,13 @@ def toolgroup_protocol_map(): } -def get_all_api_routes() -> dict[Api, list[Route]]: +def get_all_api_routes(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, list[Route]]: apis = {} - protocols = api_protocol_map() + # Lazy import to avoid circular dependency + from llama_stack.distribution.resolver import api_protocol_map + + protocols = api_protocol_map(external_apis) toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): routes = [] @@ -73,8 +75,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]: return apis -def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: - routes = get_all_api_routes() +def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls: + routes = get_all_api_routes(external_apis) route_impls: RouteImpls = {} def _convert_path_to_regex(path: str) -> str: diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 83407a25f..9c4eb0e65 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -33,6 +33,7 @@ from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig from llama_stack.distribution.distribution import 
builtin_automatically_routed_apis +from llama_stack.distribution.external import ExternalApiSpec, load_external_apis from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.server.routes import ( @@ -270,9 +271,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: class TracingMiddleware: - def __init__(self, app, impls): + def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): self.app = app self.impls = impls + self.external_apis = external_apis # FastAPI built-in paths that should bypass custom routing self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") @@ -289,7 +291,7 @@ class TracingMiddleware: return await self.app(scope, receive, send) if not hasattr(self, "route_impls"): - self.route_impls = initialize_route_impls(self.impls) + self.route_impls = initialize_route_impls(self.impls, self.external_apis) try: _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) @@ -493,7 +495,9 @@ def main(args: argparse.Namespace | None = None): else: setup_logger(TelemetryAdapter(TelemetryConfig(), {})) - all_routes = get_all_api_routes() + # Load external APIs if configured + external_apis = load_external_apis(config) + all_routes = get_all_api_routes(external_apis) if config.apis: apis_to_serve = set(config.apis) @@ -512,7 +516,10 @@ def main(args: argparse.Namespace | None = None): api = Api(api_str) routes = all_routes[api] - impl = impls[api] + try: + impl = impls[api] + except KeyError as e: + raise ValueError(f"Could not find provider implementation for {api} API") from e for route in routes: if not hasattr(impl, route.name): @@ -543,7 +550,7 @@ def main(args: argparse.Namespace | None = None): app.exception_handler(Exception)(global_exception_handler) app.__llama_stack_impls__ = impls - 
app.add_middleware(TracingMiddleware, impls=impls) + app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) import uvicorn diff --git a/tests/external-provider/llama-stack-provider-ollama/README.md b/tests/external-provider/llama-stack-provider-ollama/README.md deleted file mode 100644 index 8bd2b6a87..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Ollama external provider for Llama Stack - -Template code to create a new external provider for Llama Stack. diff --git a/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml b/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml deleted file mode 100644 index 2ae1e2cf3..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml +++ /dev/null @@ -1,7 +0,0 @@ -adapter: - adapter_type: custom_ollama - pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"] - config_class: llama_stack_provider_ollama.config.OllamaImplConfig - module: llama_stack_provider_ollama -api_dependencies: [] -optional_api_dependencies: [] diff --git a/tests/external-provider/llama-stack-provider-ollama/pyproject.toml b/tests/external-provider/llama-stack-provider-ollama/pyproject.toml deleted file mode 100644 index ca1fecc42..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -dependencies = [ - "llama-stack", - "pydantic", - "ollama", - "aiohttp", - "aiosqlite", - "autoevals", - "chardet", - "chromadb-client", - "datasets", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", 
-] - -name = "llama-stack-provider-ollama" -version = "0.1.0" -description = "External provider for Ollama using the Llama Stack API" -readme = "README.md" -requires-python = ">=3.12" diff --git a/tests/external-provider/llama-stack-provider-ollama/run.yaml b/tests/external-provider/llama-stack-provider-ollama/run.yaml deleted file mode 100644 index 158f6800f..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/run.yaml +++ /dev/null @@ -1,94 +0,0 @@ -version: '2' -image_name: ollama -apis: -- inference -- telemetry -- tool_runtime -- datasetio -- vector_io -providers: - inference: - - provider_id: custom_ollama - provider_type: remote::custom_ollama - config: - url: ${env.OLLAMA_URL:http://localhost:11434} - vector_io: - - provider_id: faiss - provider_type: inline::faiss - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db - telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - service_name: "${env.OTEL_SERVICE_NAME:\u200B}" - sinks: ${env.TELEMETRY_SINKS:console,sqlite} - sqlite_db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/trace_store.db - datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db - - provider_id: localfs - provider_type: inline::localfs - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db - tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - config: - api_key: ${env.BRAVE_SEARCH_API_KEY:} - max_results: 3 - - provider_id: tavily-search - provider_type: remote::tavily-search - config: - api_key: ${env.TAVILY_SEARCH_API_KEY:} - max_results: 3 - - provider_id: rag-runtime - provider_type: inline::rag-runtime - config: 
{} - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol - config: {} - - provider_id: wolfram-alpha - provider_type: remote::wolfram-alpha - config: - api_key: ${env.WOLFRAM_ALPHA_API_KEY:} -metadata_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db -models: -- metadata: {} - model_id: ${env.INFERENCE_MODEL} - provider_id: custom_ollama - model_type: llm -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: custom_ollama - provider_model_id: all-minilm:latest - model_type: embedding -shields: [] -vector_dbs: [] -datasets: [] -scoring_fns: [] -benchmarks: [] -tool_groups: -- toolgroup_id: builtin::websearch - provider_id: tavily-search -- toolgroup_id: builtin::rag - provider_id: rag-runtime -- toolgroup_id: builtin::wolfram_alpha - provider_id: wolfram-alpha -server: - port: 8321 -external_providers_dir: ~/.llama/providers.d diff --git a/tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml b/tests/external/build.yaml similarity index 64% rename from tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml rename to tests/external/build.yaml index 1f3ab3817..90dcc97aa 100644 --- a/tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml +++ b/tests/external/build.yaml @@ -2,8 +2,9 @@ version: '2' distribution_spec: description: Custom distro for CI tests providers: - inference: - - remote::custom_ollama -image_type: container + weather: + - remote::kaze +image_type: venv image_name: ci-test external_providers_dir: ~/.llama/providers.d +external_apis_dir: ~/.llama/apis.d diff --git a/tests/external/kaze.yaml b/tests/external/kaze.yaml new file mode 100644 index 000000000..c61ac0e31 --- /dev/null +++ b/tests/external/kaze.yaml @@ -0,0 +1,6 @@ +adapter: + adapter_type: kaze + pip_packages: ["tests/external/llama-stack-provider-kaze"] + config_class: llama_stack_provider_kaze.config.KazeProviderConfig + module: 
llama_stack_provider_kaze +optional_api_dependencies: [] diff --git a/tests/external/llama-stack-api-weather/pyproject.toml b/tests/external/llama-stack-api-weather/pyproject.toml new file mode 100644 index 000000000..566e1e9aa --- /dev/null +++ b/tests/external/llama-stack-api-weather/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "llama-stack-api-weather" +version = "0.1.0" +description = "Weather API for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_api_weather", "llama_stack_api_weather.*"] diff --git a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/__init__.py b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/__init__.py new file mode 100644 index 000000000..d0227615d --- /dev/null +++ b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Weather API for Llama Stack.""" + +from .weather import WeatherProvider, available_providers + +__all__ = ["WeatherProvider", "available_providers"] diff --git a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py new file mode 100644 index 000000000..4b3bfb641 --- /dev/null +++ b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Protocol + +from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec +from llama_stack.schema_utils import webmethod + + +def available_providers() -> list[ProviderSpec]: + return [ + RemoteProviderSpec( + api=Api.weather, + provider_type="remote::kaze", + config_class="llama_stack_provider_kaze.KazeProviderConfig", + adapter=AdapterSpec( + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], + config_class="llama_stack_provider_kaze.KazeProviderConfig", + ), + ), + ] + + +class WeatherProvider(Protocol): + """ + A protocol for the Weather API. + """ + + @webmethod(route="/weather/locations", method="GET") + async def get_available_locations(self) -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... diff --git a/tests/external/llama-stack-provider-kaze/pyproject.toml b/tests/external/llama-stack-provider-kaze/pyproject.toml new file mode 100644 index 000000000..7bbf1f843 --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "llama-stack-provider-kaze" +version = "0.1.0" +description = "Kaze weather provider for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic", "aiohttp"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"] diff --git a/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py new file mode 100644 index 000000000..581ff38c7 --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Kaze weather provider for Llama Stack.""" + +from .config import KazeProviderConfig +from .kaze import WeatherKazeAdapter + +__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"] + + +async def get_adapter_impl(config: KazeProviderConfig, _deps): + from .kaze import WeatherKazeAdapter + + impl = WeatherKazeAdapter(config) + await impl.initialize() + return impl diff --git a/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py new file mode 100644 index 000000000..4b82698ed --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class KazeProviderConfig(BaseModel): + """Configuration for the Kaze weather provider.""" diff --git a/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py new file mode 100644 index 000000000..120b5438d --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from llama_stack_api_weather.weather import WeatherProvider + +from .config import KazeProviderConfig + + +class WeatherKazeAdapter(WeatherProvider): + """Kaze weather provider implementation.""" + + def __init__( + self, + config: KazeProviderConfig, + ) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def get_available_locations(self) -> dict[str, list[str]]: + """Get available weather locations.""" + return {"locations": ["Paris", "Tokyo"]} diff --git a/tests/external/run-byoa.yaml b/tests/external/run-byoa.yaml new file mode 100644 index 000000000..5774ae9da --- /dev/null +++ b/tests/external/run-byoa.yaml @@ -0,0 +1,13 @@ +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: + weather: + - provider_id: kaze + provider_type: remote::kaze + config: {} +external_apis_dir: ~/.llama/apis.d +external_providers_dir: ~/.llama/providers.d +server: + port: 8321 diff --git a/tests/external/weather.yaml b/tests/external/weather.yaml new file mode 100644 index 000000000..a84fcc921 --- /dev/null +++ b/tests/external/weather.yaml @@ -0,0 +1,4 @@ +module: llama_stack_api_weather +name: weather +pip_packages: ["tests/external/llama-stack-api-weather"] +protocol: WeatherProvider