Merge branch 'main' into update-api-docs

This commit is contained in:
Sai Prashanth S 2025-07-25 09:32:28 -07:00 committed by GitHub
commit cd16c72cdf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
136 changed files with 4279 additions and 1465 deletions

27
.github/actions/setup-vllm/action.yml vendored Normal file
View file

@ -0,0 +1,27 @@
name: Setup VLLM
description: Start VLLM
runs:
using: "composite"
steps:
- name: Start VLLM
shell: bash
run: |
# Start vllm container
docker run -d \
--name vllm \
-p 8000:8000 \
--privileged=true \
quay.io/higginsd/vllm-cpu:65393ee064 \
--host 0.0.0.0 \
--port 8000 \
--enable-auto-tool-choice \
--tool-call-parser llama3_json \
--model /root/.cache/Llama-3.2-1B-Instruct \
--served-model-name meta-llama/Llama-3.2-1B-Instruct
# Wait for vllm to be ready
echo "Waiting for vllm to be ready..."
timeout 900 bash -c 'until curl -f http://localhost:8000/health; do
echo "Waiting for vllm..."
sleep 5
done'

View file

@ -14,8 +14,6 @@ updates:
schedule: schedule:
interval: "weekly" interval: "weekly"
day: "saturday" day: "saturday"
# ignore all non-security updates: https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file#open-pull-requests-limit
open-pull-requests-limit: 0
labels: labels:
- type/dependencies - type/dependencies
- python - python

22
.github/workflows/README.md vendored Normal file
View file

@ -0,0 +1,22 @@
# Llama Stack CI
Llama Stack uses GitHub Actions for Continous Integration (CI). Below is a table detailing what CI the project includes and the purpose.
| Name | File | Purpose |
| ---- | ---- | ------- |
| Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md |
| Coverage Badge | [coverage-badge.yml](coverage-badge.yml) | Creates PR for updating the code coverage badge |
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
| Integration Tests | [integration-tests.yml](integration-tests.yml) | Run the integration test suite with Ollama |
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
| Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |
| Check semantic PR titles | [semantic-pr.yml](semantic-pr.yml) | Ensure that PR titles follow the conventional commit spec |
| Close stale issues and PRs | [stale_bot.yml](stale_bot.yml) | Run the Stale Bot action |
| Test External Providers Installed via Module | [test-external-provider-module.yml](test-external-provider-module.yml) | Test External Provider installation via Python module |
| Test External API and Providers | [test-external.yml](test-external.yml) | Test the External API and Provider mechanisms |
| Unit Tests | [unit-tests.yml](unit-tests.yml) | Run the unit test suite |
| Update ReadTheDocs | [update-readthedocs.yml](update-readthedocs.yml) | Update the Llama Stack ReadTheDocs site |

View file

@ -1,5 +1,7 @@
name: Update Changelog name: Update Changelog
run-name: Creates PR for updating the CHANGELOG.md
on: on:
release: release:
types: [published, unpublished, created, edited, deleted, released] types: [published, unpublished, created, edited, deleted, released]

View file

@ -1,5 +1,7 @@
name: Coverage Badge name: Coverage Badge
run-name: Creates PR for updating the code coverage badge
on: on:
push: push:
branches: [ main ] branches: [ main ]

View file

@ -1,5 +1,7 @@
name: Installer CI name: Installer CI
run-name: Test the installation script
on: on:
pull_request: pull_request:
paths: paths:
@ -17,10 +19,20 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
- name: Run ShellCheck on install.sh - name: Run ShellCheck on install.sh
run: shellcheck scripts/install.sh run: shellcheck scripts/install.sh
smoke-test: smoke-test-on-dev:
needs: lint
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 - name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Build a single provider
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template starter --image-type container --image-name test
- name: Run installer end-to-end - name: Run installer end-to-end
run: ./scripts/install.sh run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
./scripts/install.sh --image $IMAGE_ID

View file

@ -1,5 +1,7 @@
name: Integration Auth Tests name: Integration Auth Tests
run-name: Run the integration test suite with Kubernetes authentication
on: on:
push: push:
branches: [ main ] branches: [ main ]

View file

@ -1,5 +1,7 @@
name: SqlStore Integration Tests name: SqlStore Integration Tests
run-name: Run the integration test suite with SqlStore
on: on:
push: push:
branches: [ main ] branches: [ main ]

View file

@ -1,5 +1,7 @@
name: Integration Tests name: Integration Tests
run-name: Run the integration test suite with Ollama
on: on:
push: push:
branches: [ main ] branches: [ main ]
@ -14,13 +16,19 @@ on:
- '.github/workflows/integration-tests.yml' # This workflow - '.github/workflows/integration-tests.yml' # This workflow
- '.github/actions/setup-ollama/action.yml' - '.github/actions/setup-ollama/action.yml'
schedule: schedule:
- cron: '0 0 * * *' # Daily at 12 AM UTC # If changing the cron schedule, update the provider in the test-matrix job
- cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
- cron: '1 0 * * 0' # (test vllm) Weekly on Sunday at 1 AM UTC
workflow_dispatch: workflow_dispatch:
inputs: inputs:
test-all-client-versions: test-all-client-versions:
description: 'Test against both the latest and published versions' description: 'Test against both the latest and published versions'
type: boolean type: boolean
default: false default: false
test-provider:
description: 'Test against a specific provider'
type: string
default: 'ollama'
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref }}
@ -53,8 +61,17 @@ jobs:
matrix: matrix:
test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }} test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }}
client-type: [library, server] client-type: [library, server]
# Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
python-version: ["3.12", "3.13"] python-version: ["3.12", "3.13"]
client-version: ${{ (github.event_name == 'schedule' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} client-version: ${{ (github.event.schedule == '0 0 * * 0' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
exclude: # TODO: look into why these tests are failing and fix them
- provider: vllm
test-type: safety
- provider: vllm
test-type: post_training
- provider: vllm
test-type: tool_runtime
steps: steps:
- name: Checkout repository - name: Checkout repository
@ -67,8 +84,13 @@ jobs:
client-version: ${{ matrix.client-version }} client-version: ${{ matrix.client-version }}
- name: Setup ollama - name: Setup ollama
if: ${{ matrix.provider == 'ollama' }}
uses: ./.github/actions/setup-ollama uses: ./.github/actions/setup-ollama
- name: Setup vllm
if: ${{ matrix.provider == 'vllm' }}
uses: ./.github/actions/setup-vllm
- name: Build Llama Stack - name: Build Llama Stack
run: | run: |
uv run llama stack build --template ci-tests --image-type venv uv run llama stack build --template ci-tests --image-type venv
@ -81,10 +103,6 @@ jobs:
- name: Run Integration Tests - name: Run Integration Tests
env: env:
OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests
ENABLE_OLLAMA: "ollama" # for server tests
OLLAMA_URL: "http://0.0.0.0:11434"
SAFETY_MODEL: "llama-guard3:1b"
LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations
# Use 'shell' to get pipefail behavior # Use 'shell' to get pipefail behavior
# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference
@ -96,12 +114,31 @@ jobs:
else else
stack_config="server:ci-tests" stack_config="server:ci-tests"
fi fi
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
if [ "${{ matrix.provider }}" == "ollama" ]; then
export ENABLE_OLLAMA="ollama"
export OLLAMA_URL="http://0.0.0.0:11434"
export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16"
export TEXT_MODEL=ollama/$OLLAMA_INFERENCE_MODEL
export SAFETY_MODEL="llama-guard3:1b"
EXTRA_PARAMS="--safety-shield=$SAFETY_MODEL"
else
export ENABLE_VLLM="vllm"
export VLLM_URL="http://localhost:8000/v1"
export VLLM_INFERENCE_MODEL="meta-llama/Llama-3.2-1B-Instruct"
export TEXT_MODEL=vllm/$VLLM_INFERENCE_MODEL
# TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently
EXTRA_PARAMS=
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
fi
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ -k "not( ${EXCLUDE_TESTS} )" \
--text-model="ollama/llama3.2:3b-instruct-fp16" \ --text-model=$TEXT_MODEL \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
--safety-shield=$SAFETY_MODEL \ --color=yes ${EXTRA_PARAMS} \
--color=yes \
--capture=tee-sys | tee pytest-${{ matrix.test-type }}.log --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
- name: Check Storage and Memory Available After Tests - name: Check Storage and Memory Available After Tests
@ -110,16 +147,17 @@ jobs:
free -h free -h
df -h df -h
- name: Write ollama logs to file - name: Write inference logs to file
if: ${{ always() }} if: ${{ always() }}
run: | run: |
sudo docker logs ollama > ollama.log sudo docker logs ollama > ollama.log || true
sudo docker logs vllm > vllm.log || true
- name: Upload all logs to artifacts - name: Upload all logs to artifacts
if: ${{ always() }} if: ${{ always() }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with: with:
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }} name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.provider }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }}
path: | path: |
*.log *.log
retention-days: 1 retention-days: 1

View file

@ -1,5 +1,7 @@
name: Vector IO Integration Tests name: Vector IO Integration Tests
run-name: Run the integration test suite with various VectorIO providers
on: on:
push: push:
branches: [ main ] branches: [ main ]

View file

@ -1,5 +1,7 @@
name: Pre-commit name: Pre-commit
run-name: Run pre-commit checks
on: on:
pull_request: pull_request:
push: push:

View file

@ -1,5 +1,7 @@
name: Test Llama Stack Build name: Test Llama Stack Build
run-name: Test llama stack build
on: on:
push: push:
branches: branches:

View file

@ -1,5 +1,7 @@
name: Python Package Build Test name: Python Package Build Test
run-name: Test building the llama-stack PyPI project
on: on:
push: push:
branches: branches:
@ -20,7 +22,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv - name: Install uv
uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1 uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
activate-environment: true activate-environment: true

View file

@ -1,5 +1,7 @@
name: Check semantic PR titles name: Check semantic PR titles
run-name: Ensure that PR titles follow the conventional commit spec
on: on:
pull_request_target: pull_request_target:
types: types:

View file

@ -1,5 +1,7 @@
name: Close stale issues and PRs name: Close stale issues and PRs
run-name: Run the Stale Bot action
on: on:
schedule: schedule:
- cron: '0 0 * * *' # every day at midnight - cron: '0 0 * * *' # every day at midnight

View file

@ -1,4 +1,6 @@
name: Test External Providers name: Test External Providers Installed via Module
run-name: Test External Provider installation via Python module
on: on:
push: push:
@ -11,10 +13,10 @@ on:
- 'uv.lock' - 'uv.lock'
- 'pyproject.toml' - 'pyproject.toml'
- 'requirements.txt' - 'requirements.txt'
- '.github/workflows/test-external-providers.yml' # This workflow - '.github/workflows/test-external-providers-module.yml' # This workflow
jobs: jobs:
test-external-providers: test-external-providers-from-module:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
@ -28,39 +30,38 @@ jobs:
- name: Install dependencies - name: Install dependencies
uses: ./.github/actions/setup-runner uses: ./.github/actions/setup-runner
- name: Install Ramalama
shell: bash
run: |
uv pip install ramalama
- name: Run Ramalama
shell: bash
run: |
nohup ramalama serve llama3.2:3b-instruct-fp16 > ramalama_server.log 2>&1 &
- name: Apply image type to config file - name: Apply image type to config file
run: | run: |
yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml yq -i '.image_type = "${{ matrix.image-type }}"' tests/external/ramalama-stack/run.yaml
cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml cat tests/external/ramalama-stack/run.yaml
- name: Setup directory for Ollama custom provider
run: |
mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
- name: Create provider configuration
run: |
mkdir -p /home/runner/.llama/providers.d/remote/inference
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml
- name: Build distro from config file - name: Build distro from config file
run: | run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/ramalama-stack/build.yaml
- name: Start Llama Stack server in background - name: Start Llama Stack server in background
if: ${{ matrix.image-type }} == 'venv' if: ${{ matrix.image-type }} == 'venv'
env: env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" INFERENCE_MODEL: "llama3.2:3b-instruct-fp16"
run: | run: |
# Use the virtual environment created by the build step (name comes from build config) # Use the virtual environment created by the build step (name comes from build config)
source ci-test/bin/activate source ramalama-stack-test/bin/activate
uv pip list uv pip list
nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & nohup llama stack run tests/external/ramalama-stack/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready - name: Wait for Llama Stack server to be ready
run: | run: |
for i in {1..30}; do for i in {1..30}; do
if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then if ! grep -q "successfully connected to Ramalama" server.log; then
echo "Waiting for Llama Stack server to load the provider..." echo "Waiting for Llama Stack server to load the provider..."
sleep 1 sleep 1
else else

77
.github/workflows/test-external.yml vendored Normal file
View file

@ -0,0 +1,77 @@
name: Test External API and Providers
run-name: Test the External API and Provider mechanisms
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
- 'llama_stack/**'
- 'tests/integration/**'
- 'uv.lock'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/test-external.yml' # This workflow
jobs:
test-external:
runs-on: ubuntu-latest
strategy:
matrix:
image-type: [venv]
# We don't do container yet, it's tricky to install a package from the host into the
# container and point 'uv pip install' to the correct path...
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install dependencies
uses: ./.github/actions/setup-runner
- name: Create API configuration
run: |
mkdir -p /home/runner/.llama/apis.d
cp tests/external/weather.yaml /home/runner/.llama/apis.d/weather.yaml
- name: Create provider configuration
run: |
mkdir -p /home/runner/.llama/providers.d/remote/weather
cp tests/external/kaze.yaml /home/runner/.llama/providers.d/remote/weather/kaze.yaml
- name: Print distro dependencies
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml --print-deps-only
- name: Build distro from config file
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml
- name: Start Llama Stack server in background
if: ${{ matrix.image-type }} == 'venv'
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
# Use the virtual environment created by the build step (name comes from build config)
source ci-test/bin/activate
uv pip list
nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready
run: |
echo "Waiting for Llama Stack server..."
for i in {1..30}; do
if curl -sSf http://localhost:8321/v1/health | grep -q "OK"; then
echo "Llama Stack server is up!"
exit 0
fi
sleep 1
done
echo "Llama Stack server failed to start"
cat server.log
exit 1
- name: Test external API
run: |
curl -sSf http://localhost:8321/v1/weather/locations

View file

@ -1,5 +1,7 @@
name: Unit Tests name: Unit Tests
run-name: Run the unit test suite
on: on:
push: push:
branches: [ main ] branches: [ main ]

View file

@ -1,5 +1,7 @@
name: Update ReadTheDocs name: Update ReadTheDocs
run-name: Update the Llama Stack ReadTheDocs site
on: on:
workflow_dispatch: workflow_dispatch:
inputs: inputs:

View file

@ -145,6 +145,15 @@ repos:
echo; echo;
exit 1; exit 1;
} || true } || true
- id: generate-ci-docs
name: Generate CI documentation
additional_dependencies:
- uv==0.7.8
entry: uv run ./scripts/gen-ci-docs.py
language: python
pass_filenames: false
require_serial: true
files: ^.github/workflows/.*$
ci: ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -0,0 +1,392 @@
# External APIs
Llama Stack supports external APIs that live outside of the main codebase. This allows you to:
- Create and maintain your own APIs independently
- Share APIs with others without contributing to the main codebase
- Keep API-specific code separate from the core Llama Stack code
## Configuration
To enable external APIs, you need to configure the `external_apis_dir` in your Llama Stack configuration. This directory should contain your external API specifications:
```yaml
external_apis_dir: ~/.llama/apis.d/
```
## Directory Structure
The external APIs directory should follow this structure:
```
apis.d/
custom_api1.yaml
custom_api2.yaml
```
Each YAML file in these directories defines an API specification.
## API Specification
Here's an example of an external API specification for a weather API:
```yaml
module: weather
api_dependencies:
- inference
protocol: WeatherAPI
name: weather
pip_packages:
- llama-stack-api-weather
```
### API Specification Fields
- `module`: Python module containing the API implementation
- `protocol`: Name of the protocol class for the API
- `name`: Name of the API
- `pip_packages`: List of pip packages to install the API, typically a single package
## Required Implementation
External APIs must expose a `available_providers()` function in their module that returns a list of provider names:
```python
# llama_stack_api_weather/api.py
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.weather,
provider_type="inline::darksky",
pip_packages=[],
module="llama_stack_provider_darksky",
config_class="llama_stack_provider_darksky.DarkSkyWeatherImplConfig",
),
]
```
A Protocol class like so:
```python
# llama_stack_api_weather/api.py
from typing import Protocol
from llama_stack.schema_utils import webmethod
class WeatherAPI(Protocol):
"""
A protocol for the Weather API.
"""
@webmethod(route="/locations", method="GET")
async def get_available_locations() -> dict[str, list[str]]:
"""
Get the available locations.
"""
...
```
## Example: Custom API
Here's a complete example of creating and using a custom API:
1. First, create the API package:
```bash
mkdir -p llama-stack-api-weather
cd llama-stack-api-weather
mkdir src/llama_stack_api_weather
git init
uv init
```
2. Edit `pyproject.toml`:
```toml
[project]
name = "llama-stack-api-weather"
version = "0.1.0"
description = "Weather API for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic"]
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
include = ["llama_stack_api_weather", "llama_stack_api_weather.*"]
```
3. Create the initial files:
```bash
touch src/llama_stack_api_weather/__init__.py
touch src/llama_stack_api_weather/api.py
```
```python
# llama-stack-api-weather/src/llama_stack_api_weather/__init__.py
"""Weather API for Llama Stack."""
from .api import WeatherAPI, available_providers
__all__ = ["WeatherAPI", "available_providers"]
```
4. Create the API implementation:
```python
# llama-stack-api-weather/src/llama_stack_api_weather/weather.py
from typing import Protocol
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
ProviderSpec,
RemoteProviderSpec,
)
from llama_stack.schema_utils import webmethod
def available_providers() -> list[ProviderSpec]:
return [
RemoteProviderSpec(
api=Api.weather,
provider_type="remote::kaze",
config_class="llama_stack_provider_kaze.KazeProviderConfig",
adapter=AdapterSpec(
adapter_type="kaze",
module="llama_stack_provider_kaze",
pip_packages=["llama_stack_provider_kaze"],
config_class="llama_stack_provider_kaze.KazeProviderConfig",
),
),
]
class WeatherProvider(Protocol):
"""
A protocol for the Weather API.
"""
@webmethod(route="/weather/locations", method="GET")
async def get_available_locations() -> dict[str, list[str]]:
"""
Get the available locations.
"""
...
```
5. Create the API specification:
```yaml
# ~/.llama/apis.d/weather.yaml
module: llama_stack_api_weather
name: weather
pip_packages: ["llama-stack-api-weather"]
protocol: WeatherProvider
```
6. Install the API package:
```bash
uv pip install -e .
```
7. Configure Llama Stack to use external APIs:
```yaml
version: "2"
image_name: "llama-stack-api-weather"
apis:
- weather
providers: {}
external_apis_dir: ~/.llama/apis.d
```
The API will now be available at `/v1/weather/locations`.
## Example: custom provider for the weather API
1. Create the provider package:
```bash
mkdir -p llama-stack-provider-kaze
cd llama-stack-provider-kaze
uv init
```
2. Edit `pyproject.toml`:
```toml
[project]
name = "llama-stack-provider-kaze"
version = "0.1.0"
description = "Kaze weather provider for Llama Stack"
readme = "README.md"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "aiohttp"]
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
where = ["src"]
include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"]
```
3. Create the initial files:
```bash
touch src/llama_stack_provider_kaze/__init__.py
touch src/llama_stack_provider_kaze/kaze.py
```
4. Create the provider implementation:
Initialization function:
```python
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py
"""Kaze weather provider for Llama Stack."""
from .config import KazeProviderConfig
from .kaze import WeatherKazeAdapter
__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"]
async def get_adapter_impl(config: KazeProviderConfig, _deps):
from .kaze import WeatherKazeAdapter
impl = WeatherKazeAdapter(config)
await impl.initialize()
return impl
```
Configuration:
```python
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py
from pydantic import BaseModel, Field
class KazeProviderConfig(BaseModel):
"""Configuration for the Kaze weather provider."""
base_url: str = Field(
"https://api.kaze.io/v1",
description="Base URL for the Kaze weather API",
)
```
Main implementation:
```python
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py
from llama_stack_api_weather.api import WeatherProvider
from .config import KazeProviderConfig
class WeatherKazeAdapter(WeatherProvider):
"""Kaze weather provider implementation."""
def __init__(
self,
config: KazeProviderConfig,
) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def get_available_locations(self) -> dict[str, list[str]]:
"""Get available weather locations."""
return {"locations": ["Paris", "Tokyo"]}
```
5. Create the provider specification:
```yaml
# ~/.llama/providers.d/remote/weather/kaze.yaml
adapter:
adapter_type: kaze
pip_packages: ["llama_stack_provider_kaze"]
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
module: llama_stack_provider_kaze
optional_api_dependencies: []
```
6. Install the provider package:
```bash
uv pip install -e .
```
7. Configure Llama Stack to use the provider:
```yaml
# ~/.llama/run-byoa.yaml
version: "2"
image_name: "llama-stack-api-weather"
apis:
- weather
providers:
weather:
- provider_id: kaze
provider_type: remote::kaze
config: {}
external_apis_dir: ~/.llama/apis.d
external_providers_dir: ~/.llama/providers.d
server:
port: 8321
```
8. Run the server:
```bash
python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml
```
9. Test the API:
```bash
curl -sSf http://127.0.0.1:8321/v1/weather/locations
{"locations":["Paris","Tokyo"]}%
```
## Best Practices
1. **Package Naming**: Use a clear and descriptive name for your API package.
2. **Version Management**: Keep your API package versioned and compatible with the Llama Stack version you're using.
3. **Dependencies**: Only include the minimum required dependencies in your API package.
4. **Documentation**: Include clear documentation in your API package about:
- Installation requirements
- Configuration options
- API endpoints and usage
- Any limitations or known issues
5. **Testing**: Include tests in your API package to ensure it works correctly with Llama Stack.
## Troubleshooting
If your external API isn't being loaded:
1. Check that the `external_apis_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the API package is installed in your Python environment.

View file

@ -10,9 +10,11 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
- **Eval**: generate outputs (via Inference or Agents) and perform scoring - **Eval**: generate outputs (via Inference or Agents) and perform scoring
- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents - **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
- **Telemetry**: collect telemetry data from the system - **Telemetry**: collect telemetry data from the system
- **Post Training**: fine-tune a model
- **Tool Runtime**: interact with various tools and protocols
- **Responses**: generate responses from an LLM using this OpenAI compatible API.
We are working on adding a few more APIs to complete the application lifecycle. These will include: We are working on adding a few more APIs to complete the application lifecycle. These will include:
- **Batch Inference**: run inference on a dataset of inputs - **Batch Inference**: run inference on a dataset of inputs
- **Batch Agents**: run agents on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs
- **Post Training**: fine-tune a model
- **Synthetic Data Generation**: generate synthetic data for model development - **Synthetic Data Generation**: generate synthetic data for model development

View file

@ -504,6 +504,47 @@ created by users sharing a team with them:
description: any user has read access to any resource created by a user with the same team description: any user has read access to any resource created by a user with the same team
``` ```
#### API Endpoint Authorization with Scopes
In addition to resource-based access control, Llama Stack supports endpoint-level authorization using OAuth 2.0 style scopes. When authentication is enabled, specific API endpoints require users to have particular scopes in their authentication token.
**Scope-Gated APIs:**
The following APIs are currently gated by scopes:
- **Telemetry API** (scope: `telemetry.read`):
- `POST /telemetry/traces` - Query traces
- `GET /telemetry/traces/{trace_id}` - Get trace by ID
- `GET /telemetry/traces/{trace_id}/spans/{span_id}` - Get span by ID
- `POST /telemetry/spans/{span_id}/tree` - Get span tree
- `POST /telemetry/spans` - Query spans
- `POST /telemetry/metrics/{metric_name}` - Query metrics
**Authentication Configuration:**
For **JWT/OAuth2 providers**, scopes should be included in the JWT's claims:
```json
{
"sub": "user123",
"scope": "telemetry.read",
"aud": "llama-stack"
}
```
For **custom authentication providers**, the endpoint must return user attributes including the `scopes` array:
```json
{
"principal": "user123",
"attributes": {
"scopes": ["telemetry.read"]
}
}
```
**Behavior:**
- Users without the required scope receive a 403 Forbidden response
- When authentication is disabled, scope checks are bypassed
- Endpoints without `required_scope` work normally for all authenticated users
### Quota Configuration ### Quota Configuration
The `quota` section allows you to enable server-side request throttling for both The `quota` section allows you to enable server-side request throttling for both

View file

@ -1,3 +1,6 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source --> <!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution # NVIDIA Distribution

View file

@ -7,7 +7,17 @@ Llama Stack supports external providers that live outside of the main codebase.
## Configuration ## Configuration
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications: To enable external providers, you need to add `module` into your build yaml, allowing Llama Stack to install the required package corresponding to the external provider.
an example entry in your build.yaml should look like:
```
- provider_id: ramalama
provider_type: remote::ramalama
module: ramalama_stack
```
Additionally you can configure the `external_providers_dir` in your Llama Stack configuration. This method is in the process of being deprecated in favor of the `module` method. If using this method, the external provider directory should contain your external provider specifications:
```yaml ```yaml
external_providers_dir: ~/.llama/providers.d/ external_providers_dir: ~/.llama/providers.d/
@ -112,6 +122,31 @@ container_image: custom-vector-store:latest # optional
## Required Implementation ## Required Implementation
## All Providers
All providers must contain a `get_provider_spec` function in their `provider` module. This is a standardized structure that Llama Stack expects and is necessary for getting things such as the config class. The `get_provider_spec` method returns a structure identical to the `adapter`. An example function may look like:
```python
from llama_stack.providers.datatypes import (
ProviderSpec,
Api,
AdapterSpec,
remote_provider_spec,
)
def get_provider_spec() -> ProviderSpec:
return remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ramalama",
pip_packages=["ramalama>=0.8.5", "pymilvus"],
config_class="ramalama_stack.config.RamalamaImplConfig",
module="ramalama_stack",
),
)
```
### Remote Providers ### Remote Providers
Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments: Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
@ -155,7 +190,7 @@ Version: 0.1.0
Location: /path/to/venv/lib/python3.10/site-packages Location: /path/to/venv/lib/python3.10/site-packages
``` ```
## Example: Custom Ollama Provider ## Example using `external_providers_dir`: Custom Ollama Provider
Here's a complete example of creating and using a custom Ollama provider: Here's a complete example of creating and using a custom Ollama provider:
@ -206,6 +241,35 @@ external_providers_dir: ~/.llama/providers.d/
The provider will now be available in Llama Stack with the type `remote::custom_ollama`. The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
## Example using `module`: ramalama-stack
[ramalama-stack](https://github.com/containers/ramalama-stack) is a recognized external provider that supports installation via module.
To install Llama Stack with this external provider a user can provider the following build.yaml:
```yaml
version: 2
distribution_spec:
description: Use (an external) Ramalama server for running LLM inference
container_image: null
providers:
inference:
- provider_id: ramalama
provider_type: remote::ramalama
module: ramalama_stack==0.3.0a0
image_type: venv
image_name: null
external_providers_dir: null
additional_pip_packages:
- aiosqlite
- sqlalchemy[asyncio]
```
No other steps are required other than `llama stack build` and `llama stack run`. The build process will use `module` to install all of the provider dependencies, retrieve the spec, etc.
The provider will now be available in Llama Stack with the type `remote::ramalama`.
## Best Practices ## Best Practices
1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable. 1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.
@ -229,9 +293,10 @@ information. Execute the test for the Provider type you are developing.
If your external provider isn't being loaded: If your external provider isn't being loaded:
1. Check that `module` points to a published pip package with a top level `provider` module including `get_provider_spec`.
1. Check that the `external_providers_dir` path is correct and accessible. 1. Check that the `external_providers_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted. 2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed. 3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more 4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more
information using `LLAMA_STACK_LOGGING=all=debug`. information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the provider package is installed in your Python environment. 5. Verify that the provider package is installed in your Python environment if using `external_providers_dir`.

