diff --git a/.github/actions/setup-vllm/action.yml b/.github/actions/setup-vllm/action.yml new file mode 100644 index 000000000..17ebd42f2 --- /dev/null +++ b/.github/actions/setup-vllm/action.yml @@ -0,0 +1,27 @@ +name: Setup VLLM +description: Start VLLM +runs: + using: "composite" + steps: + - name: Start VLLM + shell: bash + run: | + # Start vllm container + docker run -d \ + --name vllm \ + -p 8000:8000 \ + --privileged=true \ + quay.io/higginsd/vllm-cpu:65393ee064 \ + --host 0.0.0.0 \ + --port 8000 \ + --enable-auto-tool-choice \ + --tool-call-parser llama3_json \ + --model /root/.cache/Llama-3.2-1B-Instruct \ + --served-model-name meta-llama/Llama-3.2-1B-Instruct + + # Wait for vllm to be ready + echo "Waiting for vllm to be ready..." + timeout 900 bash -c 'until curl -f http://localhost:8000/health; do + echo "Waiting for vllm..." + sleep 5 + done' diff --git a/.github/dependabot.yml b/.github/dependabot.yml index d68af5615..134efd93b 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -14,8 +14,6 @@ updates: schedule: interval: "weekly" day: "saturday" - # ignore all non-security updates: https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#open-pull-requests-limit - open-pull-requests-limit: 0 labels: - type/dependencies - python diff --git a/.github/workflows/README.md b/.github/workflows/README.md new file mode 100644 index 000000000..3347b05f8 --- /dev/null +++ b/.github/workflows/README.md @@ -0,0 +1,22 @@ +# Llama Stack CI + +Llama Stack uses GitHub Actions for Continous Integration (CI). Below is a table detailing what CI the project includes and the purpose. + +| Name | File | Purpose | +| ---- | ---- | ------- | +| Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md | +| Coverage Badge | [coverage-badge.yml](coverage-badge.yml) | Creates PR for updating the code coverage badge | +| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | +| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | +| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | +| Integration Tests | [integration-tests.yml](integration-tests.yml) | Run the integration test suite with Ollama | +| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | +| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | +| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | +| Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project | +| Check semantic PR titles | [semantic-pr.yml](semantic-pr.yml) | Ensure that PR titles follow the conventional commit spec | +| Close stale issues and PRs | [stale_bot.yml](stale_bot.yml) | Run the Stale Bot action | +| Test External Providers Installed via Module | [test-external-provider-module.yml](test-external-provider-module.yml) | Test External Provider installation via Python module | +| Test External API and Providers | [test-external.yml](test-external.yml) | Test the External API and Provider mechanisms | +| Unit Tests | [unit-tests.yml](unit-tests.yml) | Run the unit test suite | +| Update ReadTheDocs | [update-readthedocs.yml](update-readthedocs.yml) | Update the Llama Stack ReadTheDocs site | diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index c497348b0..e406d99ee 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -1,5 +1,7 @@ name: Update Changelog +run-name: Creates PR for updating the CHANGELOG.md + on: release: types: [published, unpublished, created, edited, deleted, released] diff --git a/.github/workflows/coverage-badge.yml b/.github/workflows/coverage-badge.yml index 54bde1749..75428539e 100644 --- a/.github/workflows/coverage-badge.yml +++ b/.github/workflows/coverage-badge.yml @@ -1,5 +1,7 @@ name: Coverage Badge +run-name: Creates PR for updating the code coverage badge + on: push: branches: [ main ] diff --git a/.github/workflows/install-script-ci.yml b/.github/workflows/install-script-ci.yml index d711444e8..5dc2b4412 100644 --- a/.github/workflows/install-script-ci.yml +++ b/.github/workflows/install-script-ci.yml @@ -1,5 +1,7 @@ name: Installer CI +run-name: Test the installation script + on: pull_request: paths: @@ -17,10 +19,20 @@ jobs: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 - name: Run ShellCheck on install.sh run: shellcheck scripts/install.sh - smoke-test: - needs: lint + smoke-test-on-dev: runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Install dependencies + uses: ./.github/actions/setup-runner + + - name: Build a single provider + run: | + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template starter --image-type container --image-name test + - name: Run installer end-to-end - run: ./scripts/install.sh + run: | + IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + ./scripts/install.sh --image $IMAGE_ID diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index cf10e005c..ef2066497 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -1,5 +1,7 @@ name: Integration Auth Tests +run-name: Run the integration test suite with Kubernetes authentication + on: push: branches: [ main ] diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index aeeecf395..4e5b64963 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -1,5 +1,7 @@ name: SqlStore Integration Tests +run-name: Run the integration test suite with SqlStore + on: push: branches: [ main ] diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 082f1e204..c7b7fc55b 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,5 +1,7 @@ name: Integration Tests +run-name: Run the integration test suite with Ollama + on: push: branches: [ main ] @@ -14,13 +16,19 @@ on: - '.github/workflows/integration-tests.yml' # This workflow - '.github/actions/setup-ollama/action.yml' schedule: - - cron: '0 0 * * *' # Daily at 12 AM UTC + # If changing the cron schedule, update the provider in the test-matrix job + - cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC + - cron: '1 0 * * 0' # (test vllm) Weekly on Sunday at 1 AM UTC workflow_dispatch: inputs: test-all-client-versions: description: 'Test against both the latest and published versions' type: boolean default: false + test-provider: + description: 'Test against a specific provider' + type: string + default: 'ollama' concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -53,8 +61,17 @@ jobs: matrix: test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }} client-type: [library, server] + # Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama) + provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }} python-version: ["3.12", "3.13"] - client-version: ${{ (github.event_name == 'schedule' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} + client-version: ${{ (github.event.schedule == '0 0 * * 0' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} + exclude: # TODO: look into why these tests are failing and fix them + - provider: vllm + test-type: safety + - provider: vllm + test-type: post_training + - provider: vllm + test-type: tool_runtime steps: - name: Checkout repository @@ -67,8 +84,13 @@ jobs: client-version: ${{ matrix.client-version }} - name: Setup ollama + if: ${{ matrix.provider == 'ollama' }} uses: ./.github/actions/setup-ollama + - name: Setup vllm + if: ${{ matrix.provider == 'vllm' }} + uses: ./.github/actions/setup-vllm + - name: Build Llama Stack run: | uv run llama stack build --template ci-tests --image-type venv @@ -81,10 +103,6 @@ jobs: - name: Run Integration Tests env: - OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests - ENABLE_OLLAMA: "ollama" # for server tests - OLLAMA_URL: "http://0.0.0.0:11434" - SAFETY_MODEL: "llama-guard3:1b" LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations # Use 'shell' to get pipefail behavior # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference @@ -96,12 +114,31 @@ jobs: else stack_config="server:ci-tests" fi + + EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag" + if [ "${{ matrix.provider }}" == "ollama" ]; then + export ENABLE_OLLAMA="ollama" + export OLLAMA_URL="http://0.0.0.0:11434" + export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16" + export TEXT_MODEL=ollama/$OLLAMA_INFERENCE_MODEL + export SAFETY_MODEL="llama-guard3:1b" + EXTRA_PARAMS="--safety-shield=$SAFETY_MODEL" + else + export ENABLE_VLLM="vllm" + export VLLM_URL="http://localhost:8000/v1" + export VLLM_INFERENCE_MODEL="meta-llama/Llama-3.2-1B-Instruct" + export TEXT_MODEL=vllm/$VLLM_INFERENCE_MODEL + # TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently + EXTRA_PARAMS= + EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls" + fi + + uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ - -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ - --text-model="ollama/llama3.2:3b-instruct-fp16" \ + -k "not( ${EXCLUDE_TESTS} )" \ + --text-model=$TEXT_MODEL \ --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --safety-shield=$SAFETY_MODEL \ - --color=yes \ + --color=yes ${EXTRA_PARAMS} \ --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log - name: Check Storage and Memory Available After Tests @@ -110,16 +147,17 @@ jobs: free -h df -h - - name: Write ollama logs to file + - name: Write inference logs to file if: ${{ always() }} run: | - sudo docker logs ollama > ollama.log + sudo docker logs ollama > ollama.log || true + sudo docker logs vllm > vllm.log || true - name: Upload all logs to artifacts if: ${{ always() }} uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: - name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }} + name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.provider }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }} path: | *.log retention-days: 1 diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 525c17d46..9a02bbcf8 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -1,5 +1,7 @@ name: Vector IO Integration Tests +run-name: Run the integration test suite with various VectorIO providers + on: push: branches: [ main ] diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 326abb37b..2c1c8febb 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,5 +1,7 @@ name: Pre-commit +run-name: Run pre-commit checks + on: pull_request: push: diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 392fddda6..284076d50 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -1,5 +1,7 @@ name: Test Llama Stack Build +run-name: Test llama stack build + on: push: branches: diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index efd1f2cc9..67dc49cce 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -1,5 +1,7 @@ name: Python Package Build Test +run-name: Test building the llama-stack PyPI project + on: push: branches: @@ -20,7 +22,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install uv - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1 + uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 2dc1ed473..4df7324c4 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -1,5 +1,7 @@ name: Check semantic PR titles +run-name: Ensure that PR titles follow the conventional commit spec + on: pull_request_target: types: diff --git a/.github/workflows/stale_bot.yml b/.github/workflows/stale_bot.yml index 06318b5f7..087df72d7 100644 --- a/.github/workflows/stale_bot.yml +++ b/.github/workflows/stale_bot.yml @@ -1,5 +1,7 @@ name: Close stale issues and PRs +run-name: Run the Stale Bot action + on: schedule: - cron: '0 0 * * *' # every day at midnight diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-provider-module.yml similarity index 55% rename from .github/workflows/test-external-providers.yml rename to .github/workflows/test-external-provider-module.yml index cdf18fab7..284f9baa2 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external-provider-module.yml @@ -1,4 +1,6 @@ -name: Test External Providers +name: Test External Providers Installed via Module + +run-name: Test External Provider installation via Python module on: push: @@ -11,10 +13,10 @@ on: - 'uv.lock' - 'pyproject.toml' - 'requirements.txt' - - '.github/workflows/test-external-providers.yml' # This workflow + - '.github/workflows/test-external-providers-module.yml' # This workflow jobs: - test-external-providers: + test-external-providers-from-module: runs-on: ubuntu-latest strategy: matrix: @@ -28,39 +30,38 @@ jobs: - name: Install dependencies uses: ./.github/actions/setup-runner + - name: Install Ramalama + shell: bash + run: | + uv pip install ramalama + + - name: Run Ramalama + shell: bash + run: | + nohup ramalama serve llama3.2:3b-instruct-fp16 > ramalama_server.log 2>&1 & - name: Apply image type to config file run: | - yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml - cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml - - - name: Setup directory for Ollama custom provider - run: | - mkdir -p tests/external-provider/llama-stack-provider-ollama/src/ - cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama - - - name: Create provider configuration - run: | - mkdir -p /home/runner/.llama/providers.d/remote/inference - cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml + yq -i '.image_type = "${{ matrix.image-type }}"' tests/external/ramalama-stack/run.yaml + cat tests/external/ramalama-stack/run.yaml - name: Build distro from config file run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/ramalama-stack/build.yaml - name: Start Llama Stack server in background if: ${{ matrix.image-type }} == 'venv' env: - INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" + INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" run: | # Use the virtual environment created by the build step (name comes from build config) - source ci-test/bin/activate + source ramalama-stack-test/bin/activate uv pip list - nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & + nohup llama stack run tests/external/ramalama-stack/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & - name: Wait for Llama Stack server to be ready run: | for i in {1..30}; do - if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then + if ! grep -q "successfully connected to Ramalama" server.log; then echo "Waiting for Llama Stack server to load the provider..." sleep 1 else diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml new file mode 100644 index 000000000..9dd72ad61 --- /dev/null +++ b/.github/workflows/test-external.yml @@ -0,0 +1,77 @@ +name: Test External API and Providers + +run-name: Test the External API and Provider mechanisms + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + paths: + - 'llama_stack/**' + - 'tests/integration/**' + - 'uv.lock' + - 'pyproject.toml' + - 'requirements.txt' + - '.github/workflows/test-external.yml' # This workflow + +jobs: + test-external: + runs-on: ubuntu-latest + strategy: + matrix: + image-type: [venv] + # We don't do container yet, it's tricky to install a package from the host into the + # container and point 'uv pip install' to the correct path... + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Install dependencies + uses: ./.github/actions/setup-runner + + - name: Create API configuration + run: | + mkdir -p /home/runner/.llama/apis.d + cp tests/external/weather.yaml /home/runner/.llama/apis.d/weather.yaml + + - name: Create provider configuration + run: | + mkdir -p /home/runner/.llama/providers.d/remote/weather + cp tests/external/kaze.yaml /home/runner/.llama/providers.d/remote/weather/kaze.yaml + + - name: Print distro dependencies + run: | + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml --print-deps-only + + - name: Build distro from config file + run: | + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml + + - name: Start Llama Stack server in background + if: ${{ matrix.image-type }} == 'venv' + env: + INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" + run: | + # Use the virtual environment created by the build step (name comes from build config) + source ci-test/bin/activate + uv pip list + nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & + + - name: Wait for Llama Stack server to be ready + run: | + echo "Waiting for Llama Stack server..." + for i in {1..30}; do + if curl -sSf http://localhost:8321/v1/health | grep -q "OK"; then + echo "Llama Stack server is up!" + exit 0 + fi + sleep 1 + done + echo "Llama Stack server failed to start" + cat server.log + exit 1 + + - name: Test external API + run: | + curl -sSf http://localhost:8321/v1/weather/locations diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 41034b45f..f0c63f83d 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,5 +1,7 @@ name: Unit Tests +run-name: Run the unit test suite + on: push: branches: [ main ] diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 981332a77..1dcfdeca5 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -1,5 +1,7 @@ name: Update ReadTheDocs +run-name: Update the Llama Stack ReadTheDocs site + on: workflow_dispatch: inputs: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cf72ecd0e..a1acdbe84 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -145,6 +145,15 @@ repos: echo; exit 1; } || true + - id: generate-ci-docs + name: Generate CI documentation + additional_dependencies: + - uv==0.7.8 + entry: uv run ./scripts/gen-ci-docs.py + language: python + pass_filenames: false + require_serial: true + files: ^.github/workflows/.*$ ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks diff --git a/docs/readme.md b/docs/README.md similarity index 100% rename from docs/readme.md rename to docs/README.md diff --git a/docs/source/apis/external.md b/docs/source/apis/external.md new file mode 100644 index 000000000..025267c33 --- /dev/null +++ b/docs/source/apis/external.md @@ -0,0 +1,392 @@ +# External APIs + +Llama Stack supports external APIs that live outside of the main codebase. This allows you to: +- Create and maintain your own APIs independently +- Share APIs with others without contributing to the main codebase +- Keep API-specific code separate from the core Llama Stack code + +## Configuration + +To enable external APIs, you need to configure the `external_apis_dir` in your Llama Stack configuration. This directory should contain your external API specifications: + +```yaml +external_apis_dir: ~/.llama/apis.d/ +``` + +## Directory Structure + +The external APIs directory should follow this structure: + +``` +apis.d/ + custom_api1.yaml + custom_api2.yaml +``` + +Each YAML file in these directories defines an API specification. + +## API Specification + +Here's an example of an external API specification for a weather API: + +```yaml +module: weather +api_dependencies: + - inference +protocol: WeatherAPI +name: weather +pip_packages: + - llama-stack-api-weather +``` + +### API Specification Fields + +- `module`: Python module containing the API implementation +- `protocol`: Name of the protocol class for the API +- `name`: Name of the API +- `pip_packages`: List of pip packages to install the API, typically a single package + +## Required Implementation + +External APIs must expose a `available_providers()` function in their module that returns a list of provider names: + +```python +# llama_stack_api_weather/api.py +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.weather, + provider_type="inline::darksky", + pip_packages=[], + module="llama_stack_provider_darksky", + config_class="llama_stack_provider_darksky.DarkSkyWeatherImplConfig", + ), + ] +``` + +A Protocol class like so: + +```python +# llama_stack_api_weather/api.py +from typing import Protocol + +from llama_stack.schema_utils import webmethod + + +class WeatherAPI(Protocol): + """ + A protocol for the Weather API. + """ + + @webmethod(route="/locations", method="GET") + async def get_available_locations() -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... +``` + +## Example: Custom API + +Here's a complete example of creating and using a custom API: + +1. First, create the API package: + +```bash +mkdir -p llama-stack-api-weather +cd llama-stack-api-weather +mkdir src/llama_stack_api_weather +git init +uv init +``` + +2. Edit `pyproject.toml`: + +```toml +[project] +name = "llama-stack-api-weather" +version = "0.1.0" +description = "Weather API for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_api_weather", "llama_stack_api_weather.*"] +``` + +3. Create the initial files: + +```bash +touch src/llama_stack_api_weather/__init__.py +touch src/llama_stack_api_weather/api.py +``` + +```python +# llama-stack-api-weather/src/llama_stack_api_weather/__init__.py +"""Weather API for Llama Stack.""" + +from .api import WeatherAPI, available_providers + +__all__ = ["WeatherAPI", "available_providers"] +``` + +4. Create the API implementation: + +```python +# llama-stack-api-weather/src/llama_stack_api_weather/weather.py +from typing import Protocol + +from llama_stack.providers.datatypes import ( + AdapterSpec, + Api, + ProviderSpec, + RemoteProviderSpec, +) +from llama_stack.schema_utils import webmethod + + +def available_providers() -> list[ProviderSpec]: + return [ + RemoteProviderSpec( + api=Api.weather, + provider_type="remote::kaze", + config_class="llama_stack_provider_kaze.KazeProviderConfig", + adapter=AdapterSpec( + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], + config_class="llama_stack_provider_kaze.KazeProviderConfig", + ), + ), + ] + + +class WeatherProvider(Protocol): + """ + A protocol for the Weather API. + """ + + @webmethod(route="/weather/locations", method="GET") + async def get_available_locations() -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... +``` + +5. Create the API specification: + +```yaml +# ~/.llama/apis.d/weather.yaml +module: llama_stack_api_weather +name: weather +pip_packages: ["llama-stack-api-weather"] +protocol: WeatherProvider + +``` + +6. Install the API package: + +```bash +uv pip install -e . +``` + +7. Configure Llama Stack to use external APIs: + +```yaml +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: {} +external_apis_dir: ~/.llama/apis.d +``` + +The API will now be available at `/v1/weather/locations`. + +## Example: custom provider for the weather API + +1. Create the provider package: + +```bash +mkdir -p llama-stack-provider-kaze +cd llama-stack-provider-kaze +uv init +``` + +2. Edit `pyproject.toml`: + +```toml +[project] +name = "llama-stack-provider-kaze" +version = "0.1.0" +description = "Kaze weather provider for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic", "aiohttp"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"] +``` + +3. Create the initial files: + +```bash +touch src/llama_stack_provider_kaze/__init__.py +touch src/llama_stack_provider_kaze/kaze.py +``` + +4. Create the provider implementation: + + +Initialization function: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py +"""Kaze weather provider for Llama Stack.""" + +from .config import KazeProviderConfig +from .kaze import WeatherKazeAdapter + +__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"] + + +async def get_adapter_impl(config: KazeProviderConfig, _deps): + from .kaze import WeatherKazeAdapter + + impl = WeatherKazeAdapter(config) + await impl.initialize() + return impl +``` + +Configuration: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py +from pydantic import BaseModel, Field + + +class KazeProviderConfig(BaseModel): + """Configuration for the Kaze weather provider.""" + + base_url: str = Field( + "https://api.kaze.io/v1", + description="Base URL for the Kaze weather API", + ) +``` + +Main implementation: + +```python +# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py +from llama_stack_api_weather.api import WeatherProvider + +from .config import KazeProviderConfig + + +class WeatherKazeAdapter(WeatherProvider): + """Kaze weather provider implementation.""" + + def __init__( + self, + config: KazeProviderConfig, + ) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def get_available_locations(self) -> dict[str, list[str]]: + """Get available weather locations.""" + return {"locations": ["Paris", "Tokyo"]} +``` + +5. Create the provider specification: + +```yaml +# ~/.llama/providers.d/remote/weather/kaze.yaml +adapter: + adapter_type: kaze + pip_packages: ["llama_stack_provider_kaze"] + config_class: llama_stack_provider_kaze.config.KazeProviderConfig + module: llama_stack_provider_kaze +optional_api_dependencies: [] +``` + +6. Install the provider package: + +```bash +uv pip install -e . +``` + +7. Configure Llama Stack to use the provider: + +```yaml +# ~/.llama/run-byoa.yaml +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: + weather: + - provider_id: kaze + provider_type: remote::kaze + config: {} +external_apis_dir: ~/.llama/apis.d +external_providers_dir: ~/.llama/providers.d +server: + port: 8321 +``` + +8. Run the server: + +```bash +python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml +``` + +9. Test the API: + +```bash +curl -sSf http://127.0.0.1:8321/v1/weather/locations +{"locations":["Paris","Tokyo"]}% +``` + +## Best Practices + +1. **Package Naming**: Use a clear and descriptive name for your API package. + +2. **Version Management**: Keep your API package versioned and compatible with the Llama Stack version you're using. + +3. **Dependencies**: Only include the minimum required dependencies in your API package. + +4. **Documentation**: Include clear documentation in your API package about: + - Installation requirements + - Configuration options + - API endpoints and usage + - Any limitations or known issues + +5. **Testing**: Include tests in your API package to ensure it works correctly with Llama Stack. + +## Troubleshooting + +If your external API isn't being loaded: + +1. Check that the `external_apis_dir` path is correct and accessible. +2. Verify that the YAML files are properly formatted. +3. Ensure all required Python packages are installed. +4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`. +5. Verify that the API package is installed in your Python environment. diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md index 6da77a9e6..5a10d6498 100644 --- a/docs/source/concepts/apis.md +++ b/docs/source/concepts/apis.md @@ -10,9 +10,11 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s - **Eval**: generate outputs (via Inference or Agents) and perform scoring - **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents - **Telemetry**: collect telemetry data from the system +- **Post Training**: fine-tune a model +- **Tool Runtime**: interact with various tools and protocols +- **Responses**: generate responses from an LLM using this OpenAI compatible API. We are working on adding a few more APIs to complete the application lifecycle. These will include: - **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs -- **Post Training**: fine-tune a model - **Synthetic Data Generation**: generate synthetic data for model development diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 6362effe8..775749dd6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -504,6 +504,47 @@ created by users sharing a team with them: description: any user has read access to any resource created by a user with the same team ``` +#### API Endpoint Authorization with Scopes + +In addition to resource-based access control, Llama Stack supports endpoint-level authorization using OAuth 2.0 style scopes. When authentication is enabled, specific API endpoints require users to have particular scopes in their authentication token. + +**Scope-Gated APIs:** +The following APIs are currently gated by scopes: + +- **Telemetry API** (scope: `telemetry.read`): + - `POST /telemetry/traces` - Query traces + - `GET /telemetry/traces/{trace_id}` - Get trace by ID + - `GET /telemetry/traces/{trace_id}/spans/{span_id}` - Get span by ID + - `POST /telemetry/spans/{span_id}/tree` - Get span tree + - `POST /telemetry/spans` - Query spans + - `POST /telemetry/metrics/{metric_name}` - Query metrics + +**Authentication Configuration:** + +For **JWT/OAuth2 providers**, scopes should be included in the JWT's claims: +```json +{ + "sub": "user123", + "scope": "telemetry.read", + "aud": "llama-stack" +} +``` + +For **custom authentication providers**, the endpoint must return user attributes including the `scopes` array: +```json +{ + "principal": "user123", + "attributes": { + "scopes": ["telemetry.read"] + } +} +``` + +**Behavior:** +- Users without the required scope receive a 403 Forbidden response +- When authentication is disabled, scope checks are bypassed +- Endpoints without `required_scope` work normally for all authenticated users + ### Quota Configuration The `quota` section allows you to enable server-side request throttling for both diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md index 47e38f73d..928be15ad 100644 --- a/docs/source/distributions/self_hosted_distro/nvidia.md +++ b/docs/source/distributions/self_hosted_distro/nvidia.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- # NVIDIA Distribution diff --git a/docs/source/providers/external.md b/docs/source/providers/external.md index db0bc01e3..092b3a476 100644 --- a/docs/source/providers/external.md +++ b/docs/source/providers/external.md @@ -7,7 +7,17 @@ Llama Stack supports external providers that live outside of the main codebase. ## Configuration -To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications: +To enable external providers, you need to add `module` into your build yaml, allowing Llama Stack to install the required package corresponding to the external provider. + +an example entry in your build.yaml should look like: + +``` +- provider_id: ramalama + provider_type: remote::ramalama + module: ramalama_stack +``` + +Additionally you can configure the `external_providers_dir` in your Llama Stack configuration. This method is in the process of being deprecated in favor of the `module` method. If using this method, the external provider directory should contain your external provider specifications: ```yaml external_providers_dir: ~/.llama/providers.d/ @@ -112,6 +122,31 @@ container_image: custom-vector-store:latest # optional ## Required Implementation +## All Providers + +All providers must contain a `get_provider_spec` function in their `provider` module. This is a standardized structure that Llama Stack expects and is necessary for getting things such as the config class. The `get_provider_spec` method returns a structure identical to the `adapter`. An example function may look like: + +```python +from llama_stack.providers.datatypes import ( + ProviderSpec, + Api, + AdapterSpec, + remote_provider_spec, +) + + +def get_provider_spec() -> ProviderSpec: + return remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="ramalama", + pip_packages=["ramalama>=0.8.5", "pymilvus"], + config_class="ramalama_stack.config.RamalamaImplConfig", + module="ramalama_stack", + ), + ) +``` + ### Remote Providers Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments: @@ -155,7 +190,7 @@ Version: 0.1.0 Location: /path/to/venv/lib/python3.10/site-packages ``` -## Example: Custom Ollama Provider +## Example using `external_providers_dir`: Custom Ollama Provider Here's a complete example of creating and using a custom Ollama provider: @@ -206,6 +241,35 @@ external_providers_dir: ~/.llama/providers.d/ The provider will now be available in Llama Stack with the type `remote::custom_ollama`. + +## Example using `module`: ramalama-stack + +[ramalama-stack](https://github.com/containers/ramalama-stack) is a recognized external provider that supports installation via module. + +To install Llama Stack with this external provider a user can provider the following build.yaml: + +```yaml +version: 2 +distribution_spec: + description: Use (an external) Ramalama server for running LLM inference + container_image: null + providers: + inference: + - provider_id: ramalama + provider_type: remote::ramalama + module: ramalama_stack==0.3.0a0 +image_type: venv +image_name: null +external_providers_dir: null +additional_pip_packages: +- aiosqlite +- sqlalchemy[asyncio] +``` + +No other steps are required other than `llama stack build` and `llama stack run`. The build process will use `module` to install all of the provider dependencies, retrieve the spec, etc. + +The provider will now be available in Llama Stack with the type `remote::ramalama`. + ## Best Practices 1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable. @@ -229,9 +293,10 @@ information. Execute the test for the Provider type you are developing. If your external provider isn't being loaded: +1. Check that `module` points to a published pip package with a top level `provider` module including `get_provider_spec`. 1. Check that the `external_providers_dir` path is correct and accessible. 2. Verify that the YAML files are properly formatted. 3. Ensure all required Python packages are installed. 4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`. -5. Verify that the provider package is installed in your Python environment. +5. Verify that the provider package is installed in your Python environment if using `external_providers_dir`. diff --git a/docs/source/providers/inference/remote_fireworks.md b/docs/source/providers/inference/remote_fireworks.md index 351586c34..862860c29 100644 --- a/docs/source/providers/inference/remote_fireworks.md +++ b/docs/source/providers/inference/remote_fireworks.md @@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| +| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key | diff --git a/docs/source/providers/inference/remote_ollama.md b/docs/source/providers/inference/remote_ollama.md index 23b8f87a2..f9f0a7622 100644 --- a/docs/source/providers/inference/remote_ollama.md +++ b/docs/source/providers/inference/remote_ollama.md @@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `url` | `` | No | http://localhost:11434 | | -| `refresh_models` | `` | No | False | refresh and re-register models periodically | -| `refresh_models_interval` | `` | No | 300 | interval in seconds to refresh models | +| `refresh_models` | `` | No | False | Whether to refresh models periodically | ## Sample Configuration diff --git a/docs/source/providers/inference/remote_together.md b/docs/source/providers/inference/remote_together.md index f33ff42f2..d1fe3e82b 100644 --- a/docs/source/providers/inference/remote_together.md +++ b/docs/source/providers/inference/remote_together.md @@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| +| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. | | `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key | diff --git a/docs/source/providers/inference/remote_vllm.md b/docs/source/providers/inference/remote_vllm.md index 5291199a4..172d35873 100644 --- a/docs/source/providers/inference/remote_vllm.md +++ b/docs/source/providers/inference/remote_vllm.md @@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers. | `api_token` | `str \| None` | No | fake | The API token | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | | `refresh_models` | `` | No | False | Whether to refresh models periodically | -| `refresh_models_interval` | `` | No | 300 | Interval in seconds to refresh models | ## Sample Configuration diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index cebd1e580..eb972e829 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -4,15 +4,83 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +from enum import Enum, EnumMeta -from pydantic import BaseModel +from pydantic import BaseModel, Field from llama_stack.schema_utils import json_schema_type +class DynamicApiMeta(EnumMeta): + def __new__(cls, name, bases, namespace): + # Store the original enum values + original_values = {k: v for k, v in namespace.items() if not k.startswith("_")} + + # Create the enum class + cls = super().__new__(cls, name, bases, namespace) + + # Store the original values for reference + cls._original_values = original_values + # Initialize _dynamic_values + cls._dynamic_values = {} + + return cls + + def __call__(cls, value): + try: + return super().__call__(value) + except ValueError as e: + # If this value was already dynamically added, return it + if value in cls._dynamic_values: + return cls._dynamic_values[value] + + # If the value doesn't exist, create a new enum member + # Create a new member name from the value + member_name = value.lower().replace("-", "_") + + # If this member name already exists in the enum, return the existing member + if member_name in cls._member_map_: + return cls._member_map_[member_name] + + # Instead of creating a new member, raise ValueError to force users to use Api.add() to + # register new APIs explicitly + raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e + + def __iter__(cls): + # Allow iteration over both static and dynamic members + yield from super().__iter__() + if hasattr(cls, "_dynamic_values"): + yield from cls._dynamic_values.values() + + def add(cls, value): + """ + Add a new API to the enum. + Used to register external APIs. + """ + member_name = value.lower().replace("-", "_") + + # If this member name already exists in the enum, return it + if member_name in cls._member_map_: + return cls._member_map_[member_name] + + # Create a new enum member + member = object.__new__(cls) + member._name_ = member_name + member._value_ = value + + # Add it to the enum class + cls._member_map_[member_name] = member + cls._member_names_.append(member_name) + cls._member_type_ = str + + # Store it in our dynamic values + cls._dynamic_values[value] = member + + return member + + @json_schema_type -class Api(Enum): +class Api(Enum, metaclass=DynamicApiMeta): """Enumeration of all available APIs in the Llama Stack system. :cvar providers: Provider management and configuration :cvar inference: Text generation, chat completions, and embeddings @@ -35,7 +103,6 @@ class Api(Enum): :cvar files: File storage and management :cvar inspect: Built-in system inspection and introspection """ - providers = "providers" inference = "inference" safety = "safety" @@ -77,3 +144,12 @@ class Error(BaseModel): title: str detail: str instance: str | None = None + + +class ExternalApiSpec(BaseModel): + """Specification for an external API implementation.""" + + module: str = Field(..., description="Python module containing the API implementation") + name: str = Field(..., description="Name of the API") + pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API") + protocol: str = Field(..., description="Name of the protocol class for the API") diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index a3ecff289..dbd0defb1 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -911,12 +911,6 @@ class OpenAIEmbeddingsResponse(BaseModel): class ModelStore(Protocol): async def get_model(self, identifier: str) -> Model: ... - async def update_registered_llm_models( - self, - provider_id: str, - models: list[Model], - ) -> None: ... - class TextTruncation(Enum): """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left. diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index fbf6d6880..92422ac1b 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho # Add this constant near the top of the file, after the imports DEFAULT_TTL_DAYS = 7 +REQUIRED_SCOPE = "telemetry.read" + @json_schema_type class SpanStatus(Enum): @@ -422,7 +424,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/traces", method="POST") + @webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE) async def query_traces( self, attribute_filters: list[QueryCondition] | None = None, @@ -440,7 +442,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") + @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE) async def get_trace(self, trace_id: str) -> Trace: """Get a trace by its ID. @@ -449,7 +451,9 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET") + @webmethod( + route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE + ) async def get_span(self, trace_id: str, span_id: str) -> Span: """Get a span by its ID. @@ -459,7 +463,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST") + @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE) async def get_span_tree( self, span_id: str, @@ -475,7 +479,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/spans", method="POST") + @webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE) async def query_spans( self, attribute_filters: list[QueryCondition], @@ -508,7 +512,7 @@ class Telemetry(Protocol): """ ... - @webmethod(route="/telemetry/metrics/{metric_name}", method="POST") + @webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE) async def query_metrics( self, metric_name: str, diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 3f94b1e2c..af2a46739 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -36,6 +36,7 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.stack import replace_env_vars from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR @@ -93,7 +94,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) elif args.providers: - providers_list: dict[str, str | list[str]] = dict() + provider_list: dict[str, list[Provider]] = dict() for api_provider in args.providers.split(","): if "=" not in api_provider: cprint( @@ -102,7 +103,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: file=sys.stderr, ) sys.exit(1) - api, provider = api_provider.split("=") + api, provider_type = api_provider.split("=") providers_for_api = get_provider_registry().get(Api(api), None) if providers_for_api is None: cprint( @@ -111,16 +112,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None: file=sys.stderr, ) sys.exit(1) - if provider in providers_for_api: - if api not in providers_list: - providers_list[api] = [] - # Use type guarding to ensure we have a list - provider_value = providers_list[api] - if isinstance(provider_value, list): - provider_value.append(provider) - else: - # Convert string to list and append - providers_list[api] = [provider_value, provider] + if provider_type in providers_for_api: + provider = Provider( + provider_type=provider_type, + provider_id=provider_type.split("::")[1], + config={}, + module=None, + ) + provider_list.setdefault(api, []).append(provider) else: cprint( f"{provider} is not a valid provider for the {api} API.", @@ -129,7 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) sys.exit(1) distribution_spec = DistributionSpec( - providers=providers_list, + providers=provider_list, description=",".join(args.providers), ) if not args.image_type: @@ -190,7 +189,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: cprint("Tip: use to see options for the providers.\n", color="green", file=sys.stderr) - providers: dict[str, str | list[str]] = dict() + providers: dict[str, list[Provider]] = dict() for api, providers_for_api in get_provider_registry().items(): available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")] if not available_providers: @@ -236,11 +235,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None: if args.print_deps_only: print(f"# Dependencies for {args.template or args.config or image_name}") - normal_deps, special_deps = get_provider_dependencies(build_config) + normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config) normal_deps += SERVER_DEPENDENCIES print(f"uv pip install {' '.join(normal_deps)}") for special_dep in special_deps: print(f"uv pip install {special_dep}") + for external_dep in external_provider_dependencies: + print(f"uv pip install {external_dep}") return try: @@ -303,27 +304,25 @@ def _generate_run_config( provider_registry = get_provider_registry(build_config) for api in apis: run_config.providers[api] = [] - provider_types = build_config.distribution_spec.providers[api] - if isinstance(provider_types, str): - provider_types = [provider_types] + providers = build_config.distribution_spec.providers[api] - for i, provider_type in enumerate(provider_types): - pid = provider_type.split("::")[-1] + for provider in providers: + pid = provider.provider_id - p = provider_registry[Api(api)][provider_type] + p = provider_registry[Api(api)][provider.provider_type] if p.deprecation_error: raise InvalidProviderError(p.deprecation_error) try: - config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class) - except ModuleNotFoundError: + config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class) + except (ModuleNotFoundError, ValueError) as exc: # HACK ALERT: # This code executes after building is done, the import cannot work since the # package is either available in the venv or container - not available on the host. # TODO: use a "is_external" flag in ProviderSpec to check if the provider is # external cprint( - f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping", + f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}", color="yellow", file=sys.stderr, ) @@ -336,9 +335,10 @@ def _generate_run_config( config = {} p_spec = Provider( - provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid, - provider_type=provider_type, + provider_id=pid, + provider_type=provider.provider_type, config=config, + module=provider.module, ) run_config.providers[api].append(p_spec) @@ -401,9 +401,32 @@ def _run_stack_build_command_from_build_config( run_config_file = _generate_run_config(build_config, build_dir, image_name) with open(build_file_path, "w") as f: - to_write = json.loads(build_config.model_dump_json()) + to_write = json.loads(build_config.model_dump_json(exclude_none=True)) f.write(yaml.dump(to_write, sort_keys=False)) + # We first install the external APIs so that the build process can use them and discover the + # providers dependencies + if build_config.external_apis_dir: + cprint("Installing external APIs", color="yellow", file=sys.stderr) + external_apis = load_external_apis(build_config) + if external_apis: + # install the external APIs + packages = [] + for _, api_spec in external_apis.items(): + if api_spec.pip_packages: + packages.extend(api_spec.pip_packages) + cprint( + f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}", + color="yellow", + file=sys.stderr, + ) + return_code = run_command(["uv", "pip", "install", *packages]) + if return_code != 0: + packages_str = ", ".join(packages) + raise RuntimeError( + f"Failed to install external APIs packages: {packages_str} (return code: {return_code})" + ) + return_code = build_image( build_config, build_file_path, diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 699ed72da..b4eaac1c7 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -14,6 +14,7 @@ from termcolor import cprint from llama_stack.distribution.datatypes import BuildConfig from llama_stack.distribution.distribution import get_provider_registry +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.utils.exec import run_command from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.providers.datatypes import Api @@ -41,7 +42,7 @@ class ApiInput(BaseModel): def get_provider_dependencies( config: BuildConfig | DistributionTemplate, -) -> tuple[list[str], list[str]]: +) -> tuple[list[str], list[str], list[str]]: """Get normal and special dependencies from provider configuration.""" if isinstance(config, DistributionTemplate): config = config.build_config() @@ -50,6 +51,7 @@ def get_provider_dependencies( additional_pip_packages = config.additional_pip_packages deps = [] + external_provider_deps = [] registry = get_provider_registry(config) for api_str, provider_or_providers in providers.items(): providers_for_api = registry[Api(api_str)] @@ -64,8 +66,16 @@ def get_provider_dependencies( raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`") provider_spec = providers_for_api[provider_type] - deps.extend(provider_spec.pip_packages) - if provider_spec.container_image: + if hasattr(provider_spec, "is_external") and provider_spec.is_external: + # this ensures we install the top level module for our external providers + if provider_spec.module: + if isinstance(provider_spec.module, str): + external_provider_deps.append(provider_spec.module) + else: + external_provider_deps.extend(provider_spec.module) + if hasattr(provider_spec, "pip_packages"): + deps.extend(provider_spec.pip_packages) + if hasattr(provider_spec, "container_image") and provider_spec.container_image: raise ValueError("A stack's dependencies cannot have a container image") normal_deps = [] @@ -78,7 +88,7 @@ def get_provider_dependencies( normal_deps.extend(additional_pip_packages or []) - return list(set(normal_deps)), list(set(special_deps)) + return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps)) def print_pip_install_help(config: BuildConfig): @@ -103,41 +113,59 @@ def build_image( ): container_base = build_config.distribution_spec.container_image or "python:3.12-slim" - normal_deps, special_deps = get_provider_dependencies(build_config) + normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config) normal_deps += SERVER_DEPENDENCIES + if build_config.external_apis_dir: + external_apis = load_external_apis(build_config) + if external_apis: + for _, api_spec in external_apis.items(): + normal_deps.extend(api_spec.pip_packages) if build_config.image_type == LlamaStackImageType.CONTAINER.value: script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh") args = [ script, + "--template-or-config", template_or_config, + "--image-name", image_name, + "--container-base", container_base, + "--normal-deps", " ".join(normal_deps), ] - # When building from a config file (not a template), include the run config path in the # build arguments if run_config is not None: - args.append(run_config) + args.extend(["--run-config", run_config]) elif build_config.image_type == LlamaStackImageType.CONDA.value: script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh") args = [ script, + "--env-name", str(image_name), + "--build-file-path", str(build_file_path), + "--normal-deps", " ".join(normal_deps), ] elif build_config.image_type == LlamaStackImageType.VENV.value: script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh") args = [ script, + "--env-name", str(image_name), + "--normal-deps", " ".join(normal_deps), ] + # Always pass both arguments, even if empty, to maintain consistent positional arguments if special_deps: - args.append("#".join(special_deps)) + args.extend(["--optional-deps", "#".join(special_deps)]) + if external_provider_deps: + args.extend( + ["--external-provider-deps", "#".join(external_provider_deps)] + ) # the script will install external provider module, get its deps, and install those too. return_code = run_command(args) diff --git a/llama_stack/distribution/build_conda_env.sh b/llama_stack/distribution/build_conda_env.sh index 61a2d5973..48ac3a1ab 100755 --- a/llama_stack/distribution/build_conda_env.sh +++ b/llama_stack/distribution/build_conda_env.sh @@ -9,10 +9,91 @@ LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} +PYPI_VERSION=${PYPI_VERSION:-} # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} +set -euo pipefail + +# Define color codes +RED='\033[0;31m' +GREEN='\033[0;32m' +NC='\033[0m' # No Color + +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +source "$SCRIPT_DIR/common.sh" + +# Usage function +usage() { + echo "Usage: $0 --env-name --build-file-path --normal-deps [--external-provider-deps ] [--optional-deps ]" + echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'" + exit 1 +} + +# Parse arguments +env_name="" +build_file_path="" +normal_deps="" +external_provider_deps="" +optional_deps="" + +while [[ $# -gt 0 ]]; do + key="$1" + case "$key" in + --env-name) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --env-name requires a string value" >&2 + usage + fi + env_name="$2" + shift 2 + ;; + --build-file-path) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --build-file-path requires a string value" >&2 + usage + fi + build_file_path="$2" + shift 2 + ;; + --normal-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --normal-deps requires a string value" >&2 + usage + fi + normal_deps="$2" + shift 2 + ;; + --external-provider-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --external-provider-deps requires a string value" >&2 + usage + fi + external_provider_deps="$2" + shift 2 + ;; + --optional-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --optional-deps requires a string value" >&2 + usage + fi + optional_deps="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + usage + ;; + esac +done + +# Check required arguments +if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then + echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2 + usage +fi + if [ -n "$LLAMA_STACK_DIR" ]; then echo "Using llama-stack-dir=$LLAMA_STACK_DIR" fi @@ -20,50 +101,18 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" fi -if [ "$#" -lt 3 ]; then - echo "Usage: $0 []" >&2 - echo "Example: $0 my-conda-env ./my-stack-build.yaml 'numpy pandas scipy'" >&2 - exit 1 -fi - -special_pip_deps="$4" - -set -euo pipefail - -env_name="$1" -build_file_path="$2" -pip_dependencies="$3" - -# Define color codes -RED='\033[0;31m' -GREEN='\033[0;32m' -NC='\033[0m' # No Color - -# this is set if we actually create a new conda in which case we need to clean up -ENVNAME="" - -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -source "$SCRIPT_DIR/common.sh" - ensure_conda_env_python310() { - local env_name="$1" - local pip_dependencies="$2" - local special_pip_deps="$3" + # Use only global variables set by flag parser local python_version="3.12" - # Check if conda command is available if ! is_command_available conda; then printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2 exit 1 fi - # Check if the environment exists if conda env list | grep -q "^${env_name} "; then printf "Conda environment '${env_name}' exists. Checking Python version...\n" - - # Check Python version in the environment current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2) - if [ "$current_version" = "$python_version" ]; then printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n" else @@ -73,37 +122,37 @@ ensure_conda_env_python310() { else printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n" conda create -n "${env_name}" python="${python_version}" -y - - ENVNAME="${env_name}" - # setup_cleanup_handlers fi eval "$(conda shell.bash hook)" conda deactivate && conda activate "${env_name}" - "$CONDA_PREFIX"/bin/pip install uv if [ -n "$TEST_PYPI_VERSION" ]; then - # these packages are damaged in test-pypi, so install them first uv pip install fastapi libcst uv pip install --extra-index-url https://test.pypi.org/simple/ \ llama-stack=="$TEST_PYPI_VERSION" \ - "$pip_dependencies" - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + "$normal_deps" + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" + for part in "${parts[@]}"; do + echo "$part" + uv pip install $part + done + fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" for part in "${parts[@]}"; do echo "$part" uv pip install "$part" done fi else - # Re-installing llama-stack in the new conda environment if [ -n "$LLAMA_STACK_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n" uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" else @@ -115,31 +164,44 @@ ensure_conda_env_python310() { fi uv pip install --no-cache-dir "$SPEC_VERSION" fi - if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n" uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" fi - - # Install pip dependencies printf "Installing pip dependencies\n" - uv pip install $pip_dependencies - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + uv pip install $normal_deps + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do echo "$part" uv pip install $part done fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + echo "Getting provider spec for module: $part and installing dependencies" + package_name=$(echo "$part" | sed 's/[<>=!].*//') + python3 -c " +import importlib +import sys +try: + module = importlib.import_module(f'$package_name.provider') + spec = module.get_provider_spec() + if hasattr(spec, 'pip_packages') and spec.pip_packages: + print('\\n'.join(spec.pip_packages)) +except Exception as e: + print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr) +" | uv pip install -r - + done + fi fi - mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml" } -ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps" +ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps" diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 6985c1cd0..7c406d3e7 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -27,52 +27,103 @@ RUN_CONFIG_PATH=/app/run.yaml BUILD_CONTEXT_DIR=$(pwd) -if [ "$#" -lt 4 ]; then - # This only works for templates - echo "Usage: $0 [] []" >&2 - exit 1 -fi set -euo pipefail -template_or_config="$1" -shift -image_name="$1" -shift -container_base="$1" -shift -pip_dependencies="$1" -shift - -# Handle optional arguments -run_config="" -special_pip_deps="" - -# Check if there are more arguments -# The logics is becoming cumbersom, we should refactor it if we can do better -if [ $# -gt 0 ]; then - # Check if the argument ends with .yaml - if [[ "$1" == *.yaml ]]; then - run_config="$1" - shift - # If there's another argument after .yaml, it must be special_pip_deps - if [ $# -gt 0 ]; then - special_pip_deps="$1" - fi - else - # If it's not .yaml, it must be special_pip_deps - special_pip_deps="$1" - fi -fi - # Define color codes RED='\033[0;31m' NC='\033[0m' # No Color +# Usage function +usage() { + echo "Usage: $0 --image-name --container-base --normal-deps [--run-config ] [--external-provider-deps ] [--optional-deps ]" + echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'" + exit 1 +} + +# Parse arguments +image_name="" +container_base="" +normal_deps="" +external_provider_deps="" +optional_deps="" +run_config="" +template_or_config="" + +while [[ $# -gt 0 ]]; do + key="$1" + case "$key" in + --image-name) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --image-name requires a string value" >&2 + usage + fi + image_name="$2" + shift 2 + ;; + --container-base) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --container-base requires a string value" >&2 + usage + fi + container_base="$2" + shift 2 + ;; + --normal-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --normal-deps requires a string value" >&2 + usage + fi + normal_deps="$2" + shift 2 + ;; + --external-provider-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --external-provider-deps requires a string value" >&2 + usage + fi + external_provider_deps="$2" + shift 2 + ;; + --optional-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --optional-deps requires a string value" >&2 + usage + fi + optional_deps="$2" + shift 2 + ;; + --run-config) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --run-config requires a string value" >&2 + usage + fi + run_config="$2" + shift 2 + ;; + --template-or-config) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --template-or-config requires a string value" >&2 + usage + fi + template_or_config="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + usage + ;; + esac +done + +# Check required arguments +if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then + echo "Error: --image-name, --container-base, and --normal-deps are required." >&2 + usage +fi + CONTAINER_BINARY=${CONTAINER_BINARY:-docker} CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain} - TEMP_DIR=$(mktemp -d) - SCRIPT_DIR=$(dirname "$(readlink -f "$0")") source "$SCRIPT_DIR/common.sh" @@ -81,18 +132,15 @@ add_to_container() { if [ -t 0 ]; then printf '%s\n' "$1" >>"$output_file" else - # If stdin is not a terminal, read from it (heredoc) cat >>"$output_file" fi } -# Check if container command is available if ! is_command_available "$CONTAINER_BINARY"; then printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2 exit 1 fi -# Update and install UBI9 components if UBI9 base image is used if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then add_to_container << EOF FROM $container_base @@ -135,16 +183,16 @@ EOF # Add pip dependencies first since llama-stack is what will change most often # so we can reuse layers. -if [ -n "$pip_dependencies" ]; then - read -ra pip_args <<< "$pip_dependencies" +if [ -n "$normal_deps" ]; then + read -ra pip_args <<< "$normal_deps" quoted_deps=$(printf " %q" "${pip_args[@]}") add_to_container << EOF RUN $MOUNT_CACHE uv pip install $quoted_deps EOF fi -if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" +if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do read -ra pip_args <<< "$part" quoted_deps=$(printf " %q" "${pip_args[@]}") @@ -154,7 +202,33 @@ EOF done fi -# Function to get Python command +if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + read -ra pip_args <<< "$part" + quoted_deps=$(printf " %q" "${pip_args[@]}") + add_to_container <=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0] + module = importlib.import_module(f'{package_name}.provider') + spec = module.get_provider_spec() + if hasattr(spec, 'pip_packages') and spec.pip_packages: + if isinstance(spec.pip_packages, (list, tuple)): + print('\n'.join(spec.pip_packages)) +except Exception as e: + print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr) +PYTHON +EOF + done +fi + get_python_cmd() { if is_command_available python; then echo "python" diff --git a/llama_stack/distribution/build_venv.sh b/llama_stack/distribution/build_venv.sh index 264cedf9c..93db9ab28 100755 --- a/llama_stack/distribution/build_venv.sh +++ b/llama_stack/distribution/build_venv.sh @@ -18,6 +18,76 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-} VIRTUAL_ENV=${VIRTUAL_ENV:-} +set -euo pipefail + +# Define color codes +RED='\033[0;31m' +NC='\033[0m' # No Color + +SCRIPT_DIR=$(dirname "$(readlink -f "$0")") +source "$SCRIPT_DIR/common.sh" + +# Usage function +usage() { + echo "Usage: $0 --env-name --normal-deps [--external-provider-deps ] [--optional-deps ]" + echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'" + exit 1 +} + +# Parse arguments +env_name="" +normal_deps="" +external_provider_deps="" +optional_deps="" + +while [[ $# -gt 0 ]]; do + key="$1" + case "$key" in + --env-name) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --env-name requires a string value" >&2 + usage + fi + env_name="$2" + shift 2 + ;; + --normal-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --normal-deps requires a string value" >&2 + usage + fi + normal_deps="$2" + shift 2 + ;; + --external-provider-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --external-provider-deps requires a string value" >&2 + usage + fi + external_provider_deps="$2" + shift 2 + ;; + --optional-deps) + if [[ -z "$2" || "$2" == --* ]]; then + echo "Error: --optional-deps requires a string value" >&2 + usage + fi + optional_deps="$2" + shift 2 + ;; + *) + echo "Unknown option: $1" >&2 + usage + ;; + esac +done + +# Check required arguments +if [[ -z "$env_name" || -z "$normal_deps" ]]; then + echo "Error: --env-name and --normal-deps are required." >&2 + usage +fi + if [ -n "$LLAMA_STACK_DIR" ]; then echo "Using llama-stack-dir=$LLAMA_STACK_DIR" fi @@ -25,29 +95,6 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" fi -if [ "$#" -lt 2 ]; then - echo "Usage: $0 []" >&2 - echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2 - exit 1 -fi - -special_pip_deps="$3" - -set -euo pipefail - -env_name="$1" -pip_dependencies="$2" - -# Define color codes -RED='\033[0;31m' -NC='\033[0m' # No Color - -# this is set if we actually create a new conda in which case we need to clean up -ENVNAME="" - -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -source "$SCRIPT_DIR/common.sh" - # pre-run checks to make sure we can proceed with the installation pre_run_checks() { local env_name="$1" @@ -71,49 +118,44 @@ pre_run_checks() { } run() { - local env_name="$1" - local pip_dependencies="$2" - local special_pip_deps="$3" - + # Use only global variables set by flag parser if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then echo "Installing dependencies in system Python environment" - # if env == __system__, ensure we set UV_SYSTEM_PYTHON export UV_SYSTEM_PYTHON=1 elif [ "$VIRTUAL_ENV" == "$env_name" ]; then echo "Virtual environment $env_name is already active" else echo "Using virtual environment $env_name" uv venv "$env_name" - # shellcheck source=/dev/null source "$env_name/bin/activate" fi if [ -n "$TEST_PYPI_VERSION" ]; then - # these packages are damaged in test-pypi, so install them first uv pip install fastapi libcst - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected uv pip install --extra-index-url https://test.pypi.org/simple/ \ --index-strategy unsafe-best-match \ llama-stack=="$TEST_PYPI_VERSION" \ - $pip_dependencies - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + $normal_deps + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do echo "$part" - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected uv pip install $part done fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + echo "$part" + uv pip install "$part" + done + fi else - # Re-installing llama-stack in the new virtual environment if [ -n "$LLAMA_STACK_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" else @@ -125,27 +167,41 @@ run() { printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2 exit 1 fi - printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" fi - # Install pip dependencies printf "Installing pip dependencies\n" - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected - uv pip install $pip_dependencies - if [ -n "$special_pip_deps" ]; then - IFS='#' read -ra parts <<<"$special_pip_deps" + uv pip install $normal_deps + if [ -n "$optional_deps" ]; then + IFS='#' read -ra parts <<<"$optional_deps" for part in "${parts[@]}"; do - echo "$part" - # shellcheck disable=SC2086 - # we are building a command line so word splitting is expected + echo "Installing special provider module: $part" uv pip install $part done fi + if [ -n "$external_provider_deps" ]; then + IFS='#' read -ra parts <<<"$external_provider_deps" + for part in "${parts[@]}"; do + echo "Installing external provider module: $part" + uv pip install "$part" + echo "Getting provider spec for module: $part and installing dependencies" + package_name=$(echo "$part" | sed 's/[<>=!].*//') + python3 -c " +import importlib +import sys +try: + module = importlib.import_module(f'$package_name.provider') + spec = module.get_provider_spec() + if hasattr(spec, 'pip_packages') and spec.pip_packages: + print('\\n'.join(spec.pip_packages)) +except Exception as e: + print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr) +" | uv pip install -r - + done + fi fi } pre_run_checks "$env_name" -run "$env_name" "$pip_dependencies" "$special_pip_deps" +run diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 2238eef93..355233d53 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec logger.info(f"Configuring API `{api_str}`...") updated_providers = [] - for i, provider_type in enumerate(plist): + for i, provider in enumerate(plist): if i >= 1: - others = ", ".join(plist[i:]) + others = ", ".join(p.provider_type for p in plist[i:]) logger.info( f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n" ) break - logger.info(f"> Configuring provider `({provider_type})`") + logger.info(f"> Configuring provider `({provider.provider_type})`") updated_providers.append( configure_single_provider( provider_registry[api], Provider( - provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type), - provider_type=provider_type, + provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id), + provider_type=provider.provider_type, config={}, ), ) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index ead1331f3..c17aadcc1 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2 RoutingKey = str | list[str] +class RegistryEntrySource(StrEnum): + via_register_api = "via_register_api" + listed_from_provider = "listed_from_provider" + + class User(BaseModel): principal: str # further attributes that may be used for access control decisions @@ -50,6 +55,7 @@ class ResourceWithOwner(Resource): resource. This can be used to constrain access to the resource.""" owner: User | None = None + source: RegistryEntrySource = RegistryEntrySource.via_register_api # Use the extended Resource for all routable objects @@ -130,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec): pip_packages: list[str] = Field(default_factory=list) +class Provider(BaseModel): + # provider_id of None means that the provider is not enabled - this happens + # when the provider is enabled via a conditional environment variable + provider_id: str | None + provider_type: str + config: dict[str, Any] = {} + module: str | None = Field( + default=None, + description=""" + Fully-qualified name of the external provider module to import. The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation + + Example: `module: ramalama_stack` + """, + ) + + class DistributionSpec(BaseModel): description: str | None = Field( default="", description="Description of the distribution", ) container_image: str | None = None - providers: dict[str, str | list[str]] = Field( + providers: dict[str, list[Provider]] = Field( default_factory=dict, description=""" -Provider Types for each of the APIs provided by this distribution. If you -select multiple providers, you should provide an appropriate 'routing_map' -in the runtime configuration to help route to the correct provider.""", + Provider Types for each of the APIs provided by this distribution. If you + select multiple providers, you should provide an appropriate 'routing_map' + in the runtime configuration to help route to the correct provider. + """, ) -class Provider(BaseModel): - # provider_id of None means that the provider is not enabled - this happens - # when the provider is enabled via a conditional environment variable - provider_id: str | None - provider_type: str - config: dict[str, Any] - - class LoggingConfig(BaseModel): category_levels: dict[str, str] = Field( default_factory=dict, @@ -381,6 +398,11 @@ a default SQLite store will be used.""", description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", ) + external_apis_dir: Path | None = Field( + default=None, + description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.", + ) + @field_validator("external_providers_dir") @classmethod def validate_external_providers_dir(cls, v): @@ -412,6 +434,10 @@ class BuildConfig(BaseModel): default_factory=list, description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.", ) + external_apis_dir: Path | None = Field( + default=None, + description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.", + ) @field_validator("external_providers_dir") @classmethod diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index e37b2c443..6e7297e32 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -12,6 +12,8 @@ from typing import Any import yaml from pydantic import BaseModel +from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec +from llama_stack.distribution.external import load_external_apis from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( AdapterSpec, @@ -96,12 +98,10 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam return spec -def get_provider_registry( - config=None, -) -> dict[Api, dict[str, ProviderSpec]]: +def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]: """Get the provider registry, optionally including external providers. - This function loads both built-in providers and external providers from YAML files. + This function loads both built-in providers and external providers from YAML files or from their provided modules. External providers are loaded from a directory structure like: providers.d/ @@ -122,8 +122,13 @@ def get_provider_registry( safety/ llama-guard.yaml + This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction. + So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet. + There is special handling for all of the potential cases this method can be called from. + Args: config: Optional object containing the external providers directory path + building: Optional bool delineating whether or not this is being called from a build process Returns: A dictionary mapping APIs to their available providers @@ -133,58 +138,140 @@ def get_provider_registry( ValueError: If any provider spec is invalid """ - ret: dict[Api, dict[str, ProviderSpec]] = {} + registry: dict[Api, dict[str, ProviderSpec]] = {} for api in providable_apis(): name = api.name.lower() logger.debug(f"Importing module {name}") try: module = importlib.import_module(f"llama_stack.providers.registry.{name}") - ret[api] = {a.provider_type: a for a in module.available_providers()} + registry[api] = {a.provider_type: a for a in module.available_providers()} except ImportError as e: logger.warning(f"Failed to import module {name}: {e}") - # Check if config has the external_providers_dir attribute - if config and hasattr(config, "external_providers_dir") and config.external_providers_dir: - external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) - if not os.path.exists(external_providers_dir): - raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") - logger.info(f"Loading external providers from {external_providers_dir}") + # Refresh providable APIs with external APIs if any + external_apis = load_external_apis(config) + for api, api_spec in external_apis.items(): + name = api_spec.name.lower() + logger.info(f"Importing external API {name} module {api_spec.module}") + try: + module = importlib.import_module(api_spec.module) + registry[api] = {a.provider_type: a for a in module.available_providers()} + except (ImportError, AttributeError) as e: + # Populate the registry with an empty dict to avoid breaking the provider registry + # This assume that the in-tree provider(s) are not available for this API which means + # that users will need to use external providers for this API. + registry[api] = {} + logger.error( + f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n" + "Install the API package to load any in-tree providers for this API." + ) - for api in providable_apis(): - api_name = api.name.lower() + # Check if config has external providers + if config: + if hasattr(config, "external_providers_dir") and config.external_providers_dir: + registry = get_external_providers_from_dir(registry, config) + # else lets check for modules in each provider + registry = get_external_providers_from_module( + registry=registry, + config=config, + building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)), + ) - # Process both remote and inline providers - for provider_type in ["remote", "inline"]: - api_dir = os.path.join(external_providers_dir, provider_type, api_name) - if not os.path.exists(api_dir): - logger.debug(f"No {provider_type} provider directory found for {api_name}") - continue + return registry - # Look for provider spec files in the API directory - for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")): - provider_name = os.path.splitext(os.path.basename(spec_path))[0] - logger.info(f"Loading {provider_type} provider spec from {spec_path}") - try: - with open(spec_path) as f: - spec_data = yaml.safe_load(f) +def get_external_providers_from_dir( + registry: dict[Api, dict[str, ProviderSpec]], config +) -> dict[Api, dict[str, ProviderSpec]]: + logger.warning( + "Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead." + ) + external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) + if not os.path.exists(external_providers_dir): + raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") + logger.info(f"Loading external providers from {external_providers_dir}") - if provider_type == "remote": - spec = _load_remote_provider_spec(spec_data, api) - provider_type_key = f"remote::{provider_name}" - else: - spec = _load_inline_provider_spec(spec_data, api, provider_name) - provider_type_key = f"inline::{provider_name}" + for api in providable_apis(): + api_name = api.name.lower() - logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") - if provider_type_key in ret[api]: - logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") - ret[api][provider_type_key] = spec - logger.info(f"Successfully loaded external provider {provider_type_key}") - except yaml.YAMLError as yaml_err: - logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") - raise yaml_err - except Exception as e: - logger.error(f"Failed to load provider spec from {spec_path}: {e}") - raise e - return ret + # Process both remote and inline providers + for provider_type in ["remote", "inline"]: + api_dir = os.path.join(external_providers_dir, provider_type, api_name) + if not os.path.exists(api_dir): + logger.debug(f"No {provider_type} provider directory found for {api_name}") + continue + + # Look for provider spec files in the API directory + for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")): + provider_name = os.path.splitext(os.path.basename(spec_path))[0] + logger.info(f"Loading {provider_type} provider spec from {spec_path}") + + try: + with open(spec_path) as f: + spec_data = yaml.safe_load(f) + + if provider_type == "remote": + spec = _load_remote_provider_spec(spec_data, api) + provider_type_key = f"remote::{provider_name}" + else: + spec = _load_inline_provider_spec(spec_data, api, provider_name) + provider_type_key = f"inline::{provider_name}" + + logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") + if provider_type_key in registry[api]: + logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") + registry[api][provider_type_key] = spec + logger.info(f"Successfully loaded external provider {provider_type_key}") + except yaml.YAMLError as yaml_err: + logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") + raise yaml_err + except Exception as e: + logger.error(f"Failed to load provider spec from {spec_path}: {e}") + raise e + + return registry + + +def get_external_providers_from_module( + registry: dict[Api, dict[str, ProviderSpec]], config, building: bool +) -> dict[Api, dict[str, ProviderSpec]]: + provider_list = None + if isinstance(config, BuildConfig): + provider_list = config.distribution_spec.providers.items() + else: + provider_list = config.providers.items() + if provider_list is None: + logger.warning("Could not get list of providers from config") + return registry + for provider_api, providers in provider_list: + for provider in providers: + if not hasattr(provider, "module") or provider.module is None: + continue + # get provider using module + try: + if not building: + package_name = provider.module.split("==")[0] + module = importlib.import_module(f"{package_name}.provider") + # if config class is wrong you will get an error saying module could not be imported + spec = module.get_provider_spec() + else: + # pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run + spec = ProviderSpec( + api=Api(provider_api), + provider_type=provider.provider_type, + is_external=True, + module=provider.module, + config_class="", + ) + provider_type = provider.provider_type + # in the case we are building we CANNOT import this module of course because it has not been installed. + # return a partially filled out spec that the build script will populate. + registry[Api(provider_api)][provider_type] = spec + except ModuleNotFoundError as exc: + raise ValueError( + "get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available" + ) from exc + except Exception as e: + logger.error(f"Failed to load provider spec from module {provider.module}: {e}") + raise e + return registry diff --git a/llama_stack/distribution/external.py b/llama_stack/distribution/external.py new file mode 100644 index 000000000..0a7da16b1 --- /dev/null +++ b/llama_stack/distribution/external.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +import yaml + +from llama_stack.apis.datatypes import Api, ExternalApiSpec +from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="core") + + +def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]: + """Load external API specifications from the configured directory. + + Args: + config: StackRunConfig or BuildConfig containing the external APIs directory path + + Returns: + A dictionary mapping API names to their specifications + """ + if not config or not config.external_apis_dir: + return {} + + external_apis_dir = config.external_apis_dir.expanduser().resolve() + if not external_apis_dir.is_dir(): + logger.error(f"External APIs directory is not a directory: {external_apis_dir}") + return {} + + logger.info(f"Loading external APIs from {external_apis_dir}") + external_apis: dict[Api, ExternalApiSpec] = {} + + # Look for YAML files in the external APIs directory + for yaml_path in external_apis_dir.glob("*.yaml"): + try: + with open(yaml_path) as f: + spec_data = yaml.safe_load(f) + + spec = ExternalApiSpec(**spec_data) + api = Api.add(spec.name) + logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}") + external_apis[api] = spec + except yaml.YAMLError as yaml_err: + logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}") + raise + except Exception: + logger.exception(f"Failed to load external API spec from {yaml_path}") + raise + + return external_apis diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index 5822070ad..f62de4f6b 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -16,6 +16,7 @@ from llama_stack.apis.inspect import ( VersionInfo, ) from llama_stack.distribution.datatypes import StackRunConfig +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.server.routes import get_all_api_routes from llama_stack.providers.datatypes import HealthStatus @@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect): run_config: StackRunConfig = self.config.run_config ret = [] - all_endpoints = get_all_api_routes() + external_apis = load_external_apis(run_config) + all_endpoints = get_all_api_routes(external_apis) for api, endpoints in all_endpoints.items(): # Always include provider and inspect APIs, filter others based on run config if api.value in ["providers", "inspect"]: @@ -53,7 +55,8 @@ class DistributionInspectImpl(Inspect): method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[], # These APIs don't have "real" providers - they're internal to the stack ) - for e in endpoints + for e, _ in endpoints + if e.methods is not None ] ) else: @@ -66,7 +69,8 @@ class DistributionInspectImpl(Inspect): method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[p.provider_type for p in providers], ) - for e in endpoints + for e, _ in endpoints + if e.methods is not None ] ) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 5dc0078d4..bcb0b9167 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient): if not self.skip_logger_removal: self._remove_root_logger_handlers() - return self.loop.run_until_complete(self.async_client.initialize()) + # use a new event loop to avoid interfering with the main event loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return loop.run_until_complete(self.async_client.initialize()) + finally: + asyncio.set_event_loop(None) def _remove_root_logger_handlers(self): """ @@ -243,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): file=sys.stderr, ) if self.config_path_or_template_name.endswith(".yaml"): - # Convert Provider objects to their types - provider_types: dict[str, str | list[str]] = {} - for api, providers in self.config.providers.items(): - types = [p.provider_type for p in providers] - # Convert single-item lists to strings - provider_types[api] = types[0] if len(types) == 1 else types build_config = BuildConfig( distribution_spec=DistributionSpec( - providers=provider_types, + providers=self.config.providers, ), external_providers_dir=self.config.external_providers_dir, ) @@ -353,13 +353,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = options.params or {} body |= options.json_data or {} - matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) + matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params body, field_names = self._handle_file_uploads(options, body) body = self._convert_body(path, options.method, body, exclude_params=set(field_names)) - await start_trace(route, {"__location__": "library_client"}) + + trace_path = webmethod.descriptive_name or route_path + await start_trace(trace_path, {"__location__": "library_client"}) try: result = await matched_func(**body) finally: @@ -409,12 +411,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): path = options.url body = options.params or {} body |= options.json_data or {} - func, path_params, route = find_matching_route(options.method, path, self.route_impls) + func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params body = self._convert_body(path, options.method, body) - await start_trace(route, {"__location__": "library_client"}) + trace_path = webmethod.descriptive_name or route_path + await start_trace(trace_path, {"__location__": "library_client"}) async def gen(): try: @@ -445,8 +448,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream) # so we need to convert it to AsyncStream + # mypy can't track runtime variables inside the [...] of a generic, so ignore that check args = get_args(stream_cls) - stream_cls = AsyncStream[args[0]] + stream_cls = AsyncStream[args[0]] # type: ignore[valid-type] response = AsyncAPIResponse( raw=mock_response, client=self, @@ -468,7 +472,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): exclude_params = exclude_params or set() - func, _, _ = find_matching_route(method, path, self.route_impls) + func, _, _, _ = find_matching_route(method, path, self.route_impls) sig = inspect.signature(func) # Strip NOT_GIVENs to use the defaults in signature diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/distribution/request_headers.py index 81d494e04..509c2be44 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/distribution/request_headers.py @@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None: if not provider_data: return None return provider_data.get("__authenticated_user") + + +def user_from_scope(scope: dict) -> User | None: + """Create a User object from ASGI scope data (set by authentication middleware)""" + user_attributes = scope.get("user_attributes", {}) + principal = scope.get("principal", "") + + # auth not enabled + if not principal and not user_attributes: + return None + + return User(principal=principal, attributes=user_attributes) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index c83218276..db6856ed2 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets +from llama_stack.apis.datatypes import ExternalApiSpec from llama_stack.apis.eval import Eval from llama_stack.apis.files import Files from llama_stack.apis.inference import Inference, InferenceProvider @@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.distribution import builtin_automatically_routed_apis +from llama_stack.distribution.external import load_external_apis from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger @@ -59,8 +61,16 @@ class InvalidProviderError(Exception): pass -def api_protocol_map() -> dict[Api, Any]: - return { +def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]: + """Get a mapping of API types to their protocol classes. + + Args: + external_apis: Optional dictionary of external API specifications + + Returns: + Dictionary mapping API types to their protocol classes + """ + protocols = { Api.providers: ProvidersAPI, Api.agents: Agents, Api.inference: Inference, @@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]: Api.files: Files, } + if external_apis: + for api, api_spec in external_apis.items(): + try: + module = importlib.import_module(api_spec.module) + api_class = getattr(module, api_spec.protocol) -def api_protocol_map_for_compliance_check() -> dict[Api, Any]: + protocols[api] = api_class + except (ImportError, AttributeError): + logger.exception(f"Failed to load external API {api_spec.name}") + + return protocols + + +def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]: + external_apis = load_external_apis(config) return { - **api_protocol_map(), + **api_protocol_map(external_apis), Api.inference: InferenceProvider, } @@ -250,7 +273,7 @@ async def instantiate_providers( dist_registry: DistributionRegistry, run_config: StackRunConfig, policy: list[AccessRule], -) -> dict: +) -> dict[Api, Any]: """Instantiates providers asynchronously while managing dependencies.""" impls: dict[Api, Any] = {} inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} @@ -322,7 +345,7 @@ async def instantiate_provider( policy: list[AccessRule], ): provider_spec = provider.spec - if not hasattr(provider_spec, "module"): + if not hasattr(provider_spec, "module") or provider_spec.module is None: raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}") @@ -360,7 +383,7 @@ async def instantiate_provider( impl.__provider_spec__ = provider_spec impl.__provider_config__ = config - protocols = api_protocol_map_for_compliance_check() + protocols = api_protocol_map_for_compliance_check(run_config) additional_protocols = additional_protocols_map() # TODO: check compliance for special tool groups # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 2f6ac90bb..caf0780fd 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable): for p in self.impls_by_provider_id.values(): await p.shutdown() + async def refresh(self) -> None: + pass + async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: from .benchmarks import BenchmarksRoutingTable from .datasets import DatasetsRoutingTable @@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable): if obj.type == ResourceType.model.value: await self.dist_registry.register(registered_obj) return registered_obj - else: await self.dist_registry.register(obj) return obj diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index f2787b308..022c3dd40 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -10,6 +10,7 @@ from typing import Any from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.distribution.datatypes import ( ModelWithOwner, + RegistryEntrySource, ) from llama_stack.log import get_logger @@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core") class ModelsRoutingTable(CommonRoutingTableImpl, Models): + listed_providers: set[str] = set() + + async def refresh(self) -> None: + for provider_id, provider in self.impls_by_provider_id.items(): + refresh = await provider.should_refresh_models() + if not (refresh or provider_id in self.listed_providers): + continue + + try: + models = await provider.list_models() + except Exception as e: + logger.exception(f"Model refresh failed for provider {provider_id}: {e}") + continue + + self.listed_providers.add(provider_id) + if models is None: + continue + + await self.update_registered_models(provider_id, models) + async def list_models(self) -> ListModelsResponse: return ListModelsResponse(data=await self.get_all_with_type("model")) @@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=metadata, model_type=model_type, + source=RegistryEntrySource.via_register_api, ) registered_model = await self.register_object(model) return registered_model @@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): raise ValueError(f"Model {model_id} not found") await self.unregister_object(existing_model) - async def update_registered_llm_models( + async def update_registered_models( self, provider_id: str, models: list[Model], @@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): # from run.yaml) that we need to keep track of model_ids = {} for model in existing_models: - # we leave embeddings models alone because often we don't get metadata - # (embedding dimension, etc.) from the provider - if model.provider_id == provider_id and model.model_type == ModelType.llm: + if model.provider_id != provider_id: + continue + if model.source == RegistryEntrySource.via_register_api: model_ids[model.provider_resource_id] = model.identifier - logger.debug(f"unregistering model {model.identifier}") - await self.unregister_object(model) + continue + + logger.debug(f"unregistering model {model.identifier}") + await self.unregister_object(model) for model in models: - if model.model_type != ModelType.llm: - continue if model.provider_resource_id in model_ids: - model.identifier = model_ids[model.provider_resource_id] + # avoid overwriting a non-provider-registered model entry + continue logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") await self.register_object( @@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): provider_id=provider_id, metadata=model.metadata, model_type=model.model_type, + source=RegistryEntrySource.listed_from_provider, ) ) diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py index fadbf7b49..87c1a2ab6 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/distribution/server/auth.py @@ -7,9 +7,12 @@ import json import httpx +from aiohttp import hdrs -from llama_stack.distribution.datatypes import AuthenticationConfig +from llama_stack.distribution.datatypes import AuthenticationConfig, User +from llama_stack.distribution.request_headers import user_from_scope from llama_stack.distribution.server.auth_providers import create_auth_provider +from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") @@ -78,12 +81,14 @@ class AuthenticationMiddleware: access resources that don't have access_attributes defined. """ - def __init__(self, app, auth_config: AuthenticationConfig): + def __init__(self, app, auth_config: AuthenticationConfig, impls): self.app = app + self.impls = impls self.auth_provider = create_auth_provider(auth_config) async def __call__(self, scope, receive, send): if scope["type"] == "http": + # First, handle authentication headers = dict(scope.get("headers", [])) auth_header = headers.get(b"authorization", b"").decode() @@ -121,15 +126,50 @@ class AuthenticationMiddleware: f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" ) + # Scope-based API access control + path = scope.get("path", "") + method = scope.get("method", hdrs.METH_GET) + + if not hasattr(self, "route_impls"): + self.route_impls = initialize_route_impls(self.impls) + + try: + _, _, _, webmethod = find_matching_route(method, path, self.route_impls) + except ValueError: + # If no matching endpoint is found, pass through to FastAPI + return await self.app(scope, receive, send) + + if webmethod.required_scope: + user = user_from_scope(scope) + if not _has_required_scope(webmethod.required_scope, user): + return await self._send_auth_error( + send, + f"Access denied: user does not have required scope: {webmethod.required_scope}", + status=403, + ) + return await self.app(scope, receive, send) - async def _send_auth_error(self, send, message): + async def _send_auth_error(self, send, message, status=401): await send( { "type": "http.response.start", - "status": 401, + "status": status, "headers": [[b"content-type", b"application/json"]], } ) - error_msg = json.dumps({"error": {"message": message}}).encode() + error_key = "message" if status == 401 else "detail" + error_msg = json.dumps({"error": {error_key: message}}).encode() await send({"type": "http.response.body", "body": error_msg}) + + +def _has_required_scope(required_scope: str, user: User | None) -> bool: + # if no user, assume auth is not enabled + if not user: + return True + + if not user.attributes: + return False + + user_scopes = user.attributes.get("scopes", []) + return required_scope in user_scopes diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/distribution/server/routes.py index ea66fec5a..ca6f629af 100644 --- a/llama_stack/distribution/server/routes.py +++ b/llama_stack/distribution/server/routes.py @@ -12,17 +12,18 @@ from typing import Any from aiohttp import hdrs from starlette.routing import Route +from llama_stack.apis.datatypes import Api, ExternalApiSpec from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.distribution.resolver import api_protocol_map -from llama_stack.providers.datatypes import Api +from llama_stack.schema_utils import WebMethod EndpointFunc = Callable[..., Any] PathParams = dict[str, str] -RouteInfo = tuple[EndpointFunc, str] +RouteInfo = tuple[EndpointFunc, str, WebMethod] PathImpl = dict[str, RouteInfo] RouteImpls = dict[str, PathImpl] -RouteMatch = tuple[EndpointFunc, PathParams, str] +RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod] def toolgroup_protocol_map(): @@ -31,10 +32,12 @@ def toolgroup_protocol_map(): } -def get_all_api_routes() -> dict[Api, list[Route]]: +def get_all_api_routes( + external_apis: dict[Api, ExternalApiSpec] | None = None, +) -> dict[Api, list[tuple[Route, WebMethod]]]: apis = {} - protocols = api_protocol_map() + protocols = api_protocol_map(external_apis) toolgroup_protocols = toolgroup_protocol_map() for api, protocol in protocols.items(): routes = [] @@ -65,7 +68,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]: else: http_method = hdrs.METH_POST routes.append( - Route(path=path, methods=[http_method], name=name, endpoint=None) + (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod) ) # setting endpoint to None since don't use a Router object apis[api] = routes @@ -73,8 +76,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]: return apis -def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: - routes = get_all_api_routes() +def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls: + api_to_routes = get_all_api_routes(external_apis) route_impls: RouteImpls = {} def _convert_path_to_regex(path: str) -> str: @@ -88,10 +91,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: return f"^{pattern}$" - for api, api_routes in routes.items(): + for api, api_routes in api_to_routes.items(): if api not in impls: continue - for route in api_routes: + for route, webmethod in api_routes: impl = impls[api] func = getattr(impl, route.name) # Get the first (and typically only) method from the set, filtering out HEAD @@ -104,6 +107,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: route_impls[method][_convert_path_to_regex(route.path)] = ( func, route.path, + webmethod, ) return route_impls @@ -118,7 +122,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout route_impls: A dictionary of endpoint implementations Returns: - A tuple of (endpoint_function, path_params, descriptive_name) + A tuple of (endpoint_function, path_params, route_path, webmethod_metadata) Raises: ValueError: If no matching endpoint is found @@ -127,11 +131,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout if not impls: raise ValueError(f"No endpoint found for {path}") - for regex, (func, descriptive_name) in impls.items(): + for regex, (func, route_path, webmethod) in impls.items(): match = re.match(regex, path) if match: # Extract named groups from the regex match path_params = match.groupdict() - return func, path_params, descriptive_name + return func, path_params, route_path, webmethod raise ValueError(f"No endpoint found for {path}") diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index f05c4ad83..9259fc243 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -40,7 +40,12 @@ from llama_stack.distribution.datatypes import ( StackRunConfig, ) from llama_stack.distribution.distribution import builtin_automatically_routed_apis -from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context +from llama_stack.distribution.external import ExternalApiSpec, load_external_apis +from llama_stack.distribution.request_headers import ( + PROVIDER_DATA_VAR, + request_provider_data_context, + user_from_scope, +) from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.server.routes import ( find_matching_route, @@ -51,6 +56,7 @@ from llama_stack.distribution.stack import ( cast_image_name_to_string, construct_stack, replace_env_vars, + shutdown_stack, validate_env_pair, ) from llama_stack.distribution.utils.config import redact_sensitive_fields @@ -146,18 +152,7 @@ async def shutdown(app): Handled by the lifespan context manager. The shutdown process involves shutting down all implementations registered in the application. """ - for impl in app.__llama_stack_impls__.values(): - impl_name = impl.__class__.__name__ - logger.info("Shutting down %s", impl_name) - try: - if hasattr(impl, "shutdown"): - await asyncio.wait_for(impl.shutdown(), timeout=5) - else: - logger.warning("No shutdown method for %s", impl_name) - except TimeoutError: - logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) - except (Exception, asyncio.CancelledError) as e: - logger.exception("Failed to shutdown %s: %s", impl_name, {e}) + await shutdown_stack(app.__llama_stack_impls__) @asynccontextmanager @@ -222,9 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: @functools.wraps(func) async def route_handler(request: Request, **kwargs): # Get auth attributes from the request scope - user_attributes = request.scope.get("user_attributes", {}) - principal = request.scope.get("principal", "") - user = User(principal=principal, attributes=user_attributes) + user = user_from_scope(request.scope) await log_request_pre_validation(request) @@ -282,9 +275,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable: class TracingMiddleware: - def __init__(self, app, impls): + def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): self.app = app self.impls = impls + self.external_apis = external_apis # FastAPI built-in paths that should bypass custom routing self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") @@ -301,10 +295,12 @@ class TracingMiddleware: return await self.app(scope, receive, send) if not hasattr(self, "route_impls"): - self.route_impls = initialize_route_impls(self.impls) + self.route_impls = initialize_route_impls(self.impls, self.external_apis) try: - _, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) + _, _, route_path, webmethod = find_matching_route( + scope.get("method", hdrs.METH_GET), path, self.route_impls + ) except ValueError: # If no matching endpoint is found, pass through to FastAPI logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") @@ -321,6 +317,7 @@ class TracingMiddleware: if tracestate: trace_attributes["tracestate"] = tracestate + trace_path = webmethod.descriptive_name or route_path trace_context = await start_trace(trace_path, trace_attributes) async def send_with_trace_id(message): @@ -432,10 +429,21 @@ def main(args: argparse.Namespace | None = None): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) - # Add authentication middleware if configured + try: + # Create and set the event loop that will be used for both construction and server runtime + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Construct the stack in the persistent event loop + impls = loop.run_until_complete(construct_stack(config)) + + except InvalidProviderError as e: + logger.error(f"Error: {str(e)}") + sys.exit(1) + if config.server.auth: logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") - app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) + app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls) else: if config.server.quota: quota = config.server.quota @@ -466,24 +474,14 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) - try: - # Create and set the event loop that will be used for both construction and server runtime - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Construct the stack in the persistent event loop - impls = loop.run_until_complete(construct_stack(config)) - - except InvalidProviderError as e: - logger.error(f"Error: {str(e)}") - sys.exit(1) - if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: setup_logger(TelemetryAdapter(TelemetryConfig(), {})) - all_routes = get_all_api_routes() + # Load external APIs if configured + external_apis = load_external_apis(config) + all_routes = get_all_api_routes(external_apis) if config.apis: apis_to_serve = set(config.apis) @@ -502,9 +500,12 @@ def main(args: argparse.Namespace | None = None): api = Api(api_str) routes = all_routes[api] - impl = impls[api] + try: + impl = impls[api] + except KeyError as e: + raise ValueError(f"Could not find provider implementation for {api} API") from e - for route in routes: + for route, _ in routes: if not hasattr(impl, route.name): # ideally this should be a typing violation already raise ValueError(f"Could not find method {route.name} on {impl}!") @@ -533,7 +534,7 @@ def main(args: argparse.Namespace | None = None): app.exception_handler(Exception)(global_exception_handler) app.__llama_stack_impls__ = impls - app.add_middleware(TracingMiddleware, impls=impls) + app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) import uvicorn diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index d7270156a..0dfd12828 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import importlib.resources import os import re @@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls +from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger @@ -90,6 +92,10 @@ RESOURCES = [ ] +REGISTRY_REFRESH_INTERVAL_SECONDS = 300 +REGISTRY_REFRESH_TASK = None + + async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): for rsrc, api, register_method, list_method in RESOURCES: objects = getattr(run_config, rsrc) @@ -324,9 +330,53 @@ async def construct_stack( add_internal_implementations(impls, run_config) await register_resources(run_config, impls) + + global REGISTRY_REFRESH_TASK + REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls)) + + def cb(task): + import traceback + + if task.cancelled(): + logger.error("Model refresh task cancelled") + elif task.exception(): + logger.error(f"Model refresh task failed: {task.exception()}") + traceback.print_exception(task.exception()) + else: + logger.debug("Model refresh task completed") + + REGISTRY_REFRESH_TASK.add_done_callback(cb) return impls +async def shutdown_stack(impls: dict[Api, Any]): + for impl in impls.values(): + impl_name = impl.__class__.__name__ + logger.info(f"Shutting down {impl_name}") + try: + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning(f"No shutdown method for {impl_name}") + except TimeoutError: + logger.exception(f"Shutdown timeout for {impl_name}") + except (Exception, asyncio.CancelledError) as e: + logger.exception(f"Failed to shutdown {impl_name}: {e}") + + global REGISTRY_REFRESH_TASK + if REGISTRY_REFRESH_TASK: + REGISTRY_REFRESH_TASK.cancel() + + +async def refresh_registry(impls: dict[Api, Any]): + routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)] + while True: + for routing_table in routing_tables: + await routing_table.refresh() + + await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS) + + def get_stack_run_config_from_template(template: str) -> StackRunConfig: template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml" diff --git a/llama_stack/log.py b/llama_stack/log.py index fcbb79a5d..fb6fa85f9 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -6,6 +6,7 @@ import logging import os +import re import sys from logging.config import dictConfig @@ -30,6 +31,7 @@ CATEGORIES = [ "eval", "tools", "client", + "telemetry", ] # Initialize category levels with default level @@ -113,6 +115,11 @@ def parse_environment_config(env_config: str) -> dict[str, int]: return category_levels +def strip_rich_markup(text): + """Remove Rich markup tags like [dim], [bold magenta], etc.""" + return re.sub(r"\[/?[a-zA-Z0-9 _#=,]+\]", "", text) + + class CustomRichHandler(RichHandler): def __init__(self, *args, **kwargs): kwargs["console"] = Console(width=150) @@ -131,6 +138,19 @@ class CustomRichHandler(RichHandler): self.markup = original_markup +class CustomFileHandler(logging.FileHandler): + def __init__(self, filename, mode="a", encoding=None, delay=False): + super().__init__(filename, mode, encoding, delay) + # Default formatter to match console output + self.default_formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s") + self.setFormatter(self.default_formatter) + + def emit(self, record): + if hasattr(record, "msg"): + record.msg = strip_rich_markup(str(record.msg)) + super().emit(record) + + def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None: """ Configure logging based on the provided category log levels and an optional log file. @@ -167,8 +187,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None # Add a file handler if log_file is set if log_file: handlers["file"] = { - "class": "logging.FileHandler", - "formatter": "rich", + "()": CustomFileHandler, "filename": log_file, "mode": "a", "encoding": "utf-8", diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 424380324..faf7ff18c 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol): async def unregister_model(self, model_id: str) -> None: ... + # the Stack router will query each provider for their list of models + # if a `refresh_interval_seconds` is provided, this method will be called + # periodically to refresh the list of models + # + # NOTE: each model returned will be registered with the model registry. this means + # a callback to the `register_model()` method will be made. this is duplicative and + # may be removed in the future. + async def list_models(self) -> list[Model] | None: ... + + async def should_refresh_models(self) -> bool: ... + class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... @@ -104,6 +115,19 @@ class ProviderSpec(BaseModel): description="If this provider is deprecated and does NOT work, specify the error message here", ) + module: str | None = Field( + default=None, + description=""" + Fully-qualified name of the module to import. The module is expected to have: + + - `get_adapter_impl(config, deps)`: returns the adapter implementation + + Example: `module: ramalama_stack` + """, + ) + + is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.") + # used internally by the resolver; this is a hack for now deps__: list[str] = Field(default_factory=list) @@ -124,7 +148,7 @@ class AdapterSpec(BaseModel): description="Unique identifier for this adapter", ) module: str = Field( - ..., + default_factory=str, description=""" Fully-qualified name of the module to import. The module is expected to have: @@ -162,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack If a provider depends on other providers, the dependencies MUST NOT specify a container image. """, ) - module: str = Field( - ..., - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - - `get_provider_impl(config, deps)`: returns the local implementation -""", - ) + # module field is inherited from ProviderSpec provider_data_validator: str | None = Field( default=None, ) @@ -212,9 +229,7 @@ API responses, specify the adapter here. def container_image(self) -> str | None: return None - @property - def module(self) -> str: - return self.adapter.module + # module field is inherited from ProviderSpec @property def pip_packages(self) -> list[str]: @@ -226,14 +241,19 @@ API responses, specify the adapter here. def remote_provider_spec( - api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None + api: Api, + adapter: AdapterSpec, + api_dependencies: list[Api] | None = None, + optional_api_dependencies: list[Api] | None = None, ) -> RemoteProviderSpec: return RemoteProviderSpec( api=api, provider_type=f"remote::{adapter.adapter_type}", config_class=adapter.config_class, + module=adapter.module, adapter=adapter, api_dependencies=api_dependencies or [], + optional_api_dependencies=optional_api_dependencies or [], ) diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index e238e1b78..88d7a98ec 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl( if self.config.create_distributed_process_group: self.generator.stop() + async def should_refresh_models(self) -> bool: + return False + + async def list_models(self) -> list[Model] | None: + return None + async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 890c526f5..fea8a8189 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -20,6 +20,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) +from llama_stack.apis.models import ModelType from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl( InferenceProvider, ModelsProtocolPrivate, ): + __provider_id__: str + def __init__(self, config: SentenceTransformersInferenceConfig) -> None: self.config = config @@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl( async def shutdown(self) -> None: pass + async def should_refresh_models(self) -> bool: + return False + + async def list_models(self) -> list[Model] | None: + return [ + Model( + identifier="all-MiniLM-L6-v2", + provider_resource_id="all-MiniLM-L6-v2", + provider_id=self.__provider_id__, + metadata={ + "embedding_dimension": 384, + }, + model_type=ModelType.embedding, + ), + ] + async def register_model(self, model: Model) -> Model: return model diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py index e187bdb3b..b4c77437d 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py @@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan from opentelemetry.sdk.trace.export import SpanProcessor from opentelemetry.trace.status import StatusCode -# Colors for console output -COLORS = { - "reset": "\033[0m", - "bold": "\033[1m", - "dim": "\033[2m", - "red": "\033[31m", - "green": "\033[32m", - "yellow": "\033[33m", - "blue": "\033[34m", - "magenta": "\033[35m", - "cyan": "\033[36m", - "white": "\033[37m", -} +from llama_stack.log import get_logger + +logger = get_logger(name="console_span_processor", category="telemetry") class ConsoleSpanProcessor(SpanProcessor): @@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor): return timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3] - - print( - f"{COLORS['dim']}{timestamp}{COLORS['reset']} " - f"{COLORS['magenta']}[START]{COLORS['reset']} " - f"{COLORS['dim']}{span.name}{COLORS['reset']}" - ) + logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]") def on_end(self, span: ReadableSpan) -> None: if span.attributes and span.attributes.get("__autotraced__"): return timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3] - - span_context = ( - f"{COLORS['dim']}{timestamp}{COLORS['reset']} " - f"{COLORS['magenta']}[END]{COLORS['reset']} " - f"{COLORS['dim']}{span.name}{COLORS['reset']}" - ) - + span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]" if span.status.status_code == StatusCode.ERROR: - span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}" + span_context += " [bold red][ERROR][/bold red]" elif span.status.status_code != StatusCode.UNSET: - span_context += f"{COLORS['reset']} [{span.status.status_code}]" - + span_context += f" [{span.status.status_code}]" duration_ms = (span.end_time - span.start_time) / 1e6 - span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)" - - print(span_context) + span_context += f" ({duration_ms:.2f}ms)" + logger.info(span_context) if self.print_attributes and span.attributes: for key, value in span.attributes.items(): @@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor): str_value = str(value) if len(str_value) > 1000: str_value = str_value[:997] + "..." - print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}") + logger.info(f" [dim]{key}[/dim]: {str_value}") for event in span.events: event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3] - severity = event.attributes.get("severity", "info") message = event.attributes.get("message", event.name) - if isinstance(message, dict | list): + if isinstance(message, dict) or isinstance(message, list): message = json.dumps(message, indent=2) - - severity_colors = { - "error": f"{COLORS['bold']}{COLORS['red']}", - "warn": f"{COLORS['bold']}{COLORS['yellow']}", - "info": COLORS["white"], - "debug": COLORS["dim"], - } - msg_color = severity_colors.get(severity, COLORS["white"]) - - print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}") - + severity_color = { + "error": "red", + "warn": "yellow", + "info": "white", + "debug": "dim", + }.get(severity, "white") + logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}") if event.attributes: for key, value in event.attributes.items(): if key.startswith("__") or key in ["message", "severity"]: continue - print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") + logger.info(f"/r[dim]{key}[/dim]: {value}") def shutdown(self) -> None: """Shutdown the processor.""" diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index a57b4a4ee..edee4649d 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex): self.kvstore = kvstore self.bank_id = bank_id + # A list of chunk id's in the same order as they are in the index, + # must be updated when chunks are added or removed + self.chunk_id_lock = asyncio.Lock() + self.chunk_ids: list[Any] = [] + @classmethod async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): instance = cls(dimension, kvstore, bank_id) @@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex): buffer = io.BytesIO(base64.b64decode(data["faiss_index"])) try: self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False)) + self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()] except Exception as e: logger.debug(e, exc_info=True) raise ValueError( @@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex): for i, chunk in enumerate(chunks): self.chunk_by_index[indexlen + i] = chunk - self.index.add(np.array(embeddings).astype(np.float32)) + async with self.chunk_id_lock: + self.index.add(np.array(embeddings).astype(np.float32)) + self.chunk_ids.extend([chunk.chunk_id for chunk in chunks]) # Save updated index await self._save_index() + async def delete_chunk(self, chunk_id: str) -> None: + if chunk_id not in self.chunk_ids: + return + + async with self.chunk_id_lock: + index = self.chunk_ids.index(chunk_id) + self.index.remove_ids(np.array([index])) + + new_chunk_by_index = {} + for idx, chunk in self.chunk_by_index.items(): + # Shift all chunks after the removed chunk to the left + if idx > index: + new_chunk_by_index[idx - 1] = chunk + else: + new_chunk_by_index[idx] = chunk + self.chunk_by_index = new_chunk_by_index + self.chunk_ids.pop(index) + + await self._save_index() + async def query_vector( self, embedding: NDArray, @@ -260,3 +288,9 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr raise ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a faiss index""" + faiss_index = self.cache[store_id].index + for chunk_id in chunk_ids: + await faiss_index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index f2598cc7c..cfa4e2263 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -425,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) + async def delete_chunk(self, chunk_id: str) -> None: + """Remove a chunk from the SQLite vector store.""" + + def _delete_chunk(): + connection = _create_sqlite_connection(self.db_path) + cur = connection.cursor() + try: + cur.execute("BEGIN TRANSACTION") + + # Delete from metadata table + cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,)) + + # Delete from vector table + cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,)) + + # Delete from FTS table + cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,)) + + connection.commit() + except Exception as e: + connection.rollback() + logger.error(f"Error deleting chunk {chunk_id}: {e}") + raise + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_delete_chunk) + class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): """ @@ -520,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc if not index: raise ValueError(f"Vector DB {vector_db_id} not found") return await index.query_chunks(query, params) + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a sqlite_vec index.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + for chunk_id in chunk_ids: + # Use the index's delete_chunk method + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index e391341b4..063b382df 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de """, ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), remote_provider_spec( Api.vector_io, diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 072d558f4..b23f2d31b 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -6,13 +6,14 @@ from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import Field, SecretStr +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @json_schema_type -class FireworksImplConfig(BaseModel): +class FireworksImplConfig(RemoteInferenceProviderConfig): url: str = Field( default="https://api.fireworks.ai/inference/v1", description="The URL for the Fireworks server", diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 1c82ff3a8..c76aa39f3 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: FireworksImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) self.config = config async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py index ae261f47c..ce13f0d83 100644 --- a/llama_stack/providers/remote/inference/ollama/config.py +++ b/llama_stack/providers/remote/inference/ollama/config.py @@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434" class OllamaImplConfig(BaseModel): url: str = DEFAULT_OLLAMA_URL - refresh_models: bool = Field(default=False, description="refresh and re-register models periodically") - refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models") + refresh_models: bool = Field( + default=False, + description="Whether to refresh models periodically", + ) @classmethod def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 76d789d07..ba20185d3 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -98,14 +98,16 @@ class OllamaInferenceAdapter( def __init__(self, config: OllamaImplConfig) -> None: self.register_helper = ModelRegistryHelper(MODEL_ENTRIES) self.config = config - self._client = None + self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {} self._openai_client = None @property def client(self) -> AsyncClient: - if self._client is None: - self._client = AsyncClient(host=self.config.url) - return self._client + # ollama client attaches itself to the current event loop (sadly?) + loop = asyncio.get_running_loop() + if loop not in self._clients: + self._clients[loop] = AsyncClient(host=self.config.url) + return self._clients[loop] @property def openai_client(self) -> AsyncOpenAI: @@ -121,59 +123,61 @@ class OllamaInferenceAdapter( "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal" ) - if self.config.refresh_models: - logger.debug("ollama starting background model refresh task") - self._refresh_task = asyncio.create_task(self._refresh_models()) - - def cb(task): - if task.cancelled(): - import traceback - - logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}") - elif task.exception(): - logger.error(f"ollama background refresh task died: {task.exception()}") - else: - logger.error("ollama background refresh task completed unexpectedly") - - self._refresh_task.add_done_callback(cb) - - async def _refresh_models(self) -> None: - # Wait for model store to be available (with timeout) - waited_time = 0 - while not self.model_store and waited_time < 60: - await asyncio.sleep(1) - waited_time += 1 - - if not self.model_store: - raise ValueError("Model store not set after waiting 60 seconds") + async def should_refresh_models(self) -> bool: + return self.config.refresh_models + async def list_models(self) -> list[Model] | None: provider_id = self.__provider_id__ - while True: - try: - response = await self.client.list() - except Exception as e: - logger.warning(f"Failed to list models: {str(e)}") - await asyncio.sleep(self.config.refresh_models_interval) + response = await self.client.list() + + # always add the two embedding models which can be pulled on demand + models = [ + Model( + identifier="all-minilm:l6-v2", + provider_resource_id="all-minilm:l6-v2", + provider_id=provider_id, + metadata={ + "embedding_dimension": 384, + "context_length": 512, + }, + model_type=ModelType.embedding, + ), + # add all-minilm alias + Model( + identifier="all-minilm", + provider_resource_id="all-minilm:l6-v2", + provider_id=provider_id, + metadata={ + "embedding_dimension": 384, + "context_length": 512, + }, + model_type=ModelType.embedding, + ), + Model( + identifier="nomic-embed-text", + provider_resource_id="nomic-embed-text", + provider_id=provider_id, + metadata={ + "embedding_dimension": 768, + "context_length": 8192, + }, + model_type=ModelType.embedding, + ), + ] + for m in response.models: + # kill embedding models since we don't know dimensions for them + if m.details.family in ["bert"]: continue - - models = [] - for m in response.models: - model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm - if model_type == ModelType.embedding: - continue - models.append( - Model( - identifier=m.model, - provider_resource_id=m.model, - provider_id=provider_id, - metadata={}, - model_type=model_type, - ) + models.append( + Model( + identifier=m.model, + provider_resource_id=m.model, + provider_id=provider_id, + metadata={}, + model_type=ModelType.llm, ) - await self.model_store.update_registered_llm_models(provider_id, models) - logger.debug(f"ollama refreshed model list ({len(models)} models)") - - await asyncio.sleep(self.config.refresh_models_interval) + ) + return models async def health(self) -> HealthResponse: """ @@ -190,12 +194,7 @@ class OllamaInferenceAdapter( return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") async def shutdown(self) -> None: - if hasattr(self, "_refresh_task") and not self._refresh_task.done(): - logger.debug("ollama cancelling background refresh task") - self._refresh_task.cancel() - - self._client = None - self._openai_client = None + self._clients.clear() async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/remote/inference/together/config.py b/llama_stack/providers/remote/inference/together/config.py index f166e4277..211be7efe 100644 --- a/llama_stack/providers/remote/inference/together/config.py +++ b/llama_stack/providers/remote/inference/together/config.py @@ -6,13 +6,14 @@ from typing import Any -from pydantic import BaseModel, Field, SecretStr +from pydantic import Field, SecretStr +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.schema_utils import json_schema_type @json_schema_type -class TogetherImplConfig(BaseModel): +class TogetherImplConfig(RemoteInferenceProviderConfig): url: str = Field( default="https://api.together.xyz/v1", description="The URL for the Together AI server", diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index e1eb934c5..46094c146 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference") class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): def __init__(self, config: TogetherImplConfig) -> None: - ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) self.config = config async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index ee72f974a..a5bf0e4bc 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel): default=False, description="Whether to refresh models periodically", ) - refresh_models_interval: int = Field( - default=300, - description="Interval in seconds to refresh models", - ) @field_validator("tls_verify") @classmethod diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 8bdba1e88..621658a48 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio import json from collections.abc import AsyncGenerator, AsyncIterator from typing import Any @@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): # automatically set by the resolver when instantiating the provider __provider_id__: str model_store: ModelStore | None = None - _refresh_task: asyncio.Task | None = None def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) @@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): self.client = None async def initialize(self) -> None: - if not self.config.url: - # intentionally don't raise an error here, we want to allow the provider to be "dormant" - # or available in distributions like "starter" without causing a ruckus - return + pass - if self.config.refresh_models: - self._refresh_task = asyncio.create_task(self._refresh_models()) - - def cb(task): - import traceback - - if task.cancelled(): - log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}") - elif task.exception(): - # print the stack trace for the exception - exc = task.exception() - log.error(f"vLLM background refresh task died: {exc}") - traceback.print_exception(exc) - else: - log.error("vLLM background refresh task completed unexpectedly") - - self._refresh_task.add_done_callback(cb) - - async def _refresh_models(self) -> None: - provider_id = self.__provider_id__ - waited_time = 0 - while not self.model_store and waited_time < 60: - await asyncio.sleep(1) - waited_time += 1 - - if not self.model_store: - raise ValueError("Model store not set after waiting 60 seconds") + async def should_refresh_models(self) -> bool: + return self.config.refresh_models + async def list_models(self) -> list[Model] | None: self._lazy_initialize_client() assert self.client is not None # mypy - while True: - try: - models = [] - async for m in self.client.models.list(): - model_type = ModelType.llm # unclear how to determine embedding vs. llm models - models.append( - Model( - identifier=m.id, - provider_resource_id=m.id, - provider_id=provider_id, - metadata={}, - model_type=model_type, - ) - ) - await self.model_store.update_registered_llm_models(provider_id, models) - log.debug(f"vLLM refreshed model list ({len(models)} models)") - except Exception as e: - log.error(f"vLLM background refresh task failed: {e}") - await asyncio.sleep(self.config.refresh_models_interval) + models = [] + async for m in self.client.models.list(): + model_type = ModelType.llm # unclear how to determine embedding vs. llm models + models.append( + Model( + identifier=m.id, + provider_resource_id=m.id, + provider_id=self.__provider_id__, + metadata={}, + model_type=model_type, + ) + ) + return models async def shutdown(self) -> None: - if self._refresh_task: - self._refresh_task.cancel() - self._refresh_task = None + pass async def unregister_model(self, model_id: str) -> None: pass diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index bd968d96d..26aeaedfb 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -57,12 +57,15 @@ class ChromaIndex(EmbeddingIndex): self.collection = collection self.kvstore = kvstore + async def initialize(self): + pass + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)] + ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks] await maybe_await( self.collection.add( documents=[chunk.model_dump_json() for chunk in chunks], @@ -112,6 +115,9 @@ class ChromaIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Chroma") + async def delete_chunk(self, chunk_id: str) -> None: + raise NotImplementedError("delete_chunk is not supported in Chroma") + async def query_hybrid( self, embedding: NDArray, @@ -137,9 +143,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.client = None self.cache = {} self.kvstore: KVStore | None = None + self.vector_db_store = None async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) + self.vector_db_store = self.kvstore + if isinstance(self.config, RemoteChromaVectorIOConfig): log.info(f"Connecting to Chroma server at: {self.config.url}") url = self.config.url.rstrip("/") @@ -172,6 +181,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id not in self.cache: + log.warning(f"Vector DB {vector_db_id} not found") + return + await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] @@ -182,6 +195,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ttl_seconds: int | None = None, ) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) + if index is None: + raise ValueError(f"Vector DB {vector_db_id} not found in Chroma") await index.insert_chunks(chunks) @@ -193,6 +208,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> QueryChunksResponse: index = await self._get_and_cache_vector_db_index(vector_db_id) + if index is None: + raise ValueError(f"Vector DB {vector_db_id} not found in Chroma") + return await index.query_chunks(query, params) async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: @@ -208,3 +226,6 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api) self.cache[vector_db_id] = index return index + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index dc4852821..f1652a80e 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -247,6 +247,16 @@ class MilvusIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Hybrid search is not supported in Milvus") + async def delete_chunk(self, chunk_id: str) -> None: + """Remove a chunk from the Milvus collection.""" + try: + await asyncio.to_thread( + self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"' + ) + except Exception as e: + logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}") + raise + class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( @@ -369,3 +379,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) return await index.query_chunks(query, params) + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a milvus vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + for chunk_id in chunk_ids: + # Use the index's delete_chunk method + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/remote/vector_io/pgvector/__init__.py b/llama_stack/providers/remote/vector_io/pgvector/__init__.py index 9f528db74..59eef4c81 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/__init__.py +++ b/llama_stack/providers/remote/vector_io/pgvector/__init__.py @@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]): from .pgvector import PGVectorVectorIOAdapter - impl = PGVectorVectorIOAdapter(config, deps[Api.inference]) + impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None)) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 3aeb3f30d..643c27328 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex): for i, chunk in enumerate(chunks): values.append( ( - f"{chunk.metadata['document_id']}:chunk-{i}", + f"{chunk.chunk_id}", Json(chunk.model_dump()), embeddings[i].tolist(), ) @@ -159,6 +159,11 @@ class PGVectorIndex(EmbeddingIndex): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") + async def delete_chunk(self, chunk_id: str) -> None: + """Remove a chunk from the PostgreSQL table.""" + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: + cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,)) + class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( @@ -265,3 +270,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) return self.cache[vector_db_id] + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a PostgreSQL vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + + for chunk_id in chunk_ids: + # Use the index's delete_chunk method + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 5bdea0ce8..3df3da27f 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -82,6 +82,9 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) + async def delete_chunk(self, chunk_id: str) -> None: + raise NotImplementedError("delete_chunk is not supported in qdrant") + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( await self.client.query_points( @@ -307,3 +310,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): file_id: str, ) -> VectorStoreFileObject: raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 35bb40454..543835e20 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -66,6 +66,9 @@ class WeaviateIndex(EmbeddingIndex): # TODO: make this async friendly collection.data.insert_many(data_objects) + async def delete_chunk(self, chunk_id: str) -> None: + raise NotImplementedError("delete_chunk is not supported in Chroma") + async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: collection = self.client.collections.get(self.collection_name) @@ -264,3 +267,6 @@ class WeaviateVectorIOAdapter( async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") + + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 651d58e2a..bceeaf198 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import ( logger = get_logger(name=__name__, category="core") +class RemoteInferenceProviderConfig(BaseModel): + allowed_models: list[str] | None = Field( + default=None, + description="List of models that should be registered with the model registry. If None, all models are allowed.", + ) + + # TODO: this class is more confusing than useful right now. We need to make it # more closer to the Model class. class ProviderModelEntry(BaseModel): @@ -65,7 +72,10 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider class ModelRegistryHelper(ModelsProtocolPrivate): - def __init__(self, model_entries: list[ProviderModelEntry]): + __provider_id__: str + + def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None): + self.allowed_models = allowed_models self.alias_to_provider_id_map = {} self.provider_id_to_llama_model_map = {} for entry in model_entries: @@ -79,6 +89,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate): self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model + async def list_models(self) -> list[Model] | None: + models = [] + for entry in self.model_entries: + ids = [entry.provider_model_id] + entry.aliases + for id in ids: + if self.allowed_models and id not in self.allowed_models: + continue + models.append( + Model( + model_id=id, + provider_resource_id=entry.provider_model_id, + model_type=ModelType.llm, + metadata=entry.metadata, + provider_id=self.__provider_id__, + ) + ) + return models + + async def should_refresh_models(self) -> bool: + return False + def get_provider_model_id(self, identifier: str) -> str | None: return self.alias_to_provider_id_map.get(identifier, None) diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index f178e9299..ee69d7c52 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -152,6 +152,11 @@ class OpenAIVectorStoreMixin(ABC): """Load existing OpenAI vector stores into the in-memory cache.""" self.openai_vector_stores = await self._load_openai_vector_stores() + @abstractmethod + async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: + """Delete a chunk from a vector store.""" + pass + @abstractmethod async def register_vector_db(self, vector_db: VectorDB) -> None: """Register a vector database (provider-specific implementation).""" @@ -763,17 +768,15 @@ class OpenAIVectorStoreMixin(ABC): if vector_store_id not in self.openai_vector_stores: raise ValueError(f"Vector store {vector_store_id} not found") + dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) + chunks = [Chunk.model_validate(c) for c in dict_chunks] + await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id]) + store_info = self.openai_vector_stores[vector_store_id].copy() file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id) await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id) - # TODO: We need to actually delete the embeddings from the underlying vector store... - # Also uncomment the corresponding integration test marked as xfail - # - # test_openai_vector_store_delete_file_removes_from_vector_store in - # tests/integration/vector_io/test_openai_vector_stores.py - # Update in-memory cache store_info["file_ids"].remove(file_id) store_info["file_counts"][file.status] -= 1 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index f892d33c6..4a8749cba 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -231,6 +231,10 @@ class EmbeddingIndex(ABC): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): raise NotImplementedError() + @abstractmethod + async def delete_chunk(self, chunk_id: str): + raise NotImplementedError() + @abstractmethod async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError() diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index fbf992c82..76593a4b8 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -4,13 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +from enum import Enum from typing import Any, cast import httpx -from mcp import ClientSession +from mcp import ClientSession, McpError from mcp import types as mcp_types from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem from llama_stack.apis.tools import ( @@ -21,31 +24,61 @@ from llama_stack.apis.tools import ( ) from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.log import get_logger +from llama_stack.providers.utils.tools.ttl_dict import TTLDict logger = get_logger(__name__, category="tools") +protocol_cache = TTLDict(ttl_seconds=3600) + + +class MCPProtol(Enum): + UNKNOWN = 0 + STREAMABLE_HTTP = 1 + SSE = 2 + @asynccontextmanager -async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): - try: - async with sse_client(endpoint, headers=headers) as streams: - async with ClientSession(*streams) as session: - await session.initialize() - yield session - except* httpx.HTTPStatusError as eg: - for exc in eg.exceptions: - # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, - # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because - # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. - err = cast(httpx.HTTPStatusError, exc) - if err.response.status_code == 401: - raise AuthenticationRequiredError(exc) from exc - raise +async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]: + # we use a ttl'd dict to cache the happy path protocol for each endpoint + # but, we always fall back to trying the other protocol if we cannot initialize the session + connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE] + mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN) + if mcp_protocol == MCPProtol.SSE: + connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP] + + for i, strategy in enumerate(connection_strategies): + try: + client = streamablehttp_client + if strategy == MCPProtol.SSE: + client = sse_client + async with client(endpoint, headers=headers) as client_streams: + async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session: + await session.initialize() + protocol_cache[endpoint] = strategy + yield session + return + except* httpx.HTTPStatusError as eg: + for exc in eg.exceptions: + # mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, + # so we explicitly cast each item to httpx.HTTPStatusError. This is safe because + # `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. + err = cast(httpx.HTTPStatusError, exc) + if err.response.status_code == 401: + raise AuthenticationRequiredError(exc) from exc + if i == len(connection_strategies) - 1: + raise + except* McpError: + if i < len(connection_strategies) - 1: + logger.warning( + f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + else: + raise async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: tools = [] - async with sse_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: tools_result = await session.list_tools() for tool in tools_result.tools: parameters = [] @@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs async def invoke_mcp_tool( endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] ) -> ToolInvocationResult: - async with sse_client_wrapper(endpoint, headers) as session: + async with client_wrapper(endpoint, headers) as session: result = await session.call_tool(tool_name, kwargs) content: list[InterleavedContentItem] = [] diff --git a/llama_stack/providers/utils/tools/ttl_dict.py b/llama_stack/providers/utils/tools/ttl_dict.py new file mode 100644 index 000000000..2a2605a52 --- /dev/null +++ b/llama_stack/providers/utils/tools/ttl_dict.py @@ -0,0 +1,70 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +from threading import RLock +from typing import Any + + +class TTLDict(dict): + """ + A dictionary with a ttl for each item + """ + + def __init__(self, ttl_seconds: float, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ttl_seconds = ttl_seconds + self._expires: dict[Any, Any] = {} # expires holds when an item will expire + self._lock = RLock() + + if args or kwargs: + for k, v in self.items(): + self.__setitem__(k, v) + + def __delitem__(self, key): + with self._lock: + del self._expires[key] + super().__delitem__(key) + + def __setitem__(self, key, value): + with self._lock: + self._expires[key] = time.monotonic() + self.ttl_seconds + super().__setitem__(key, value) + + def _is_expired(self, key): + if key not in self._expires: + return False + return time.monotonic() > self._expires[key] + + def __getitem__(self, key): + with self._lock: + if self._is_expired(key): + del self._expires[key] + super().__delitem__(key) + raise KeyError(f"{key} has expired and was removed") + + return super().__getitem__(key) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + _ = self[key] + return True + except KeyError: + return False + + def __repr__(self): + with self._lock: + for key in self.keys(): + if self._is_expired(key): + del self._expires[key] + super().__delitem__(key) + return f"TTLDict({self.ttl_seconds}, {super().__repr__()})" diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index 694de333e..93382a881 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -22,6 +22,7 @@ class WebMethod: # A descriptive name of the corresponding span created by tracing descriptive_name: str | None = None experimental: bool | None = False + required_scope: str | None = None T = TypeVar("T", bound=Callable[..., Any]) @@ -36,6 +37,7 @@ def webmethod( raw_bytes_request_body: bool | None = False, descriptive_name: str | None = None, experimental: bool | None = False, + required_scope: str | None = None, ) -> Callable[[T], T]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -45,6 +47,7 @@ def webmethod( :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON. :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON. :param experimental: True if the operation is experimental and subject to change. + :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer'). """ def wrap(func: T) -> T: @@ -57,6 +60,7 @@ def webmethod( raw_bytes_request_body=raw_bytes_request_body, descriptive_name=descriptive_name, experimental=experimental, + required_scope=required_scope, ) return func diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index 625e36e4f..2f18e5d26 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -3,57 +3,98 @@ distribution_spec: description: CI tests for Llama Stack providers: inference: - - remote::cerebras - - remote::ollama - - remote::vllm - - remote::tgi - - remote::hf::serverless - - remote::hf::endpoint - - remote::fireworks - - remote::together - - remote::bedrock - - remote::databricks - - remote::nvidia - - remote::runpod - - remote::openai - - remote::anthropic - - remote::gemini - - remote::groq - - remote::llama-openai-compat - - remote::sambanova - - remote::passthrough - - inline::sentence-transformers + - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_type: remote::cerebras + - provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_type: remote::ollama + - provider_id: ${env.ENABLE_VLLM:=__disabled__} + provider_type: remote::vllm + - provider_id: ${env.ENABLE_TGI:=__disabled__} + provider_type: remote::tgi + - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__} + provider_type: remote::hf::serverless + - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__} + provider_type: remote::hf::endpoint + - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_type: remote::fireworks + - provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_type: remote::together + - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_type: remote::bedrock + - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_type: remote::databricks + - provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_type: remote::nvidia + - provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_type: remote::runpod + - provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_type: remote::openai + - provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_type: remote::anthropic + - provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_type: remote::gemini + - provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_type: remote::groq + - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__} + provider_type: remote::llama-openai-compat + - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_type: remote::sambanova + - provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__} + provider_type: remote::passthrough + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss - - inline::sqlite-vec - - inline::milvus - - remote::chromadb - - remote::pgvector + - provider_id: ${env.ENABLE_FAISS:=faiss} + provider_type: inline::faiss + - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__} + provider_type: inline::sqlite-vec + - provider_id: ${env.ENABLE_MILVUS:=__disabled__} + provider_type: inline::milvus + - provider_id: ${env.ENABLE_CHROMADB:=__disabled__} + provider_type: remote::chromadb + - provider_id: ${env.ENABLE_PGVECTOR:=__disabled__} + provider_type: remote::pgvector files: - - inline::localfs + - provider_id: localfs + provider_type: inline::localfs safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference post_training: - - inline::huggingface + - provider_id: huggingface + provider_type: inline::huggingface eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: ci-tests additional_pip_packages: - aiosqlite - asyncpg diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index 1396d54a8..6f8a192ee 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -56,7 +56,6 @@ providers: api_key: ${env.TOGETHER_API_KEY} - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} provider_type: remote::bedrock - config: {} - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} provider_type: remote::databricks config: @@ -107,7 +106,6 @@ providers: api_key: ${env.PASSTHROUGH_API_KEY} - provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: ${env.ENABLE_FAISS:=faiss} provider_type: inline::faiss @@ -208,10 +206,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -229,10 +225,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index ff8d58a08..d19934ee5 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ -4,32 +4,50 @@ distribution_spec: container providers: inference: - - remote::tgi - - inline::sentence-transformers + - provider_id: tgi + provider_type: remote::tgi + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss - - remote::chromadb - - remote::pgvector + - provider_id: faiss + provider_type: inline::faiss + - provider_id: chromadb + provider_type: remote::chromadb + - provider_id: pgvector + provider_type: remote::pgvector safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime image_type: conda +image_name: dell additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/dell.py b/llama_stack/templates/dell/dell.py index 5a6f52a89..b2210e7dc 100644 --- a/llama_stack/templates/dell/dell.py +++ b/llama_stack/templates/dell/dell.py @@ -19,18 +19,32 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::tgi", "inline::sentence-transformers"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": [ + Provider(provider_id="tgi", provider_type="remote::tgi"), + Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"), + ], + "vector_io": [ + Provider(provider_id="faiss", provider_type="inline::faiss"), + Provider(provider_id="chromadb", provider_type="remote::chromadb"), + Provider(provider_id="pgvector", provider_type="remote::pgvector"), + ], + "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], + "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "datasetio": [ + Provider(provider_id="huggingface", provider_type="remote::huggingface"), + Provider(provider_id="localfs", provider_type="inline::localfs"), + ], + "scoring": [ + Provider(provider_id="basic", provider_type="inline::basic"), + Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"), + Provider(provider_id="braintrust", provider_type="inline::braintrust"), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", + Provider(provider_id="brave-search", provider_type="remote::brave-search"), + Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), + Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), ], } name = "dell" diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/templates/dell/run-with-safety.yaml index 768fad4fa..ecc6729eb 100644 --- a/llama_stack/templates/dell/run-with-safety.yaml +++ b/llama_stack/templates/dell/run-with-safety.yaml @@ -22,7 +22,6 @@ providers: url: ${env.DEH_SAFETY_URL} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: chromadb provider_type: remote::chromadb @@ -74,10 +73,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -95,7 +92,6 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/templates/dell/run.yaml index de2ada009..fc2553526 100644 --- a/llama_stack/templates/dell/run.yaml +++ b/llama_stack/templates/dell/run.yaml @@ -18,7 +18,6 @@ providers: url: ${env.DEH_URL} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: chromadb provider_type: remote::chromadb @@ -70,10 +69,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -91,7 +88,6 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index 2119eeddd..0a0bc0aea 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -3,32 +3,50 @@ distribution_spec: description: Use Meta Reference for running LLM inference providers: inference: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference vector_io: - - inline::faiss - - remote::chromadb - - remote::pgvector + - provider_id: faiss + provider_type: inline::faiss + - provider_id: chromadb + provider_type: remote::chromadb + - provider_id: pgvector + provider_type: remote::pgvector safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: meta-reference-gpu additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/templates/meta-reference-gpu/meta_reference.py index 4bfb4e9d8..6ca500eff 100644 --- a/llama_stack/templates/meta-reference-gpu/meta_reference.py +++ b/llama_stack/templates/meta-reference-gpu/meta_reference.py @@ -25,19 +25,91 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["inline::meta-reference"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "vector_io": [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + ), + Provider( + provider_id="chromadb", + provider_type="remote::chromadb", + ), + Provider( + provider_id="pgvector", + provider_type="remote::pgvector", + ), + ], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "eval": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "datasetio": [ + Provider( + provider_id="huggingface", + provider_type="remote::huggingface", + ), + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ), + Provider( + provider_id="llm-as-judge", + provider_type="inline::llm-as-judge", + ), + Provider( + provider_id="braintrust", + provider_type="inline::braintrust", + ), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider( + provider_id="brave-search", + provider_type="remote::brave-search", + ), + Provider( + provider_id="tavily-search", + provider_type="remote::tavily-search", + ), + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } name = "meta-reference-gpu" diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml index 49657a680..910f9ec46 100644 --- a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml +++ b/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml @@ -24,7 +24,6 @@ providers: max_seq_len: ${env.MAX_SEQ_LEN:=4096} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} - provider_id: meta-reference-safety provider_type: inline::meta-reference config: @@ -88,10 +87,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -109,10 +106,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index 2923b5faf..5266f3c84 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -24,7 +24,6 @@ providers: max_seq_len: ${env.MAX_SEQ_LEN:=4096} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -78,10 +77,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -99,10 +96,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index 51685b2e3..572a70408 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -3,27 +3,39 @@ distribution_spec: description: Use NVIDIA NIM for running LLM inference, evaluation and safety providers: inference: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia vector_io: - - inline::faiss + - provider_id: faiss + provider_type: inline::faiss safety: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia post_training: - - remote::nvidia + - provider_id: nvidia + provider_type: remote::nvidia datasetio: - - inline::localfs - - remote::nvidia + - provider_id: localfs + provider_type: inline::localfs + - provider_id: nvidia + provider_type: remote::nvidia scoring: - - inline::basic + - provider_id: basic + provider_type: inline::basic tool_runtime: - - inline::rag-runtime + - provider_id: rag-runtime + provider_type: inline::rag-runtime image_type: conda +image_name: nvidia additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/doc_template.md b/llama_stack/templates/nvidia/doc_template.md index 3cb8245df..5a180d49f 100644 --- a/llama_stack/templates/nvidia/doc_template.md +++ b/llama_stack/templates/nvidia/doc_template.md @@ -1,3 +1,6 @@ +--- +orphan: true +--- # NVIDIA Distribution The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index e5c13aa74..25beeae75 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -17,16 +17,65 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::nvidia"], - "vector_io": ["inline::faiss"], - "safety": ["remote::nvidia"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["remote::nvidia"], - "post_training": ["remote::nvidia"], - "datasetio": ["inline::localfs", "remote::nvidia"], - "scoring": ["inline::basic"], - "tool_runtime": ["inline::rag-runtime"], + "inference": [ + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ) + ], + "vector_io": [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + ) + ], + "safety": [ + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "eval": [ + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ) + ], + "post_training": [Provider(provider_id="nvidia", provider_type="remote::nvidia", config={})], + "datasetio": [ + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ) + ], + "tool_runtime": [ + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ) + ], } inference_provider = Provider( diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 7017a5955..015724050 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -85,11 +85,9 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} tool_runtime: - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index ccddf11a2..f087e89ee 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -74,11 +74,9 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} tool_runtime: - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/templates/open-benchmark/build.yaml index 5f82c5243..6647b471c 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/templates/open-benchmark/build.yaml @@ -3,36 +3,58 @@ distribution_spec: description: Distribution for running open benchmarks providers: inference: - - remote::openai - - remote::anthropic - - remote::gemini - - remote::groq - - remote::together + - provider_id: openai + provider_type: remote::openai + - provider_id: anthropic + provider_type: remote::anthropic + - provider_id: gemini + provider_type: remote::gemini + - provider_id: groq + provider_type: remote::groq + - provider_id: together + provider_type: remote::together vector_io: - - inline::sqlite-vec - - remote::chromadb - - remote::pgvector + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec + - provider_id: chromadb + provider_type: remote::chromadb + - provider_id: pgvector + provider_type: remote::pgvector safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: open-benchmark additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index ae25c9fc9..3a17e7525 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -96,19 +96,33 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo def get_distribution_template() -> DistributionTemplate: inference_providers, available_models = get_inference_providers() providers = { - "inference": [p.provider_type for p in inference_providers], - "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": inference_providers, + "vector_io": [ + Provider(provider_id="sqlite-vec", provider_type="inline::sqlite-vec"), + Provider(provider_id="chromadb", provider_type="remote::chromadb"), + Provider(provider_id="pgvector", provider_type="remote::pgvector"), + ], + "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], + "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "datasetio": [ + Provider(provider_id="huggingface", provider_type="remote::huggingface"), + Provider(provider_id="localfs", provider_type="inline::localfs"), + ], + "scoring": [ + Provider(provider_id="basic", provider_type="inline::basic"), + Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"), + Provider(provider_id="braintrust", provider_type="inline::braintrust"), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider(provider_id="brave-search", provider_type="remote::brave-search"), + Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), + Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } name = "open-benchmark" diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 828b960a2..ba6a5e9d6 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -106,10 +106,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -127,10 +125,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/registry.db diff --git a/llama_stack/templates/postgres-demo/build.yaml b/llama_stack/templates/postgres-demo/build.yaml index 645b59613..d5e816a54 100644 --- a/llama_stack/templates/postgres-demo/build.yaml +++ b/llama_stack/templates/postgres-demo/build.yaml @@ -3,22 +3,33 @@ distribution_spec: description: Quick start template for running Llama Stack with several popular providers providers: inference: - - remote::vllm - - inline::sentence-transformers + - provider_id: vllm-inference + provider_type: remote::vllm + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - remote::chromadb + - provider_id: chromadb + provider_type: remote::chromadb safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: postgres-demo additional_pip_packages: - asyncpg - psycopg2-binary diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/templates/postgres-demo/postgres_demo.py index c7ab222ec..24e3f6f27 100644 --- a/llama_stack/templates/postgres-demo/postgres_demo.py +++ b/llama_stack/templates/postgres-demo/postgres_demo.py @@ -34,16 +34,24 @@ def get_distribution_template() -> DistributionTemplate: ), ] providers = { - "inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]), - "vector_io": ["remote::chromadb"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], + "inference": inference_providers + + [ + Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"), + ], + "vector_io": [ + Provider(provider_id="chromadb", provider_type="remote::chromadb"), + ], + "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")], + "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], + "telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider(provider_id="brave-search", provider_type="remote::brave-search"), + Provider(provider_id="tavily-search", provider_type="remote::tavily-search"), + Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } name = "postgres-demo" diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/templates/postgres-demo/run.yaml index feb85e316..747b7dc53 100644 --- a/llama_stack/templates/postgres-demo/run.yaml +++ b/llama_stack/templates/postgres-demo/run.yaml @@ -18,7 +18,6 @@ providers: tls_verify: ${env.VLLM_TLS_VERIFY:=true} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: ${env.ENABLE_CHROMADB:+chromadb} provider_type: remote::chromadb @@ -70,10 +69,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: postgres host: ${env.POSTGRES_HOST:=localhost} diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index 8180124f6..9b540ab62 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -3,57 +3,98 @@ distribution_spec: description: Quick start template for running Llama Stack with several popular providers providers: inference: - - remote::cerebras - - remote::ollama - - remote::vllm - - remote::tgi - - remote::hf::serverless - - remote::hf::endpoint - - remote::fireworks - - remote::together - - remote::bedrock - - remote::databricks - - remote::nvidia - - remote::runpod - - remote::openai - - remote::anthropic - - remote::gemini - - remote::groq - - remote::llama-openai-compat - - remote::sambanova - - remote::passthrough - - inline::sentence-transformers + - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_type: remote::cerebras + - provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_type: remote::ollama + - provider_id: ${env.ENABLE_VLLM:=__disabled__} + provider_type: remote::vllm + - provider_id: ${env.ENABLE_TGI:=__disabled__} + provider_type: remote::tgi + - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__} + provider_type: remote::hf::serverless + - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__} + provider_type: remote::hf::endpoint + - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_type: remote::fireworks + - provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_type: remote::together + - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_type: remote::bedrock + - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_type: remote::databricks + - provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_type: remote::nvidia + - provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_type: remote::runpod + - provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_type: remote::openai + - provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_type: remote::anthropic + - provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_type: remote::gemini + - provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_type: remote::groq + - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__} + provider_type: remote::llama-openai-compat + - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_type: remote::sambanova + - provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__} + provider_type: remote::passthrough + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss - - inline::sqlite-vec - - inline::milvus - - remote::chromadb - - remote::pgvector + - provider_id: ${env.ENABLE_FAISS:=faiss} + provider_type: inline::faiss + - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__} + provider_type: inline::sqlite-vec + - provider_id: ${env.ENABLE_MILVUS:=__disabled__} + provider_type: inline::milvus + - provider_id: ${env.ENABLE_CHROMADB:=__disabled__} + provider_type: remote::chromadb + - provider_id: ${env.ENABLE_PGVECTOR:=__disabled__} + provider_type: remote::pgvector files: - - inline::localfs + - provider_id: localfs + provider_type: inline::localfs safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference post_training: - - inline::huggingface + - provider_id: huggingface + provider_type: inline::huggingface eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: starter additional_pip_packages: - aiosqlite - asyncpg diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index c38933f98..d60800ebb 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -56,7 +56,6 @@ providers: api_key: ${env.TOGETHER_API_KEY} - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} provider_type: remote::bedrock - config: {} - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} provider_type: remote::databricks config: @@ -107,7 +106,6 @@ providers: api_key: ${env.PASSTHROUGH_API_KEY} - provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: ${env.ENABLE_FAISS:=faiss} provider_type: inline::faiss @@ -208,10 +206,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -229,10 +225,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py index cee1094db..489117702 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/templates/starter/starter.py @@ -253,21 +253,91 @@ def get_distribution_template() -> DistributionTemplate: ] providers = { - "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]), - "vector_io": ([p.provider_type for p in vector_io_providers]), - "files": ["inline::localfs"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "post_training": ["inline::huggingface"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": remote_inference_providers + + [ + Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + ) + ], + "vector_io": vector_io_providers, + "files": [ + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ) + ], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "post_training": [ + Provider( + provider_id="huggingface", + provider_type="inline::huggingface", + ) + ], + "eval": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "datasetio": [ + Provider( + provider_id="huggingface", + provider_type="remote::huggingface", + ), + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ), + Provider( + provider_id="llm-as-judge", + provider_type="inline::llm-as-judge", + ), + Provider( + provider_id="braintrust", + provider_type="inline::braintrust", + ), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider( + provider_id="brave-search", + provider_type="remote::brave-search", + ), + Provider( + provider_id="tavily-search", + provider_type="remote::tavily-search", + ), + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } files_provider = Provider( diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index fb2528873..e9054f95d 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from pathlib import Path -from typing import Literal +from typing import Any, Literal import jinja2 import rich @@ -35,6 +35,51 @@ from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as get_sql_pip_packages +def filter_empty_values(obj: Any) -> Any: + """Recursively filter out specific empty values from a dictionary or list. + + This function removes: + - Empty strings ('') only when they are the 'module' field + - Empty dictionaries ({}) only when they are the 'config' field + - None values (always excluded) + """ + if obj is None: + return None + + if isinstance(obj, dict): + filtered = {} + for key, value in obj.items(): + # Special handling for specific fields + if key == "module" and isinstance(value, str) and value == "": + # Skip empty module strings + continue + elif key == "config" and isinstance(value, dict) and not value: + # Skip empty config dictionaries + continue + elif key == "container_image" and not value: + # Skip empty container_image names + continue + else: + # For all other fields, recursively filter but preserve empty values + filtered_value = filter_empty_values(value) + # if filtered_value is not None: + filtered[key] = filtered_value + return filtered + + elif isinstance(obj, list): + filtered = [] + for item in obj: + filtered_item = filter_empty_values(item) + if filtered_item is not None: + filtered.append(filtered_item) + return filtered + + else: + # For all other types (including empty strings and dicts that aren't module/config), + # preserve them as-is + return obj + + def get_model_registry( available_models: dict[str, list[ProviderModelEntry]], ) -> tuple[list[ModelInput], bool]: @@ -138,31 +183,26 @@ class RunConfigSettings(BaseModel): def run_config( self, name: str, - providers: dict[str, list[str]], + providers: dict[str, list[Provider]], container_image: str | None = None, ) -> dict: provider_registry = get_provider_registry() - provider_configs = {} - for api_str, provider_types in providers.items(): + for api_str, provider_objs in providers.items(): if api_providers := self.provider_overrides.get(api_str): # Convert Provider objects to dicts for YAML serialization - provider_configs[api_str] = [ - p.model_dump(exclude_none=True) if isinstance(p, Provider) else p for p in api_providers - ] + provider_configs[api_str] = [p.model_dump(exclude_none=True) for p in api_providers] continue provider_configs[api_str] = [] - for provider_type in provider_types: - provider_id = provider_type.split("::")[-1] - + for provider in provider_objs: api = Api(api_str) - if provider_type not in provider_registry[api]: - raise ValueError(f"Unknown provider type: {provider_type} for API: {api_str}") + if provider.provider_type not in provider_registry[api]: + raise ValueError(f"Unknown provider type: {provider.provider_type} for API: {api_str}") - config_class = provider_registry[api][provider_type].config_class + config_class = provider_registry[api][provider.provider_type].config_class assert config_class is not None, ( - f"No config class for provider type: {provider_type} for API: {api_str}" + f"No config class for provider type: {provider.provider_type} for API: {api_str}" ) config_class = instantiate_class_type(config_class) @@ -171,14 +211,9 @@ class RunConfigSettings(BaseModel): else: config = {} - provider_configs[api_str].append( - Provider( - provider_id=provider_id, - provider_type=provider_type, - config=config, - ).model_dump(exclude_none=True) - ) - + provider.config = config + # Convert Provider object to dict for YAML serialization + provider_configs[api_str].append(provider.model_dump(exclude_none=True)) # Get unique set of APIs from providers apis = sorted(providers.keys()) @@ -222,7 +257,7 @@ class DistributionTemplate(BaseModel): description: str distro_type: Literal["self_hosted", "remote_hosted", "ondevice"] - providers: dict[str, list[str]] + providers: dict[str, list[Provider]] run_configs: dict[str, RunConfigSettings] template_path: Path | None = None @@ -255,13 +290,28 @@ class DistributionTemplate(BaseModel): if self.additional_pip_packages: additional_pip_packages.extend(self.additional_pip_packages) + # Create minimal providers for build config (without runtime configs) + build_providers = {} + for api, providers in self.providers.items(): + build_providers[api] = [] + for provider in providers: + # Create a minimal provider object with only essential build information + build_provider = Provider( + provider_id=provider.provider_id, + provider_type=provider.provider_type, + config={}, # Empty config for build + module=provider.module, + ) + build_providers[api].append(build_provider) + return BuildConfig( distribution_spec=DistributionSpec( description=self.description, container_image=self.container_image, - providers=self.providers, + providers=build_providers, ), - image_type="conda", # default to conda, can be overridden + image_type="conda", + image_name=self.name, additional_pip_packages=sorted(set(additional_pip_packages)), ) @@ -270,7 +320,7 @@ class DistributionTemplate(BaseModel): providers_table += "|-----|-------------|\n" for api, providers in sorted(self.providers.items()): - providers_str = ", ".join(f"`{p}`" for p in providers) + providers_str = ", ".join(f"`{p.provider_type}`" for p in providers) providers_table += f"| {api} | {providers_str} |\n" template = self.template_path.read_text() @@ -334,7 +384,7 @@ class DistributionTemplate(BaseModel): build_config = self.build_config() with open(yaml_output_dir / "build.yaml", "w") as f: yaml.safe_dump( - build_config.model_dump(exclude_none=True), + filter_empty_values(build_config.model_dump(exclude_none=True)), f, sort_keys=False, ) @@ -343,7 +393,7 @@ class DistributionTemplate(BaseModel): run_config = settings.run_config(self.name, self.providers, self.container_image) with open(yaml_output_dir / yaml_pth, "w") as f: yaml.safe_dump( - {k: v for k, v in run_config.items() if v is not None}, + filter_empty_values(run_config), f, sort_keys=False, ) diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml index 08ee2c5ce..bc992f0c7 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/templates/watsonx/build.yaml @@ -3,31 +3,49 @@ distribution_spec: description: Use watsonx for running LLM inference providers: inference: - - remote::watsonx - - inline::sentence-transformers + - provider_id: watsonx + provider_type: remote::watsonx + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers vector_io: - - inline::faiss + - provider_id: faiss + provider_type: inline::faiss safety: - - inline::llama-guard + - provider_id: llama-guard + provider_type: inline::llama-guard agents: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference telemetry: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference eval: - - inline::meta-reference + - provider_id: meta-reference + provider_type: inline::meta-reference datasetio: - - remote::huggingface - - inline::localfs + - provider_id: huggingface + provider_type: remote::huggingface + - provider_id: localfs + provider_type: inline::localfs scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol + - provider_id: brave-search + provider_type: remote::brave-search + - provider_id: tavily-search + provider_type: remote::tavily-search + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol image_type: conda +image_name: watsonx additional_pip_packages: -- aiosqlite - sqlalchemy[asyncio] +- aiosqlite +- aiosqlite diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml index afbbdb917..f5fe31bef 100644 --- a/llama_stack/templates/watsonx/run.yaml +++ b/llama_stack/templates/watsonx/run.yaml @@ -20,7 +20,6 @@ providers: project_id: ${env.WATSONX_PROJECT_ID:=} - provider_id: sentence-transformers provider_type: inline::sentence-transformers - config: {} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -74,10 +73,8 @@ providers: scoring: - provider_id: basic provider_type: inline::basic - config: {} - provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: {} - provider_id: braintrust provider_type: inline::braintrust config: @@ -95,10 +92,8 @@ providers: max_results: 3 - provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} - provider_id: model-context-protocol provider_type: remote::model-context-protocol - config: {} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/watsonx}/registry.db diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py index ea185f05d..c13bbea36 100644 --- a/llama_stack/templates/watsonx/watsonx.py +++ b/llama_stack/templates/watsonx/watsonx.py @@ -18,19 +18,87 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::watsonx", "inline::sentence-transformers"], - "vector_io": ["inline::faiss"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "inference": [ + Provider( + provider_id="watsonx", + provider_type="remote::watsonx", + ), + Provider( + provider_id="sentence-transformers", + provider_type="inline::sentence-transformers", + ), + ], + "vector_io": [ + Provider( + provider_id="faiss", + provider_type="inline::faiss", + ) + ], + "safety": [ + Provider( + provider_id="llama-guard", + provider_type="inline::llama-guard", + ) + ], + "agents": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "telemetry": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "eval": [ + Provider( + provider_id="meta-reference", + provider_type="inline::meta-reference", + ) + ], + "datasetio": [ + Provider( + provider_id="huggingface", + provider_type="remote::huggingface", + ), + Provider( + provider_id="localfs", + provider_type="inline::localfs", + ), + ], + "scoring": [ + Provider( + provider_id="basic", + provider_type="inline::basic", + ), + Provider( + provider_id="llm-as-judge", + provider_type="inline::llm-as-judge", + ), + Provider( + provider_id="braintrust", + provider_type="inline::braintrust", + ), + ], "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", + Provider( + provider_id="brave-search", + provider_type="remote::brave-search", + ), + Provider( + provider_id="tavily-search", + provider_type="remote::tavily-search", + ), + Provider( + provider_id="rag-runtime", + provider_type="inline::rag-runtime", + ), + Provider( + provider_id="model-context-protocol", + provider_type="remote::model-context-protocol", + ), ], } diff --git a/llama_stack/ui/package-lock.json b/llama_stack/ui/package-lock.json index 158569241..6412741aa 100644 --- a/llama_stack/ui/package-lock.json +++ b/llama_stack/ui/package-lock.json @@ -15,7 +15,7 @@ "@radix-ui/react-tooltip": "^1.2.6", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", - "llama-stack-client": "^0.2.14", + "llama-stack-client": "^0.2.15", "lucide-react": "^0.510.0", "next": "15.3.3", "next-auth": "^4.24.11", @@ -6468,14 +6468,15 @@ } }, "node_modules/form-data": { - "version": "4.0.2", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz", - "integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", + "integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", "license": "MIT", "dependencies": { "asynckit": "^0.4.0", "combined-stream": "^1.0.8", "es-set-tostringtag": "^2.1.0", + "hasown": "^2.0.2", "mime-types": "^2.1.12" }, "engines": { @@ -9099,9 +9100,9 @@ "license": "MIT" }, "node_modules/llama-stack-client": { - "version": "0.2.14", - "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz", - "integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==", + "version": "0.2.15", + "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.15.tgz", + "integrity": "sha512-onfYzgPWAxve4uP7BuiK/ZdEC7w6X1PIXXXpQY57qZC7C4xUAM5kwfT3JWIe/jE22Lwc2vTN1ScfYlAYcoYAsg==", "license": "Apache-2.0", "dependencies": { "@types/node": "^18.11.18", diff --git a/scripts/gen-ci-docs.py b/scripts/gen-ci-docs.py new file mode 100755 index 000000000..630cfe765 --- /dev/null +++ b/scripts/gen-ci-docs.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +import yaml + +REPO_ROOT = Path(__file__).parent.parent + + +def parse_workflow_file(file_path): + """Parse a workflow YAML file and extract name and run-name.""" + try: + with open(file_path, encoding="utf-8") as f: + content = yaml.safe_load(f) + + name = content["name"] + run_name = content["run-name"] + + return name, run_name + except Exception as e: + raise Exception(f"Error parsing {file_path}") from e + + +def generate_ci_docs(): + """Generate the CI documentation README.md file.""" + + # Define paths + workflows_dir = REPO_ROOT / ".github/workflows" + readme_path = workflows_dir / "README.md" + + # Header section to preserve + header = """# Llama Stack CI + +Llama Stack uses GitHub Actions for Continous Integration (CI). Below is a table detailing what CI the project includes and the purpose. + +| Name | File | Purpose | +| ---- | ---- | ------- | +""" + + # Get all .yml files in workflows directory + yml_files = [] + for file_path in workflows_dir.glob("*.yml"): + yml_files.append(file_path) + + # Sort files alphabetically for consistent output + yml_files.sort(key=lambda x: x.name) + + # Generate table rows + table_rows = [] + for file_path in yml_files: + name, run_name = parse_workflow_file(file_path) + filename = file_path.name + + # Create markdown link in the format [filename.yml](filename.yml) + file_link = f"[{filename}]({filename})" + + # Create table row + row = f"| {name} | {file_link} | {run_name} |" + table_rows.append(row) + + # Combine header and table rows + content = header + "\n".join(table_rows) + "\n" + + # Write to README.md + with open(readme_path, "w", encoding="utf-8") as f: + f.write(content) + + print(f"Generated {readme_path} with {len(table_rows)} workflow entries") + + +if __name__ == "__main__": + generate_ci_docs() diff --git a/scripts/install.sh b/scripts/install.sh index b5afe43b8..5dc74fae1 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -15,11 +15,40 @@ set -Eeuo pipefail PORT=8321 OLLAMA_PORT=11434 MODEL_ALIAS="llama3.2:3b" -SERVER_IMAGE="docker.io/llamastack/distribution-ollama:0.2.2" -WAIT_TIMEOUT=300 +SERVER_IMAGE="docker.io/llamastack/distribution-starter:latest" +WAIT_TIMEOUT=30 +TEMP_LOG="" + +# Cleanup function to remove temporary files +cleanup() { + if [ -n "$TEMP_LOG" ] && [ -f "$TEMP_LOG" ]; then + rm -f "$TEMP_LOG" + fi +} + +# Set up trap to clean up on exit, error, or interrupt +trap cleanup EXIT ERR INT TERM log(){ printf "\e[1;32m%s\e[0m\n" "$*"; } -die(){ printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2; exit 1; } +die(){ + printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2 + printf "\e[1;31m🐛 Report an issue @ https://github.com/meta-llama/llama-stack/issues if you think it's a bug\e[0m\n" >&2 + exit 1 +} + +# Helper function to execute command with logging +execute_with_log() { + local cmd=("$@") + TEMP_LOG=$(mktemp) + if ! "${cmd[@]}" > "$TEMP_LOG" 2>&1; then + log "❌ Command failed; dumping output:" + log "Command that failed: ${cmd[*]}" + log "Command output:" + cat "$TEMP_LOG" + return 1 + fi + return 0 +} wait_for_service() { local url="$1" @@ -27,7 +56,7 @@ wait_for_service() { local timeout="$3" local name="$4" local start ts - log "⏳ Waiting for ${name}…" + log "⏳ Waiting for ${name}..." start=$(date +%s) while true; do if curl --retry 5 --retry-delay 1 --retry-max-time "$timeout" --retry-all-errors --silent --fail "$url" 2>/dev/null | grep -q "$pattern"; then @@ -38,24 +67,24 @@ wait_for_service() { return 1 fi printf '.' - sleep 1 done + printf '\n' return 0 } usage() { cat << EOF -📚 Llama-Stack Deployment Script +📚 Llama Stack Deployment Script Description: - This script sets up and deploys Llama-Stack with Ollama integration in containers. + This script sets up and deploys Llama Stack with Ollama integration in containers. It handles both Docker and Podman runtimes and includes automatic platform detection. Usage: $(basename "$0") [OPTIONS] Options: - -p, --port PORT Server port for Llama-Stack (default: ${PORT}) + -p, --port PORT Server port for Llama Stack (default: ${PORT}) -o, --ollama-port PORT Ollama service port (default: ${OLLAMA_PORT}) -m, --model MODEL Model alias to use (default: ${MODEL_ALIAS}) -i, --image IMAGE Server image (default: ${SERVER_IMAGE}) @@ -129,15 +158,15 @@ fi # CONTAINERS_MACHINE_PROVIDER=libkrun podman machine init if [ "$ENGINE" = "podman" ] && [ "$(uname -s)" = "Darwin" ]; then if ! podman info &>/dev/null; then - log "⌛️ Initializing Podman VM…" + log "⌛️ Initializing Podman VM..." podman machine init &>/dev/null || true podman machine start &>/dev/null || true - log "⌛️ Waiting for Podman API…" + log "⌛️ Waiting for Podman API..." until podman info &>/dev/null; do sleep 1 done - log "✅ Podman VM is up" + log "✅ Podman VM is up." fi fi @@ -145,8 +174,10 @@ fi for name in ollama-server llama-stack; do ids=$($ENGINE ps -aq --filter "name=^${name}$") if [ -n "$ids" ]; then - log "⚠️ Found existing container(s) for '${name}', removing…" - $ENGINE rm -f "$ids" > /dev/null 2>&1 + log "⚠️ Found existing container(s) for '${name}', removing..." + if ! execute_with_log $ENGINE rm -f "$ids"; then + die "Container cleanup failed" + fi fi done @@ -154,28 +185,32 @@ done # 0. Create a shared network ############################################################################### if ! $ENGINE network inspect llama-net >/dev/null 2>&1; then - log "🌐 Creating network…" - $ENGINE network create llama-net >/dev/null 2>&1 + log "🌐 Creating network..." + if ! execute_with_log $ENGINE network create llama-net; then + die "Network creation failed" + fi fi ############################################################################### # 1. Ollama ############################################################################### -log "🦙 Starting Ollama…" -$ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \ +log "🦙 Starting Ollama..." +if ! execute_with_log $ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \ --network llama-net \ -p "${OLLAMA_PORT}:${OLLAMA_PORT}" \ - docker.io/ollama/ollama > /dev/null 2>&1 + docker.io/ollama/ollama > /dev/null 2>&1; then + die "Ollama startup failed" +fi if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then - log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:" + log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:" $ENGINE logs --tail 200 ollama-server die "Ollama startup failed" fi -log "📦 Ensuring model is pulled: ${MODEL_ALIAS}…" -if ! $ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}" > /dev/null 2>&1; then - log "❌ Failed to pull model ${MODEL_ALIAS}; dumping container logs:" +log "📦 Ensuring model is pulled: ${MODEL_ALIAS}..." +if ! execute_with_log $ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}"; then + log "❌ Failed to pull model ${MODEL_ALIAS}; dumping container logs:" $ENGINE logs --tail 200 ollama-server die "Model pull failed" fi @@ -187,25 +222,29 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \ --network llama-net \ -p "${PORT}:${PORT}" \ "${SERVER_IMAGE}" --port "${PORT}" \ - --env INFERENCE_MODEL="${MODEL_ALIAS}" \ - --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" ) + --env OLLAMA_INFERENCE_MODEL="${MODEL_ALIAS}" \ + --env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \ + --env ENABLE_OLLAMA=ollama) -log "🦙 Starting Llama‑Stack…" -$ENGINE "${cmd[@]}" > /dev/null 2>&1 +log "🦙 Starting Llama Stack..." +if ! execute_with_log $ENGINE "${cmd[@]}"; then + die "Llama Stack startup failed" +fi -if ! wait_for_service "http://127.0.0.1:${PORT}/v1/health" "OK" "$WAIT_TIMEOUT" "Llama-Stack API"; then - log "❌ Llama-Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:" +if ! wait_for_service "http://127.0.0.1:${PORT}/v1/health" "OK" "$WAIT_TIMEOUT" "Llama Stack API"; then + log "❌ Llama Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:" $ENGINE logs --tail 200 llama-stack - die "Llama-Stack startup failed" + die "Llama Stack startup failed" fi ############################################################################### # Done ############################################################################### log "" -log "🎉 Llama‑Stack is ready!" +log "🎉 Llama Stack is ready!" log "👉 API endpoint: http://localhost:${PORT}" log "📖 Documentation: https://llama-stack.readthedocs.io/en/latest/references/index.html" -log "💻 To access the llama‑stack CLI, exec into the container:" +log "💻 To access the llama stack CLI, exec into the container:" log " $ENGINE exec -ti llama-stack bash" +log "🐛 Report an issue @ https://github.com/meta-llama/llama-stack/issues if you think it's a bug" log "" diff --git a/tests/external-provider/llama-stack-provider-ollama/README.md b/tests/external-provider/llama-stack-provider-ollama/README.md deleted file mode 100644 index 8bd2b6a87..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Ollama external provider for Llama Stack - -Template code to create a new external provider for Llama Stack. diff --git a/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml b/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml deleted file mode 100644 index 2ae1e2cf3..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml +++ /dev/null @@ -1,7 +0,0 @@ -adapter: - adapter_type: custom_ollama - pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"] - config_class: llama_stack_provider_ollama.config.OllamaImplConfig - module: llama_stack_provider_ollama -api_dependencies: [] -optional_api_dependencies: [] diff --git a/tests/external-provider/llama-stack-provider-ollama/pyproject.toml b/tests/external-provider/llama-stack-provider-ollama/pyproject.toml deleted file mode 100644 index ca1fecc42..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/pyproject.toml +++ /dev/null @@ -1,43 +0,0 @@ -[project] -dependencies = [ - "llama-stack", - "pydantic", - "ollama", - "aiohttp", - "aiosqlite", - "autoevals", - "chardet", - "chromadb-client", - "datasets", - "faiss-cpu", - "fastapi", - "fire", - "httpx", - "matplotlib", - "mcp", - "nltk", - "numpy", - "openai", - "opentelemetry-exporter-otlp-proto-http", - "opentelemetry-sdk", - "pandas", - "pillow", - "psycopg2-binary", - "pymongo", - "pypdf", - "redis", - "requests", - "scikit-learn", - "scipy", - "sentencepiece", - "tqdm", - "transformers", - "tree_sitter", - "uvicorn", -] - -name = "llama-stack-provider-ollama" -version = "0.1.0" -description = "External provider for Ollama using the Llama Stack API" -readme = "README.md" -requires-python = ">=3.12" diff --git a/tests/external-provider/llama-stack-provider-ollama/run.yaml b/tests/external-provider/llama-stack-provider-ollama/run.yaml deleted file mode 100644 index 65fd7571c..000000000 --- a/tests/external-provider/llama-stack-provider-ollama/run.yaml +++ /dev/null @@ -1,124 +0,0 @@ -version: 2 -image_name: ollama -apis: -- agents -- datasetio -- eval -- inference -- safety -- scoring -- telemetry -- tool_runtime -- vector_io - -providers: - inference: - - provider_id: ollama - provider_type: remote::ollama - config: - url: ${env.OLLAMA_URL:=http://localhost:11434} - vector_io: - - provider_id: faiss - provider_type: inline::faiss - config: - metadata_store: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/faiss_store.db - safety: - - provider_id: llama-guard - provider_type: inline::llama-guard - config: {} - agents: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - agents_store: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db - telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - service_name: "${env.OTEL_SERVICE_NAME:=\u200b}" - sinks: ${env.TELEMETRY_SINKS:=console,sqlite} - sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db - eval: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - metadata_store: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db - datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - config: - metadata_store: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db - - provider_id: localfs - provider_type: inline::localfs - config: - metadata_store: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db - scoring: - - provider_id: basic - provider_type: inline::basic - config: {} - tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - config: - api_key: ${env.BRAVE_SEARCH_API_KEY:+} - max_results: 3 - - provider_id: tavily-search - provider_type: remote::tavily-search - config: - api_key: ${env.TAVILY_SEARCH_API_KEY:+} - max_results: 3 - - provider_id: rag-runtime - provider_type: inline::rag-runtime - config: {} - - provider_id: wolfram-alpha - provider_type: remote::wolfram-alpha - config: - api_key: ${env.WOLFRAM_ALPHA_API_KEY:+} - -metadata_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db -models: -- metadata: {} - model_id: ${env.INFERENCE_MODEL} - provider_id: custom_ollama - model_type: llm -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: custom_ollama - provider_model_id: all-minilm:l6-v2 - model_type: embedding -shields: [] -vector_dbs: [] -datasets: [] -scoring_fns: [] -benchmarks: [] -tool_groups: -- toolgroup_id: builtin::websearch - provider_id: tavily-search -- toolgroup_id: builtin::rag - provider_id: rag-runtime -- toolgroup_id: builtin::wolfram_alpha - provider_id: wolfram-alpha -server: - port: 8321 -external_providers_dir: ~/.llama/providers.d diff --git a/tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml b/tests/external/build.yaml similarity index 55% rename from tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml rename to tests/external/build.yaml index 1f3ab3817..c928febdb 100644 --- a/tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml +++ b/tests/external/build.yaml @@ -2,8 +2,10 @@ version: '2' distribution_spec: description: Custom distro for CI tests providers: - inference: - - remote::custom_ollama -image_type: container + weather: + - provider_id: kaze + provider_type: remote::kaze +image_type: venv image_name: ci-test external_providers_dir: ~/.llama/providers.d +external_apis_dir: ~/.llama/apis.d diff --git a/tests/external/kaze.yaml b/tests/external/kaze.yaml new file mode 100644 index 000000000..c61ac0e31 --- /dev/null +++ b/tests/external/kaze.yaml @@ -0,0 +1,6 @@ +adapter: + adapter_type: kaze + pip_packages: ["tests/external/llama-stack-provider-kaze"] + config_class: llama_stack_provider_kaze.config.KazeProviderConfig + module: llama_stack_provider_kaze +optional_api_dependencies: [] diff --git a/tests/external/llama-stack-api-weather/pyproject.toml b/tests/external/llama-stack-api-weather/pyproject.toml new file mode 100644 index 000000000..566e1e9aa --- /dev/null +++ b/tests/external/llama-stack-api-weather/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "llama-stack-api-weather" +version = "0.1.0" +description = "Weather API for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_api_weather", "llama_stack_api_weather.*"] diff --git a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/__init__.py b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/__init__.py new file mode 100644 index 000000000..d0227615d --- /dev/null +++ b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Weather API for Llama Stack.""" + +from .weather import WeatherProvider, available_providers + +__all__ = ["WeatherProvider", "available_providers"] diff --git a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py new file mode 100644 index 000000000..4b3bfb641 --- /dev/null +++ b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Protocol + +from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec +from llama_stack.schema_utils import webmethod + + +def available_providers() -> list[ProviderSpec]: + return [ + RemoteProviderSpec( + api=Api.weather, + provider_type="remote::kaze", + config_class="llama_stack_provider_kaze.KazeProviderConfig", + adapter=AdapterSpec( + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], + config_class="llama_stack_provider_kaze.KazeProviderConfig", + ), + ), + ] + + +class WeatherProvider(Protocol): + """ + A protocol for the Weather API. + """ + + @webmethod(route="/weather/locations", method="GET") + async def get_available_locations() -> dict[str, list[str]]: + """ + Get the available locations. + """ + ... diff --git a/tests/external/llama-stack-provider-kaze/pyproject.toml b/tests/external/llama-stack-provider-kaze/pyproject.toml new file mode 100644 index 000000000..7bbf1f843 --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "llama-stack-provider-kaze" +version = "0.1.0" +description = "Kaze weather provider for Llama Stack" +readme = "README.md" +requires-python = ">=3.10" +dependencies = ["llama-stack", "pydantic", "aiohttp"] + +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] +include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"] diff --git a/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py new file mode 100644 index 000000000..581ff38c7 --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Kaze weather provider for Llama Stack.""" + +from .config import KazeProviderConfig +from .kaze import WeatherKazeAdapter + +__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"] + + +async def get_adapter_impl(config: KazeProviderConfig, _deps): + from .kaze import WeatherKazeAdapter + + impl = WeatherKazeAdapter(config) + await impl.initialize() + return impl diff --git a/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py new file mode 100644 index 000000000..4b82698ed --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pydantic import BaseModel + + +class KazeProviderConfig(BaseModel): + """Configuration for the Kaze weather provider.""" diff --git a/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py new file mode 100644 index 000000000..120b5438d --- /dev/null +++ b/tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack_api_weather.weather import WeatherProvider + +from .config import KazeProviderConfig + + +class WeatherKazeAdapter(WeatherProvider): + """Kaze weather provider implementation.""" + + def __init__( + self, + config: KazeProviderConfig, + ) -> None: + self.config = config + + async def initialize(self) -> None: + pass + + async def get_available_locations(self) -> dict[str, list[str]]: + """Get available weather locations.""" + return {"locations": ["Paris", "Tokyo"]} diff --git a/tests/external/ramalama-stack/build.yaml b/tests/external/ramalama-stack/build.yaml new file mode 100644 index 000000000..c781e6537 --- /dev/null +++ b/tests/external/ramalama-stack/build.yaml @@ -0,0 +1,14 @@ +version: 2 +distribution_spec: + description: Use (an external) Ramalama server for running LLM inference + container_image: null + providers: + inference: + - provider_id: ramalama + provider_type: remote::ramalama + module: ramalama_stack==0.3.0a0 +image_type: venv +image_name: ramalama-stack-test +additional_pip_packages: +- aiosqlite +- sqlalchemy[asyncio] diff --git a/tests/external/ramalama-stack/run.yaml b/tests/external/ramalama-stack/run.yaml new file mode 100644 index 000000000..9d1d34df3 --- /dev/null +++ b/tests/external/ramalama-stack/run.yaml @@ -0,0 +1,12 @@ +version: 2 +image_name: ramalama +apis: +- inference +providers: + inference: + - provider_id: ramalama + provider_type: remote::ramalama + module: ramalama_stack==0.3.0a0 + config: {} +server: + port: 8321 diff --git a/tests/external/run-byoa.yaml b/tests/external/run-byoa.yaml new file mode 100644 index 000000000..5774ae9da --- /dev/null +++ b/tests/external/run-byoa.yaml @@ -0,0 +1,13 @@ +version: "2" +image_name: "llama-stack-api-weather" +apis: + - weather +providers: + weather: + - provider_id: kaze + provider_type: remote::kaze + config: {} +external_apis_dir: ~/.llama/apis.d +external_providers_dir: ~/.llama/providers.d +server: + port: 8321 diff --git a/tests/external/weather.yaml b/tests/external/weather.yaml new file mode 100644 index 000000000..a84fcc921 --- /dev/null +++ b/tests/external/weather.yaml @@ -0,0 +1,4 @@ +module: llama_stack_api_weather +name: weather +pip_packages: ["tests/external/llama-stack-api-weather"] +protocol: WeatherProvider diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index e82714ffd..52227d5e3 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -179,9 +179,7 @@ def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_model model=text_model_id, prompt=prompt, stream=False, - extra_body={ - "prompt_logprobs": prompt_logprobs, - }, + prompt_logprobs=prompt_logprobs, ) assert len(response.choices) > 0 choice = response.choices[0] @@ -196,9 +194,7 @@ def test_openai_completion_guided_choice(llama_stack_client, client_with_models, model=text_model_id, prompt=prompt, stream=False, - extra_body={ - "guided_choice": ["joy", "sadness"], - }, + guided_choice=["joy", "sadness"], ) assert len(response.choices) > 0 choice = response.choices[0] diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 71d2bc55e..a34c5b410 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -20,22 +20,15 @@ logger = logging.getLogger(__name__) def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): - vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] - for p in vector_io_providers: - if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "inline::chromadb"]: - return - - pytest.skip("OpenAI vector stores are not supported by any provider") - - -def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models): vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] for p in vector_io_providers: if p.provider_type in [ "inline::faiss", "inline::sqlite-vec", "inline::milvus", + "inline::chromadb", "remote::pgvector", + "remote::chromadb", ]: return @@ -457,7 +450,6 @@ def test_openai_vector_store_search_with_max_num_results( def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store attach file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient") @@ -509,7 +501,6 @@ def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store attach files on creation.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient") @@ -566,7 +557,6 @@ def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_s def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store list files.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient") @@ -640,7 +630,6 @@ def test_openai_vector_store_list_files_invalid_vector_store(compat_client_with_ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store retrieve file contents.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files retrieve contents is not yet supported with LlamaStackClient") @@ -682,7 +671,6 @@ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_sto def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store delete file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient") @@ -735,12 +723,9 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client assert updated_vector_store.file_counts.in_progress == 0 -# TODO: Remove this xfail once we have a way to remove embeddings from vector store -@pytest.mark.xfail(reason="Vector Store Files delete doesn't remove embeddings from vector store", strict=True) def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store delete file removes from vector store.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient") @@ -782,7 +767,6 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client def test_openai_vector_store_update_file(compat_client_with_empty_stores, client_with_models): """Test OpenAI vector store update file.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files update is not yet supported with LlamaStackClient") @@ -831,7 +815,6 @@ def test_create_vector_store_files_duplicate_vector_store_name(compat_client_wit This test confirms that client.vector_stores.create() creates a unique ID """ skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models) if isinstance(compat_client_with_empty_stores, LlamaStackClient): pytest.skip("Vector Store Files create is not yet supported with LlamaStackClient") diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 12b05ebff..c1b57cb4f 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.distribution.datatypes import RegistryEntrySource from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable @@ -45,6 +46,30 @@ class InferenceImpl(Impl): async def unregister_model(self, model_id: str): return model_id + async def should_refresh_models(self): + return False + + async def list_models(self): + return [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="provider-model-2", + provider_resource_id="provider-model-2", + provider_id="test_provider", + metadata={"embedding_dimension": 512}, + model_type=ModelType.embedding, + ), + ] + + async def shutdown(self): + pass + class SafetyImpl(Impl): def __init__(self): @@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry): raise AssertionError("Should have raised ValueError for non-existent model") except ValueError as e: assert "not found" in str(e) + + +async def test_models_source_tracking_default(cached_disk_dist_registry): + """Test that models registered via register_model get default source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register model via register_model (should get default source) + await table.register_model(model_id="user-model", provider_id="test_provider") + + models = await table.list_models() + assert len(models.data) == 1 + model = models.data[0] + assert model.source == RegistryEntrySource.via_register_api + assert model.identifier == "test_provider/user-model" + + # Cleanup + await table.shutdown() + + +async def test_models_source_tracking_provider(cached_disk_dist_registry): + """Test that models registered via update_registered_models get provider source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Simulate provider refresh by calling update_registered_models + provider_models = [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="provider-model-2", + provider_resource_id="provider-model-2", + provider_id="test_provider", + metadata={"embedding_dimension": 512}, + model_type=ModelType.embedding, + ), + ] + await table.update_registered_models("test_provider", provider_models) + + models = await table.list_models() + assert len(models.data) == 2 + + # All models should have provider source + for model in models.data: + assert model.source == RegistryEntrySource.listed_from_provider + assert model.provider_id == "test_provider" + + # Cleanup + await table.shutdown() + + +async def test_models_source_interaction_preserves_default(cached_disk_dist_registry): + """Test that provider refresh preserves user-registered models with default source.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # First register a user model with same provider_resource_id as provider will later provide + await table.register_model( + model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider" + ) + + # Verify user model is registered with default source + models = await table.list_models() + assert len(models.data) == 1 + user_model = models.data[0] + assert user_model.source == RegistryEntrySource.via_register_api + assert user_model.identifier == "my-custom-alias" + assert user_model.provider_resource_id == "provider-model-1" + + # Now simulate provider refresh + provider_models = [ + Model( + identifier="provider-model-1", + provider_resource_id="provider-model-1", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + Model( + identifier="different-model", + provider_resource_id="different-model", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models) + + # Verify user model with alias is preserved, but provider added new model + models = await table.list_models() + assert len(models.data) == 2 + + # Find the user model and provider model + user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None) + provider_model = next((m for m in models.data if m.identifier == "different-model"), None) + + assert user_model is not None + assert user_model.source == RegistryEntrySource.via_register_api + assert user_model.provider_resource_id == "provider-model-1" + + assert provider_model is not None + assert provider_model.source == RegistryEntrySource.listed_from_provider + assert provider_model.provider_resource_id == "different-model" + + # Cleanup + await table.shutdown() + + +async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry): + """Test that provider refresh removes old provider models but keeps default ones.""" + table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + # Register a user model + await table.register_model(model_id="user-model", provider_id="test_provider") + + # Add some provider models + provider_models_v1 = [ + Model( + identifier="provider-model-old", + provider_resource_id="provider-model-old", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models_v1) + + # Verify we have both user and provider models + models = await table.list_models() + assert len(models.data) == 2 + + # Now update with new provider models (should remove old provider models) + provider_models_v2 = [ + Model( + identifier="provider-model-new", + provider_resource_id="provider-model-new", + provider_id="test_provider", + metadata={}, + model_type=ModelType.llm, + ), + ] + await table.update_registered_models("test_provider", provider_models_v2) + + # Should have user model + new provider model, old provider model gone + models = await table.list_models() + assert len(models.data) == 2 + + identifiers = {m.identifier for m in models.data} + assert "test_provider/user-model" in identifiers # User model preserved + assert "provider-model-new" in identifiers # New provider model (uses provider's identifier) + assert "provider-model-old" not in identifiers # Old provider model removed + + # Verify sources are correct + user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None) + provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None) + + assert user_model.source == RegistryEntrySource.via_register_api + assert provider_model.source == RegistryEntrySource.listed_from_provider + + # Cleanup + await table.shutdown() diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index ae24602d7..5aac113eb 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -106,6 +106,40 @@ def api_directories(tmp_path): return remote_inference_dir, inline_inference_dir +def make_import_module_side_effect( + builtin_provider_spec=None, + external_module=None, + raise_for_external=False, + missing_get_provider_spec=False, +): + from types import SimpleNamespace + + def import_module_side_effect(name): + if name == "llama_stack.providers.registry.inference": + mock_builtin = SimpleNamespace( + available_providers=lambda: [ + builtin_provider_spec + or ProviderSpec( + api=Api.inference, + provider_type="test_provider", + config_class="test_provider.config.TestProviderConfig", + module="test_provider", + ) + ] + ) + return mock_builtin + elif name == "external_test.provider": + if raise_for_external: + raise ModuleNotFoundError(name) + if missing_get_provider_spec: + return SimpleNamespace() + return external_module + else: + raise ModuleNotFoundError(name) + + return import_module_side_effect + + class TestProviderRegistry: """Test suite for provider registry functionality.""" @@ -221,3 +255,124 @@ pip_packages: with pytest.raises(KeyError) as exc_info: get_provider_registry(base_config) assert "config_class" in str(exc_info.value) + + def test_external_provider_from_module_success(self, mock_providers): + """Test loading an external provider from a module (success path).""" + from types import SimpleNamespace + + from llama_stack.distribution.datatypes import Provider, StackRunConfig + from llama_stack.providers.datatypes import Api, ProviderSpec + + # Simulate a provider module with get_provider_spec + fake_spec = ProviderSpec( + api=Api.inference, + provider_type="external_test", + config_class="external_test.config.ExternalTestConfig", + module="external_test", + ) + fake_module = SimpleNamespace(get_provider_spec=lambda: fake_spec) + + import_module_side_effect = make_import_module_side_effect(external_module=fake_module) + + with patch("importlib.import_module", side_effect=import_module_side_effect) as mock_import: + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + registry = get_provider_registry(config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.config_class == "external_test.config.ExternalTestConfig" + mock_import.assert_any_call("llama_stack.providers.registry.inference") + mock_import.assert_any_call("external_test.provider") + + def test_external_provider_from_module_not_found(self, mock_providers): + """Test handling ModuleNotFoundError for missing provider module.""" + from llama_stack.distribution.datatypes import Provider, StackRunConfig + + import_module_side_effect = make_import_module_side_effect(raise_for_external=True) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(ValueError) as exc_info: + get_provider_registry(config) + assert "get_provider_spec not found" in str(exc_info.value) + + def test_external_provider_from_module_missing_get_provider_spec(self, mock_providers): + """Test handling missing get_provider_spec in provider module (should raise ValueError).""" + from llama_stack.distribution.datatypes import Provider, StackRunConfig + + import_module_side_effect = make_import_module_side_effect(missing_get_provider_spec=True) + + with patch("importlib.import_module", side_effect=import_module_side_effect): + config = StackRunConfig( + image_name="test_image", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ) + with pytest.raises(AttributeError): + get_provider_registry(config) + + def test_external_provider_from_module_building(self, mock_providers): + """Test loading an external provider from a module during build (building=True, partial spec).""" + from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec, Provider + from llama_stack.providers.datatypes import Api + + # No importlib patch needed, should not import module when type of `config` is BuildConfig or DistributionSpec + build_config = BuildConfig( + version=2, + image_type="container", + image_name="test_image", + distribution_spec=DistributionSpec( + description="test", + providers={ + "inference": [ + Provider( + provider_id="external_test", + provider_type="external_test", + config={}, + module="external_test", + ) + ] + }, + ), + ) + registry = get_provider_registry(build_config) + assert Api.inference in registry + assert "external_test" in registry[Api.inference] + provider = registry[Api.inference]["external_test"] + assert provider.module == "external_test" + assert provider.is_external is True + # config_class is empty string in partial spec + assert provider.config_class == "" diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py index 73fc32a02..bfd91f466 100644 --- a/tests/unit/providers/nvidia/test_safety.py +++ b/tests/unit/providers/nvidia/test_safety.py @@ -5,321 +5,353 @@ # the root directory of this source tree. import os -import unittest from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest from llama_stack.apis.inference import CompletionMessage, UserMessage +from llama_stack.apis.resource import ResourceType from llama_stack.apis.safety import RunShieldResponse, ViolationLevel from llama_stack.apis.shields import Shield +from llama_stack.models.llama.datatypes import StopReason from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter -class TestNVIDIASafetyAdapter(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" +class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter): + """Test implementation that provides the required shield_store.""" - # Initialize the adapter - self.config = NVIDIASafetyConfig( - guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], - ) - self.adapter = NVIDIASafetyAdapter(config=self.config) - self.shield_store = AsyncMock() - self.adapter.shield_store = self.shield_store + def __init__(self, config: NVIDIASafetyConfig, shield_store): + super().__init__(config) + self.shield_store = shield_store - # Mock the HTTP request methods - self.guardrails_post_patcher = patch( - "llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post" - ) - self.mock_guardrails_post = self.guardrails_post_patcher.start() - self.mock_guardrails_post.return_value = {"status": "allowed"} - def tearDown(self): - """Clean up after each test.""" - self.guardrails_post_patcher.stop() +@pytest.fixture +def nvidia_adapter(): + """Set up the NVIDIASafetyAdapter for testing.""" + os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async + # Initialize the adapter + config = NVIDIASafetyConfig( + guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], + ) - def _assert_request( - self, - mock_call: MagicMock, - expected_url: str, - expected_headers: dict[str, str] | None = None, - expected_json: dict[str, Any] | None = None, - ) -> None: - """ - Helper method to verify request details in mock API calls. + # Create a mock shield store that implements the ShieldStore protocol + shield_store = AsyncMock() + shield_store.get_shield = AsyncMock() - Args: - mock_call: The MagicMock object that was called - expected_url: The expected URL to which the request was made - expected_headers: Optional dictionary of expected request headers - expected_json: Optional dictionary of expected JSON payload - """ - call_args = mock_call.call_args + adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store) - # Check URL - assert call_args[0][0] == expected_url + return adapter - # Check headers if provided - if expected_headers: - for key, value in expected_headers.items(): - assert call_args[1]["headers"][key] == value - # Check JSON if provided - if expected_json: - for key, value in expected_json.items(): - if isinstance(value, dict): - for nested_key, nested_value in value.items(): - assert call_args[1]["json"][key][nested_key] == nested_value - else: - assert call_args[1]["json"][key] == value +@pytest.fixture +def mock_guardrails_post(): + """Mock the HTTP request methods.""" + with patch("llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post") as mock_post: + mock_post.return_value = {"status": "allowed"} + yield mock_post - def test_register_shield_with_valid_id(self): - shield = Shield( - provider_id="nvidia", - type="shield", - identifier="test-shield", - provider_resource_id="test-model", - ) - # Register the shield - self.run_async(self.adapter.register_shield(shield)) +def _assert_request( + mock_call: MagicMock, + expected_url: str, + expected_headers: dict[str, str] | None = None, + expected_json: dict[str, Any] | None = None, +) -> None: + """ + Helper method to verify request details in mock API calls. - def test_register_shield_without_id(self): - shield = Shield( - provider_id="nvidia", - type="shield", - identifier="test-shield", - provider_resource_id="", - ) + Args: + mock_call: The MagicMock object that was called + expected_url: The expected URL to which the request was made + expected_headers: Optional dictionary of expected request headers + expected_json: Optional dictionary of expected JSON payload + """ + call_args = mock_call.call_args - # Register the shield should raise a ValueError - with self.assertRaises(ValueError): - self.run_async(self.adapter.register_shield(shield)) + # Check URL + assert call_args[0][0] == expected_url - def test_run_shield_allowed(self): - # Set up the shield - shield_id = "test-shield" - shield = Shield( - provider_id="nvidia", - type="shield", - identifier=shield_id, - provider_resource_id="test-model", - ) - self.shield_store.get_shield.return_value = shield + # Check headers if provided + if expected_headers: + for key, value in expected_headers.items(): + assert call_args[1]["headers"][key] == value - # Mock Guardrails API response - self.mock_guardrails_post.return_value = {"status": "allowed"} + # Check JSON if provided + if expected_json: + for key, value in expected_json.items(): + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + assert call_args[1]["json"][key][nested_key] == nested_value + else: + assert call_args[1]["json"][key] == value - # Run the shield - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", - content="I'm doing well, thank you for asking!", - stop_reason="end_of_message", - tool_calls=[], - ), - ] - result = self.run_async(self.adapter.run_shield(shield_id, messages)) - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) +async def test_register_shield_with_valid_id(nvidia_adapter): + adapter = nvidia_adapter - # Verify the Guardrails API was called correctly - self.mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", - data={ - "model": shield_id, - "messages": [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, - ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, + shield = Shield( + provider_id="nvidia", + type=ResourceType.shield, + identifier="test-shield", + provider_resource_id="test-model", + ) + + # Register the shield + await adapter.register_shield(shield) + + +async def test_register_shield_without_id(nvidia_adapter): + adapter = nvidia_adapter + + shield = Shield( + provider_id="nvidia", + type=ResourceType.shield, + identifier="test-shield", + provider_resource_id="", + ) + + # Register the shield should raise a ValueError + with pytest.raises(ValueError): + await adapter.register_shield(shield) + + +async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post): + adapter = nvidia_adapter + + # Set up the shield + shield_id = "test-shield" + shield = Shield( + provider_id="nvidia", + type=ResourceType.shield, + identifier=shield_id, + provider_resource_id="test-model", + ) + adapter.shield_store.get_shield.return_value = shield + + # Mock Guardrails API response + mock_guardrails_post.return_value = {"status": "allowed"} + + # Run the shield + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + CompletionMessage( + role="assistant", + content="I'm doing well, thank you for asking!", + stop_reason=StopReason.end_of_message, + tool_calls=[], + ), + ] + result = await adapter.run_shield(shield_id, messages) + + # Verify the shield store was called + adapter.shield_store.get_shield.assert_called_once_with(shield_id) + + # Verify the Guardrails API was called correctly + mock_guardrails_post.assert_called_once_with( + path="/v1/guardrail/checks", + data={ + "model": shield_id, + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, + ], + "temperature": 1.0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 160, + "stream": False, + "guardrails": { + "config_id": "self-check", }, - ) + }, + ) - # Verify the result - assert isinstance(result, RunShieldResponse) - assert result.violation is None + # Verify the result + assert isinstance(result, RunShieldResponse) + assert result.violation is None - def test_run_shield_blocked(self): - # Set up the shield - shield_id = "test-shield" - shield = Shield( - provider_id="nvidia", - type="shield", - identifier=shield_id, - provider_resource_id="test-model", - ) - self.shield_store.get_shield.return_value = shield - # Mock Guardrails API response - self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} +async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post): + adapter = nvidia_adapter - # Run the shield - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", - content="I'm doing well, thank you for asking!", - stop_reason="end_of_message", - tool_calls=[], - ), - ] - result = self.run_async(self.adapter.run_shield(shield_id, messages)) + # Set up the shield + shield_id = "test-shield" + shield = Shield( + provider_id="nvidia", + type=ResourceType.shield, + identifier=shield_id, + provider_resource_id="test-model", + ) + adapter.shield_store.get_shield.return_value = shield - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) + # Mock Guardrails API response + mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} - # Verify the Guardrails API was called correctly - self.mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", - data={ - "model": shield_id, - "messages": [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, - ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, + # Run the shield + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + CompletionMessage( + role="assistant", + content="I'm doing well, thank you for asking!", + stop_reason=StopReason.end_of_message, + tool_calls=[], + ), + ] + result = await adapter.run_shield(shield_id, messages) + + # Verify the shield store was called + adapter.shield_store.get_shield.assert_called_once_with(shield_id) + + # Verify the Guardrails API was called correctly + mock_guardrails_post.assert_called_once_with( + path="/v1/guardrail/checks", + data={ + "model": shield_id, + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, + ], + "temperature": 1.0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 160, + "stream": False, + "guardrails": { + "config_id": "self-check", }, - ) + }, + ) - # Verify the result - assert result.violation is not None - assert isinstance(result, RunShieldResponse) - assert result.violation.user_message == "Sorry I cannot do this." - assert result.violation.violation_level == ViolationLevel.ERROR - assert result.violation.metadata == {"reason": "harmful_content"} + # Verify the result + assert result.violation is not None + assert isinstance(result, RunShieldResponse) + assert result.violation.user_message == "Sorry I cannot do this." + assert result.violation.violation_level == ViolationLevel.ERROR + assert result.violation.metadata == {"reason": "harmful_content"} - def test_run_shield_not_found(self): - # Set up shield store to return None - shield_id = "non-existent-shield" - self.shield_store.get_shield.return_value = None - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - ] +async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post): + adapter = nvidia_adapter - with self.assertRaises(ValueError): - self.run_async(self.adapter.run_shield(shield_id, messages)) + # Set up shield store to return None + shield_id = "non-existent-shield" + adapter.shield_store.get_shield.return_value = None - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + ] - # Verify the Guardrails API was not called - self.mock_guardrails_post.assert_not_called() + with pytest.raises(ValueError): + await adapter.run_shield(shield_id, messages) - def test_run_shield_http_error(self): - shield_id = "test-shield" - shield = Shield( - provider_id="nvidia", - type="shield", - identifier=shield_id, - provider_resource_id="test-model", - ) - self.shield_store.get_shield.return_value = shield + # Verify the shield store was called + adapter.shield_store.get_shield.assert_called_once_with(shield_id) - # Mock Guardrails API to raise an exception - error_msg = "API Error: 500 Internal Server Error" - self.mock_guardrails_post.side_effect = Exception(error_msg) + # Verify the Guardrails API was not called + mock_guardrails_post.assert_not_called() - # Running the shield should raise an exception - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", - content="I'm doing well, thank you for asking!", - stop_reason="end_of_message", - tool_calls=[], - ), - ] - with self.assertRaises(Exception) as context: - self.run_async(self.adapter.run_shield(shield_id, messages)) - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) +async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post): + adapter = nvidia_adapter - # Verify the Guardrails API was called correctly - self.mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", - data={ - "model": shield_id, - "messages": [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, - ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, + shield_id = "test-shield" + shield = Shield( + provider_id="nvidia", + type=ResourceType.shield, + identifier=shield_id, + provider_resource_id="test-model", + ) + adapter.shield_store.get_shield.return_value = shield + + # Mock Guardrails API to raise an exception + error_msg = "API Error: 500 Internal Server Error" + mock_guardrails_post.side_effect = Exception(error_msg) + + # Running the shield should raise an exception + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + CompletionMessage( + role="assistant", + content="I'm doing well, thank you for asking!", + stop_reason=StopReason.end_of_message, + tool_calls=[], + ), + ] + with pytest.raises(Exception) as exc_info: + await adapter.run_shield(shield_id, messages) + + # Verify the shield store was called + adapter.shield_store.get_shield.assert_called_once_with(shield_id) + + # Verify the Guardrails API was called correctly + mock_guardrails_post.assert_called_once_with( + path="/v1/guardrail/checks", + data={ + "model": shield_id, + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, + ], + "temperature": 1.0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 160, + "stream": False, + "guardrails": { + "config_id": "self-check", }, - ) - # Verify the exception message - assert error_msg in str(context.exception) + }, + ) + # Verify the exception message + assert error_msg in str(exc_info.value) - def test_init_nemo_guardrails(self): - from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails - test_config_id = "test-custom-config-id" - config = NVIDIASafetyConfig( - guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], - config_id=test_config_id, - ) - # Initialize with default parameters - test_model = "test-model" - guardrails = NeMoGuardrails(config, test_model) +def test_init_nemo_guardrails(): + from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails - # Verify the attributes are set correctly - assert guardrails.config_id == test_config_id - assert guardrails.model == test_model - assert guardrails.threshold == 0.9 # Default value - assert guardrails.temperature == 1.0 # Default value - assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] + os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" - # Initialize with custom parameters - guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7) + test_config_id = "test-custom-config-id" + config = NVIDIASafetyConfig( + guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], + config_id=test_config_id, + ) + # Initialize with default parameters + test_model = "test-model" + guardrails = NeMoGuardrails(config, test_model) - # Verify the attributes are set correctly - assert guardrails.config_id == test_config_id - assert guardrails.model == test_model - assert guardrails.threshold == 0.8 - assert guardrails.temperature == 0.7 - assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] + # Verify the attributes are set correctly + assert guardrails.config_id == test_config_id + assert guardrails.model == test_model + assert guardrails.threshold == 0.9 # Default value + assert guardrails.temperature == 1.0 # Default value + assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] - def test_init_nemo_guardrails_invalid_temperature(self): - from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails + # Initialize with custom parameters + guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7) - config = NVIDIASafetyConfig( - guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], - config_id="test-custom-config-id", - ) - with self.assertRaises(ValueError): - NeMoGuardrails(config, "test-model", temperature=0) + # Verify the attributes are set correctly + assert guardrails.config_id == test_config_id + assert guardrails.model == test_model + assert guardrails.threshold == 0.8 + assert guardrails.temperature == 0.7 + assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] + + +def test_init_nemo_guardrails_invalid_temperature(): + from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails + + os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" + + config = NVIDIASafetyConfig( + guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], + config_id="test-custom-config-id", + ) + with pytest.raises(ValueError): + NeMoGuardrails(config, "test-model", temperature=0) diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 45e37d6ff..bcba06140 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -8,6 +8,7 @@ import random import numpy as np import pytest +from chromadb import PersistentClient from pymilvus import MilvusClient, connections from llama_stack.apis.vector_dbs import VectorDB @@ -18,7 +19,7 @@ from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, Faiss from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter -from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter +from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter EMBEDDING_DIMENSION = 384 @@ -26,6 +27,11 @@ COLLECTION_PREFIX = "test_collection" MILVUS_ALIAS = "test_milvus" +@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma"]) +def vector_provider(request): + return request.param + + @pytest.fixture def vector_db_id() -> str: return f"test-vector-db-{random.randint(1, 100)}" @@ -94,11 +100,6 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata): return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata]) -@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"]) -def vector_provider(request): - return request.param - - @pytest.fixture(scope="session") def mock_inference_api(embedding_dimension): class MockInferenceAPI: @@ -246,10 +247,10 @@ def chroma_vec_db_path(tmp_path_factory): @pytest.fixture async def chroma_vec_index(chroma_vec_db_path, embedding_dimension): - index = ChromaIndex( - embedding_dimension=embedding_dimension, - persist_directory=chroma_vec_db_path, - ) + client = PersistentClient(path=chroma_vec_db_path) + name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" + collection = await maybe_await(client.get_or_create_collection(name)) + index = ChromaIndex(client=client, collection=collection) await index.initialize() yield index await index.delete() @@ -257,7 +258,10 @@ async def chroma_vec_index(chroma_vec_db_path, embedding_dimension): @pytest.fixture async def chroma_vec_adapter(chroma_vec_db_path, mock_inference_api, embedding_dimension): - config = ChromaVectorIOConfig(persist_directory=chroma_vec_db_path) + config = ChromaVectorIOConfig( + db_path=chroma_vec_db_path, + kvstore=SqliteKVStoreConfig(), + ) adapter = ChromaVectorIOAdapter( config=config, inference_api=mock_inference_api, diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index bf7663d2e..98889f38e 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -86,10 +86,14 @@ async def test_register_and_unregister_vector_db(vector_io_adapter): assert dummy.identifier not in vector_io_adapter.cache -async def test_query_unregistered_raises(vector_io_adapter): +async def test_query_unregistered_raises(vector_io_adapter, vector_provider): fake_emb = np.zeros(8, dtype=np.float32) - with pytest.raises(ValueError): - await vector_io_adapter.query_chunks("no_such_db", fake_emb) + if vector_provider == "chroma": + with pytest.raises(AttributeError): + await vector_io_adapter.query_chunks("no_such_db", fake_emb) + else: + with pytest.raises(ValueError): + await vector_io_adapter.query_chunks("no_such_db", fake_emb) async def test_insert_chunks_calls_underlying_index(vector_io_adapter): diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 7012a7f17..adf0140e2 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -19,7 +19,8 @@ from llama_stack.distribution.datatypes import ( OAuth2JWKSConfig, OAuth2TokenAuthConfig, ) -from llama_stack.distribution.server.auth import AuthenticationMiddleware +from llama_stack.distribution.request_headers import User +from llama_stack.distribution.server.auth import AuthenticationMiddleware, _has_required_scope from llama_stack.distribution.server.auth_providers import ( get_attributes_from_claims, ) @@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint): ), access_policy=[], ) - return AuthenticationMiddleware(mock_app, auth_config), mock_app + return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app + + +@pytest.fixture +def mock_impls(): + """Mock implementations for scope testing""" + return {} + + +@pytest.fixture +def scope_middleware_with_mocks(mock_auth_endpoint): + """Create AuthenticationMiddleware with mocked route implementations""" + mock_app = AsyncMock() + auth_config = AuthenticationConfig( + provider_config=CustomAuthConfig( + type=AuthProviderType.CUSTOM, + endpoint=mock_auth_endpoint, + ), + access_policy=[], + ) + middleware = AuthenticationMiddleware(mock_app, auth_config, {}) + + # Mock the route_impls to simulate finding routes with required scopes + from llama_stack.schema_utils import WebMethod + + scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read") + + public_webmethod = WebMethod(route="/test/public", method="GET") + + # Mock the route finding logic + def mock_find_matching_route(method, path, route_impls): + if method == "POST" and path == "/test/scoped": + return None, {}, "/test/scoped", scoped_webmethod + elif method == "GET" and path == "/test/public": + return None, {}, "/test/public", public_webmethod + else: + raise ValueError("No matching route") + + import llama_stack.distribution.server.auth + + llama_stack.distribution.server.auth.find_matching_route = mock_find_matching_route + llama_stack.distribution.server.auth.initialize_route_impls = lambda impls: {} + + return middleware, mock_app async def mock_post_success(*args, **kwargs): @@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs): raise Exception("Connection error") +async def mock_post_success_with_scope(*args, **kwargs): + """Mock auth response for user with test.read scope""" + return MockResponse( + 200, + { + "message": "Authentication successful", + "principal": "test-user", + "attributes": { + "scopes": ["test.read", "other.scope"], + "roles": ["user"], + }, + }, + ) + + +async def mock_post_success_no_scope(*args, **kwargs): + """Mock auth response for user without required scope""" + return MockResponse( + 200, + { + "message": "Authentication successful", + "principal": "test-user", + "attributes": { + "scopes": ["other.scope"], + "roles": ["user"], + }, + }, + ) + + # HTTP Endpoint Tests def test_missing_auth_header(http_client): response = http_client.get("/test") @@ -252,7 +326,7 @@ def oauth2_app(): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -351,7 +425,7 @@ def oauth2_app_with_jwks_token(): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -442,7 +516,7 @@ def introspection_app(mock_introspection_endpoint): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -472,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint): ), access_policy=[], ) - app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={}) @app.get("/test") def test_endpoint(): @@ -581,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication( ) assert response.status_code == 200 assert response.json() == {"message": "Authentication successful"} + + +# Scope-based authorization tests +@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope) +async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key): + """Test that user with required scope can access protected endpoint""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/scoped", + "method": "POST", + "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())], + } + + await middleware(scope, mock_receive, mock_send) + + # Should call the downstream app (no 403 error sent) + mock_app.assert_called_once_with(scope, mock_receive, mock_send) + mock_send.assert_not_called() + + +@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope) +async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key): + """Test that user without required scope gets 403 access denied""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/scoped", + "method": "POST", + "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())], + } + + await middleware(scope, mock_receive, mock_send) + + # Should send 403 error, not call downstream app + mock_app.assert_not_called() + assert mock_send.call_count == 2 # start + body + + # Check the response + start_call = mock_send.call_args_list[0][0][0] + assert start_call["status"] == 403 + + body_call = mock_send.call_args_list[1][0][0] + body_text = body_call["body"].decode() + assert "Access denied" in body_text + assert "test.read" in body_text + + +@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope) +async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key): + """Test that public endpoints work without specific scopes""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/public", + "method": "GET", + "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())], + } + + await middleware(scope, mock_receive, mock_send) + + # Should call the downstream app (no error) + mock_app.assert_called_once_with(scope, mock_receive, mock_send) + mock_send.assert_not_called() + + +async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks): + """Test that when auth is disabled (no user), scope checks are bypassed""" + middleware, mock_app = scope_middleware_with_mocks + mock_receive = AsyncMock() + mock_send = AsyncMock() + + scope = { + "type": "http", + "path": "/test/scoped", + "method": "POST", + "headers": [], # No authorization header + } + + await middleware(scope, mock_receive, mock_send) + + # Should send 401 auth error, not call downstream app + mock_app.assert_not_called() + assert mock_send.call_count == 2 # start + body + + # Check the response + start_call = mock_send.call_args_list[0][0][0] + assert start_call["status"] == 401 + + body_call = mock_send.call_args_list[1][0][0] + body_text = body_call["body"].decode() + assert "Authentication required" in body_text + + +def test_has_required_scope_function(): + """Test the _has_required_scope function directly""" + # Test user with required scope + user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]}) + assert _has_required_scope("test.read", user_with_scope) + + # Test user without required scope + user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]}) + assert not _has_required_scope("test.read", user_without_scope) + + # Test user with no scopes attribute + user_no_scopes = User(principal="test-user", attributes={}) + assert not _has_required_scope("test.read", user_no_scopes) + + # Test no user (auth disabled) + assert _has_required_scope("test.read", None)