diff --git a/.github/workflows/test-external-provider-module.yml b/.github/workflows/test-external-provider-module.yml
new file mode 100644
index 000000000..30fddb981
--- /dev/null
+++ b/.github/workflows/test-external-provider-module.yml
@@ -0,0 +1,72 @@
+name: Test External Providers Installed via Module
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+    paths:
+      - 'llama_stack/**'
+      - 'tests/integration/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/test-external-provider-module.yml' # This workflow
+
+jobs:
+  test-external-providers-from-module:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        image-type: [venv]
+        # We don't do container yet, it's tricky to install a package from the host into the
+        # container and point 'uv pip install' to the correct path...
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner
+
+      - name: Install Ramalama
+        shell: bash
+        run: |
+          uv pip install ramalama
+
+      - name: Run Ramalama
+        shell: bash
+        run: |
+          nohup ramalama serve llama3.2:3b-instruct-fp16 > ramalama_server.log 2>&1 &
+
+      - name: Apply image type to config file
+        run: |
+          yq -i '.image_type = "${{ matrix.image-type }}"' tests/external/ramalama-stack/run.yaml
+          cat tests/external/ramalama-stack/run.yaml
+
+      - name: Build distro from config file
+        run: |
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/ramalama-stack/build.yaml
+
+      - name: Start Llama Stack server in background
+        if: ${{ matrix.image-type == 'venv' }}
+        env:
+          INFERENCE_MODEL: "llama3.2:3b-instruct-fp16"
+        run: |
+          # Use the virtual environment created by the build step (name comes from build config)
+          source ramalama-stack-test/bin/activate
+          uv pip list
+          nohup llama stack run tests/external/ramalama-stack/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
+
+      - name: Wait for Llama Stack server to be ready
+        run: |
+          for i in {1..30}; do
+            if ! grep -q "successfully connected to Ramalama" server.log; then
+              echo "Waiting for Llama Stack server to load the provider..."
+              sleep 1
+            else
+              echo "Provider loaded"
+              exit 0
+            fi
+          done
+          echo "Provider failed to load"
+          cat server.log
+          exit 1
diff --git a/docs/source/providers/external.md b/docs/source/providers/external.md
index db0bc01e3..092b3a476 100644
--- a/docs/source/providers/external.md
+++ b/docs/source/providers/external.md
@@ -7,7 +7,17 @@ Llama Stack supports external providers that live outside of the main codebase.

 ## Configuration

-To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
+To enable external providers, add a `module` entry to your build YAML; Llama Stack uses it to install the package that implements the external provider.
+
+An example entry in your `build.yaml` looks like:
+
+```yaml
+- provider_id: ramalama
+  provider_type: remote::ramalama
+  module: ramalama_stack
+```
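+
+At build time, Llama Stack installs the `module` package and then imports `<module>.provider` to call `get_provider_spec()` (the contract is described under "Required Implementation" below), so that any `pip_packages` declared by the spec are installed as well. A minimal sketch of that resolution step, assuming the `ramalama_stack` package from the example above is installed:
+
+```python
+import importlib
+
+# A version pin such as "ramalama_stack==0.3.0a0" is stripped down to the import name.
+package_name = "ramalama_stack==0.3.0a0".split("==")[0]
+
+# The package's `provider` module is imported and asked for its spec.
+spec = importlib.import_module(f"{package_name}.provider").get_provider_spec()
+print(spec.pip_packages)  # the extra dependencies the build will install
+```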
+
+Alternatively, you can configure the `external_providers_dir` in your Llama Stack configuration. This method is being deprecated in favor of the `module` method. If using it, the directory should contain your external provider specifications:

 ```yaml
 external_providers_dir: ~/.llama/providers.d/
 ```
@@ -112,6 +122,31 @@ container_image: custom-vector-store:latest # optional

 ## Required Implementation

+### All Providers
+
+All providers must expose a `get_provider_spec` function in their `provider` module. Llama Stack calls this function to discover provider metadata, such as the config class, in a standardized way. `get_provider_spec` returns a spec carrying the same information as the `adapter` section of a YAML provider spec. An example function may look like:
+
+```python
+from llama_stack.providers.datatypes import (
+    ProviderSpec,
+    Api,
+    AdapterSpec,
+    remote_provider_spec,
+)
+
+
+def get_provider_spec() -> ProviderSpec:
+    return remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="ramalama",
+            pip_packages=["ramalama>=0.8.5", "pymilvus"],
+            config_class="ramalama_stack.config.RamalamaImplConfig",
+            module="ramalama_stack",
+        ),
+    )
+```
+
 ### Remote Providers

 Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
@@ -155,7 +190,7 @@ Version: 0.1.0
 Location: /path/to/venv/lib/python3.10/site-packages
 ```

-## Example: Custom Ollama Provider
+## Example using `external_providers_dir`: Custom Ollama Provider

 Here's a complete example of creating and using a custom Ollama provider:
@@ -206,6 +241,35 @@ external_providers_dir: ~/.llama/providers.d/

 The provider will now be available in Llama Stack with the type `remote::custom_ollama`.

+
+## Example using `module`: ramalama-stack
+
+[ramalama-stack](https://github.com/containers/ramalama-stack) is a recognized external provider that supports installation via `module`.
+
+To install Llama Stack with this external provider, provide the following `build.yaml`:
+
+```yaml
+version: 2
+distribution_spec:
+  description: Use (an external) Ramalama server for running LLM inference
+  container_image: null
+  providers:
+    inference:
+    - provider_id: ramalama
+      provider_type: remote::ramalama
+      module: ramalama_stack==0.3.0a0
+image_type: venv
+image_name: null
+external_providers_dir: null
+additional_pip_packages:
+- aiosqlite
+- sqlalchemy[asyncio]
+```
+
+No steps beyond `llama stack build` and `llama stack run` are required: the build process uses `module` to install the provider package, retrieve its spec, and install the dependencies the spec declares.
+
+The provider will now be available in Llama Stack with the type `remote::ramalama`.
+
 ## Best Practices

 1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.
@@ -229,9 +293,10 @@
 information. Execute the test for the Provider type you are developing.

 If your external provider isn't being loaded:

-1. Check that the `external_providers_dir` path is correct and accessible.
-2. Verify that the YAML files are properly formatted.
-3. Ensure all required Python packages are installed.
-4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
-5. Verify that the provider package is installed in your Python environment.
+1. Check that `module` points to a published pip package with a top-level `provider` module that includes `get_provider_spec`.
+2. Check that the `external_providers_dir` path is correct and accessible.
+3. Verify that the YAML files are properly formatted.
+4. Ensure all required Python packages are installed.
+5. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
+6. Verify that the provider package is installed in your Python environment if using `external_providers_dir`.
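+
+As a quick sanity check for the `module` path (a sketch, assuming the `ramalama_stack` package from the examples above), confirm in the build's environment that the package is importable and exposes the expected entry point:
+
+```python
+import importlib
+
+# Raises ModuleNotFoundError if the package or its `provider` submodule is missing,
+# and AttributeError if get_provider_spec is not defined.
+spec = importlib.import_module("ramalama_stack.provider").get_provider_spec()
+assert spec.provider_type == "remote::ramalama"
+```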
diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py
index 83aefa4a9..af2a46739 100644
--- a/llama_stack/cli/stack/_build.py
+++ b/llama_stack/cli/stack/_build.py
@@ -94,7 +94,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
             )
             sys.exit(1)
     elif args.providers:
-        providers_list: dict[str, str | list[str]] = dict()
+        provider_list: dict[str, list[Provider]] = dict()
         for api_provider in args.providers.split(","):
             if "=" not in api_provider:
                 cprint(
@@ -103,7 +103,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                     file=sys.stderr,
                 )
                 sys.exit(1)
-            api, provider = api_provider.split("=")
+            api, provider_type = api_provider.split("=")
             providers_for_api = get_provider_registry().get(Api(api), None)
             if providers_for_api is None:
                 cprint(
@@ -112,16 +112,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                     file=sys.stderr,
                 )
                 sys.exit(1)
-            if provider in providers_for_api:
-                if api not in providers_list:
-                    providers_list[api] = []
-                # Use type guarding to ensure we have a list
-                provider_value = providers_list[api]
-                if isinstance(provider_value, list):
-                    provider_value.append(provider)
-                else:
-                    # Convert string to list and append
-                    providers_list[api] = [provider_value, provider]
+            if provider_type in providers_for_api:
+                provider = Provider(
+                    provider_type=provider_type,
+                    provider_id=provider_type.split("::")[1],
+                    config={},
+                    module=None,
+                )
+                provider_list.setdefault(api, []).append(provider)
             else:
                 cprint(
-                    f"{provider} is not a valid provider for the {api} API.",
+                    f"{provider_type} is not a valid provider for the {api} API.",
                     color="red",
                     file=sys.stderr,
                 )
                 sys.exit(1)
         distribution_spec = DistributionSpec(
-            providers=providers_list,
+            providers=provider_list,
             description=",".join(args.providers),
         )
     if not args.image_type:
@@ -191,7 +189,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)

-        providers: dict[str, str | list[str]] = dict()
+        providers: dict[str, list[Provider]] = dict()
         for api, providers_for_api in get_provider_registry().items():
             available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
             if not available_providers:
@@ -237,11 +235,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:

     if args.print_deps_only:
         print(f"# Dependencies for {args.template or args.config or image_name}")
-        normal_deps, special_deps = get_provider_dependencies(build_config)
+        normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
         normal_deps += SERVER_DEPENDENCIES
         print(f"uv pip install {' '.join(normal_deps)}")
         for special_dep in special_deps:
             print(f"uv pip install {special_dep}")
+        for external_dep in external_provider_dependencies:
+            print(f"uv pip install {external_dep}")
         return

     try:
@@ -304,27 +304,25 @@ def _generate_run_config(
     provider_registry = get_provider_registry(build_config)
     for api in apis:
         run_config.providers[api] = []
-        provider_types = build_config.distribution_spec.providers[api]
-        if isinstance(provider_types, str):
-            provider_types = [provider_types]
+        providers = build_config.distribution_spec.providers[api]

-        for i, provider_type in enumerate(provider_types):
-            pid = provider_type.split("::")[-1]
+        for provider in providers:
+            pid = provider.provider_id

-            p = provider_registry[Api(api)][provider_type]
+            p = provider_registry[Api(api)][provider.provider_type]
             if p.deprecation_error:
                 raise
InvalidProviderError(p.deprecation_error) try: - config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class) - except ModuleNotFoundError: + config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class) + except (ModuleNotFoundError, ValueError) as exc: # HACK ALERT: # This code executes after building is done, the import cannot work since the # package is either available in the venv or container - not available on the host. # TODO: use a "is_external" flag in ProviderSpec to check if the provider is # external cprint( - f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping", + f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}", color="yellow", file=sys.stderr, ) @@ -337,9 +335,10 @@ def _generate_run_config( config = {} p_spec = Provider( - provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid, - provider_type=provider_type, + provider_id=pid, + provider_type=provider.provider_type, config=config, + module=provider.module, ) run_config.providers[api].append(p_spec) @@ -402,7 +401,7 @@ def _run_stack_build_command_from_build_config( run_config_file = _generate_run_config(build_config, build_dir, image_name) with open(build_file_path, "w") as f: - to_write = json.loads(build_config.model_dump_json()) + to_write = json.loads(build_config.model_dump_json(exclude_none=True)) f.write(yaml.dump(to_write, sort_keys=False)) # We first install the external APIs so that the build process can use them and discover the diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 819bf4e94..b4eaac1c7 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -42,7 +42,7 @@ class ApiInput(BaseModel): def get_provider_dependencies( config: BuildConfig | DistributionTemplate, -) -> tuple[list[str], list[str]]: +) -> tuple[list[str], list[str], list[str]]: """Get normal and special dependencies from provider configuration.""" if isinstance(config, DistributionTemplate): config = config.build_config() @@ -51,6 +51,7 @@ def get_provider_dependencies( additional_pip_packages = config.additional_pip_packages deps = [] + external_provider_deps = [] registry = get_provider_registry(config) for api_str, provider_or_providers in providers.items(): providers_for_api = registry[Api(api_str)] @@ -65,8 +66,16 @@ def get_provider_dependencies( raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`") provider_spec = providers_for_api[provider_type] - deps.extend(provider_spec.pip_packages) - if provider_spec.container_image: + if hasattr(provider_spec, "is_external") and provider_spec.is_external: + # this ensures we install the top level module for our external providers + if provider_spec.module: + if isinstance(provider_spec.module, str): + external_provider_deps.append(provider_spec.module) + else: + external_provider_deps.extend(provider_spec.module) + if hasattr(provider_spec, "pip_packages"): + deps.extend(provider_spec.pip_packages) + if hasattr(provider_spec, "container_image") and provider_spec.container_image: raise ValueError("A stack's dependencies cannot have a container image") normal_deps = [] @@ -79,7 +88,7 @@ def get_provider_dependencies( normal_deps.extend(additional_pip_packages or []) - return list(set(normal_deps)), list(set(special_deps)) + return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps)) def 
print_pip_install_help(config: BuildConfig):
@@ -104,7 +113,7 @@ def build_image(
 ):
     container_base = build_config.distribution_spec.container_image or "python:3.12-slim"

-    normal_deps, special_deps = get_provider_dependencies(build_config)
+    normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
     normal_deps += SERVER_DEPENDENCIES

     if build_config.external_apis_dir:
         external_apis = load_external_apis(build_config)
@@ -116,34 +125,47 @@ def build_image(
         script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
         args = [
             script,
+            "--template-or-config",
             template_or_config,
+            "--image-name",
             image_name,
+            "--container-base",
             container_base,
+            "--normal-deps",
             " ".join(normal_deps),
         ]
-
         # When building from a config file (not a template), include the run config path in the
         # build arguments
         if run_config is not None:
-            args.append(run_config)
+            args.extend(["--run-config", run_config])
     elif build_config.image_type == LlamaStackImageType.CONDA.value:
         script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
         args = [
             script,
+            "--env-name",
             str(image_name),
+            "--build-file-path",
             str(build_file_path),
+            "--normal-deps",
             " ".join(normal_deps),
         ]
     elif build_config.image_type == LlamaStackImageType.VENV.value:
         script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
         args = [
             script,
+            "--env-name",
             str(image_name),
+            "--normal-deps",
             " ".join(normal_deps),
         ]

+    # Pass these flags only when there is something to pass; the scripts treat absent flags as empty
     if special_deps:
-        args.append("#".join(special_deps))
+        args.extend(["--optional-deps", "#".join(special_deps)])
+    if external_provider_deps:
+        # the script will install each external provider module, query its spec, and install the spec's dependencies too
+        args.extend(["--external-provider-deps", "#".join(external_provider_deps)])
return_code = run_command(args) diff --git a/llama_stack/distribution/build_conda_env.sh b/llama_stack/distribution/build_conda_env.sh index 61a2d5973..48ac3a1ab 100755 --- a/llama_stack/distribution/build_conda_env.sh +++ b/llama_stack/distribution/build_conda_env.sh @@ -9,10 +9,91 @@ LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} +PYPI_VERSION=${PYPI_VERSION:-} # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} +set -euo pipefail + +# Define color codes +RED='\033[0;31m' +GREEN='\033[0;32m' +NC='\033[0m' # No Color + +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +source "$SCRIPT_DIR/common.sh" + +# Usage function +usage() { + echo "Usage: $0 --env-name --build-file-path --normal-deps [--external-provider-deps ] [--optional-deps ]" + echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'" + exit 1 +} + +# Parse arguments +env_name="" +build_file_path="" +normal_deps="" +external_provider_deps="" +optional_deps="" + +while [[ $# -gt 0 ]]; do + key="$1" + case "$key" in + --env-name) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --env-name requires a string value" >&2 + usage + fi + env_name="$2" + shift 2 + ;; + --build-file-path) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --build-file-path requires a string value" >&2 + usage + fi + build_file_path="$2" + shift 2 + ;; + --normal-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --normal-deps requires a string value" >&2 + usage + fi + normal_deps="$2" + shift 2 + ;; + --external-provider-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --external-provider-deps requires a string value" >&2 + usage + fi + external_provider_deps="$2" + shift 2 + ;; + --optional-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --optional-deps requires a string value" >&2 + usage + fi + optional_deps="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + usage + ;; + esac +done + +# Check required arguments +if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then + echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2 + usage +fi + if [ -n "$LLAMA_STACK_DIR" ]; then echo "Using llama-stack-dir=$LLAMA_STACK_DIR" fi @@ -20,50 +101,18 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" fi -if [ "$#" -lt 3 ]; then - echo "Usage: $0 []" >&2 - echo "Example: $0 my-conda-env ./my-stack-build.yaml 'numpy pandas scipy'" >&2 - exit 1 -fi - -special_pip_deps="$4" - -set -euo pipefail - -env_name="$1" -build_file_path="$2" -pip_dependencies="$3" - -# Define color codes -RED='\033[0;31m' -GREEN='\033[0;32m' -NC='\033[0m' # No Color - -# this is set if we actually create a new conda in which case we need to clean up -ENVNAME="" - -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -source "$SCRIPT_DIR/common.sh" - ensure_conda_env_python310() { - local env_name="$1" - local pip_dependencies="$2" - local special_pip_deps="$3" + # Use only global variables set by flag parser local python_version="3.12" - # Check if conda command is available if ! is_command_available conda; then printf "${RED}Error: conda command not found. 
Is Conda installed and in your PATH?${NC}" >&2 exit 1 fi - # Check if the environment exists if conda env list | grep -q "^${env_name} "; then printf "Conda environment '${env_name}' exists. Checking Python version...\n" - - # Check Python version in the environment current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2) - if [ "$current_version" = "$python_version" ]; then printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n" else @@ -73,37 +122,37 @@ ensure_conda_env_python310() { else printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n" conda create -n "${env_name}" python="${python_version}" -y - - ENVNAME="${env_name}" - # setup_cleanup_handlers fi eval "$(conda shell.bash hook)" conda deactivate && conda activate "${env_name}" - "$CONDA_PREFIX"/bin/pip install uv if [ -n "$TEST_PYPI_VERSION" ]; then - # these packages are damaged in test-pypi, so install them first uv pip install fastapi libcst uv pip install --extra-index-url https://test.pypi.org/simple/ \ llama-stack=="$TEST_PYPI_VERSION" \ - "$pip_dependencies" - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + "$normal_deps" + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" + for part in "${parts[@]}"; do + echo "$part" + uv pip install $part + done + fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" for part in "${parts[@]}"; do echo "$part" uv pip install "$part" done fi else - # Re-installing llama-stack in the new conda environment if [ -n "$LLAMA_STACK_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n" uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" else @@ -115,31 +164,44 @@ ensure_conda_env_python310() { fi uv pip install --no-cache-dir "$SPEC_VERSION" fi - if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then if [ ! 
-d "$LLAMA_STACK_CLIENT_DIR" ]; then printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n" uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" fi - - # Install pip dependencies printf "Installing pip dependencies\n" - uv pip install $pip_dependencies - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + uv pip install $normal_deps + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do echo "$part" uv pip install $part done fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + echo "Getting provider spec for module: $part and installing dependencies" + package_name=$(echo "$part" | sed 's/[<>=!].*//') + python3 -c " +import importlib +import sys +try: + module = importlib.import_module(f'$package_name.provider') + spec = module.get_provider_spec() + if hasattr(spec, 'pip_packages') and spec.pip_packages: + print('\\n'.join(spec.pip_packages)) +except Exception as e: + print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr) +" | uv pip install -r - + done + fi fi - mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml" } -ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps" +ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps" diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 6985c1cd0..7c406d3e7 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -27,52 +27,103 @@ RUN_CONFIG_PATH=/app/run.yaml BUILD_CONTEXT_DIR=$(pwd) -if [ "$#" -lt 4 ]; then - # This only works for templates - echo "Usage: $0 [] []" >&2 - exit 1 -fi set -euo pipefail -template_or_config="$1" -shift -image_name="$1" -shift -container_base="$1" -shift -pip_dependencies="$1" -shift - -# Handle optional arguments -run_config="" -special_pip_deps="" - -# Check if there are more arguments -# The logics is becoming cumbersom, we should refactor it if we can do better -if [ $# -gt 0 ]; then - # Check if the argument ends with .yaml - if [[ "$1" == *.yaml ]]; then - run_config="$1" - shift - # If there's another argument after .yaml, it must be special_pip_deps - if [ $# -gt 0 ]; then - special_pip_deps="$1" - fi - else - # If it's not .yaml, it must be special_pip_deps - special_pip_deps="$1" - fi -fi - # Define color codes RED='\033[0;31m' NC='\033[0m' # No Color +# Usage function +usage() { + echo "Usage: $0 --image-name --container-base --normal-deps [--run-config ] [--external-provider-deps ] [--optional-deps ]" + echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'" + exit 1 +} + +# Parse arguments +image_name="" +container_base="" +normal_deps="" +external_provider_deps="" +optional_deps="" +run_config="" +template_or_config="" + +while [[ $# -gt 0 ]]; do + key="$1" + case "$key" in + --image-name) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --image-name requires a string value" >&2 + usage + fi + image_name="$2" + shift 2 + ;; + --container-base) + if [[ 
-z "$2" || "$2" == --* ]]; then + echo "Error: --container-base requires a string value" >&2 + usage + fi + container_base="$2" + shift 2 + ;; + --normal-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --normal-deps requires a string value" >&2 + usage + fi + normal_deps="$2" + shift 2 + ;; + --external-provider-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --external-provider-deps requires a string value" >&2 + usage + fi + external_provider_deps="$2" + shift 2 + ;; + --optional-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --optional-deps requires a string value" >&2 + usage + fi + optional_deps="$2" + shift 2 + ;; + --run-config) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --run-config requires a string value" >&2 + usage + fi + run_config="$2" + shift 2 + ;; + --template-or-config) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --template-or-config requires a string value" >&2 + usage + fi + template_or_config="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + usage + ;; + esac +done + +# Check required arguments +if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then + echo "Error: --image-name, --container-base, and --normal-deps are required." >&2 + usage +fi + CONTAINER_BINARY=${CONTAINER_BINARY:-docker} CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain} - TEMP_DIR=$(mktemp -d) - SCRIPT_DIR=$(dirname "$(readlink -f "$0")") source "$SCRIPT_DIR/common.sh" @@ -81,18 +132,15 @@ add_to_container() { if [ -t 0 ]; then printf '%s\n' "$1" >>"$output_file" else - # If stdin is not a terminal, read from it (heredoc) cat >>"$output_file" fi } -# Check if container command is available if ! is_command_available "$CONTAINER_BINARY"; then printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2 exit 1 fi -# Update and install UBI9 components if UBI9 base image is used if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then add_to_container << EOF FROM $container_base @@ -135,16 +183,16 @@ EOF # Add pip dependencies first since llama-stack is what will change most often # so we can reuse layers. 
-if [ -n "$pip_dependencies" ]; then - read -ra pip_args <<< "$pip_dependencies" +if [ -n "$normal_deps" ]; then + read -ra pip_args <<< "$normal_deps" quoted_deps=$(printf " %q" "${pip_args[@]}") add_to_container << EOF RUN $MOUNT_CACHE uv pip install $quoted_deps EOF fi -if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" +if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do read -ra pip_args <<< "$part" quoted_deps=$(printf " %q" "${pip_args[@]}") @@ -154,7 +202,33 @@ EOF done fi -# Function to get Python command +if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + read -ra pip_args <<< "$part" + quoted_deps=$(printf " %q" "${pip_args[@]}") + add_to_container <=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0] + module = importlib.import_module(f'{package_name}.provider') + spec = module.get_provider_spec() + if hasattr(spec, 'pip_packages') and spec.pip_packages: + if isinstance(spec.pip_packages, (list, tuple)): + print('\n'.join(spec.pip_packages)) +except Exception as e: + print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr) +PYTHON +EOF + done +fi + get_python_cmd() { if is_command_available python; then echo "python" diff --git a/llama_stack/distribution/build_venv.sh b/llama_stack/distribution/build_venv.sh index 264cedf9c..93db9ab28 100755 --- a/llama_stack/distribution/build_venv.sh +++ b/llama_stack/distribution/build_venv.sh @@ -18,6 +18,76 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-} VIRTUAL_ENV=${VIRTUAL_ENV:-} +set -euo pipefail + +# Define color codes +RED='\033[0;31m' +NC='\033[0m' # No Color + +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +source "$SCRIPT_DIR/common.sh" + +# Usage function +usage() { + echo "Usage: $0 --env-name --normal-deps [--external-provider-deps ] [--optional-deps ]" + echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'" + exit 1 +} + +# Parse arguments +env_name="" +normal_deps="" +external_provider_deps="" +optional_deps="" + +while [[ $# -gt 0 ]]; do + key="$1" + case "$key" in + --env-name) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --env-name requires a string value" >&2 + usage + fi + env_name="$2" + shift 2 + ;; + --normal-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --normal-deps requires a string value" >&2 + usage + fi + normal_deps="$2" + shift 2 + ;; + --external-provider-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --external-provider-deps requires a string value" >&2 + usage + fi + external_provider_deps="$2" + shift 2 + ;; + --optional-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --optional-deps requires a string value" >&2 + usage + fi + optional_deps="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + usage + ;; + esac +done + +# Check required arguments +if [[ -z "$env_name" || -z "$normal_deps" ]]; then + echo "Error: --env-name and --normal-deps are required." 
>&2 + usage +fi + if [ -n "$LLAMA_STACK_DIR" ]; then echo "Using llama-stack-dir=$LLAMA_STACK_DIR" fi @@ -25,29 +95,6 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" fi -if [ "$#" -lt 2 ]; then - echo "Usage: $0 []" >&2 - echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2 - exit 1 -fi - -special_pip_deps="$3" - -set -euo pipefail - -env_name="$1" -pip_dependencies="$2" - -# Define color codes -RED='\033[0;31m' -NC='\033[0m' # No Color - -# this is set if we actually create a new conda in which case we need to clean up -ENVNAME="" - -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -source "$SCRIPT_DIR/common.sh" - # pre-run checks to make sure we can proceed with the installation pre_run_checks() { local env_name="$1" @@ -71,49 +118,44 @@ pre_run_checks() { } run() { - local env_name="$1" - local pip_dependencies="$2" - local special_pip_deps="$3" - + # Use only global variables set by flag parser if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then echo "Installing dependencies in system Python environment" - # if env == __system__, ensure we set UV_SYSTEM_PYTHON export UV_SYSTEM_PYTHON=1 elif [ "$VIRTUAL_ENV" == "$env_name" ]; then echo "Virtual environment $env_name is already active" else echo "Using virtual environment $env_name" uv venv "$env_name" - # shellcheck source=/dev/null source "$env_name/bin/activate" fi if [ -n "$TEST_PYPI_VERSION" ]; then - # these packages are damaged in test-pypi, so install them first uv pip install fastapi libcst - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected uv pip install --extra-index-url https://test.pypi.org/simple/ \ --index-strategy unsafe-best-match \ llama-stack=="$TEST_PYPI_VERSION" \ - $pip_dependencies - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + $normal_deps + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do echo "$part" - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected uv pip install $part done fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + echo "$part" + uv pip install "$part" + done + fi else - # Re-installing llama-stack in the new virtual environment if [ -n "$LLAMA_STACK_DIR" ]; then if [ ! 
-d "$LLAMA_STACK_DIR" ]; then printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" else @@ -125,27 +167,41 @@ run() { printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" fi - # Install pip dependencies printf "Installing pip dependencies\n" - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected - uv pip install $pip_dependencies - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + uv pip install $normal_deps + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do - echo "$part" - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected + echo "Installing special provider module: $part" uv pip install $part done fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + echo "Installing external provider module: $part" + uv pip install "$part" + echo "Getting provider spec for module: $part and installing dependencies" + package_name=$(echo "$part" | sed 's/[<>=!].*//') + python3 -c " +import importlib +import sys +try: + module = importlib.import_module(f'$package_name.provider') + spec = module.get_provider_spec() + if hasattr(spec, 'pip_packages') and spec.pip_packages: + print('\\n'.join(spec.pip_packages)) +except Exception as e: + print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr) +" | uv pip install -r - + done + fi fi } pre_run_checks "$env_name" -run "$env_name" "$pip_dependencies" "$special_pip_deps" +run diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 2238eef93..355233d53 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec logger.info(f"Configuring API `{api_str}`...") updated_providers = [] - for i, provider_type in enumerate(plist): + for i, provider in enumerate(plist): if i >= 1: - others = ", ".join(plist[i:]) + others = ", ".join(p.provider_type for p in plist[i:]) logger.info( f"Not configuring other providers ({others}) interactively. 
Please edit the resulting YAML directly.\n" ) break - logger.info(f"> Configuring provider `({provider_type})`") + logger.info(f"> Configuring provider `({provider.provider_type})`") updated_providers.append( configure_single_provider( provider_registry[api], Provider( - provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type), - provider_type=provider_type, + provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id), + provider_type=provider.provider_type, config={}, ), ) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index f0b18606a..c17aadcc1 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -136,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec): pip_packages: list[str] = Field(default_factory=list) +class Provider(BaseModel): + # provider_id of None means that the provider is not enabled - this happens + # when the provider is enabled via a conditional environment variable + provider_id: str | None + provider_type: str + config: dict[str, Any] = {} + module: str | None = Field( + default=None, + description=""" + Fully-qualified name of the external provider module to import. The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation + + Example: `module: ramalama_stack` + """, + ) + + class DistributionSpec(BaseModel): description: str | None = Field( default="", description="Description of the distribution", ) container_image: str | None = None - providers: dict[str, str | list[str]] = Field( + providers: dict[str, list[Provider]] = Field( default_factory=dict, description=""" -Provider Types for each of the APIs provided by this distribution. If you -select multiple providers, you should provide an appropriate 'routing_map' -in the runtime configuration to help route to the correct provider.""", + Provider Types for each of the APIs provided by this distribution. If you + select multiple providers, you should provide an appropriate 'routing_map' + in the runtime configuration to help route to the correct provider. + """, ) -class Provider(BaseModel): - # provider_id of None means that the provider is not enabled - this happens - # when the provider is enabled via a conditional environment variable - provider_id: str | None - provider_type: str - config: dict[str, Any] - - class LoggingConfig(BaseModel): category_levels: dict[str, str] = Field( default_factory=dict, diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 929e11286..6e7297e32 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -12,6 +12,7 @@ from typing import Any import yaml from pydantic import BaseModel +from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec from llama_stack.distribution.external import load_external_apis from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( @@ -97,12 +98,10 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam return spec -def get_provider_registry( - config=None, -) -> dict[Api, dict[str, ProviderSpec]]: +def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]: """Get the provider registry, optionally including external providers. - This function loads both built-in providers and external providers from YAML files. 
+ This function loads both built-in providers and external providers from YAML files or from their provided modules. External providers are loaded from a directory structure like: providers.d/ @@ -123,8 +122,13 @@ def get_provider_registry( safety/ llama-guard.yaml + This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction. + So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet. + There is special handling for all of the potential cases this method can be called from. + Args: config: Optional object containing the external providers directory path + building: Optional bool delineating whether or not this is being called from a build process Returns: A dictionary mapping APIs to their available providers @@ -162,46 +166,112 @@ def get_provider_registry( "Install the API package to load any in-tree providers for this API." ) - # Check if config has the external_providers_dir attribute - if config and hasattr(config, "external_providers_dir") and config.external_providers_dir: - external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) - if not os.path.exists(external_providers_dir): - raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") - logger.info(f"Loading external providers from {external_providers_dir}") + # Check if config has external providers + if config: + if hasattr(config, "external_providers_dir") and config.external_providers_dir: + registry = get_external_providers_from_dir(registry, config) + # else lets check for modules in each provider + registry = get_external_providers_from_module( + registry=registry, + config=config, + building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)), + ) - for api in providable_apis(): - api_name = api.name.lower() - - # Process both remote and inline providers - for provider_type in ["remote", "inline"]: - api_dir = os.path.join(external_providers_dir, provider_type, api_name) - if not os.path.exists(api_dir): - logger.debug(f"No {provider_type} provider directory found for {api_name}") - continue - - # Look for provider spec files in the API directory - for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")): - provider_name = os.path.splitext(os.path.basename(spec_path))[0] - logger.info(f"Loading {provider_type} provider spec from {spec_path}") - - try: - with open(spec_path) as f: - spec_data = yaml.safe_load(f) - - if provider_type == "remote": - spec = _load_remote_provider_spec(spec_data, api) - provider_type_key = f"remote::{provider_name}" - else: - spec = _load_inline_provider_spec(spec_data, api, provider_name) - provider_type_key = f"inline::{provider_name}" - if provider_type_key in registry[api]: - logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") - registry[api][provider_type_key] = spec - logger.info(f"Successfully loaded external provider {provider_type_key}") - except yaml.YAMLError as yaml_err: - logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") - raise yaml_err - except Exception as e: - logger.error(f"Failed to load provider spec from {spec_path}: {e}") - raise e + return registry + + +def get_external_providers_from_dir( + registry: dict[Api, dict[str, ProviderSpec]], config +) -> dict[Api, dict[str, ProviderSpec]]: + logger.warning( + "Specifying external providers via 
`external_providers_dir` is being deprecated. Please specify `module:` in the provider instead." + ) + external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) + if not os.path.exists(external_providers_dir): + raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") + logger.info(f"Loading external providers from {external_providers_dir}") + + for api in providable_apis(): + api_name = api.name.lower() + + # Process both remote and inline providers + for provider_type in ["remote", "inline"]: + api_dir = os.path.join(external_providers_dir, provider_type, api_name) + if not os.path.exists(api_dir): + logger.debug(f"No {provider_type} provider directory found for {api_name}") + continue + + # Look for provider spec files in the API directory + for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")): + provider_name = os.path.splitext(os.path.basename(spec_path))[0] + logger.info(f"Loading {provider_type} provider spec from {spec_path}") + + try: + with open(spec_path) as f: + spec_data = yaml.safe_load(f) + + if provider_type == "remote": + spec = _load_remote_provider_spec(spec_data, api) + provider_type_key = f"remote::{provider_name}" + else: + spec = _load_inline_provider_spec(spec_data, api, provider_name) + provider_type_key = f"inline::{provider_name}" + + logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") + if provider_type_key in registry[api]: + logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") + registry[api][provider_type_key] = spec + logger.info(f"Successfully loaded external provider {provider_type_key}") + except yaml.YAMLError as yaml_err: + logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") + raise yaml_err + except Exception as e: + logger.error(f"Failed to load provider spec from {spec_path}: {e}") + raise e + + return registry + + +def get_external_providers_from_module( + registry: dict[Api, dict[str, ProviderSpec]], config, building: bool +) -> dict[Api, dict[str, ProviderSpec]]: + provider_list = None + if isinstance(config, BuildConfig): + provider_list = config.distribution_spec.providers.items() + else: + provider_list = config.providers.items() + if provider_list is None: + logger.warning("Could not get list of providers from config") + return registry + for provider_api, providers in provider_list: + for provider in providers: + if not hasattr(provider, "module") or provider.module is None: + continue + # get provider using module + try: + if not building: + package_name = provider.module.split("==")[0] + module = importlib.import_module(f"{package_name}.provider") + # if config class is wrong you will get an error saying module could not be imported + spec = module.get_provider_spec() + else: + # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run + spec = ProviderSpec( + api=Api(provider_api), + provider_type=provider.provider_type, + is_external=True, + module=provider.module, + config_class="", + ) + provider_type = provider.provider_type + # in the case we are building we CANNOT import this module of course because it has not been installed. + # return a partially filled out spec that the build script will populate. + registry[Api(provider_api)][provider_type] = spec + except ModuleNotFoundError as exc: + raise ValueError( + "get_provider_spec not found. 
If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available" + ) from exc + except Exception as e: + logger.error(f"Failed to load provider spec from module {provider.module}: {e}") + raise e return registry diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 07949aea7..bcb0b9167 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -249,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): file=sys.stderr, ) if self.config_path_or_template_name.endswith(".yaml"): - # Convert Provider objects to their types - provider_types: dict[str, str | list[str]] = {} - for api, providers in self.config.providers.items(): - types = [p.provider_type for p in providers] - # Convert single-item lists to strings - provider_types[api] = types[0] if len(types) == 1 else types build_config = BuildConfig( distribution_spec=DistributionSpec( - providers=provider_types, + providers=self.config.providers, ), external_providers_dir=self.config.external_providers_dir, ) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 95017debb..db6856ed2 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -345,7 +345,7 @@ async def instantiate_provider( policy: list[AccessRule], ): provider_spec = provider.spec - if not hasattr(provider_spec, "module"): + if not hasattr(provider_spec, "module") or provider_spec.module is None: raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}") diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 005bfbab8..055bf5232 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -115,6 +115,19 @@ class ProviderSpec(BaseModel): description="If this provider is deprecated and does NOT work, specify the error message here", ) + module: str | None = Field( + default=None, + description=""" + Fully-qualified name of the module to import. The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation + + Example: `module: ramalama_stack` + """, + ) + + is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.") + # used internally by the resolver; this is a hack for now deps__: list[str] = Field(default_factory=list) @@ -135,7 +148,7 @@ class AdapterSpec(BaseModel): description="Unique identifier for this adapter", ) module: str = Field( - ..., + default_factory=str, description=""" Fully-qualified name of the module to import. The module is expected to have: @@ -173,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack If a provider depends on other providers, the dependencies MUST NOT specify a container image. """, ) - module: str = Field( - ..., - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - - `get_provider_impl(config, deps)`: returns the local implementation -""", - ) + # module field is inherited from ProviderSpec provider_data_validator: str | None = Field( default=None, ) @@ -223,9 +229,7 @@ API responses, specify the adapter here. 
def container_image(self) -> str | None: return None - @property - def module(self) -> str: - return self.adapter.module + # module field is inherited from ProviderSpec @property def pip_packages(self) -> list[str]: @@ -243,6 +247,7 @@ def remote_provider_spec( api=api, provider_type=f"remote::{adapter.adapter_type}", config_class=adapter.config_class, + module=adapter.module, adapter=adapter, api_dependencies=api_dependencies or [], ) diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index 625e36e4f..2f18e5d26 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -3,57 +3,98 @@ distribution_spec: description: CI tests for Llama Stack providers: inference: - - remote::cerebras - - remote::ollama - - remote::vllm - - remote::tgi - - remote::hf::serverless - - remote::hf::endpoint - - remote::fireworks - - remote::together - - remote::bedrock - - remote::databricks - - remote::nvidia - - remote::runpod - - remote::openai - - remote::anthropic - - remote::gemini - - remote::groq - - remote::llama-openai-compat - - remote::sambanova - - remote::passthrough - - inline::sentence-transformers + - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_type: remote::cerebras + - provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_type: remote::ollama + - provider_id: ${env.ENABLE_VLLM:=__disabled__} + provider_type: remote::vllm + - provider_id: ${env.ENABLE_TGI:=__disabled__} + provider_type: remote::tgi + - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__} + provider_type: remote::hf::serverless + - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__} + provider_type: remote::hf::endpoint + - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_type: remote::fireworks + - provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_type: remote::together + - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_type: remote::bedrock + - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_type: remote::databricks + - provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_type: remote::nvidia + - provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_type: remote::runpod + - provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_type: remote::openai + - provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_type: remote::anthropic + - provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_type: remote::gemini + - provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_type: remote::groq + - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__} + provider_type: remote::llama-openai-compat + - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_type: remote::sambanova + - provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__} + provider_type: remote::passthrough + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss - - inline::sqlite-vec - - inline::milvus - - remote::chromadb - - remote::pgvector + - provider_id: ${env.ENABLE_FAISS:=faiss} + provider_type: inline::faiss + - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__} + provider_type: inline::sqlite-vec + - provider_id: ${env.ENABLE_MILVUS:=__disabled__} + provider_type: inline::milvus + - provider_id: ${env.ENABLE_CHROMADB:=__disabled__} + provider_type: remote::chromadb + - provider_id: ${env.ENABLE_PGVECTOR:=__disabled__} + provider_type: remote::pgvector files: - - inline::localfs + - provider_id: 
localfs + provider_type: inline::localfs safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference post_training: - - inline::huggingface + - provider_id: huggingface + provider_type: inline::huggingface eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: ci-tests additional_pip_packages: - aiosqlite - asyncpg diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 1396d54a8..6f8a192ee 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -56,7 +56,6 @@ providers: api_key: ${env.TOGETHER_API_KEY} - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} provider_type: remote::bedrock - config: {} - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} provider_type: remote::databricks config: @@ -107,7 +106,6 @@ providers: api_key: ${env.PASSTHROUGH_API_KEY} - provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: ${env.ENABLE_FAISS:=faiss} provider_type: inline::faiss @@ -208,10 +206,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -229,10 +225,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index ff8d58a08..d19934ee5 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ -4,32 +4,50 @@ distribution_spec: container providers: inference: - - remote::tgi - - inline::sentence-transformers + - provider_id: tgi + provider_type: remote::tgi + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss - - remote::chromadb - - remote::pgvector + - provider_id: faiss + provider_type: inline::faiss + - provider_id: chromadb + provider_type: remote::chromadb + - provider_id: pgvector + provider_type: remote::pgvector 
safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime image_type: conda +image_name: dell additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/dell.py b/llama_stack/templates/dell/dell.py index 5a6f52a89..b2210e7dc 100644 --- a/llama_stack/templates/dell/dell.py +++ b/llama_stack/templates/dell/dell.py @@ -19,18 +19,32 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::tgi", "inline::sentence-transformers"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": [ + Provider(provider_id="tgi", provider_type="remote::tgi"), + Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"), + ], + "vector_io": [ + Provider(provider_id="faiss", provider_type="inline::faiss"), + Provider(provider_id="chromadb", provider_type="remote::chromadb"), + Provider(provider_id="pgvector", provider_type="remote::pgvector"), + ], + "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], + "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "datasetio": [ + Provider(provider_id="huggingface", provider_type="remote::huggingface"), + Provider(provider_id="localfs", provider_type="inline::localfs"), + ], + "scoring": [ + Provider(provider_id="basic", provider_type="inline::basic"), + Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"), + Provider(provider_id="braintrust", provider_type="inline::braintrust"), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", + Provider(provider_id="brave-search", provider_type="remote::brave-search"), + Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), + Provider(provider_id="rag-runtime", 
provider_type="inline::rag-runtime"), ], } name = "dell" diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 768fad4fa..ecc6729eb 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -22,7 +22,6 @@ providers: url: ${env.DEH_SAFETY_URL} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: chromadb provider_type: remote::chromadb @@ -74,10 +73,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -95,7 +92,6 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index de2ada009..fc2553526 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -18,7 +18,6 @@ providers: url: ${env.DEH_URL} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: chromadb provider_type: remote::chromadb @@ -70,10 +69,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -91,7 +88,6 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index 2119eeddd..0a0bc0aea 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -3,32 +3,50 @@ distribution_spec: description: Use Meta Reference for running LLM inference providers: inference: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference vector_io: - - inline::faiss - - remote::chromadb - - remote::pgvector + - provider_id: faiss + provider_type: inline::faiss + - provider_id: chromadb + provider_type: remote::chromadb + - provider_id: pgvector + provider_type: remote::pgvector safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - 
inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: meta-reference-gpu additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/templates/meta-reference-gpu/meta_reference.py index 4bfb4e9d8..6ca500eff 100644 --- a/llama_stack/templates/meta-reference-gpu/meta_reference.py +++ b/llama_stack/templates/meta-reference-gpu/meta_reference.py @@ -25,19 +25,91 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["inline::meta-reference"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "vector_io": [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + ), + Provider( + provider_id="chromadb", + provider_type="remote::chromadb", + ), + Provider( + provider_id="pgvector", + provider_type="remote::pgvector", + ), + ], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "eval": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "datasetio": [ + Provider( + provider_id="huggingface", + provider_type="remote::huggingface", + ), + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ), + Provider( + provider_id="llm-as-judge", + provider_type="inline::llm-as-judge", + ), + Provider( + provider_id="braintrust", + provider_type="inline::braintrust", + ), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider( + provider_id="brave-search", + provider_type="remote::brave-search", + ), + Provider( + provider_id="tavily-search", + provider_type="remote::tavily-search", + ), + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } name = "meta-reference-gpu" diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 49657a680..910f9ec46 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -24,7 +24,6 @@ providers: max_seq_len: ${env.MAX_SEQ_LEN:=4096} - provider_id: sentence-transformers provider_type: 
inline::sentence-transformers - config: {} - provider_id: meta-reference-safety provider_type: inline::meta-reference config: @@ -88,10 +87,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -109,10 +106,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index 2923b5faf..5266f3c84 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -24,7 +24,6 @@ providers: max_seq_len: ${env.MAX_SEQ_LEN:=4096} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -78,10 +77,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -99,10 +96,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index 51685b2e3..572a70408 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -3,27 +3,39 @@ distribution_spec: description: Use NVIDIA NIM for running LLM inference, evaluation and safety providers: inference: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia vector_io: - - inline::faiss + - provider_id: faiss + provider_type: inline::faiss safety: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia post_training: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia datasetio: - - inline::localfs - - remote::nvidia + - provider_id: localfs + provider_type: inline::localfs + - provider_id: nvidia + provider_type: remote::nvidia scoring: - - inline::basic + - provider_id: basic + provider_type: inline::basic tool_runtime: - - inline::rag-runtime + - provider_id: rag-runtime + provider_type: inline::rag-runtime image_type: conda +image_name: nvidia additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index e5c13aa74..25beeae75 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -17,16 +17,65 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def 
get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::nvidia"], - "vector_io": ["inline::faiss"], - "safety": ["remote::nvidia"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["remote::nvidia"], - "post_training": ["remote::nvidia"], - "datasetio": ["inline::localfs", "remote::nvidia"], - "scoring": ["inline::basic"], - "tool_runtime": ["inline::rag-runtime"], + "inference": [ + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ) + ], + "vector_io": [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + ) + ], + "safety": [ + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "eval": [ + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ) + ], + "post_training": [Provider(provider_id="nvidia", provider_type="remote::nvidia")], + "datasetio": [ + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ) + ], + "tool_runtime": [ + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ) + ], } inference_provider = Provider( diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 7017a5955..015724050 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -85,11 +85,9 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} tool_runtime: - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index ccddf11a2..f087e89ee 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -74,11 +74,9 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} tool_runtime: - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index 5f82c5243..6647b471c 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/templates/open-benchmark/build.yaml @@ -3,36 +3,58 @@ distribution_spec: description: Distribution for running open benchmarks providers: inference: - - remote::openai - - remote::anthropic - - remote::gemini - - remote::groq - - remote::together + - provider_id: openai + provider_type: remote::openai + - provider_id: anthropic + provider_type: remote::anthropic + - provider_id: gemini + provider_type: remote::gemini + - provider_id: groq + provider_type: remote::groq + - provider_id: together + provider_type: remote::together vector_io: - - inline::sqlite-vec - - remote::chromadb - - remote::pgvector + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec + - provider_id: chromadb + provider_type: 
remote::chromadb + - provider_id: pgvector + provider_type: remote::pgvector safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: open-benchmark additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index ae25c9fc9..3a17e7525 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -96,19 +96,33 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo def get_distribution_template() -> DistributionTemplate: inference_providers, available_models = get_inference_providers() providers = { - "inference": [p.provider_type for p in inference_providers], - "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": inference_providers, + "vector_io": [ + Provider(provider_id="sqlite-vec", provider_type="inline::sqlite-vec"), + Provider(provider_id="chromadb", provider_type="remote::chromadb"), + Provider(provider_id="pgvector", provider_type="remote::pgvector"), + ], + "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], + "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "datasetio": [ + Provider(provider_id="huggingface", provider_type="remote::huggingface"), + Provider(provider_id="localfs", provider_type="inline::localfs"), + ], + "scoring": [ + Provider(provider_id="basic", provider_type="inline::basic"), + Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"), + Provider(provider_id="braintrust", provider_type="inline::braintrust"), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", 
- "remote::model-context-protocol", + Provider(provider_id="brave-search", provider_type="remote::brave-search"), + Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), + Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } name = "open-benchmark" diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 828b960a2..ba6a5e9d6 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -106,10 +106,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -127,10 +125,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/registry.db diff --git a/llama_stack/templates/postgres-demo/build.yaml b/llama_stack/templates/postgres-demo/build.yaml index 645b59613..d5e816a54 100644 --- a/llama_stack/templates/postgres-demo/build.yaml +++ b/llama_stack/templates/postgres-demo/build.yaml @@ -3,22 +3,33 @@ distribution_spec: description: Quick start template for running Llama Stack with several popular providers providers: inference: - - remote::vllm - - inline::sentence-transformers + - provider_id: vllm-inference + provider_type: remote::vllm + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - remote::chromadb + - provider_id: chromadb + provider_type: remote::chromadb safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: postgres-demo additional_pip_packages: - asyncpg - psycopg2-binary diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py index c7ab222ec..24e3f6f27 100644 --- a/llama_stack/templates/postgres-demo/postgres_demo.py +++ b/llama_stack/templates/postgres-demo/postgres_demo.py @@ -34,16 +34,24 @@ def get_distribution_template() -> DistributionTemplate: ), ] providers = { - "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]), - "vector_io": ["remote::chromadb"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], + "inference": inference_providers + + [ + Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"), + ], + "vector_io": [ + 
Provider(provider_id="chromadb", provider_type="remote::chromadb"), + ], + "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], + "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider(provider_id="brave-search", provider_type="remote::brave-search"), + Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), + Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } name = "postgres-demo" diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/templates/postgres-demo/run.yaml index feb85e316..747b7dc53 100644 --- a/llama_stack/templates/postgres-demo/run.yaml +++ b/llama_stack/templates/postgres-demo/run.yaml @@ -18,7 +18,6 @@ providers: tls_verify: ${env.VLLM_TLS_VERIFY:=true} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: ${env.ENABLE_CHROMADB:+chromadb} provider_type: remote::chromadb @@ -70,10 +69,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: postgres host: ${env.POSTGRES_HOST:=localhost} diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index 8180124f6..9b540ab62 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -3,57 +3,98 @@ distribution_spec: description: Quick start template for running Llama Stack with several popular providers providers: inference: - - remote::cerebras - - remote::ollama - - remote::vllm - - remote::tgi - - remote::hf::serverless - - remote::hf::endpoint - - remote::fireworks - - remote::together - - remote::bedrock - - remote::databricks - - remote::nvidia - - remote::runpod - - remote::openai - - remote::anthropic - - remote::gemini - - remote::groq - - remote::llama-openai-compat - - remote::sambanova - - remote::passthrough - - inline::sentence-transformers + - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_type: remote::cerebras + - provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_type: remote::ollama + - provider_id: ${env.ENABLE_VLLM:=__disabled__} + provider_type: remote::vllm + - provider_id: ${env.ENABLE_TGI:=__disabled__} + provider_type: remote::tgi + - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__} + provider_type: remote::hf::serverless + - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__} + provider_type: remote::hf::endpoint + - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_type: remote::fireworks + - provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_type: remote::together + - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_type: remote::bedrock + - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_type: remote::databricks + - provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_type: remote::nvidia + - provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_type: remote::runpod + - provider_id: ${env.ENABLE_OPENAI:=__disabled__} + 
provider_type: remote::openai + - provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_type: remote::anthropic + - provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_type: remote::gemini + - provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_type: remote::groq + - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__} + provider_type: remote::llama-openai-compat + - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_type: remote::sambanova + - provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__} + provider_type: remote::passthrough + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss - - inline::sqlite-vec - - inline::milvus - - remote::chromadb - - remote::pgvector + - provider_id: ${env.ENABLE_FAISS:=faiss} + provider_type: inline::faiss + - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__} + provider_type: inline::sqlite-vec + - provider_id: ${env.ENABLE_MILVUS:=__disabled__} + provider_type: inline::milvus + - provider_id: ${env.ENABLE_CHROMADB:=__disabled__} + provider_type: remote::chromadb + - provider_id: ${env.ENABLE_PGVECTOR:=__disabled__} + provider_type: remote::pgvector files: - - inline::localfs + - provider_id: localfs + provider_type: inline::localfs safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference post_training: - - inline::huggingface + - provider_id: huggingface + provider_type: inline::huggingface eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: starter additional_pip_packages: - aiosqlite - asyncpg diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index c38933f98..d60800ebb 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -56,7 +56,6 @@ providers: api_key: ${env.TOGETHER_API_KEY} - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} provider_type: remote::bedrock - config: {} - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} provider_type: remote::databricks config: @@ -107,7 +106,6 @@ providers: api_key: ${env.PASSTHROUGH_API_KEY} - provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: ${env.ENABLE_FAISS:=faiss} provider_type: inline::faiss @@ -208,10 +206,8 @@ providers: scoring: - 
provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -229,10 +225,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py index cee1094db..489117702 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/templates/starter/starter.py @@ -253,21 +253,91 @@ def get_distribution_template() -> DistributionTemplate: ] providers = { - "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]), - "vector_io": ([p.provider_type for p in vector_io_providers]), - "files": ["inline::localfs"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "post_training": ["inline::huggingface"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": remote_inference_providers + + [ + Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + ) + ], + "vector_io": vector_io_providers, + "files": [ + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ) + ], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "post_training": [ + Provider( + provider_id="huggingface", + provider_type="inline::huggingface", + ) + ], + "eval": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "datasetio": [ + Provider( + provider_id="huggingface", + provider_type="remote::huggingface", + ), + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ), + Provider( + provider_id="llm-as-judge", + provider_type="inline::llm-as-judge", + ), + Provider( + provider_id="braintrust", + provider_type="inline::braintrust", + ), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider( + provider_id="brave-search", + provider_type="remote::brave-search", + ), + Provider( + provider_id="tavily-search", + provider_type="remote::tavily-search", + ), + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } files_provider = Provider( diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index fb2528873..e9054f95d 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
from pathlib import Path -from typing import Literal +from typing import Any, Literal import jinja2 import rich @@ -35,6 +35,51 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages +def filter_empty_values(obj: Any) -> Any: + """Recursively filter out specific empty values from a dictionary or list. + + This function removes: + - Empty strings ('') only when they are the 'module' field + - Empty dictionaries ({}) only when they are the 'config' field + - Empty or None values when they are the 'container_image' field + - None entries inside lists + """ + if obj is None: + return None + + if isinstance(obj, dict): + filtered = {} + for key, value in obj.items(): + # Special handling for specific fields + if key == "module" and isinstance(value, str) and value == "": + # Skip empty module strings + continue + elif key == "config" and isinstance(value, dict) and not value: + # Skip empty config dictionaries + continue + elif key == "container_image" and not value: + # Skip empty container_image names + continue + else: + # For all other fields, recursively filter but preserve empty values + filtered_value = filter_empty_values(value) + filtered[key] = filtered_value + return filtered + + elif isinstance(obj, list): + filtered = [] + for item in obj: + filtered_item = filter_empty_values(item) + if filtered_item is not None: + filtered.append(filtered_item) + return filtered + + else: + # For all other types (including empty strings and dicts that aren't module/config), + # preserve them as-is + return obj + + def get_model_registry( available_models: dict[str, list[ProviderModelEntry]], ) -> tuple[list[ModelInput], bool]: @@ -138,31 +183,26 @@ class RunConfigSettings(BaseModel): def run_config( self, name: str, - providers: dict[str, list[str]], + providers: dict[str, list[Provider]], container_image: str | None = None, ) -> dict: provider_registry = get_provider_registry() - provider_configs = {} - for api_str, provider_types in providers.items(): + for api_str, provider_objs in providers.items(): if api_providers := self.provider_overrides.get(api_str): # Convert Provider objects to dicts for YAML serialization - provider_configs[api_str] = [ - p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers - ] + provider_configs[api_str] = [p.model_dump(exclude_none=True) for p in api_providers] continue provider_configs[api_str] = [] - for provider_type in provider_types: - provider_id = provider_type.split("::")[-1] - + for provider in provider_objs: api = Api(api_str) - if provider_type not in provider_registry[api]: - raise ValueError(f"Unknown provider type: {provider_type} for API: {api_str}") + if provider.provider_type not in provider_registry[api]: + raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}") - config_class = provider_registry[api][provider_type].config_class + config_class = provider_registry[api][provider.provider_type].config_class assert config_class is not None, ( - f"No config class for provider type: {provider_type} for API: {api_str}" + f"No config class for provider type: {provider.provider_type} for API: {api_str}" ) config_class = instantiate_class_type(config_class) @@ -171,14 +211,9 @@ class RunConfigSettings(BaseModel): else: config = {} - provider_configs[api_str].append( - Provider( - provider_id=provider_id, - provider_type=provider_type, - config=config, - ).model_dump(exclude_none=True) - ) - + 
provider.config = config + # Convert Provider object to dict for YAML serialization + provider_configs[api_str].append(provider.model_dump(exclude_none=True)) # Get unique set of APIs from providers apis = sorted(providers.keys()) @@ -222,7 +257,7 @@ class DistributionTemplate(BaseModel): description: str distro_type: Literal["self_hosted", "remote_hosted", "ondevice"] - providers: dict[str, list[str]] + providers: dict[str, list[Provider]] run_configs: dict[str, RunConfigSettings] template_path: Path | None = None @@ -255,13 +290,28 @@ class DistributionTemplate(BaseModel): if self.additional_pip_packages: additional_pip_packages.extend(self.additional_pip_packages) + # Create minimal providers for build config (without runtime configs) + build_providers = {} + for api, providers in self.providers.items(): + build_providers[api] = [] + for provider in providers: + # Create a minimal provider object with only essential build information + build_provider = Provider( + provider_id=provider.provider_id, + provider_type=provider.provider_type, + config={}, # Empty config for build + module=provider.module, + ) + build_providers[api].append(build_provider) + return BuildConfig( distribution_spec=DistributionSpec( description=self.description, container_image=self.container_image, - providers=self.providers, + providers=build_providers, ), - image_type="conda", # default to conda, can be overridden + image_type="conda", + image_name=self.name, additional_pip_packages=sorted(set(additional_pip_packages)), ) @@ -270,7 +320,7 @@ class DistributionTemplate(BaseModel): providers_table += "|-----|-------------|\n" for api, providers in sorted(self.providers.items()): - providers_str = ", ".join(f"`{p}`" for p in providers) + providers_str = ", ".join(f"`{p.provider_type}`" for p in providers) providers_table += f"| {api} | {providers_str} |\n" template = self.template_path.read_text() @@ -334,7 +384,7 @@ class DistributionTemplate(BaseModel): build_config = self.build_config() with open(yaml_output_dir / "build.yaml", "w") as f: yaml.safe_dump( - build_config.model_dump(exclude_none=True), + filter_empty_values(build_config.model_dump(exclude_none=True)), f, sort_keys=False, ) @@ -343,7 +393,7 @@ class DistributionTemplate(BaseModel): run_config = settings.run_config(self.name, self.providers, self.container_image) with open(yaml_output_dir / yaml_pth, "w") as f: yaml.safe_dump( - {k: v for k, v in run_config.items() if v is not None}, + filter_empty_values(run_config), f, sort_keys=False, ) diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml index 08ee2c5ce..bc992f0c7 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -3,31 +3,48 @@ distribution_spec: description: Use watsonx for running LLM inference providers: inference: - - remote::watsonx - - inline::sentence-transformers + - provider_id: watsonx + provider_type: remote::watsonx + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss + - provider_id: faiss + provider_type: inline::faiss safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: watsonx additional_pip_packages: -- aiosqlite - sqlalchemy[asyncio] +- aiosqlite diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index afbbdb917..f5fe31bef 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -20,7 +20,6 @@ providers: project_id: ${env.WATSONX_PROJECT_ID:=} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -74,10 +73,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -95,10 +92,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py index ea185f05d..c13bbea36 100644 --- a/llama_stack/templates/watsonx/watsonx.py +++ b/llama_stack/templates/watsonx/watsonx.py @@ -18,19 +18,87 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::watsonx", "inline::sentence-transformers"], - "vector_io": ["inline::faiss"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": [ + Provider( + provider_id="watsonx", + provider_type="remote::watsonx", + ), + Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + ), + ], + "vector_io": [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + ) + ], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "eval": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "datasetio": 
[ + Provider( + provider_id="huggingface", + provider_type="remote::huggingface", + ), + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ), + Provider( + provider_id="llm-as-judge", + provider_type="inline::llm-as-judge", + ), + Provider( + provider_id="braintrust", + provider_type="inline::braintrust", + ), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider( + provider_id="brave-search", + provider_type="remote::brave-search", + ), + Provider( + provider_id="tavily-search", + provider_type="remote::tavily-search", + ), + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } diff --git a/tests/external/build.yaml b/tests/external/build.yaml index 90dcc97aa..c928febdb 100644 --- a/tests/external/build.yaml +++ b/tests/external/build.yaml @@ -3,7 +3,8 @@ distribution_spec: description: Custom distro for CI tests providers: weather: - - remote::kaze + - provider_id: kaze + provider_type: remote::kaze image_type: venv image_name: ci-test external_providers_dir: ~/.llama/providers.d diff --git a/tests/external/ramalama-stack/build.yaml b/tests/external/ramalama-stack/build.yaml new file mode 100644 index 000000000..c781e6537 --- /dev/null +++ b/tests/external/ramalama-stack/build.yaml @@ -0,0 +1,14 @@ +version: 2 +distribution_spec: + description: Use (an external) Ramalama server for running LLM inference + container_image: null + providers: + inference: + - provider_id: ramalama + provider_type: remote::ramalama + module: ramalama_stack==0.3.0a0 +image_type: venv +image_name: ramalama-stack-test +additional_pip_packages: +- aiosqlite +- sqlalchemy[asyncio] diff --git a/tests/external/ramalama-stack/run.yaml b/tests/external/ramalama-stack/run.yaml new file mode 100644 index 000000000..9d1d34df3 --- /dev/null +++ b/tests/external/ramalama-stack/run.yaml @@ -0,0 +1,12 @@ +version: 2 +image_name: ramalama +apis: +- inference +providers: + inference: + - provider_id: ramalama + provider_type: remote::ramalama + module: ramalama_stack==0.3.0a0 + config: {} +server: + port: 8321 diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index ae24602d7..5aac113eb 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -106,6 +106,40 @@ def api_directories(tmp_path): return remote_inference_dir, inline_inference_dir +def make_import_module_side_effect( + builtin_provider_spec=None, + external_module=None, + raise_for_external=False, + missing_get_provider_spec=False, +): + from types import SimpleNamespace + + def import_module_side_effect(name): + if name == "llama_stack.providers.registry.inference": + mock_builtin = SimpleNamespace( + available_providers=lambda: [ + builtin_provider_spec + or ProviderSpec( + api=Api.inference, + provider_type="test_provider", + config_class="test_provider.config.TestProviderConfig", + module="test_provider", + ) + ] + ) + return mock_builtin + elif name == "external_test.provider": + if raise_for_external: + raise ModuleNotFoundError(name) + if missing_get_provider_spec: + return SimpleNamespace() + return external_module + else: + raise ModuleNotFoundError(name) + + return import_module_side_effect + + class 
TestProviderRegistry: """Test suite for provider registry functionality.""" @@ -221,3 +255,124 @@ pip_packages: with pytest.raises(KeyError) as exc_info: get_provider_registry(base_config) assert "config_class" in str(exc_info.value) + + def test_external_provider_from_module_success(self, mock_providers): + """Test loading an external provider from a module (success path).""" + from types import SimpleNamespace + + from llama_stack.distribution.datatypes import Provider, StackRunConfig + from llama_stack.providers.datatypes import Api, ProviderSpec + + # Simulate a provider module with get_provider_spec + fake_spec = ProviderSpec( + api=Api.inference, + provider_type="external_test", + config_class="external_test.config.ExternalTestConfig", + module="external_test", + ) + fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec) + + import_module_side_effect = make_import_module_side_effect(external_module=fake_module) + + with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import: + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + registry = get_provider_registry(config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.config_class == "external_test.config.ExternalTestConfig" + mock_import.assert_any_call("llama_stack.providers.registry.inference") + mock_import.assert_any_call("external_test.provider") + + def test_external_provider_from_module_not_found(self, mock_providers): + """Test handling ModuleNotFoundError for missing provider module.""" + from llama_stack.distribution.datatypes import Provider, StackRunConfig + + import_module_side_effect = make_import_module_side_effect(raise_for_external=True) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(ValueError) as exc_info: + get_provider_registry(config) + assert "get_provider_spec not found" in str(exc_info.value) + + def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers): + """Test handling missing get_provider_spec in provider module (should raise ValueError).""" + from llama_stack.distribution.datatypes import Provider, StackRunConfig + + import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(AttributeError): + get_provider_registry(config) + + def test_external_provider_from_module_building(self, mock_providers): + """Test loading an external provider from a module during build (building=True, partial spec).""" + from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider + from llama_stack.providers.datatypes import Api + + # No importlib patch needed, should not import module when type of 
`config` is BuildConfig or DistributionSpec + build_config = BuildConfig( + version=2, + image_type="container", + image_name="test_image", + distribution_spec=DistributionSpec( + description="test", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ), + ) + registry = get_provider_registry(build_config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.is_external is True + # config_class is empty string in partial spec + assert provider.config_class == ""
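
Note (illustrative, not part of the patch): a minimal sketch of the serialization behavior introduced by filter_empty_values in llama_stack/templates/template.py above. The import path mirrors the patched module, and the sample dict and expected output are assumptions chosen to exercise each filtered field, not output taken from the repository.

# Sketch: how filter_empty_values is expected to prune a serialized build config.
from llama_stack.templates.template import filter_empty_values  # path assumed from the patch

sample = {
    "distribution_spec": {
        "container_image": None,  # dropped: empty 'container_image'
        "providers": {
            "inference": [
                {
                    "provider_id": "ramalama",
                    "provider_type": "remote::ramalama",
                    "config": {},  # dropped: empty 'config'
                    "module": "",  # dropped: empty 'module'
                }
            ]
        },
    },
    "image_type": "venv",
}

# Per the implementation above, only the empty 'config', 'module', and
# 'container_image' fields are removed; everything else is preserved.
assert filter_empty_values(sample) == {
    "distribution_spec": {
        "providers": {
            "inference": [
                {"provider_id": "ramalama", "provider_type": "remote::ramalama"}
            ]
        }
    },
    "image_type": "venv",
}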