View file

@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
| Field | Type | Required | Default | Description | | Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------| |-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server | | `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |

View file

@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
| Field | Type | Required | Default | Description | | Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------| |-------|------|----------|---------|-------------|
| `url` | `<class 'str'>` | No | http://localhost:11434 | | | `url` | `<class 'str'>` | No | http://localhost:11434 | |
| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically | | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |
## Sample Configuration ## Sample Configuration

View file

@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel
| Field | Type | Required | Default | Description | | Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------| |-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server | | `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key | | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |

View file

@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
| `api_token` | `str \| None` | No | fake | The API token | | `api_token` | `str \| None` | No | fake | The API token |
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically | | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
## Sample Configuration ## Sample Configuration

View file

@ -4,15 +4,83 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum, EnumMeta
from pydantic import BaseModel from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
class DynamicApiMeta(EnumMeta):
def __new__(cls, name, bases, namespace):
# Store the original enum values
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
# Create the enum class
cls = super().__new__(cls, name, bases, namespace)
# Store the original values for reference
cls._original_values = original_values
# Initialize _dynamic_values
cls._dynamic_values = {}
return cls
def __call__(cls, value):
try:
return super().__call__(value)
except ValueError as e:
# If this value was already dynamically added, return it
if value in cls._dynamic_values:
return cls._dynamic_values[value]
# If the value doesn't exist, create a new enum member
# Create a new member name from the value
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return the existing member
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
# register new APIs explicitly
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
def __iter__(cls):
# Allow iteration over both static and dynamic members
yield from super().__iter__()
if hasattr(cls, "_dynamic_values"):
yield from cls._dynamic_values.values()
def add(cls, value):
"""
Add a new API to the enum.
Used to register external APIs.
"""
member_name = value.lower().replace("-", "_")
# If this member name already exists in the enum, return it
if member_name in cls._member_map_:
return cls._member_map_[member_name]
# Create a new enum member
member = object.__new__(cls)
member._name_ = member_name
member._value_ = value
# Add it to the enum class
cls._member_map_[member_name] = member
cls._member_names_.append(member_name)
cls._member_type_ = str
# Store it in our dynamic values
cls._dynamic_values[value] = member
return member
@json_schema_type @json_schema_type
class Api(Enum): class Api(Enum, metaclass=DynamicApiMeta):
"""Enumeration of all available APIs in the Llama Stack system. """Enumeration of all available APIs in the Llama Stack system.
:cvar providers: Provider management and configuration :cvar providers: Provider management and configuration
:cvar inference: Text generation, chat completions, and embeddings :cvar inference: Text generation, chat completions, and embeddings
@ -35,7 +103,6 @@ class Api(Enum):
:cvar files: File storage and management :cvar files: File storage and management
:cvar inspect: Built-in system inspection and introspection :cvar inspect: Built-in system inspection and introspection
""" """
providers = "providers" providers = "providers"
inference = "inference" inference = "inference"
safety = "safety" safety = "safety"
@ -77,3 +144,12 @@ class Error(BaseModel):
title: str title: str
detail: str detail: str
instance: str | None = None instance: str | None = None
class ExternalApiSpec(BaseModel):
"""Specification for an external API implementation."""
module: str = Field(..., description="Python module containing the API implementation")
name: str = Field(..., description="Name of the API")
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
protocol: str = Field(..., description="Name of the protocol class for the API")

View file

@ -911,12 +911,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
class ModelStore(Protocol): class ModelStore(Protocol):
async def get_model(self, identifier: str) -> Model: ... async def get_model(self, identifier: str) -> Model: ...
async def update_registered_llm_models(
self,
provider_id: str,
models: list[Model],
) -> None: ...
class TextTruncation(Enum): class TextTruncation(Enum):
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left. """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.

View file

@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
# Add this constant near the top of the file, after the imports # Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7 DEFAULT_TTL_DAYS = 7
REQUIRED_SCOPE = "telemetry.read"
@json_schema_type @json_schema_type
class SpanStatus(Enum): class SpanStatus(Enum):
@ -422,7 +424,7 @@ class Telemetry(Protocol):
""" """
... ...
@webmethod(route="/telemetry/traces", method="POST") @webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
async def query_traces( async def query_traces(
self, self,
attribute_filters: list[QueryCondition] | None = None, attribute_filters: list[QueryCondition] | None = None,
@ -440,7 +442,7 @@ class Telemetry(Protocol):
""" """
... ...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
async def get_trace(self, trace_id: str) -> Trace: async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID. """Get a trace by its ID.
@ -449,7 +451,9 @@ class Telemetry(Protocol):
""" """
... ...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET") @webmethod(
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
)
async def get_span(self, trace_id: str, span_id: str) -> Span: async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID. """Get a span by its ID.
@ -459,7 +463,7 @@ class Telemetry(Protocol):
""" """
... ...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST") @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
async def get_span_tree( async def get_span_tree(
self, self,
span_id: str, span_id: str,
@ -475,7 +479,7 @@ class Telemetry(Protocol):
""" """
... ...
@webmethod(route="/telemetry/spans", method="POST") @webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
async def query_spans( async def query_spans(
self, self,
attribute_filters: list[QueryCondition], attribute_filters: list[QueryCondition],
@ -508,7 +512,7 @@ class Telemetry(Protocol):
""" """
... ...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST") @webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
async def query_metrics( async def query_metrics(
self, self,
metric_name: str, metric_name: str,

View file

@ -36,6 +36,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig, StackRunConfig,
) )
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import replace_env_vars from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
@ -93,7 +94,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
) )
sys.exit(1) sys.exit(1)
elif args.providers: elif args.providers:
providers_list: dict[str, str | list[str]] = dict() provider_list: dict[str, list[Provider]] = dict()
for api_provider in args.providers.split(","): for api_provider in args.providers.split(","):
if "=" not in api_provider: if "=" not in api_provider:
cprint( cprint(
@ -102,7 +103,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
file=sys.stderr, file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
api, provider = api_provider.split("=") api, provider_type = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None) providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None: if providers_for_api is None:
cprint( cprint(
@ -111,16 +112,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
file=sys.stderr, file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
if provider in providers_for_api: if provider_type in providers_for_api:
if api not in providers_list: provider = Provider(
providers_list[api] = [] provider_type=provider_type,
# Use type guarding to ensure we have a list provider_id=provider_type.split("::")[1],
provider_value = providers_list[api] config={},
if isinstance(provider_value, list): module=None,
provider_value.append(provider) )
else: provider_list.setdefault(api, []).append(provider)
# Convert string to list and append
providers_list[api] = [provider_value, provider]
else: else:
cprint( cprint(
f"{provider} is not a valid provider for the {api} API.", f"{provider} is not a valid provider for the {api} API.",
@ -129,7 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
) )
sys.exit(1) sys.exit(1)
distribution_spec = DistributionSpec( distribution_spec = DistributionSpec(
providers=providers_list, providers=provider_list,
description=",".join(args.providers), description=",".join(args.providers),
) )
if not args.image_type: if not args.image_type:
@ -190,7 +189,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr) cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers: dict[str, str | list[str]] = dict() providers: dict[str, list[Provider]] = dict()
for api, providers_for_api in get_provider_registry().items(): for api, providers_for_api in get_provider_registry().items():
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")] available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
if not available_providers: if not available_providers:
@ -236,11 +235,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if args.print_deps_only: if args.print_deps_only:
print(f"# Dependencies for {args.template or args.config or image_name}") print(f"# Dependencies for {args.template or args.config or image_name}")
normal_deps, special_deps = get_provider_dependencies(build_config) normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES normal_deps += SERVER_DEPENDENCIES
print(f"uv pip install {' '.join(normal_deps)}") print(f"uv pip install {' '.join(normal_deps)}")
for special_dep in special_deps: for special_dep in special_deps:
print(f"uv pip install {special_dep}") print(f"uv pip install {special_dep}")
for external_dep in external_provider_dependencies:
print(f"uv pip install {external_dep}")
return return
try: try:
@ -303,27 +304,25 @@ def _generate_run_config(
provider_registry = get_provider_registry(build_config) provider_registry = get_provider_registry(build_config)
for api in apis: for api in apis:
run_config.providers[api] = [] run_config.providers[api] = []
provider_types = build_config.distribution_spec.providers[api] providers = build_config.distribution_spec.providers[api]
if isinstance(provider_types, str):
provider_types = [provider_types]
for i, provider_type in enumerate(provider_types): for provider in providers:
pid = provider_type.split("::")[-1] pid = provider.provider_id
p = provider_registry[Api(api)][provider_type] p = provider_registry[Api(api)][provider.provider_type]
if p.deprecation_error: if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error) raise InvalidProviderError(p.deprecation_error)
try: try:
config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class) config_type = instantiate_class_type(provider_registry[Api(api)][provider.provider_type].config_class)
except ModuleNotFoundError: except (ModuleNotFoundError, ValueError) as exc:
# HACK ALERT: # HACK ALERT:
# This code executes after building is done, the import cannot work since the # This code executes after building is done, the import cannot work since the
# package is either available in the venv or container - not available on the host. # package is either available in the venv or container - not available on the host.
# TODO: use a "is_external" flag in ProviderSpec to check if the provider is # TODO: use a "is_external" flag in ProviderSpec to check if the provider is
# external # external
cprint( cprint(
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping", f"Failed to import provider {provider.provider_type} for API {api} - assuming it's external, skipping: {exc}",
color="yellow", color="yellow",
file=sys.stderr, file=sys.stderr,
) )
@ -336,9 +335,10 @@ def _generate_run_config(
config = {} config = {}
p_spec = Provider( p_spec = Provider(
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid, provider_id=pid,
provider_type=provider_type, provider_type=provider.provider_type,
config=config, config=config,
module=provider.module,
) )
run_config.providers[api].append(p_spec) run_config.providers[api].append(p_spec)
@ -401,9 +401,32 @@ def _run_stack_build_command_from_build_config(
run_config_file = _generate_run_config(build_config, build_dir, image_name) run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f: with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json()) to_write = json.loads(build_config.model_dump_json(exclude_none=True))
f.write(yaml.dump(to_write, sort_keys=False)) f.write(yaml.dump(to_write, sort_keys=False))
# We first install the external APIs so that the build process can use them and discover the
# providers dependencies
if build_config.external_apis_dir:
cprint("Installing external APIs", color="yellow", file=sys.stderr)
external_apis = load_external_apis(build_config)
if external_apis:
# install the external APIs
packages = []
for _, api_spec in external_apis.items():
if api_spec.pip_packages:
packages.extend(api_spec.pip_packages)
cprint(
f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}",
color="yellow",
file=sys.stderr,
)
return_code = run_command(["uv", "pip", "install", *packages])
if return_code != 0:
packages_str = ", ".join(packages)
raise RuntimeError(
f"Failed to install external APIs packages: {packages_str} (return code: {return_code})"
)
return_code = build_image( return_code = build_image(
build_config, build_config,
build_file_path, build_file_path,

View file

@ -14,6 +14,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.utils.exec import run_command from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -41,7 +42,7 @@ class ApiInput(BaseModel):
def get_provider_dependencies( def get_provider_dependencies(
config: BuildConfig | DistributionTemplate, config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]: ) -> tuple[list[str], list[str], list[str]]:
"""Get normal and special dependencies from provider configuration.""" """Get normal and special dependencies from provider configuration."""
if isinstance(config, DistributionTemplate): if isinstance(config, DistributionTemplate):
config = config.build_config() config = config.build_config()
@ -50,6 +51,7 @@ def get_provider_dependencies(
additional_pip_packages = config.additional_pip_packages additional_pip_packages = config.additional_pip_packages
deps = [] deps = []
external_provider_deps = []
registry = get_provider_registry(config) registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items(): for api_str, provider_or_providers in providers.items():
providers_for_api = registry[Api(api_str)] providers_for_api = registry[Api(api_str)]
@ -64,8 +66,16 @@ def get_provider_dependencies(
raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`") raise ValueError(f"Provider `{provider}` is not available for API `{api_str}`")
provider_spec = providers_for_api[provider_type] provider_spec = providers_for_api[provider_type]
deps.extend(provider_spec.pip_packages) if hasattr(provider_spec, "is_external") and provider_spec.is_external:
if provider_spec.container_image: # this ensures we install the top level module for our external providers
if provider_spec.module:
if isinstance(provider_spec.module, str):
external_provider_deps.append(provider_spec.module)
else:
external_provider_deps.extend(provider_spec.module)
if hasattr(provider_spec, "pip_packages"):
deps.extend(provider_spec.pip_packages)
if hasattr(provider_spec, "container_image") and provider_spec.container_image:
raise ValueError("A stack's dependencies cannot have a container image") raise ValueError("A stack's dependencies cannot have a container image")
normal_deps = [] normal_deps = []
@ -78,7 +88,7 @@ def get_provider_dependencies(
normal_deps.extend(additional_pip_packages or []) normal_deps.extend(additional_pip_packages or [])
return list(set(normal_deps)), list(set(special_deps)) return list(set(normal_deps)), list(set(special_deps)), list(set(external_provider_deps))
def print_pip_install_help(config: BuildConfig): def print_pip_install_help(config: BuildConfig):
@ -103,41 +113,59 @@ def build_image(
): ):
container_base = build_config.distribution_spec.container_image or "python:3.12-slim" container_base = build_config.distribution_spec.container_image or "python:3.12-slim"
normal_deps, special_deps = get_provider_dependencies(build_config) normal_deps, special_deps, external_provider_deps = get_provider_dependencies(build_config)
normal_deps += SERVER_DEPENDENCIES normal_deps += SERVER_DEPENDENCIES
if build_config.external_apis_dir:
external_apis = load_external_apis(build_config)
if external_apis:
for _, api_spec in external_apis.items():
normal_deps.extend(api_spec.pip_packages)
if build_config.image_type == LlamaStackImageType.CONTAINER.value: if build_config.image_type == LlamaStackImageType.CONTAINER.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh") script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
args = [ args = [
script, script,
"--template-or-config",
template_or_config, template_or_config,
"--image-name",
image_name, image_name,
"--container-base",
container_base, container_base,
"--normal-deps",
" ".join(normal_deps), " ".join(normal_deps),
] ]
# When building from a config file (not a template), include the run config path in the # When building from a config file (not a template), include the run config path in the
# build arguments # build arguments
if run_config is not None: if run_config is not None:
args.append(run_config) args.extend(["--run-config", run_config])
elif build_config.image_type == LlamaStackImageType.CONDA.value: elif build_config.image_type == LlamaStackImageType.CONDA.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh") script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [ args = [
script, script,
"--env-name",
str(image_name), str(image_name),
"--build-file-path",
str(build_file_path), str(build_file_path),
"--normal-deps",
" ".join(normal_deps), " ".join(normal_deps),
] ]
elif build_config.image_type == LlamaStackImageType.VENV.value: elif build_config.image_type == LlamaStackImageType.VENV.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh") script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh")
args = [ args = [
script, script,
"--env-name",
str(image_name), str(image_name),
"--normal-deps",
" ".join(normal_deps), " ".join(normal_deps),
] ]
# Always pass both arguments, even if empty, to maintain consistent positional arguments
if special_deps: if special_deps:
args.append("#".join(special_deps)) args.extend(["--optional-deps", "#".join(special_deps)])
if external_provider_deps:
args.extend(
["--external-provider-deps", "#".join(external_provider_deps)]
) # the script will install external provider module, get its deps, and install those too.
return_code = run_command(args) return_code = run_command(args)

View file

@ -9,10 +9,91 @@
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-} LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out # This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694 # Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <conda_env_name> --build-file-path <build_file_path> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
build_file_path=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--build-file-path)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --build-file-path requires a string value" >&2
usage
fi
build_file_path="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then
echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR" echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi fi
@ -20,50 +101,18 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi fi
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <distribution_type> <conda_env_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> my-conda-env ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
special_pip_deps="$4"
set -euo pipefail
env_name="$1"
build_file_path="$2"
pip_dependencies="$3"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
ensure_conda_env_python310() { ensure_conda_env_python310() {
local env_name="$1" # Use only global variables set by flag parser
local pip_dependencies="$2"
local special_pip_deps="$3"
local python_version="3.12" local python_version="3.12"
# Check if conda command is available
if ! is_command_available conda; then if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2 printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1 exit 1
fi fi
# Check if the environment exists
if conda env list | grep -q "^${env_name} "; then if conda env list | grep -q "^${env_name} "; then
printf "Conda environment '${env_name}' exists. Checking Python version...\n" printf "Conda environment '${env_name}' exists. Checking Python version...\n"
# Check Python version in the environment
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2) current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then if [ "$current_version" = "$python_version" ]; then
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n" printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
else else
@ -73,37 +122,37 @@ ensure_conda_env_python310() {
else else
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n" printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
conda create -n "${env_name}" python="${python_version}" -y conda create -n "${env_name}" python="${python_version}" -y
ENVNAME="${env_name}"
# setup_cleanup_handlers
fi fi
eval "$(conda shell.bash hook)" eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}" conda deactivate && conda activate "${env_name}"
"$CONDA_PREFIX"/bin/pip install uv "$CONDA_PREFIX"/bin/pip install uv
if [ -n "$TEST_PYPI_VERSION" ]; then if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \ uv pip install --extra-index-url https://test.pypi.org/simple/ \
llama-stack=="$TEST_PYPI_VERSION" \ llama-stack=="$TEST_PYPI_VERSION" \
"$pip_dependencies" "$normal_deps"
if [ -n "$special_pip_deps" ]; then if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps" IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do for part in "${parts[@]}"; do
echo "$part" echo "$part"
uv pip install "$part" uv pip install "$part"
done done
fi fi
else else
# Re-installing llama-stack in the new conda environment
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2 printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n" printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else else
@ -115,31 +164,44 @@ ensure_conda_env_python310() {
fi fi
uv pip install --no-cache-dir "$SPEC_VERSION" uv pip install --no-cache-dir "$SPEC_VERSION"
fi fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2 printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n" printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi fi
# Install pip dependencies
printf "Installing pip dependencies\n" printf "Installing pip dependencies\n"
uv pip install $pip_dependencies uv pip install $normal_deps
if [ -n "$special_pip_deps" ]; then if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps" IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do for part in "${parts[@]}"; do
echo "$part" echo "$part"
uv pip install $part uv pip install $part
done done
fi fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi fi
mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml
echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml" echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml"
} }
ensure_conda_env_python310 "$env_name" "$pip_dependencies" "$special_pip_deps" ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps"

View file

@ -27,52 +27,103 @@ RUN_CONFIG_PATH=/app/run.yaml
BUILD_CONTEXT_DIR=$(pwd) BUILD_CONTEXT_DIR=$(pwd)
if [ "$#" -lt 4 ]; then
# This only works for templates
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
exit 1
fi
set -euo pipefail set -euo pipefail
template_or_config="$1"
shift
image_name="$1"
shift
container_base="$1"
shift
pip_dependencies="$1"
shift
# Handle optional arguments
run_config=""
special_pip_deps=""
# Check if there are more arguments
# The logics is becoming cumbersom, we should refactor it if we can do better
if [ $# -gt 0 ]; then
# Check if the argument ends with .yaml
if [[ "$1" == *.yaml ]]; then
run_config="$1"
shift
# If there's another argument after .yaml, it must be special_pip_deps
if [ $# -gt 0 ]; then
special_pip_deps="$1"
fi
else
# If it's not .yaml, it must be special_pip_deps
special_pip_deps="$1"
fi
fi
# Define color codes # Define color codes
RED='\033[0;31m' RED='\033[0;31m'
NC='\033[0m' # No Color NC='\033[0m' # No Color
# Usage function
usage() {
echo "Usage: $0 --image-name <image_name> --container-base <container_base> --normal-deps <pip_dependencies> [--run-config <run_config>] [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --image-name llama-stack-img --container-base python:3.12-slim --normal-deps 'numpy pandas' --run-config ./run.yaml --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
image_name=""
container_base=""
normal_deps=""
external_provider_deps=""
optional_deps=""
run_config=""
template_or_config=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--image-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --image-name requires a string value" >&2
usage
fi
image_name="$2"
shift 2
;;
--container-base)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --container-base requires a string value" >&2
usage
fi
container_base="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
--run-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --run-config requires a string value" >&2
usage
fi
run_config="$2"
shift 2
;;
--template-or-config)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --template-or-config requires a string value" >&2
usage
fi
template_or_config="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$image_name" || -z "$container_base" || -z "$normal_deps" ]]; then
echo "Error: --image-name, --container-base, and --normal-deps are required." >&2
usage
fi
CONTAINER_BINARY=${CONTAINER_BINARY:-docker} CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain} CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
TEMP_DIR=$(mktemp -d) TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")") SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh" source "$SCRIPT_DIR/common.sh"
@ -81,18 +132,15 @@ add_to_container() {
if [ -t 0 ]; then if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file" printf '%s\n' "$1" >>"$output_file"
else else
# If stdin is not a terminal, read from it (heredoc)
cat >>"$output_file" cat >>"$output_file"
fi fi
} }
# Check if container command is available
if ! is_command_available "$CONTAINER_BINARY"; then if ! is_command_available "$CONTAINER_BINARY"; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2 printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1 exit 1
fi fi
# Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF add_to_container << EOF
FROM $container_base FROM $container_base
@ -135,16 +183,16 @@ EOF
# Add pip dependencies first since llama-stack is what will change most often # Add pip dependencies first since llama-stack is what will change most often
# so we can reuse layers. # so we can reuse layers.
if [ -n "$pip_dependencies" ]; then if [ -n "$normal_deps" ]; then
read -ra pip_args <<< "$pip_dependencies" read -ra pip_args <<< "$normal_deps"
quoted_deps=$(printf " %q" "${pip_args[@]}") quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container << EOF add_to_container << EOF
RUN $MOUNT_CACHE uv pip install $quoted_deps RUN $MOUNT_CACHE uv pip install $quoted_deps
EOF EOF
fi fi
if [ -n "$special_pip_deps" ]; then if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps" IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do for part in "${parts[@]}"; do
read -ra pip_args <<< "$part" read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}") quoted_deps=$(printf " %q" "${pip_args[@]}")
@ -154,7 +202,33 @@ EOF
done done
fi fi
# Function to get Python command if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
read -ra pip_args <<< "$part"
quoted_deps=$(printf " %q" "${pip_args[@]}")
add_to_container <<EOF
RUN $MOUNT_CACHE uv pip install $quoted_deps
EOF
add_to_container <<EOF
RUN python3 - <<PYTHON | $MOUNT_CACHE uv pip install -r -
import importlib
import sys
try:
package_name = '$part'.split('==')[0].split('>=')[0].split('<=')[0].split('!=')[0].split('<')[0].split('>')[0]
module = importlib.import_module(f'{package_name}.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
if isinstance(spec.pip_packages, (list, tuple)):
print('\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for {package_name}: {e}', file=sys.stderr)
PYTHON
EOF
done
fi
get_python_cmd() { get_python_cmd() {
if is_command_available python; then if is_command_available python; then
echo "python" echo "python"

View file

@ -18,6 +18,76 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-} UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-} VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <env_name> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name mybuild --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$normal_deps" ]]; then
echo "Error: --env-name and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR" echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi fi
@ -25,29 +95,6 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi fi
if [ "$#" -lt 2 ]; then
echo "Usage: $0 <env_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
special_pip_deps="$3"
set -euo pipefail
env_name="$1"
pip_dependencies="$2"
# Define color codes
RED='\033[0;31m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# pre-run checks to make sure we can proceed with the installation # pre-run checks to make sure we can proceed with the installation
pre_run_checks() { pre_run_checks() {
local env_name="$1" local env_name="$1"
@ -71,49 +118,44 @@ pre_run_checks() {
} }
run() { run() {
local env_name="$1" # Use only global variables set by flag parser
local pip_dependencies="$2"
local special_pip_deps="$3"
if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then
echo "Installing dependencies in system Python environment" echo "Installing dependencies in system Python environment"
# if env == __system__, ensure we set UV_SYSTEM_PYTHON
export UV_SYSTEM_PYTHON=1 export UV_SYSTEM_PYTHON=1
elif [ "$VIRTUAL_ENV" == "$env_name" ]; then elif [ "$VIRTUAL_ENV" == "$env_name" ]; then
echo "Virtual environment $env_name is already active" echo "Virtual environment $env_name is already active"
else else
echo "Using virtual environment $env_name" echo "Using virtual environment $env_name"
uv venv "$env_name" uv venv "$env_name"
# shellcheck source=/dev/null
source "$env_name/bin/activate" source "$env_name/bin/activate"
fi fi
if [ -n "$TEST_PYPI_VERSION" ]; then if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst uv pip install fastapi libcst
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install --extra-index-url https://test.pypi.org/simple/ \ uv pip install --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \ --index-strategy unsafe-best-match \
llama-stack=="$TEST_PYPI_VERSION" \ llama-stack=="$TEST_PYPI_VERSION" \
$pip_dependencies $normal_deps
if [ -n "$special_pip_deps" ]; then if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps" IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do for part in "${parts[@]}"; do
echo "$part" echo "$part"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install $part uv pip install $part
done done
fi fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install "$part"
done
fi
else else
# Re-installing llama-stack in the new virtual environment
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2 printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR" printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else else
@ -125,27 +167,41 @@ run() {
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2 printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR" printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi fi
# Install pip dependencies
printf "Installing pip dependencies\n" printf "Installing pip dependencies\n"
# shellcheck disable=SC2086 uv pip install $normal_deps
# we are building a command line so word splitting is expected if [ -n "$optional_deps" ]; then
uv pip install $pip_dependencies IFS='#' read -ra parts <<<"$optional_deps"
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do for part in "${parts[@]}"; do
echo "$part" echo "Installing special provider module: $part"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install $part uv pip install $part
done done
fi fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Installing external provider module: $part"
uv pip install "$part"
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi fi
} }
pre_run_checks "$env_name" pre_run_checks "$env_name"
run "$env_name" "$pip_dependencies" "$special_pip_deps" run

View file

@ -91,21 +91,21 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
logger.info(f"Configuring API `{api_str}`...") logger.info(f"Configuring API `{api_str}`...")
updated_providers = [] updated_providers = []
for i, provider_type in enumerate(plist): for i, provider in enumerate(plist):
if i >= 1: if i >= 1:
others = ", ".join(plist[i:]) others = ", ".join(p.provider_type for p in plist[i:])
logger.info( logger.info(
f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n" f"Not configuring other providers ({others}) interactively. Please edit the resulting YAML directly.\n"
) )
break break
logger.info(f"> Configuring provider `({provider_type})`") logger.info(f"> Configuring provider `({provider.provider_type})`")
updated_providers.append( updated_providers.append(
configure_single_provider( configure_single_provider(
provider_registry[api], provider_registry[api],
Provider( Provider(
provider_id=(f"{provider_type}-{i:02d}" if len(plist) > 1 else provider_type), provider_id=(f"{provider.provider_id}-{i:02d}" if len(plist) > 1 else provider.provider_id),
provider_type=provider_type, provider_type=provider.provider_type,
config={}, config={},
), ),
) )

View file

@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
RoutingKey = str | list[str] RoutingKey = str | list[str]
class RegistryEntrySource(StrEnum):
via_register_api = "via_register_api"
listed_from_provider = "listed_from_provider"
class User(BaseModel): class User(BaseModel):
principal: str principal: str
# further attributes that may be used for access control decisions # further attributes that may be used for access control decisions
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
resource. This can be used to constrain access to the resource.""" resource. This can be used to constrain access to the resource."""
owner: User | None = None owner: User | None = None
source: RegistryEntrySource = RegistryEntrySource.via_register_api
# Use the extended Resource for all routable objects # Use the extended Resource for all routable objects
@ -130,29 +136,40 @@ class RoutingTableProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(default_factory=list) pip_packages: list[str] = Field(default_factory=list)
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any] = {}
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the external provider module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
class DistributionSpec(BaseModel): class DistributionSpec(BaseModel):
description: str | None = Field( description: str | None = Field(
default="", default="",
description="Description of the distribution", description="Description of the distribution",
) )
container_image: str | None = None container_image: str | None = None
providers: dict[str, str | list[str]] = Field( providers: dict[str, list[Provider]] = Field(
default_factory=dict, default_factory=dict,
description=""" description="""
Provider Types for each of the APIs provided by this distribution. If you Provider Types for each of the APIs provided by this distribution. If you
select multiple providers, you should provide an appropriate 'routing_map' select multiple providers, you should provide an appropriate 'routing_map'
in the runtime configuration to help route to the correct provider.""", in the runtime configuration to help route to the correct provider.
""",
) )
class Provider(BaseModel):
# provider_id of None means that the provider is not enabled - this happens
# when the provider is enabled via a conditional environment variable
provider_id: str | None
provider_type: str
config: dict[str, Any]
class LoggingConfig(BaseModel): class LoggingConfig(BaseModel):
category_levels: dict[str, str] = Field( category_levels: dict[str, str] = Field(
default_factory=dict, default_factory=dict,
@ -381,6 +398,11 @@ a default SQLite store will be used.""",
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
) )
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir") @field_validator("external_providers_dir")
@classmethod @classmethod
def validate_external_providers_dir(cls, v): def validate_external_providers_dir(cls, v):
@ -412,6 +434,10 @@ class BuildConfig(BaseModel):
default_factory=list, default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.", description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
) )
external_apis_dir: Path | None = Field(
default=None,
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
)
@field_validator("external_providers_dir") @field_validator("external_providers_dir")
@classmethod @classmethod

View file

@ -12,6 +12,8 @@ from typing import Any
import yaml import yaml
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
from llama_stack.distribution.external import load_external_apis
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec, AdapterSpec,
@ -96,12 +98,10 @@ def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_nam
return spec return spec
def get_provider_registry( def get_provider_registry(config=None) -> dict[Api, dict[str, ProviderSpec]]:
config=None,
) -> dict[Api, dict[str, ProviderSpec]]:
"""Get the provider registry, optionally including external providers. """Get the provider registry, optionally including external providers.
This function loads both built-in providers and external providers from YAML files. This function loads both built-in providers and external providers from YAML files or from their provided modules.
External providers are loaded from a directory structure like: External providers are loaded from a directory structure like:
providers.d/ providers.d/
@ -122,8 +122,13 @@ def get_provider_registry(
safety/ safety/
llama-guard.yaml llama-guard.yaml
This method is overloaded in that it can be called from a variety of places: during build, during run, during stack construction.
So when building external providers from a module, there are scenarios where the pip package required to import the module might not be available yet.
There is special handling for all of the potential cases this method can be called from.
Args: Args:
config: Optional object containing the external providers directory path config: Optional object containing the external providers directory path
building: Optional bool delineating whether or not this is being called from a build process
Returns: Returns:
A dictionary mapping APIs to their available providers A dictionary mapping APIs to their available providers
@ -133,58 +138,140 @@ def get_provider_registry(
ValueError: If any provider spec is invalid ValueError: If any provider spec is invalid
""" """
ret: dict[Api, dict[str, ProviderSpec]] = {} registry: dict[Api, dict[str, ProviderSpec]] = {}
for api in providable_apis(): for api in providable_apis():
name = api.name.lower() name = api.name.lower()
logger.debug(f"Importing module {name}") logger.debug(f"Importing module {name}")
try: try:
module = importlib.import_module(f"llama_stack.providers.registry.{name}") module = importlib.import_module(f"llama_stack.providers.registry.{name}")
ret[api] = {a.provider_type: a for a in module.available_providers()} registry[api] = {a.provider_type: a for a in module.available_providers()}
except ImportError as e: except ImportError as e:
logger.warning(f"Failed to import module {name}: {e}") logger.warning(f"Failed to import module {name}: {e}")
# Check if config has the external_providers_dir attribute # Refresh providable APIs with external APIs if any
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir: external_apis = load_external_apis(config)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir)) for api, api_spec in external_apis.items():
if not os.path.exists(external_providers_dir): name = api_spec.name.lower()
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") logger.info(f"Importing external API {name} module {api_spec.module}")
logger.info(f"Loading external providers from {external_providers_dir}") try:
module = importlib.import_module(api_spec.module)
registry[api] = {a.provider_type: a for a in module.available_providers()}
except (ImportError, AttributeError) as e:
# Populate the registry with an empty dict to avoid breaking the provider registry
# This assume that the in-tree provider(s) are not available for this API which means
# that users will need to use external providers for this API.
registry[api] = {}
logger.error(
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
"Install the API package to load any in-tree providers for this API."
)
for api in providable_apis(): # Check if config has external providers
api_name = api.name.lower() if config:
if hasattr(config, "external_providers_dir") and config.external_providers_dir:
registry = get_external_providers_from_dir(registry, config)
# else lets check for modules in each provider
registry = get_external_providers_from_module(
registry=registry,
config=config,
building=(isinstance(config, BuildConfig) or isinstance(config, DistributionSpec)),
)
# Process both remote and inline providers return registry
for provider_type in ["remote", "inline"]:
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
if not os.path.exists(api_dir):
logger.debug(f"No {provider_type} provider directory found for {api_name}")
continue
# Look for provider spec files in the API directory
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
try: def get_external_providers_from_dir(
with open(spec_path) as f: registry: dict[Api, dict[str, ProviderSpec]], config
spec_data = yaml.safe_load(f) ) -> dict[Api, dict[str, ProviderSpec]]:
logger.warning(
"Specifying external providers via `external_providers_dir` is being deprecated. Please specify `module:` in the provider instead."
)
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}")
if provider_type == "remote": for api in providable_apis():
spec = _load_remote_provider_spec(spec_data, api) api_name = api.name.lower()
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}") # Process both remote and inline providers
if provider_type_key in ret[api]: for provider_type in ["remote", "inline"]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") api_dir = os.path.join(external_providers_dir, provider_type, api_name)
ret[api][provider_type_key] = spec if not os.path.exists(api_dir):
logger.info(f"Successfully loaded external provider {provider_type_key}") logger.debug(f"No {provider_type} provider directory found for {api_name}")
except yaml.YAMLError as yaml_err: continue
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err # Look for provider spec files in the API directory
except Exception as e: for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
logger.error(f"Failed to load provider spec from {spec_path}: {e}") provider_name = os.path.splitext(os.path.basename(spec_path))[0]
raise e logger.info(f"Loading {provider_type} provider spec from {spec_path}")
return ret
try:
with open(spec_path) as f:
spec_data = yaml.safe_load(f)
if provider_type == "remote":
spec = _load_remote_provider_spec(spec_data, api)
provider_type_key = f"remote::{provider_name}"
else:
spec = _load_inline_provider_spec(spec_data, api, provider_name)
provider_type_key = f"inline::{provider_name}"
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
if provider_type_key in registry[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
registry[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err
except Exception as e:
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
raise e
return registry
def get_external_providers_from_module(
registry: dict[Api, dict[str, ProviderSpec]], config, building: bool
) -> dict[Api, dict[str, ProviderSpec]]:
provider_list = None
if isinstance(config, BuildConfig):
provider_list = config.distribution_spec.providers.items()
else:
provider_list = config.providers.items()
if provider_list is None:
logger.warning("Could not get list of providers from config")
return registry
for provider_api, providers in provider_list:
for provider in providers:
if not hasattr(provider, "module") or provider.module is None:
continue
# get provider using module
try:
if not building:
package_name = provider.module.split("==")[0]
module = importlib.import_module(f"{package_name}.provider")
# if config class is wrong you will get an error saying module could not be imported
spec = module.get_provider_spec()
else:
# pass in a partially filled out provider spec to satisfy the registry -- knowing we will be overwriting it later upon build and run
spec = ProviderSpec(
api=Api(provider_api),
provider_type=provider.provider_type,
is_external=True,
module=provider.module,
config_class="",
)
provider_type = provider.provider_type
# in the case we are building we CANNOT import this module of course because it has not been installed.
# return a partially filled out spec that the build script will populate.
registry[Api(provider_api)][provider_type] = spec
except ModuleNotFoundError as exc:
raise ValueError(
"get_provider_spec not found. If specifying an external provider via `module` in the Provider spec, the Provider must have the `provider.get_provider_spec` module available"
) from exc
except Exception as e:
logger.error(f"Failed to load provider spec from module {provider.module}: {e}")
raise e
return registry

View file

@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import yaml
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="core")
def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
"""Load external API specifications from the configured directory.
Args:
config: StackRunConfig or BuildConfig containing the external APIs directory path
Returns:
A dictionary mapping API names to their specifications
"""
if not config or not config.external_apis_dir:
return {}
external_apis_dir = config.external_apis_dir.expanduser().resolve()
if not external_apis_dir.is_dir():
logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
return {}
logger.info(f"Loading external APIs from {external_apis_dir}")
external_apis: dict[Api, ExternalApiSpec] = {}
# Look for YAML files in the external APIs directory
for yaml_path in external_apis_dir.glob("*.yaml"):
try:
with open(yaml_path) as f:
spec_data = yaml.safe_load(f)
spec = ExternalApiSpec(**spec_data)
api = Api.add(spec.name)
logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
external_apis[api] = spec
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
raise
except Exception:
logger.exception(f"Failed to load external API spec from {yaml_path}")
raise
return external_apis

View file

@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
VersionInfo, VersionInfo,
) )
from llama_stack.distribution.datatypes import StackRunConfig from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.server.routes import get_all_api_routes from llama_stack.distribution.server.routes import get_all_api_routes
from llama_stack.providers.datatypes import HealthStatus from llama_stack.providers.datatypes import HealthStatus
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
run_config: StackRunConfig = self.config.run_config run_config: StackRunConfig = self.config.run_config
ret = [] ret = []
all_endpoints = get_all_api_routes() external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
for api, endpoints in all_endpoints.items(): for api, endpoints in all_endpoints.items():
# Always include provider and inspect APIs, filter others based on run config # Always include provider and inspect APIs, filter others based on run config
if api.value in ["providers", "inspect"]: if api.value in ["providers", "inspect"]:
@ -53,7 +55,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])), method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
) )
for e in endpoints for e, _ in endpoints
if e.methods is not None
] ]
) )
else: else:
@ -66,7 +69,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])), method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers], provider_types=[p.provider_type for p in providers],
) )
for e in endpoints for e, _ in endpoints
if e.methods is not None
] ]
) )

View file

@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
if not self.skip_logger_removal: if not self.skip_logger_removal:
self._remove_root_logger_handlers() self._remove_root_logger_handlers()
return self.loop.run_until_complete(self.async_client.initialize()) # use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self): def _remove_root_logger_handlers(self):
""" """
@ -243,15 +249,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
file=sys.stderr, file=sys.stderr,
) )
if self.config_path_or_template_name.endswith(".yaml"): if self.config_path_or_template_name.endswith(".yaml"):
# Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig( build_config = BuildConfig(
distribution_spec=DistributionSpec( distribution_spec=DistributionSpec(
providers=provider_types, providers=self.config.providers,
), ),
external_providers_dir=self.config.external_providers_dir, external_providers_dir=self.config.external_providers_dir,
) )
@ -353,13 +353,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body, field_names = self._handle_file_uploads(options, body) body, field_names = self._handle_file_uploads(options, body)
body = self._convert_body(path, options.method, body, exclude_params=set(field_names)) body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
await start_trace(route, {"__location__": "library_client"})
trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
try: try:
result = await matched_func(**body) result = await matched_func(**body)
finally: finally:
@ -409,12 +411,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
path = options.url path = options.url
body = options.params or {} body = options.params or {}
body |= options.json_data or {} body |= options.json_data or {}
func, path_params, route = find_matching_route(options.method, path, self.route_impls) func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
body |= path_params body |= path_params
body = self._convert_body(path, options.method, body) body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"}) trace_path = webmethod.descriptive_name or route_path
await start_trace(trace_path, {"__location__": "library_client"})
async def gen(): async def gen():
try: try:
@ -445,8 +448,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient # we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream) # however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
# so we need to convert it to AsyncStream # so we need to convert it to AsyncStream
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
args = get_args(stream_cls) args = get_args(stream_cls)
stream_cls = AsyncStream[args[0]] stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
response = AsyncAPIResponse( response = AsyncAPIResponse(
raw=mock_response, raw=mock_response,
client=self, client=self,
@ -468,7 +472,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
exclude_params = exclude_params or set() exclude_params = exclude_params or set()
func, _, _ = find_matching_route(method, path, self.route_impls) func, _, _, _ = find_matching_route(method, path, self.route_impls)
sig = inspect.signature(func) sig = inspect.signature(func)
# Strip NOT_GIVENs to use the defaults in signature # Strip NOT_GIVENs to use the defaults in signature

View file

@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
if not provider_data: if not provider_data:
return None return None
return provider_data.get("__authenticated_user") return provider_data.get("__authenticated_user")
def user_from_scope(scope: dict) -> User | None:
"""Create a User object from ASGI scope data (set by authentication middleware)"""
user_attributes = scope.get("user_attributes", {})
principal = scope.get("principal", "")
# auth not enabled
if not principal and not user_attributes:
return None
return User(principal=principal, attributes=user_attributes)

View file

@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.datatypes import ExternalApiSpec
from llama_stack.apis.eval import Eval from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InferenceProvider from llama_stack.apis.inference import Inference, InferenceProvider
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
StackRunConfig, StackRunConfig,
) )
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.external import load_external_apis
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
pass pass
def api_protocol_map() -> dict[Api, Any]: def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
return { """Get a mapping of API types to their protocol classes.
Args:
external_apis: Optional dictionary of external API specifications
Returns:
Dictionary mapping API types to their protocol classes
"""
protocols = {
Api.providers: ProvidersAPI, Api.providers: ProvidersAPI,
Api.agents: Agents, Api.agents: Agents,
Api.inference: Inference, Api.inference: Inference,
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
Api.files: Files, Api.files: Files,
} }
if external_apis:
for api, api_spec in external_apis.items():
try:
module = importlib.import_module(api_spec.module)
api_class = getattr(module, api_spec.protocol)
def api_protocol_map_for_compliance_check() -> dict[Api, Any]: protocols[api] = api_class
except (ImportError, AttributeError):
logger.exception(f"Failed to load external API {api_spec.name}")
return protocols
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
external_apis = load_external_apis(config)
return { return {
**api_protocol_map(), **api_protocol_map(external_apis),
Api.inference: InferenceProvider, Api.inference: InferenceProvider,
} }
@ -250,7 +273,7 @@ async def instantiate_providers(
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
run_config: StackRunConfig, run_config: StackRunConfig,
policy: list[AccessRule], policy: list[AccessRule],
) -> dict: ) -> dict[Api, Any]:
"""Instantiates providers asynchronously while managing dependencies.""" """Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {} impls: dict[Api, Any] = {}
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis} inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
@ -322,7 +345,7 @@ async def instantiate_provider(
policy: list[AccessRule], policy: list[AccessRule],
): ):
provider_spec = provider.spec provider_spec = provider.spec
if not hasattr(provider_spec, "module"): if not hasattr(provider_spec, "module") or provider_spec.module is None:
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}") logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
@ -360,7 +383,7 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check() protocols = api_protocol_map_for_compliance_check(run_config)
additional_protocols = additional_protocols_map() additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups # TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol

View file

@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable):
for p in self.impls_by_provider_id.values(): for p in self.impls_by_provider_id.values():
await p.shutdown() await p.shutdown()
async def refresh(self) -> None:
pass
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any: async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable from .datasets import DatasetsRoutingTable
@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable):
if obj.type == ResourceType.model.value: if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj) await self.dist_registry.register(registered_obj)
return registered_obj return registered_obj
else: else:
await self.dist_registry.register(obj) await self.dist_registry.register(obj)
return obj return obj

View file

@ -10,6 +10,7 @@ from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import ( from llama_stack.distribution.datatypes import (
ModelWithOwner, ModelWithOwner,
RegistryEntrySource,
) )
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ModelsRoutingTable(CommonRoutingTableImpl, Models):
listed_providers: set[str] = set()
async def refresh(self) -> None:
for provider_id, provider in self.impls_by_provider_id.items():
refresh = await provider.should_refresh_models()
if not (refresh or provider_id in self.listed_providers):
continue
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)
if models is None:
continue
await self.update_registered_models(provider_id, models)
async def list_models(self) -> ListModelsResponse: async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model")) return ListModelsResponse(data=await self.get_all_with_type("model"))
@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id, provider_id=provider_id,
metadata=metadata, metadata=metadata,
model_type=model_type, model_type=model_type,
source=RegistryEntrySource.via_register_api,
) )
registered_model = await self.register_object(model) registered_model = await self.register_object(model)
return registered_model return registered_model
@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
raise ValueError(f"Model {model_id} not found") raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model) await self.unregister_object(existing_model)
async def update_registered_llm_models( async def update_registered_models(
self, self,
provider_id: str, provider_id: str,
models: list[Model], models: list[Model],
@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
# from run.yaml) that we need to keep track of # from run.yaml) that we need to keep track of
model_ids = {} model_ids = {}
for model in existing_models: for model in existing_models:
# we leave embeddings models alone because often we don't get metadata if model.provider_id != provider_id:
# (embedding dimension, etc.) from the provider continue
if model.provider_id == provider_id and model.model_type == ModelType.llm: if model.source == RegistryEntrySource.via_register_api:
model_ids[model.provider_resource_id] = model.identifier model_ids[model.provider_resource_id] = model.identifier
logger.debug(f"unregistering model {model.identifier}") continue
await self.unregister_object(model)
logger.debug(f"unregistering model {model.identifier}")
await self.unregister_object(model)
for model in models: for model in models:
if model.model_type != ModelType.llm:
continue
if model.provider_resource_id in model_ids: if model.provider_resource_id in model_ids:
model.identifier = model_ids[model.provider_resource_id] # avoid overwriting a non-provider-registered model entry
continue
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
await self.register_object( await self.register_object(
@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_id=provider_id, provider_id=provider_id,
metadata=model.metadata, metadata=model.metadata,
model_type=model.model_type, model_type=model.model_type,
source=RegistryEntrySource.listed_from_provider,
) )
) )

View file

@ -7,9 +7,12 @@
import json import json
import httpx import httpx
from aiohttp import hdrs
from llama_stack.distribution.datatypes import AuthenticationConfig from llama_stack.distribution.datatypes import AuthenticationConfig, User
from llama_stack.distribution.request_headers import user_from_scope
from llama_stack.distribution.server.auth_providers import create_auth_provider from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="auth")
@ -78,12 +81,14 @@ class AuthenticationMiddleware:
access resources that don't have access_attributes defined. access resources that don't have access_attributes defined.
""" """
def __init__(self, app, auth_config: AuthenticationConfig): def __init__(self, app, auth_config: AuthenticationConfig, impls):
self.app = app self.app = app
self.impls = impls
self.auth_provider = create_auth_provider(auth_config) self.auth_provider = create_auth_provider(auth_config)
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
if scope["type"] == "http": if scope["type"] == "http":
# First, handle authentication
headers = dict(scope.get("headers", [])) headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode() auth_header = headers.get(b"authorization", b"").decode()
@ -121,15 +126,50 @@ class AuthenticationMiddleware:
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes" f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
) )
# Scope-based API access control
path = scope.get("path", "")
method = scope.get("method", hdrs.METH_GET)
if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls)
try:
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
except ValueError:
# If no matching endpoint is found, pass through to FastAPI
return await self.app(scope, receive, send)
if webmethod.required_scope:
user = user_from_scope(scope)
if not _has_required_scope(webmethod.required_scope, user):
return await self._send_auth_error(
send,
f"Access denied: user does not have required scope: {webmethod.required_scope}",
status=403,
)
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
async def _send_auth_error(self, send, message): async def _send_auth_error(self, send, message, status=401):
await send( await send(
{ {
"type": "http.response.start", "type": "http.response.start",
"status": 401, "status": status,
"headers": [[b"content-type", b"application/json"]], "headers": [[b"content-type", b"application/json"]],
} }
) )
error_msg = json.dumps({"error": {"message": message}}).encode() error_key = "message" if status == 401 else "detail"
error_msg = json.dumps({"error": {error_key: message}}).encode()
await send({"type": "http.response.body", "body": error_msg}) await send({"type": "http.response.body", "body": error_msg})
def _has_required_scope(required_scope: str, user: User | None) -> bool:
# if no user, assume auth is not enabled
if not user:
return True
if not user.attributes:
return False
user_scopes = user.attributes.get("scopes", [])
return required_scope in user_scopes

View file

@ -12,17 +12,18 @@ from typing import Any
from aiohttp import hdrs from aiohttp import hdrs
from starlette.routing import Route from starlette.routing import Route
from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api from llama_stack.schema_utils import WebMethod
EndpointFunc = Callable[..., Any] EndpointFunc = Callable[..., Any]
PathParams = dict[str, str] PathParams = dict[str, str]
RouteInfo = tuple[EndpointFunc, str] RouteInfo = tuple[EndpointFunc, str, WebMethod]
PathImpl = dict[str, RouteInfo] PathImpl = dict[str, RouteInfo]
RouteImpls = dict[str, PathImpl] RouteImpls = dict[str, PathImpl]
RouteMatch = tuple[EndpointFunc, PathParams, str] RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
def toolgroup_protocol_map(): def toolgroup_protocol_map():
@ -31,10 +32,12 @@ def toolgroup_protocol_map():
} }
def get_all_api_routes() -> dict[Api, list[Route]]: def get_all_api_routes(
external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
apis = {} apis = {}
protocols = api_protocol_map() protocols = api_protocol_map(external_apis)
toolgroup_protocols = toolgroup_protocol_map() toolgroup_protocols = toolgroup_protocol_map()
for api, protocol in protocols.items(): for api, protocol in protocols.items():
routes = [] routes = []
@ -65,7 +68,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
else: else:
http_method = hdrs.METH_POST http_method = hdrs.METH_POST
routes.append( routes.append(
Route(path=path, methods=[http_method], name=name, endpoint=None) (Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
) # setting endpoint to None since don't use a Router object ) # setting endpoint to None since don't use a Router object
apis[api] = routes apis[api] = routes
@ -73,8 +76,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
return apis return apis
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls: def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
routes = get_all_api_routes() api_to_routes = get_all_api_routes(external_apis)
route_impls: RouteImpls = {} route_impls: RouteImpls = {}
def _convert_path_to_regex(path: str) -> str: def _convert_path_to_regex(path: str) -> str:
@ -88,10 +91,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
return f"^{pattern}$" return f"^{pattern}$"
for api, api_routes in routes.items(): for api, api_routes in api_to_routes.items():
if api not in impls: if api not in impls:
continue continue
for route in api_routes: for route, webmethod in api_routes:
impl = impls[api] impl = impls[api]
func = getattr(impl, route.name) func = getattr(impl, route.name)
# Get the first (and typically only) method from the set, filtering out HEAD # Get the first (and typically only) method from the set, filtering out HEAD
@ -104,6 +107,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
route_impls[method][_convert_path_to_regex(route.path)] = ( route_impls[method][_convert_path_to_regex(route.path)] = (
func, func,
route.path, route.path,
webmethod,
) )
return route_impls return route_impls
@ -118,7 +122,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
route_impls: A dictionary of endpoint implementations route_impls: A dictionary of endpoint implementations
Returns: Returns:
A tuple of (endpoint_function, path_params, descriptive_name) A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
Raises: Raises:
ValueError: If no matching endpoint is found ValueError: If no matching endpoint is found
@ -127,11 +131,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
if not impls: if not impls:
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")
for regex, (func, descriptive_name) in impls.items(): for regex, (func, route_path, webmethod) in impls.items():
match = re.match(regex, path) match = re.match(regex, path)
if match: if match:
# Extract named groups from the regex match # Extract named groups from the regex match
path_params = match.groupdict() path_params = match.groupdict()
return func, path_params, descriptive_name return func, path_params, route_path, webmethod
raise ValueError(f"No endpoint found for {path}") raise ValueError(f"No endpoint found for {path}")

View file

@ -40,7 +40,12 @@ from llama_stack.distribution.datatypes import (
StackRunConfig, StackRunConfig,
) )
from llama_stack.distribution.distribution import builtin_automatically_routed_apis from llama_stack.distribution.distribution import builtin_automatically_routed_apis
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
user_from_scope,
)
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.server.routes import ( from llama_stack.distribution.server.routes import (
find_matching_route, find_matching_route,
@ -51,6 +56,7 @@ from llama_stack.distribution.stack import (
cast_image_name_to_string, cast_image_name_to_string,
construct_stack, construct_stack,
replace_env_vars, replace_env_vars,
shutdown_stack,
validate_env_pair, validate_env_pair,
) )
from llama_stack.distribution.utils.config import redact_sensitive_fields from llama_stack.distribution.utils.config import redact_sensitive_fields
@ -146,18 +152,7 @@ async def shutdown(app):
Handled by the lifespan context manager. The shutdown process involves Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application. shutting down all implementations registered in the application.
""" """
for impl in app.__llama_stack_impls__.values(): await shutdown_stack(app.__llama_stack_impls__)
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@asynccontextmanager @asynccontextmanager
@ -222,9 +217,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
@functools.wraps(func) @functools.wraps(func)
async def route_handler(request: Request, **kwargs): async def route_handler(request: Request, **kwargs):
# Get auth attributes from the request scope # Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {}) user = user_from_scope(request.scope)
principal = request.scope.get("principal", "")
user = User(principal=principal, attributes=user_attributes)
await log_request_pre_validation(request) await log_request_pre_validation(request)
@ -282,9 +275,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
class TracingMiddleware: class TracingMiddleware:
def __init__(self, app, impls): def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
self.app = app self.app = app
self.impls = impls self.impls = impls
self.external_apis = external_apis
# FastAPI built-in paths that should bypass custom routing # FastAPI built-in paths that should bypass custom routing
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
@ -301,10 +295,12 @@ class TracingMiddleware:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
if not hasattr(self, "route_impls"): if not hasattr(self, "route_impls"):
self.route_impls = initialize_route_impls(self.impls) self.route_impls = initialize_route_impls(self.impls, self.external_apis)
try: try:
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls) _, _, route_path, webmethod = find_matching_route(
scope.get("method", hdrs.METH_GET), path, self.route_impls
)
except ValueError: except ValueError:
# If no matching endpoint is found, pass through to FastAPI # If no matching endpoint is found, pass through to FastAPI
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI") logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
@ -321,6 +317,7 @@ class TracingMiddleware:
if tracestate: if tracestate:
trace_attributes["tracestate"] = tracestate trace_attributes["tracestate"] = tracestate
trace_path = webmethod.descriptive_name or route_path
trace_context = await start_trace(trace_path, trace_attributes) trace_context = await start_trace(trace_path, trace_attributes)
async def send_with_trace_id(message): async def send_with_trace_id(message):
@ -432,10 +429,21 @@ def main(args: argparse.Namespace | None = None):
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
# Add authentication middleware if configured try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth: if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth) app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
else: else:
if config.server.quota: if config.server.quota:
quota = config.server.quota quota = config.server.quota
@ -466,24 +474,14 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds, window_seconds=window_seconds,
) )
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
else: else:
setup_logger(TelemetryAdapter(TelemetryConfig(), {})) setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
all_routes = get_all_api_routes() # Load external APIs if configured
external_apis = load_external_apis(config)
all_routes = get_all_api_routes(external_apis)
if config.apis: if config.apis:
apis_to_serve = set(config.apis) apis_to_serve = set(config.apis)
@ -502,9 +500,12 @@ def main(args: argparse.Namespace | None = None):
api = Api(api_str) api = Api(api_str)
routes = all_routes[api] routes = all_routes[api]
impl = impls[api] try:
impl = impls[api]
except KeyError as e:
raise ValueError(f"Could not find provider implementation for {api} API") from e
for route in routes: for route, _ in routes:
if not hasattr(impl, route.name): if not hasattr(impl, route.name):
# ideally this should be a typing violation already # ideally this should be a typing violation already
raise ValueError(f"Could not find method {route.name} on {impl}!") raise ValueError(f"Could not find method {route.name} on {impl}!")
@ -533,7 +534,7 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls) app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
import uvicorn import uvicorn

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import importlib.resources import importlib.resources
import os import os
import re import re
@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -90,6 +92,10 @@ RESOURCES = [
] ]
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
REGISTRY_REFRESH_TASK = None
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
for rsrc, api, register_method, list_method in RESOURCES: for rsrc, api, register_method, list_method in RESOURCES:
objects = getattr(run_config, rsrc) objects = getattr(run_config, rsrc)
@ -324,9 +330,53 @@ async def construct_stack(
add_internal_implementations(impls, run_config) add_internal_implementations(impls, run_config)
await register_resources(run_config, impls) await register_resources(run_config, impls)
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry(impls: dict[Api, Any]):
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
while True:
for routing_table in routing_tables:
await routing_table.refresh()
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
def get_stack_run_config_from_template(template: str) -> StackRunConfig: def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml" template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"

View file

@ -6,6 +6,7 @@
import logging import logging
import os import os
import re
import sys import sys
from logging.config import dictConfig from logging.config import dictConfig
@ -30,6 +31,7 @@ CATEGORIES = [
"eval", "eval",
"tools", "tools",
"client", "client",
"telemetry",
] ]
# Initialize category levels with default level # Initialize category levels with default level
@ -113,6 +115,11 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
return category_levels return category_levels
def strip_rich_markup(text):
"""Remove Rich markup tags like [dim], [bold magenta], etc."""
return re.sub(r"\[/?[a-zA-Z0-9 _#=,]+\]", "", text)
class CustomRichHandler(RichHandler): class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=150) kwargs["console"] = Console(width=150)
@ -131,6 +138,19 @@ class CustomRichHandler(RichHandler):
self.markup = original_markup self.markup = original_markup
class CustomFileHandler(logging.FileHandler):
def __init__(self, filename, mode="a", encoding=None, delay=False):
super().__init__(filename, mode, encoding, delay)
# Default formatter to match console output
self.default_formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s")
self.setFormatter(self.default_formatter)
def emit(self, record):
if hasattr(record, "msg"):
record.msg = strip_rich_markup(str(record.msg))
super().emit(record)
def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None: def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
""" """
Configure logging based on the provided category log levels and an optional log file. Configure logging based on the provided category log levels and an optional log file.
@ -167,8 +187,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
# Add a file handler if log_file is set # Add a file handler if log_file is set
if log_file: if log_file:
handlers["file"] = { handlers["file"] = {
"class": "logging.FileHandler", "()": CustomFileHandler,
"formatter": "rich",
"filename": log_file, "filename": log_file,
"mode": "a", "mode": "a",
"encoding": "utf-8", "encoding": "utf-8",

View file

@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
async def unregister_model(self, model_id: str) -> None: ... async def unregister_model(self, model_id: str) -> None: ...
# the Stack router will query each provider for their list of models
# if a `refresh_interval_seconds` is provided, this method will be called
# periodically to refresh the list of models
#
# NOTE: each model returned will be registered with the model registry. this means
# a callback to the `register_model()` method will be made. this is duplicative and
# may be removed in the future.
async def list_models(self) -> list[Model] | None: ...
async def should_refresh_models(self) -> bool: ...
class ShieldsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol):
async def register_shield(self, shield: Shield) -> None: ... async def register_shield(self, shield: Shield) -> None: ...
@ -104,6 +115,19 @@ class ProviderSpec(BaseModel):
description="If this provider is deprecated and does NOT work, specify the error message here", description="If this provider is deprecated and does NOT work, specify the error message here",
) )
module: str | None = Field(
default=None,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
Example: `module: ramalama_stack`
""",
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now # used internally by the resolver; this is a hack for now
deps__: list[str] = Field(default_factory=list) deps__: list[str] = Field(default_factory=list)
@ -124,7 +148,7 @@ class AdapterSpec(BaseModel):
description="Unique identifier for this adapter", description="Unique identifier for this adapter",
) )
module: str = Field( module: str = Field(
..., default_factory=str,
description=""" description="""
Fully-qualified name of the module to import. The module is expected to have: Fully-qualified name of the module to import. The module is expected to have:
@ -162,14 +186,7 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image. If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""", """,
) )
module: str = Field( # module field is inherited from ProviderSpec
...,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_provider_impl(config, deps)`: returns the local implementation
""",
)
provider_data_validator: str | None = Field( provider_data_validator: str | None = Field(
default=None, default=None,
) )
@ -212,9 +229,7 @@ API responses, specify the adapter here.
def container_image(self) -> str | None: def container_image(self) -> str | None:
return None return None
@property # module field is inherited from ProviderSpec
def module(self) -> str:
return self.adapter.module
@property @property
def pip_packages(self) -> list[str]: def pip_packages(self) -> list[str]:
@ -226,14 +241,19 @@ API responses, specify the adapter here.
def remote_provider_spec( def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec: ) -> RemoteProviderSpec:
return RemoteProviderSpec( return RemoteProviderSpec(
api=api, api=api,
provider_type=f"remote::{adapter.adapter_type}", provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class, config_class=adapter.config_class,
module=adapter.module,
adapter=adapter, adapter=adapter,
api_dependencies=api_dependencies or [], api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
) )

View file

@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
if self.config.create_distributed_process_group: if self.config.create_distributed_process_group:
self.generator.stop() self.generator.stop()
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass

View file

@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition, ToolDefinition,
ToolPromptFormat, ToolPromptFormat,
) )
from llama_stack.apis.models import ModelType
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import ( from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
InferenceProvider, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
__provider_id__: str
def __init__(self, config: SentenceTransformersInferenceConfig) -> None: def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config self.config = config
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def should_refresh_models(self) -> bool:
return False
async def list_models(self) -> list[Model] | None:
return [
Model(
identifier="all-MiniLM-L6-v2",
provider_resource_id="all-MiniLM-L6-v2",
provider_id=self.__provider_id__,
metadata={
"embedding_dimension": 384,
},
model_type=ModelType.embedding,
),
]
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
return model return model

View file

@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode from opentelemetry.trace.status import StatusCode
# Colors for console output from llama_stack.log import get_logger
COLORS = {
"reset": "\033[0m", logger = get_logger(name="console_span_processor", category="telemetry")
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
class ConsoleSpanProcessor(SpanProcessor): class ConsoleSpanProcessor(SpanProcessor):
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
return return
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3] timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
def on_end(self, span: ReadableSpan) -> None: def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"): if span.attributes and span.attributes.get("__autotraced__"):
return return
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3] timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
if span.status.status_code == StatusCode.ERROR: if span.status.status_code == StatusCode.ERROR:
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}" span_context += " [bold red][ERROR][/bold red]"
elif span.status.status_code != StatusCode.UNSET: elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]" span_context += f" [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6 duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)" span_context += f" ({duration_ms:.2f}ms)"
logger.info(span_context)
print(span_context)
if self.print_attributes and span.attributes: if self.print_attributes and span.attributes:
for key, value in span.attributes.items(): for key, value in span.attributes.items():
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
str_value = str(value) str_value = str(value)
if len(str_value) > 1000: if len(str_value) > 1000:
str_value = str_value[:997] + "..." str_value = str_value[:997] + "..."
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}") logger.info(f" [dim]{key}[/dim]: {str_value}")
for event in span.events: for event in span.events:
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3] event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
severity = event.attributes.get("severity", "info") severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name) message = event.attributes.get("message", event.name)
if isinstance(message, dict | list): if isinstance(message, dict) or isinstance(message, list):
message = json.dumps(message, indent=2) message = json.dumps(message, indent=2)
severity_color = {
severity_colors = { "error": "red",
"error": f"{COLORS['bold']}{COLORS['red']}", "warn": "yellow",
"warn": f"{COLORS['bold']}{COLORS['yellow']}", "info": "white",
"info": COLORS["white"], "debug": "dim",
"debug": COLORS["dim"], }.get(severity, "white")
} logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
msg_color = severity_colors.get(severity, COLORS["white"])
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
if event.attributes: if event.attributes:
for key, value in event.attributes.items(): for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]: if key.startswith("__") or key in ["message", "severity"]:
continue continue
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}") logger.info(f"/r[dim]{key}[/dim]: {value}")
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shutdown the processor.""" """Shutdown the processor."""

View file

@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
self.kvstore = kvstore self.kvstore = kvstore
self.bank_id = bank_id self.bank_id = bank_id
# A list of chunk id's in the same order as they are in the index,
# must be updated when chunks are added or removed
self.chunk_id_lock = asyncio.Lock()
self.chunk_ids: list[Any] = []
@classmethod @classmethod
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id) instance = cls(dimension, kvstore, bank_id)
@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
buffer = io.BytesIO(base64.b64decode(data["faiss_index"])) buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
try: try:
self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False)) self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
except Exception as e: except Exception as e:
logger.debug(e, exc_info=True) logger.debug(e, exc_info=True)
raise ValueError( raise ValueError(
@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk self.chunk_by_index[indexlen + i] = chunk
self.index.add(np.array(embeddings).astype(np.float32)) async with self.chunk_id_lock:
self.index.add(np.array(embeddings).astype(np.float32))
self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
# Save updated index # Save updated index
await self._save_index() await self._save_index()
async def delete_chunk(self, chunk_id: str) -> None:
if chunk_id not in self.chunk_ids:
return
async with self.chunk_id_lock:
index = self.chunk_ids.index(chunk_id)
self.index.remove_ids(np.array([index]))
new_chunk_by_index = {}
for idx, chunk in self.chunk_by_index.items():
# Shift all chunks after the removed chunk to the left
if idx > index:
new_chunk_by_index[idx - 1] = chunk
else:
new_chunk_by_index[idx] = chunk
self.chunk_by_index = new_chunk_by_index
self.chunk_ids.pop(index)
await self._save_index()
async def query_vector( async def query_vector(
self, self,
embedding: NDArray, embedding: NDArray,
@ -260,3 +288,9 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a faiss index"""
faiss_index = self.cache[store_id].index
for chunk_id in chunk_ids:
await faiss_index.delete_chunk(chunk_id)

View file

@ -425,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the SQLite vector store."""
def _delete_chunk():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
cur.execute("BEGIN TRANSACTION")
# Delete from metadata table
cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
# Delete from vector table
cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
# Delete from FTS table
cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
connection.commit()
except Exception as e:
connection.rollback()
logger.error(f"Error deleting chunk {chunk_id}: {e}")
raise
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_chunk)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
""" """
@ -520,3 +549,13 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
if not index: if not index:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a sqlite_vec index."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
""", """,
), ),
api_dependencies=[Api.inference], api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
), ),
remote_provider_spec( remote_provider_spec(
Api.vector_io, Api.vector_io,

View file

@ -6,13 +6,14 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field, SecretStr from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class FireworksImplConfig(BaseModel): class FireworksImplConfig(RemoteInferenceProviderConfig):
url: str = Field( url: str = Field(
default="https://api.fireworks.ai/inference/v1", default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server", description="The URL for the Fireworks server",

View file

@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None: def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES) ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:

View file

@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
class OllamaImplConfig(BaseModel): class OllamaImplConfig(BaseModel):
url: str = DEFAULT_OLLAMA_URL url: str = DEFAULT_OLLAMA_URL
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically") refresh_models: bool = Field(
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models") default=False,
description="Whether to refresh models periodically",
)
@classmethod @classmethod
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:

View file

@ -98,14 +98,16 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None: def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES) self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.config = config self.config = config
self._client = None self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
self._openai_client = None self._openai_client = None
@property @property
def client(self) -> AsyncClient: def client(self) -> AsyncClient:
if self._client is None: # ollama client attaches itself to the current event loop (sadly?)
self._client = AsyncClient(host=self.config.url) loop = asyncio.get_running_loop()
return self._client if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url)
return self._clients[loop]
@property @property
def openai_client(self) -> AsyncOpenAI: def openai_client(self) -> AsyncOpenAI:
@ -121,59 +123,61 @@ class OllamaInferenceAdapter(
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal" "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
) )
if self.config.refresh_models: async def should_refresh_models(self) -> bool:
logger.debug("ollama starting background model refresh task") return self.config.refresh_models
self._refresh_task = asyncio.create_task(self._refresh_models())
def cb(task):
if task.cancelled():
import traceback
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
logger.error(f"ollama background refresh task died: {task.exception()}")
else:
logger.error("ollama background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
# Wait for model store to be available (with timeout)
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__ provider_id = self.__provider_id__
while True: response = await self.client.list()
try:
response = await self.client.list() # always add the two embedding models which can be pulled on demand
except Exception as e: models = [
logger.warning(f"Failed to list models: {str(e)}") Model(
await asyncio.sleep(self.config.refresh_models_interval) identifier="all-minilm:l6-v2",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
# add all-minilm alias
Model(
identifier="all-minilm",
provider_resource_id="all-minilm:l6-v2",
provider_id=provider_id,
metadata={
"embedding_dimension": 384,
"context_length": 512,
},
model_type=ModelType.embedding,
),
Model(
identifier="nomic-embed-text",
provider_resource_id="nomic-embed-text",
provider_id=provider_id,
metadata={
"embedding_dimension": 768,
"context_length": 8192,
},
model_type=ModelType.embedding,
),
]
for m in response.models:
# kill embedding models since we don't know dimensions for them
if m.details.family in ["bert"]:
continue continue
models.append(
models = [] Model(
for m in response.models: identifier=m.model,
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm provider_resource_id=m.model,
if model_type == ModelType.embedding: provider_id=provider_id,
continue metadata={},
models.append( model_type=ModelType.llm,
Model(
identifier=m.model,
provider_resource_id=m.model,
provider_id=provider_id,
metadata={},
model_type=model_type,
)
) )
await self.model_store.update_registered_llm_models(provider_id, models) )
logger.debug(f"ollama refreshed model list ({len(models)} models)") return models
await asyncio.sleep(self.config.refresh_models_interval)
async def health(self) -> HealthResponse: async def health(self) -> HealthResponse:
""" """
@ -190,12 +194,7 @@ class OllamaInferenceAdapter(
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
async def shutdown(self) -> None: async def shutdown(self) -> None:
if hasattr(self, "_refresh_task") and not self._refresh_task.done(): self._clients.clear()
logger.debug("ollama cancelling background refresh task")
self._refresh_task.cancel()
self._client = None
self._openai_client = None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass

View file

@ -6,13 +6,14 @@
from typing import Any from typing import Any
from pydantic import BaseModel, Field, SecretStr from pydantic import Field, SecretStr
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@json_schema_type @json_schema_type
class TogetherImplConfig(BaseModel): class TogetherImplConfig(RemoteInferenceProviderConfig):
url: str = Field( url: str = Field(
default="https://api.together.xyz/v1", default="https://api.together.xyz/v1",
description="The URL for the Together AI server", description="The URL for the Together AI server",

View file

@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: TogetherImplConfig) -> None: def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES) ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:

View file

@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
default=False, default=False,
description="Whether to refresh models periodically", description="Whether to refresh models periodically",
) )
refresh_models_interval: int = Field(
default=300,
description="Interval in seconds to refresh models",
)
@field_validator("tls_verify") @field_validator("tls_verify")
@classmethod @classmethod

View file

@ -3,7 +3,6 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import json import json
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
# automatically set by the resolver when instantiating the provider # automatically set by the resolver when instantiating the provider
__provider_id__: str __provider_id__: str
model_store: ModelStore | None = None model_store: ModelStore | None = None
_refresh_task: asyncio.Task | None = None
def __init__(self, config: VLLMInferenceAdapterConfig) -> None: def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None self.client = None
async def initialize(self) -> None: async def initialize(self) -> None:
if not self.config.url: pass
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
# or available in distributions like "starter" without causing a ruckus
return
if self.config.refresh_models: async def should_refresh_models(self) -> bool:
self._refresh_task = asyncio.create_task(self._refresh_models()) return self.config.refresh_models
def cb(task):
import traceback
if task.cancelled():
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
elif task.exception():
# print the stack trace for the exception
exc = task.exception()
log.error(f"vLLM background refresh task died: {exc}")
traceback.print_exception(exc)
else:
log.error("vLLM background refresh task completed unexpectedly")
self._refresh_task.add_done_callback(cb)
async def _refresh_models(self) -> None:
provider_id = self.__provider_id__
waited_time = 0
while not self.model_store and waited_time < 60:
await asyncio.sleep(1)
waited_time += 1
if not self.model_store:
raise ValueError("Model store not set after waiting 60 seconds")
async def list_models(self) -> list[Model] | None:
self._lazy_initialize_client() self._lazy_initialize_client()
assert self.client is not None # mypy assert self.client is not None # mypy
while True: models = []
try: async for m in self.client.models.list():
models = [] model_type = ModelType.llm # unclear how to determine embedding vs. llm models
async for m in self.client.models.list(): models.append(
model_type = ModelType.llm # unclear how to determine embedding vs. llm models Model(
models.append( identifier=m.id,
Model( provider_resource_id=m.id,
identifier=m.id, provider_id=self.__provider_id__,
provider_resource_id=m.id, metadata={},
provider_id=provider_id, model_type=model_type,
metadata={}, )
model_type=model_type, )
) return models
)
await self.model_store.update_registered_llm_models(provider_id, models)
log.debug(f"vLLM refreshed model list ({len(models)} models)")
except Exception as e:
log.error(f"vLLM background refresh task failed: {e}")
await asyncio.sleep(self.config.refresh_models_interval)
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self._refresh_task: pass
self._refresh_task.cancel()
self._refresh_task = None
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass

View file

@ -57,12 +57,15 @@ class ChromaIndex(EmbeddingIndex):
self.collection = collection self.collection = collection
self.kvstore = kvstore self.kvstore = kvstore
async def initialize(self):
pass
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
) )
ids = [f"{c.metadata['document_id']}:chunk-{i}" for i, c in enumerate(chunks)] ids = [f"{c.metadata.get('document_id', '')}:{c.chunk_id}" for c in chunks]
await maybe_await( await maybe_await(
self.collection.add( self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks], documents=[chunk.model_dump_json() for chunk in chunks],
@ -112,6 +115,9 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma") raise NotImplementedError("Keyword search is not supported in Chroma")
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def query_hybrid( async def query_hybrid(
self, self,
embedding: NDArray, embedding: NDArray,
@ -137,9 +143,12 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.client = None self.client = None
self.cache = {} self.cache = {}
self.kvstore: KVStore | None = None self.kvstore: KVStore | None = None
self.vector_db_store = None
async def initialize(self) -> None: async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore) self.kvstore = await kvstore_impl(self.config.kvstore)
self.vector_db_store = self.kvstore
if isinstance(self.config, RemoteChromaVectorIOConfig): if isinstance(self.config, RemoteChromaVectorIOConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}") log.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/") url = self.config.url.rstrip("/")
@ -172,6 +181,10 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) )
async def unregister_vector_db(self, vector_db_id: str) -> None: async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
log.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete() await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id] del self.cache[vector_db_id]
@ -182,6 +195,8 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
ttl_seconds: int | None = None, ttl_seconds: int | None = None,
) -> None: ) -> None:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_db_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
await index.insert_chunks(chunks) await index.insert_chunks(chunks)
@ -193,6 +208,9 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) -> QueryChunksResponse: ) -> QueryChunksResponse:
index = await self._get_and_cache_vector_db_index(vector_db_id) index = await self._get_and_cache_vector_db_index(vector_db_id)
if index is None:
raise ValueError(f"Vector DB {vector_db_id} not found in Chroma")
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex: async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex:
@ -208,3 +226,6 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api) index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index self.cache[vector_db_id] = index
return index return index
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@ -247,6 +247,16 @@ class MilvusIndex(EmbeddingIndex):
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus") raise NotImplementedError("Hybrid search is not supported in Milvus")
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the Milvus collection."""
try:
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
)
except Exception as e:
logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
raise
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__( def __init__(
@ -369,3 +379,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) )
return await index.query_chunks(query, params) return await index.query_chunks(query, params)
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a milvus vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]): async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
from .pgvector import PGVectorVectorIOAdapter from .pgvector import PGVectorVectorIOAdapter
impl = PGVectorVectorIOAdapter(config, deps[Api.inference]) impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
values.append( values.append(
( (
f"{chunk.metadata['document_id']}:chunk-{i}", f"{chunk.chunk_id}",
Json(chunk.model_dump()), Json(chunk.model_dump()),
embeddings[i].tolist(), embeddings[i].tolist(),
) )
@ -159,6 +159,11 @@ class PGVectorIndex(EmbeddingIndex):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the PostgreSQL table."""
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__( def __init__(
@ -265,3 +270,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn) index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api) self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id] return self.cache[vector_db_id]
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a PostgreSQL vector store."""
index = await self._get_and_cache_vector_db_index(store_id)
if not index:
raise ValueError(f"Vector DB {store_id} not found")
for chunk_id in chunk_ids:
# Use the index's delete_chunk method
await index.index.delete_chunk(chunk_id)

View file

@ -82,6 +82,9 @@ class QdrantIndex(EmbeddingIndex):
await self.client.upsert(collection_name=self.collection_name, points=points) await self.client.upsert(collection_name=self.collection_name, points=points)
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in qdrant")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
results = ( results = (
await self.client.query_points( await self.client.query_points(
@ -307,3 +310,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
file_id: str, file_id: str,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@ -66,6 +66,9 @@ class WeaviateIndex(EmbeddingIndex):
# TODO: make this async friendly # TODO: make this async friendly
collection.data.insert_many(data_objects) collection.data.insert_many(data_objects)
async def delete_chunk(self, chunk_id: str) -> None:
raise NotImplementedError("delete_chunk is not supported in Chroma")
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
collection = self.client.collections.get(self.collection_name) collection = self.client.collections.get(self.collection_name)
@ -264,3 +267,6 @@ class WeaviateVectorIOAdapter(
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")

View file

@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field(
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)
# TODO: this class is more confusing than useful right now. We need to make it # TODO: this class is more confusing than useful right now. We need to make it
# more closer to the Model class. # more closer to the Model class.
class ProviderModelEntry(BaseModel): class ProviderModelEntry(BaseModel):
@ -65,7 +72,10 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
class ModelRegistryHelper(ModelsProtocolPrivate): class ModelRegistryHelper(ModelsProtocolPrivate):
def __init__(self, model_entries: list[ProviderModelEntry]): __provider_id__: str
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
self.allowed_models = allowed_models
self.alias_to_provider_id_map = {} self.alias_to_provider_id_map = {}
self.provider_id_to_llama_model_map = {} self.provider_id_to_llama_model_map = {}
for entry in model_entries: for entry in model_entries:
@ -79,6 +89,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
async def list_models(self) -> list[Model] | None:
models = []
for entry in self.model_entries:
ids = [entry.provider_model_id] + entry.aliases
for id in ids:
if self.allowed_models and id not in self.allowed_models:
continue
models.append(
Model(
model_id=id,
provider_resource_id=entry.provider_model_id,
model_type=ModelType.llm,
metadata=entry.metadata,
provider_id=self.__provider_id__,
)
)
return models
async def should_refresh_models(self) -> bool:
return False
def get_provider_model_id(self, identifier: str) -> str | None: def get_provider_model_id(self, identifier: str) -> str | None:
return self.alias_to_provider_id_map.get(identifier, None) return self.alias_to_provider_id_map.get(identifier, None)

View file

@ -152,6 +152,11 @@ class OpenAIVectorStoreMixin(ABC):
"""Load existing OpenAI vector stores into the in-memory cache.""" """Load existing OpenAI vector stores into the in-memory cache."""
self.openai_vector_stores = await self._load_openai_vector_stores() self.openai_vector_stores = await self._load_openai_vector_stores()
@abstractmethod
async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
"""Delete a chunk from a vector store."""
pass
@abstractmethod @abstractmethod
async def register_vector_db(self, vector_db: VectorDB) -> None: async def register_vector_db(self, vector_db: VectorDB) -> None:
"""Register a vector database (provider-specific implementation).""" """Register a vector database (provider-specific implementation)."""
@ -763,17 +768,15 @@ class OpenAIVectorStoreMixin(ABC):
if vector_store_id not in self.openai_vector_stores: if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found") raise ValueError(f"Vector store {vector_store_id} not found")
dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
chunks = [Chunk.model_validate(c) for c in dict_chunks]
await self.delete_chunks(vector_store_id, [str(c.chunk_id) for c in chunks if c.chunk_id])
store_info = self.openai_vector_stores[vector_store_id].copy() store_info = self.openai_vector_stores[vector_store_id].copy()
file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id) file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id) await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
# TODO: We need to actually delete the embeddings from the underlying vector store...
# Also uncomment the corresponding integration test marked as xfail
#
# test_openai_vector_store_delete_file_removes_from_vector_store in
# tests/integration/vector_io/test_openai_vector_stores.py
# Update in-memory cache # Update in-memory cache
store_info["file_ids"].remove(file_id) store_info["file_ids"].remove(file_id)
store_info["file_counts"][file.status] -= 1 store_info["file_counts"][file.status] -= 1

View file

@ -231,6 +231,10 @@ class EmbeddingIndex(ABC):
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
raise NotImplementedError() raise NotImplementedError()
@abstractmethod
async def delete_chunk(self, chunk_id: str):
raise NotImplementedError()
@abstractmethod @abstractmethod
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError() raise NotImplementedError()

View file

@ -4,13 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from enum import Enum
from typing import Any, cast from typing import Any, cast
import httpx import httpx
from mcp import ClientSession from mcp import ClientSession, McpError
from mcp import types as mcp_types from mcp import types as mcp_types
from mcp.client.sse import sse_client from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
@ -21,31 +24,61 @@ from llama_stack.apis.tools import (
) )
from llama_stack.distribution.datatypes import AuthenticationRequiredError from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
logger = get_logger(__name__, category="tools") logger = get_logger(__name__, category="tools")
protocol_cache = TTLDict(ttl_seconds=3600)
class MCPProtol(Enum):
UNKNOWN = 0
STREAMABLE_HTTP = 1
SSE = 2
@asynccontextmanager @asynccontextmanager
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]): async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
try: # we use a ttl'd dict to cache the happy path protocol for each endpoint
async with sse_client(endpoint, headers=headers) as streams: # but, we always fall back to trying the other protocol if we cannot initialize the session
async with ClientSession(*streams) as session: connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
await session.initialize() mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
yield session if mcp_protocol == MCPProtol.SSE:
except* httpx.HTTPStatusError as eg: connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter, for i, strategy in enumerate(connection_strategies):
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because try:
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type. client = streamablehttp_client
err = cast(httpx.HTTPStatusError, exc) if strategy == MCPProtol.SSE:
if err.response.status_code == 401: client = sse_client
raise AuthenticationRequiredError(exc) from exc async with client(endpoint, headers=headers) as client_streams:
raise async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
await session.initialize()
protocol_cache[endpoint] = strategy
yield session
return
except* httpx.HTTPStatusError as eg:
for exc in eg.exceptions:
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
err = cast(httpx.HTTPStatusError, exc)
if err.response.status_code == 401:
raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1:
raise
except* McpError:
if i < len(connection_strategies) - 1:
logger.warning(
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
else:
raise
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse: async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
tools = [] tools = []
async with sse_client_wrapper(endpoint, headers) as session: async with client_wrapper(endpoint, headers) as session:
tools_result = await session.list_tools() tools_result = await session.list_tools()
for tool in tools_result.tools: for tool in tools_result.tools:
parameters = [] parameters = []
@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
async def invoke_mcp_tool( async def invoke_mcp_tool(
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any] endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
) -> ToolInvocationResult: ) -> ToolInvocationResult:
async with sse_client_wrapper(endpoint, headers) as session: async with client_wrapper(endpoint, headers) as session:
result = await session.call_tool(tool_name, kwargs) result = await session.call_tool(tool_name, kwargs)
content: list[InterleavedContentItem] = [] content: list[InterleavedContentItem] = []

View file

@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from threading import RLock
from typing import Any
class TTLDict(dict):
"""
A dictionary with a ttl for each item
"""
def __init__(self, ttl_seconds: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ttl_seconds = ttl_seconds
self._expires: dict[Any, Any] = {} # expires holds when an item will expire
self._lock = RLock()
if args or kwargs:
for k, v in self.items():
self.__setitem__(k, v)
def __delitem__(self, key):
with self._lock:
del self._expires[key]
super().__delitem__(key)
def __setitem__(self, key, value):
with self._lock:
self._expires[key] = time.monotonic() + self.ttl_seconds
super().__setitem__(key, value)
def _is_expired(self, key):
if key not in self._expires:
return False
return time.monotonic() > self._expires[key]
def __getitem__(self, key):
with self._lock:
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
raise KeyError(f"{key} has expired and was removed")
return super().__getitem__(key)
def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __contains__(self, key):
try:
_ = self[key]
return True
except KeyError:
return False
def __repr__(self):
with self._lock:
for key in self.keys():
if self._is_expired(key):
del self._expires[key]
super().__delitem__(key)
return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"

View file

@ -22,6 +22,7 @@ class WebMethod:
# A descriptive name of the corresponding span created by tracing # A descriptive name of the corresponding span created by tracing
descriptive_name: str | None = None descriptive_name: str | None = None
experimental: bool | None = False experimental: bool | None = False
required_scope: str | None = None
T = TypeVar("T", bound=Callable[..., Any]) T = TypeVar("T", bound=Callable[..., Any])
@ -36,6 +37,7 @@ def webmethod(
raw_bytes_request_body: bool | None = False, raw_bytes_request_body: bool | None = False,
descriptive_name: str | None = None, descriptive_name: str | None = None,
experimental: bool | None = False, experimental: bool | None = False,
required_scope: str | None = None,
) -> Callable[[T], T]: ) -> Callable[[T], T]:
""" """
Decorator that supplies additional metadata to an endpoint operation function. Decorator that supplies additional metadata to an endpoint operation function.
@ -45,6 +47,7 @@ def webmethod(
:param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON. :param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON. :param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
:param experimental: True if the operation is experimental and subject to change. :param experimental: True if the operation is experimental and subject to change.
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
""" """
def wrap(func: T) -> T: def wrap(func: T) -> T:
@ -57,6 +60,7 @@ def webmethod(
raw_bytes_request_body=raw_bytes_request_body, raw_bytes_request_body=raw_bytes_request_body,
descriptive_name=descriptive_name, descriptive_name=descriptive_name,
experimental=experimental, experimental=experimental,
required_scope=required_scope,
) )
return func return func

View file

@ -3,57 +3,98 @@ distribution_spec:
description: CI tests for Llama Stack description: CI tests for Llama Stack
providers: providers:
inference: inference:
- remote::cerebras - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
- remote::ollama provider_type: remote::cerebras
- remote::vllm - provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
- remote::tgi provider_type: remote::ollama
- remote::hf::serverless - provider_id: ${env.ENABLE_VLLM:=__disabled__}
- remote::hf::endpoint provider_type: remote::vllm
- remote::fireworks - provider_id: ${env.ENABLE_TGI:=__disabled__}
- remote::together provider_type: remote::tgi
- remote::bedrock - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__}
- remote::databricks provider_type: remote::hf::serverless
- remote::nvidia - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__}
- remote::runpod provider_type: remote::hf::endpoint
- remote::openai - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
- remote::anthropic provider_type: remote::fireworks
- remote::gemini - provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
- remote::groq provider_type: remote::together
- remote::llama-openai-compat - provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
- remote::sambanova provider_type: remote::bedrock
- remote::passthrough - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
- inline::sentence-transformers provider_type: remote::databricks
- provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_type: remote::nvidia
- provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_type: remote::runpod
- provider_id: ${env.ENABLE_OPENAI:=__disabled__}
provider_type: remote::openai
- provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__}
provider_type: remote::anthropic
- provider_id: ${env.ENABLE_GEMINI:=__disabled__}
provider_type: remote::gemini
- provider_id: ${env.ENABLE_GROQ:=__disabled__}
provider_type: remote::groq
- provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__}
provider_type: remote::llama-openai-compat
- provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_type: remote::sambanova
- provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__}
provider_type: remote::passthrough
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- inline::faiss - provider_id: ${env.ENABLE_FAISS:=faiss}
- inline::sqlite-vec provider_type: inline::faiss
- inline::milvus - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__}
- remote::chromadb provider_type: inline::sqlite-vec
- remote::pgvector - provider_id: ${env.ENABLE_MILVUS:=__disabled__}
provider_type: inline::milvus
- provider_id: ${env.ENABLE_CHROMADB:=__disabled__}
provider_type: remote::chromadb
- provider_id: ${env.ENABLE_PGVECTOR:=__disabled__}
provider_type: remote::pgvector
files: files:
- inline::localfs - provider_id: localfs
provider_type: inline::localfs
safety: safety:
- inline::llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard
agents: agents:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
post_training: post_training:
- inline::huggingface - provider_id: huggingface
provider_type: inline::huggingface
eval: eval:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- remote::huggingface - provider_id: huggingface
- inline::localfs provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- inline::basic - provider_id: basic
- inline::llm-as-judge provider_type: inline::basic
- inline::braintrust - provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search - provider_id: brave-search
- remote::tavily-search provider_type: remote::brave-search
- inline::rag-runtime - provider_id: tavily-search
- remote::model-context-protocol provider_type: remote::tavily-search
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: ci-tests
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite
- asyncpg - asyncpg

View file

@ -56,7 +56,6 @@ providers:
api_key: ${env.TOGETHER_API_KEY} api_key: ${env.TOGETHER_API_KEY}
- provider_id: ${env.ENABLE_BEDROCK:=__disabled__} - provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_type: remote::bedrock provider_type: remote::bedrock
config: {}
- provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_type: remote::databricks provider_type: remote::databricks
config: config:
@ -107,7 +106,6 @@ providers:
api_key: ${env.PASSTHROUGH_API_KEY} api_key: ${env.PASSTHROUGH_API_KEY}
- provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} - provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {}
vector_io: vector_io:
- provider_id: ${env.ENABLE_FAISS:=faiss} - provider_id: ${env.ENABLE_FAISS:=faiss}
provider_type: inline::faiss provider_type: inline::faiss
@ -208,10 +206,8 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
- provider_id: llm-as-judge - provider_id: llm-as-judge
provider_type: inline::llm-as-judge provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust - provider_id: braintrust
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
@ -229,10 +225,8 @@ providers:
max_results: 3 max_results: 3
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db

View file

@ -4,32 +4,50 @@ distribution_spec:
container container
providers: providers:
inference: inference:
- remote::tgi - provider_id: tgi
- inline::sentence-transformers provider_type: remote::tgi
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- inline::faiss - provider_id: faiss
- remote::chromadb provider_type: inline::faiss
- remote::pgvector - provider_id: chromadb
provider_type: remote::chromadb
- provider_id: pgvector
provider_type: remote::pgvector
safety: safety:
- inline::llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard
agents: agents:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
eval: eval:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- remote::huggingface - provider_id: huggingface
- inline::localfs provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- inline::basic - provider_id: basic
- inline::llm-as-judge provider_type: inline::basic
- inline::braintrust - provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search - provider_id: brave-search
- remote::tavily-search provider_type: remote::brave-search
- inline::rag-runtime - provider_id: tavily-search
provider_type: remote::tavily-search
- provider_id: rag-runtime
provider_type: inline::rag-runtime
image_type: conda image_type: conda
image_name: dell
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite
- sqlalchemy[asyncio] - sqlalchemy[asyncio]

View file

@ -19,18 +19,32 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": ["remote::tgi", "inline::sentence-transformers"], "inference": [
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], Provider(provider_id="tgi", provider_type="remote::tgi"),
"safety": ["inline::llama-guard"], Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"),
"agents": ["inline::meta-reference"], ],
"telemetry": ["inline::meta-reference"], "vector_io": [
"eval": ["inline::meta-reference"], Provider(provider_id="faiss", provider_type="inline::faiss"),
"datasetio": ["remote::huggingface", "inline::localfs"], Provider(provider_id="chromadb", provider_type="remote::chromadb"),
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], Provider(provider_id="pgvector", provider_type="remote::pgvector"),
],
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"datasetio": [
Provider(provider_id="huggingface", provider_type="remote::huggingface"),
Provider(provider_id="localfs", provider_type="inline::localfs"),
],
"scoring": [
Provider(provider_id="basic", provider_type="inline::basic"),
Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"),
Provider(provider_id="braintrust", provider_type="inline::braintrust"),
],
"tool_runtime": [ "tool_runtime": [
"remote::brave-search", Provider(provider_id="brave-search", provider_type="remote::brave-search"),
"remote::tavily-search", Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
"inline::rag-runtime", Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
], ],
} }
name = "dell" name = "dell"

View file

@ -22,7 +22,6 @@ providers:
url: ${env.DEH_SAFETY_URL} url: ${env.DEH_SAFETY_URL}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {}
vector_io: vector_io:
- provider_id: chromadb - provider_id: chromadb
provider_type: remote::chromadb provider_type: remote::chromadb
@ -74,10 +73,8 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
- provider_id: llm-as-judge - provider_id: llm-as-judge
provider_type: inline::llm-as-judge provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust - provider_id: braintrust
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
@ -95,7 +92,6 @@ providers:
max_results: 3 max_results: 3
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db

View file

@ -18,7 +18,6 @@ providers:
url: ${env.DEH_URL} url: ${env.DEH_URL}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {}
vector_io: vector_io:
- provider_id: chromadb - provider_id: chromadb
provider_type: remote::chromadb provider_type: remote::chromadb
@ -70,10 +69,8 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
- provider_id: llm-as-judge - provider_id: llm-as-judge
provider_type: inline::llm-as-judge provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust - provider_id: braintrust
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
@ -91,7 +88,6 @@ providers:
max_results: 3 max_results: 3
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/dell}/registry.db

View file

@ -3,32 +3,50 @@ distribution_spec:
description: Use Meta Reference for running LLM inference description: Use Meta Reference for running LLM inference
providers: providers:
inference: inference:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
vector_io: vector_io:
- inline::faiss - provider_id: faiss
- remote::chromadb provider_type: inline::faiss
- remote::pgvector - provider_id: chromadb
provider_type: remote::chromadb
- provider_id: pgvector
provider_type: remote::pgvector
safety: safety:
- inline::llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard
agents: agents:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
eval: eval:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- remote::huggingface - provider_id: huggingface
- inline::localfs provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- inline::basic - provider_id: basic
- inline::llm-as-judge provider_type: inline::basic
- inline::braintrust - provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search - provider_id: brave-search
- remote::tavily-search provider_type: remote::brave-search
- inline::rag-runtime - provider_id: tavily-search
- remote::model-context-protocol provider_type: remote::tavily-search
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: meta-reference-gpu
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite
- sqlalchemy[asyncio] - sqlalchemy[asyncio]

View file

@ -25,19 +25,91 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": ["inline::meta-reference"], "inference": [
"vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], Provider(
"safety": ["inline::llama-guard"], provider_id="meta-reference",
"agents": ["inline::meta-reference"], provider_type="inline::meta-reference",
"telemetry": ["inline::meta-reference"], )
"eval": ["inline::meta-reference"], ],
"datasetio": ["remote::huggingface", "inline::localfs"], "vector_io": [
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], Provider(
provider_id="faiss",
provider_type="inline::faiss",
),
Provider(
provider_id="chromadb",
provider_type="remote::chromadb",
),
Provider(
provider_id="pgvector",
provider_type="remote::pgvector",
),
],
"safety": [
Provider(
provider_id="llama-guard",
provider_type="inline::llama-guard",
)
],
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"eval": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"datasetio": [
Provider(
provider_id="huggingface",
provider_type="remote::huggingface",
),
Provider(
provider_id="localfs",
provider_type="inline::localfs",
),
],
"scoring": [
Provider(
provider_id="basic",
provider_type="inline::basic",
),
Provider(
provider_id="llm-as-judge",
provider_type="inline::llm-as-judge",
),
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
),
],
"tool_runtime": [ "tool_runtime": [
"remote::brave-search", Provider(
"remote::tavily-search", provider_id="brave-search",
"inline::rag-runtime", provider_type="remote::brave-search",
"remote::model-context-protocol", ),
Provider(
provider_id="tavily-search",
provider_type="remote::tavily-search",
),
Provider(
provider_id="rag-runtime",
provider_type="inline::rag-runtime",
),
Provider(
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }
name = "meta-reference-gpu" name = "meta-reference-gpu"

View file

@ -24,7 +24,6 @@ providers:
max_seq_len: ${env.MAX_SEQ_LEN:=4096} max_seq_len: ${env.MAX_SEQ_LEN:=4096}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {}
- provider_id: meta-reference-safety - provider_id: meta-reference-safety
provider_type: inline::meta-reference provider_type: inline::meta-reference
config: config:
@ -88,10 +87,8 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
- provider_id: llm-as-judge - provider_id: llm-as-judge
provider_type: inline::llm-as-judge provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust - provider_id: braintrust
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
@ -109,10 +106,8 @@ providers:
max_results: 3 max_results: 3
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db

View file

@ -24,7 +24,6 @@ providers:
max_seq_len: ${env.MAX_SEQ_LEN:=4096} max_seq_len: ${env.MAX_SEQ_LEN:=4096}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {}
vector_io: vector_io:
- provider_id: faiss - provider_id: faiss
provider_type: inline::faiss provider_type: inline::faiss
@ -78,10 +77,8 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
- provider_id: llm-as-judge - provider_id: llm-as-judge
provider_type: inline::llm-as-judge provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust - provider_id: braintrust
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
@ -99,10 +96,8 @@ providers:
max_results: 3 max_results: 3
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/meta-reference-gpu}/registry.db

View file

@ -3,27 +3,39 @@ distribution_spec:
description: Use NVIDIA NIM for running LLM inference, evaluation and safety description: Use NVIDIA NIM for running LLM inference, evaluation and safety
providers: providers:
inference: inference:
- remote::nvidia - provider_id: nvidia
provider_type: remote::nvidia
vector_io: vector_io:
- inline::faiss - provider_id: faiss
provider_type: inline::faiss
safety: safety:
- remote::nvidia - provider_id: nvidia
provider_type: remote::nvidia
agents: agents:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
eval: eval:
- remote::nvidia - provider_id: nvidia
provider_type: remote::nvidia
post_training: post_training:
- remote::nvidia - provider_id: nvidia
provider_type: remote::nvidia
datasetio: datasetio:
- inline::localfs - provider_id: localfs
- remote::nvidia provider_type: inline::localfs
- provider_id: nvidia
provider_type: remote::nvidia
scoring: scoring:
- inline::basic - provider_id: basic
provider_type: inline::basic
tool_runtime: tool_runtime:
- inline::rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime
image_type: conda image_type: conda
image_name: nvidia
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite
- sqlalchemy[asyncio] - sqlalchemy[asyncio]

View file

@ -1,3 +1,6 @@
---
orphan: true
---
# NVIDIA Distribution # NVIDIA Distribution
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.

View file

@ -17,16 +17,65 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
providers = { providers = {
"inference": ["remote::nvidia"], "inference": [
"vector_io": ["inline::faiss"], Provider(
"safety": ["remote::nvidia"], provider_id="nvidia",
"agents": ["inline::meta-reference"], provider_type="remote::nvidia",
"telemetry": ["inline::meta-reference"], )
"eval": ["remote::nvidia"], ],
"post_training": ["remote::nvidia"], "vector_io": [
"datasetio": ["inline::localfs", "remote::nvidia"], Provider(
"scoring": ["inline::basic"], provider_id="faiss",
"tool_runtime": ["inline::rag-runtime"], provider_type="inline::faiss",
)
],
"safety": [
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
)
],
"agents": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"telemetry": [
Provider(
provider_id="meta-reference",
provider_type="inline::meta-reference",
)
],
"eval": [
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
)
],
"post_training": [Provider(provider_id="nvidia", provider_type="remote::nvidia", config={})],
"datasetio": [
Provider(
provider_id="localfs",
provider_type="inline::localfs",
),
Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
),
],
"scoring": [
Provider(
provider_id="basic",
provider_type="inline::basic",
)
],
"tool_runtime": [
Provider(
provider_id="rag-runtime",
provider_type="inline::rag-runtime",
)
],
} }
inference_provider = Provider( inference_provider = Provider(

View file

@ -85,11 +85,9 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
tool_runtime: tool_runtime:
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db

View file

@ -74,11 +74,9 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
tool_runtime: tool_runtime:
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db

View file

@ -3,36 +3,58 @@ distribution_spec:
description: Distribution for running open benchmarks description: Distribution for running open benchmarks
providers: providers:
inference: inference:
- remote::openai - provider_id: openai
- remote::anthropic provider_type: remote::openai
- remote::gemini - provider_id: anthropic
- remote::groq provider_type: remote::anthropic
- remote::together - provider_id: gemini
provider_type: remote::gemini
- provider_id: groq
provider_type: remote::groq
- provider_id: together
provider_type: remote::together
vector_io: vector_io:
- inline::sqlite-vec - provider_id: sqlite-vec
- remote::chromadb provider_type: inline::sqlite-vec
- remote::pgvector - provider_id: chromadb
provider_type: remote::chromadb
- provider_id: pgvector
provider_type: remote::pgvector
safety: safety:
- inline::llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard
agents: agents:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
eval: eval:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
datasetio: datasetio:
- remote::huggingface - provider_id: huggingface
- inline::localfs provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
scoring: scoring:
- inline::basic - provider_id: basic
- inline::llm-as-judge provider_type: inline::basic
- inline::braintrust - provider_id: llm-as-judge
provider_type: inline::llm-as-judge
- provider_id: braintrust
provider_type: inline::braintrust
tool_runtime: tool_runtime:
- remote::brave-search - provider_id: brave-search
- remote::tavily-search provider_type: remote::brave-search
- inline::rag-runtime - provider_id: tavily-search
- remote::model-context-protocol provider_type: remote::tavily-search
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: open-benchmark
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite
- sqlalchemy[asyncio] - sqlalchemy[asyncio]

View file

@ -96,19 +96,33 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
inference_providers, available_models = get_inference_providers() inference_providers, available_models = get_inference_providers()
providers = { providers = {
"inference": [p.provider_type for p in inference_providers], "inference": inference_providers,
"vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], "vector_io": [
"safety": ["inline::llama-guard"], Provider(provider_id="sqlite-vec", provider_type="inline::sqlite-vec"),
"agents": ["inline::meta-reference"], Provider(provider_id="chromadb", provider_type="remote::chromadb"),
"telemetry": ["inline::meta-reference"], Provider(provider_id="pgvector", provider_type="remote::pgvector"),
"eval": ["inline::meta-reference"], ],
"datasetio": ["remote::huggingface", "inline::localfs"], "safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], "agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"eval": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"datasetio": [
Provider(provider_id="huggingface", provider_type="remote::huggingface"),
Provider(provider_id="localfs", provider_type="inline::localfs"),
],
"scoring": [
Provider(provider_id="basic", provider_type="inline::basic"),
Provider(provider_id="llm-as-judge", provider_type="inline::llm-as-judge"),
Provider(provider_id="braintrust", provider_type="inline::braintrust"),
],
"tool_runtime": [ "tool_runtime": [
"remote::brave-search", Provider(provider_id="brave-search", provider_type="remote::brave-search"),
"remote::tavily-search", Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
"inline::rag-runtime", Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
"remote::model-context-protocol", Provider(
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }
name = "open-benchmark" name = "open-benchmark"

View file

@ -106,10 +106,8 @@ providers:
scoring: scoring:
- provider_id: basic - provider_id: basic
provider_type: inline::basic provider_type: inline::basic
config: {}
- provider_id: llm-as-judge - provider_id: llm-as-judge
provider_type: inline::llm-as-judge provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust - provider_id: braintrust
provider_type: inline::braintrust provider_type: inline::braintrust
config: config:
@ -127,10 +125,8 @@ providers:
max_results: 3 max_results: 3
- provider_id: rag-runtime - provider_id: rag-runtime
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
config: {}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/registry.db

View file

@ -3,22 +3,33 @@ distribution_spec:
description: Quick start template for running Llama Stack with several popular providers description: Quick start template for running Llama Stack with several popular providers
providers: providers:
inference: inference:
- remote::vllm - provider_id: vllm-inference
- inline::sentence-transformers provider_type: remote::vllm
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
vector_io: vector_io:
- remote::chromadb - provider_id: chromadb
provider_type: remote::chromadb
safety: safety:
- inline::llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard
agents: agents:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
telemetry: telemetry:
- inline::meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference
tool_runtime: tool_runtime:
- remote::brave-search - provider_id: brave-search
- remote::tavily-search provider_type: remote::brave-search
- inline::rag-runtime - provider_id: tavily-search
- remote::model-context-protocol provider_type: remote::tavily-search
- provider_id: rag-runtime
provider_type: inline::rag-runtime
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
image_type: conda image_type: conda
image_name: postgres-demo
additional_pip_packages: additional_pip_packages:
- asyncpg - asyncpg
- psycopg2-binary - psycopg2-binary

View file

@ -34,16 +34,24 @@ def get_distribution_template() -> DistributionTemplate:
), ),
] ]
providers = { providers = {
"inference": ([p.provider_type for p in inference_providers] + ["inline::sentence-transformers"]), "inference": inference_providers
"vector_io": ["remote::chromadb"], + [
"safety": ["inline::llama-guard"], Provider(provider_id="sentence-transformers", provider_type="inline::sentence-transformers"),
"agents": ["inline::meta-reference"], ],
"telemetry": ["inline::meta-reference"], "vector_io": [
Provider(provider_id="chromadb", provider_type="remote::chromadb"),
],
"safety": [Provider(provider_id="llama-guard", provider_type="inline::llama-guard")],
"agents": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"telemetry": [Provider(provider_id="meta-reference", provider_type="inline::meta-reference")],
"tool_runtime": [ "tool_runtime": [
"remote::brave-search", Provider(provider_id="brave-search", provider_type="remote::brave-search"),
"remote::tavily-search", Provider(provider_id="tavily-search", provider_type="remote::tavily-search"),
"inline::rag-runtime", Provider(provider_id="rag-runtime", provider_type="inline::rag-runtime"),
"remote::model-context-protocol", Provider(
provider_id="model-context-protocol",
provider_type="remote::model-context-protocol",
),
], ],
} }
name = "postgres-demo" name = "postgres-demo"

Some files were not shown because too many files have changed in this diff Show more