Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-25 21:57:45 +00:00)

Commit 062c6a419a: Merge branch 'main' into fix-chroma
76 changed files with 2468 additions and 913 deletions
27 .github/actions/setup-vllm/action.yml (vendored, new file)
@ -0,0 +1,27 @@
name: Setup VLLM
description: Start VLLM
runs:
  using: "composite"
  steps:
    - name: Start VLLM
      shell: bash
      run: |
        # Start vllm container
        docker run -d \
            --name vllm \
            -p 8000:8000 \
            --privileged=true \
            quay.io/higginsd/vllm-cpu:65393ee064 \
            --host 0.0.0.0 \
            --port 8000 \
            --enable-auto-tool-choice \
            --tool-call-parser llama3_json \
            --model /root/.cache/Llama-3.2-1B-Instruct \
            --served-model-name meta-llama/Llama-3.2-1B-Instruct

        # Wait for vllm to be ready
        echo "Waiting for vllm to be ready..."
        timeout 900 bash -c 'until curl -f http://localhost:8000/health; do
          echo "Waiting for vllm..."
          sleep 5
        done'
18 .github/workflows/install-script-ci.yml (vendored)
|
@ -17,10 +17,20 @@ jobs:
|
|||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
|
||||
- name: Run ShellCheck on install.sh
|
||||
run: shellcheck scripts/install.sh
|
||||
smoke-test:
|
||||
needs: lint
|
||||
smoke-test-on-dev:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install dependencies
|
||||
uses: ./.github/actions/setup-runner
|
||||
|
||||
- name: Build a single provider
|
||||
run: |
|
||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template starter --image-type container --image-name test
|
||||
|
||||
- name: Run installer end-to-end
|
||||
run: ./scripts/install.sh
|
||||
run: |
|
||||
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
|
||||
./scripts/install.sh --image $IMAGE_ID
|
||||
|
|
62 .github/workflows/integration-tests.yml (vendored)
|
@ -14,13 +14,19 @@ on:
|
|||
- '.github/workflows/integration-tests.yml' # This workflow
|
||||
- '.github/actions/setup-ollama/action.yml'
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Daily at 12 AM UTC
|
||||
# If changing the cron schedule, update the provider in the test-matrix job
|
||||
- cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
|
||||
- cron: '1 0 * * 0' # (test vllm) Weekly on Sunday at 1 AM UTC
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
test-all-client-versions:
|
||||
description: 'Test against both the latest and published versions'
|
||||
type: boolean
|
||||
default: false
|
||||
test-provider:
|
||||
description: 'Test against a specific provider'
|
||||
type: string
|
||||
default: 'ollama'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
|
@ -53,8 +59,17 @@ jobs:
|
|||
matrix:
|
||||
test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }}
|
||||
client-type: [library, server]
|
||||
# Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
|
||||
provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
|
||||
python-version: ["3.12", "3.13"]
|
||||
client-version: ${{ (github.event_name == 'schedule' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
||||
client-version: ${{ (github.event.schedule == '0 0 * * 0' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
||||
exclude: # TODO: look into why these tests are failing and fix them
|
||||
- provider: vllm
|
||||
test-type: safety
|
||||
- provider: vllm
|
||||
test-type: post_training
|
||||
- provider: vllm
|
||||
test-type: tool_runtime
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
|
@ -67,8 +82,13 @@ jobs:
|
|||
client-version: ${{ matrix.client-version }}
|
||||
|
||||
- name: Setup ollama
|
||||
if: ${{ matrix.provider == 'ollama' }}
|
||||
uses: ./.github/actions/setup-ollama
|
||||
|
||||
- name: Setup vllm
|
||||
if: ${{ matrix.provider == 'vllm' }}
|
||||
uses: ./.github/actions/setup-vllm
|
||||
|
||||
- name: Build Llama Stack
|
||||
run: |
|
||||
uv run llama stack build --template ci-tests --image-type venv
|
||||
|
@ -81,10 +101,6 @@ jobs:
|
|||
|
||||
- name: Run Integration Tests
|
||||
env:
|
||||
OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests
|
||||
ENABLE_OLLAMA: "ollama" # for server tests
|
||||
OLLAMA_URL: "http://0.0.0.0:11434"
|
||||
SAFETY_MODEL: "llama-guard3:1b"
|
||||
LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations
|
||||
# Use 'shell' to get pipefail behavior
|
||||
# https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference
|
||||
|
@ -96,12 +112,31 @@ jobs:
|
|||
else
|
||||
stack_config="server:ci-tests"
|
||||
fi
|
||||
|
||||
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
|
||||
if [ "${{ matrix.provider }}" == "ollama" ]; then
|
||||
export ENABLE_OLLAMA="ollama"
|
||||
export OLLAMA_URL="http://0.0.0.0:11434"
|
||||
export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16"
|
||||
export TEXT_MODEL=ollama/$OLLAMA_INFERENCE_MODEL
|
||||
export SAFETY_MODEL="llama-guard3:1b"
|
||||
EXTRA_PARAMS="--safety-shield=$SAFETY_MODEL"
|
||||
else
|
||||
export ENABLE_VLLM="vllm"
|
||||
export VLLM_URL="http://localhost:8000/v1"
|
||||
export VLLM_INFERENCE_MODEL="meta-llama/Llama-3.2-1B-Instruct"
|
||||
export TEXT_MODEL=vllm/$VLLM_INFERENCE_MODEL
|
||||
# TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently
|
||||
EXTRA_PARAMS=
|
||||
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
|
||||
fi
|
||||
|
||||
|
||||
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
|
||||
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
|
||||
--text-model="ollama/llama3.2:3b-instruct-fp16" \
|
||||
-k "not( ${EXCLUDE_TESTS} )" \
|
||||
--text-model=$TEXT_MODEL \
|
||||
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
|
||||
--safety-shield=$SAFETY_MODEL \
|
||||
--color=yes \
|
||||
--color=yes ${EXTRA_PARAMS} \
|
||||
--capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
|
||||
|
||||
- name: Check Storage and Memory Available After Tests
|
||||
|
@ -110,16 +145,17 @@ jobs:
|
|||
free -h
|
||||
df -h
|
||||
|
||||
- name: Write ollama logs to file
|
||||
- name: Write inference logs to file
|
||||
if: ${{ always() }}
|
||||
run: |
|
||||
sudo docker logs ollama > ollama.log
|
||||
sudo docker logs ollama > ollama.log || true
|
||||
sudo docker logs vllm > vllm.log || true
|
||||
|
||||
- name: Upload all logs to artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }}
|
||||
name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.provider }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }}
|
||||
path: |
|
||||
*.log
|
||||
retention-days: 1
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
name: Test External Providers
|
||||
name: Test External API and Providers
|
||||
|
||||
on:
|
||||
push:
|
||||
|
@ -11,10 +11,10 @@ on:
|
|||
- 'uv.lock'
|
||||
- 'pyproject.toml'
|
||||
- 'requirements.txt'
|
||||
- '.github/workflows/test-external-providers.yml' # This workflow
|
||||
- '.github/workflows/test-external.yml' # This workflow
|
||||
|
||||
jobs:
|
||||
test-external-providers:
|
||||
test-external:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
|
@ -28,24 +28,23 @@ jobs:
|
|||
- name: Install dependencies
|
||||
uses: ./.github/actions/setup-runner
|
||||
|
||||
- name: Apply image type to config file
|
||||
- name: Create API configuration
|
||||
run: |
|
||||
yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
|
||||
cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
|
||||
|
||||
- name: Setup directory for Ollama custom provider
|
||||
run: |
|
||||
mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
|
||||
cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
|
||||
mkdir -p /home/runner/.llama/apis.d
|
||||
cp tests/external/weather.yaml /home/runner/.llama/apis.d/weather.yaml
|
||||
|
||||
- name: Create provider configuration
|
||||
run: |
|
||||
mkdir -p /home/runner/.llama/providers.d/remote/inference
|
||||
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml
|
||||
mkdir -p /home/runner/.llama/providers.d/remote/weather
|
||||
cp tests/external/kaze.yaml /home/runner/.llama/providers.d/remote/weather/kaze.yaml
|
||||
|
||||
- name: Print distro dependencies
|
||||
run: |
|
||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml --print-deps-only
|
||||
|
||||
- name: Build distro from config file
|
||||
run: |
|
||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
|
||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml
|
||||
|
||||
- name: Start Llama Stack server in background
|
||||
if: ${{ matrix.image-type }} == 'venv'
|
||||
|
@ -55,19 +54,22 @@ jobs:
|
|||
# Use the virtual environment created by the build step (name comes from build config)
|
||||
source ci-test/bin/activate
|
||||
uv pip list
|
||||
nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
|
||||
nohup llama stack run tests/external/run-byoa.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
echo "Waiting for Llama Stack server..."
|
||||
for i in {1..30}; do
|
||||
if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then
|
||||
echo "Waiting for Llama Stack server to load the provider..."
|
||||
sleep 1
|
||||
else
|
||||
echo "Provider loaded"
|
||||
if curl -sSf http://localhost:8321/v1/health | grep -q "OK"; then
|
||||
echo "Llama Stack server is up!"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "Provider failed to load"
|
||||
echo "Llama Stack server failed to start"
|
||||
cat server.log
|
||||
exit 1
|
||||
|
||||
- name: Test external API
|
||||
run: |
|
||||
curl -sSf http://localhost:8321/v1/weather/locations
|
49 CHANGELOG.md
|
@ -1,5 +1,34 @@
|
|||
# Changelog
|
||||
|
||||
# v0.2.15
|
||||
Published on: 2025-07-16T03:30:01Z
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.14
|
||||
Published on: 2025-07-04T16:06:48Z
|
||||
|
||||
## Highlights
|
||||
|
||||
* Support for Llama Guard 4
|
||||
* Added Milvus support to vector-stores API
|
||||
* Documentation and zero-to-hero updates for latest APIs
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.13
|
||||
Published on: 2025-06-28T04:28:11Z
|
||||
|
||||
## Highlights
|
||||
* search_mode support in OpenAI vector store API
|
||||
* Security fixes
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.12
|
||||
Published on: 2025-06-20T22:52:12Z
|
||||
|
||||
|
@ -485,23 +514,3 @@ A small but important bug-fix release to update the URL datatype for the client-
|
|||
|
||||
---
|
||||
|
||||
# v0.0.62
|
||||
Published on: 2024-12-18T02:39:43Z
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.0.61
|
||||
Published on: 2024-12-10T20:50:33Z
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.0.55
|
||||
Published on: 2024-11-23T17:14:07Z
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
|
392 docs/source/apis/external.md (new file)
|
@ -0,0 +1,392 @@
|
|||
# External APIs
|
||||
|
||||
Llama Stack supports external APIs that live outside of the main codebase. This allows you to:
|
||||
- Create and maintain your own APIs independently
|
||||
- Share APIs with others without contributing to the main codebase
|
||||
- Keep API-specific code separate from the core Llama Stack code
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable external APIs, you need to configure the `external_apis_dir` in your Llama Stack configuration. This directory should contain your external API specifications:
|
||||
|
||||
```yaml
|
||||
external_apis_dir: ~/.llama/apis.d/
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
The external APIs directory should follow this structure:
|
||||
|
||||
```
|
||||
apis.d/
|
||||
custom_api1.yaml
|
||||
custom_api2.yaml
|
||||
```
|
||||
|
||||
Each YAML file in these directories defines an API specification.
|
||||
|
||||
## API Specification
|
||||
|
||||
Here's an example of an external API specification for a weather API:
|
||||
|
||||
```yaml
|
||||
module: weather
|
||||
api_dependencies:
|
||||
- inference
|
||||
protocol: WeatherAPI
|
||||
name: weather
|
||||
pip_packages:
|
||||
- llama-stack-api-weather
|
||||
```
|
||||
|
||||
### API Specification Fields
|
||||
|
||||
- `module`: Python module containing the API implementation
|
||||
- `protocol`: Name of the protocol class for the API
|
||||
- `name`: Name of the API
|
||||
- `pip_packages`: List of pip packages to install the API, typically a single package
|
||||
|
||||
## Required Implementation
|
||||
|
||||
External APIs must expose an `available_providers()` function in their module that returns a list of provider specifications:
|
||||
|
||||
```python
|
||||
# llama_stack_api_weather/api.py
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.weather,
|
||||
provider_type="inline::darksky",
|
||||
pip_packages=[],
|
||||
module="llama_stack_provider_darksky",
|
||||
config_class="llama_stack_provider_darksky.DarkSkyWeatherImplConfig",
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
They must also define a Protocol class for the API, like so:
|
||||
|
||||
```python
|
||||
# llama_stack_api_weather/api.py
|
||||
from typing import Protocol
|
||||
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
class WeatherAPI(Protocol):
|
||||
"""
|
||||
A protocol for the Weather API.
|
||||
"""
|
||||
|
||||
@webmethod(route="/locations", method="GET")
|
||||
async def get_available_locations(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get the available locations.
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
## Example: Custom API
|
||||
|
||||
Here's a complete example of creating and using a custom API:
|
||||
|
||||
1. First, create the API package:
|
||||
|
||||
```bash
|
||||
mkdir -p llama-stack-api-weather
|
||||
cd llama-stack-api-weather
|
||||
mkdir -p src/llama_stack_api_weather
|
||||
git init
|
||||
uv init
|
||||
```
|
||||
|
||||
2. Edit `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "llama-stack-api-weather"
|
||||
version = "0.1.0"
|
||||
description = "Weather API for Llama Stack"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["llama-stack", "pydantic"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
include = ["llama_stack_api_weather", "llama_stack_api_weather.*"]
|
||||
```
|
||||
|
||||
3. Create the initial files:
|
||||
|
||||
```bash
|
||||
touch src/llama_stack_api_weather/__init__.py
|
||||
touch src/llama_stack_api_weather/api.py
|
||||
```
|
||||
|
||||
```python
|
||||
# llama-stack-api-weather/src/llama_stack_api_weather/__init__.py
|
||||
"""Weather API for Llama Stack."""
|
||||
|
||||
from .api import WeatherProvider, available_providers
|
||||
|
||||
__all__ = ["WeatherAPI", "available_providers"]
|
||||
```
|
||||
|
||||
4. Create the API implementation:
|
||||
|
||||
```python
|
||||
# llama-stack-api-weather/src/llama_stack_api_weather/api.py
|
||||
from typing import Protocol
|
||||
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
ProviderSpec,
|
||||
RemoteProviderSpec,
|
||||
)
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
RemoteProviderSpec(
|
||||
api=Api.weather,
|
||||
provider_type="remote::kaze",
|
||||
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="kaze",
|
||||
module="llama_stack_provider_kaze",
|
||||
pip_packages=["llama_stack_provider_kaze"],
|
||||
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class WeatherProvider(Protocol):
|
||||
"""
|
||||
A protocol for the Weather API.
|
||||
"""
|
||||
|
||||
@webmethod(route="/weather/locations", method="GET")
|
||||
async def get_available_locations(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get the available locations.
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
5. Create the API specification:
|
||||
|
||||
```yaml
|
||||
# ~/.llama/apis.d/weather.yaml
|
||||
module: llama_stack_api_weather
|
||||
name: weather
|
||||
pip_packages: ["llama-stack-api-weather"]
|
||||
protocol: WeatherProvider
|
||||
|
||||
```
|
||||
|
||||
6. Install the API package:
|
||||
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
7. Configure Llama Stack to use external APIs:
|
||||
|
||||
```yaml
|
||||
version: "2"
|
||||
image_name: "llama-stack-api-weather"
|
||||
apis:
|
||||
- weather
|
||||
providers: {}
|
||||
external_apis_dir: ~/.llama/apis.d
|
||||
```
|
||||
|
||||
The API will now be available at `/v1/weather/locations`.
|
||||
|
||||
## Example: custom provider for the weather API
|
||||
|
||||
1. Create the provider package:
|
||||
|
||||
```bash
|
||||
mkdir -p llama-stack-provider-kaze
|
||||
cd llama-stack-provider-kaze
|
||||
uv init
|
||||
```
|
||||
|
||||
2. Edit `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "llama-stack-provider-kaze"
|
||||
version = "0.1.0"
|
||||
description = "Kaze weather provider for Llama Stack"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["llama-stack", "pydantic", "aiohttp"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"]
|
||||
```
|
||||
|
||||
3. Create the initial files:
|
||||
|
||||
```bash
|
||||
touch src/llama_stack_provider_kaze/__init__.py
|
||||
touch src/llama_stack_provider_kaze/kaze.py
|
||||
```
|
||||
|
||||
4. Create the provider implementation:
|
||||
|
||||
|
||||
Initialization function:
|
||||
|
||||
```python
|
||||
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py
|
||||
"""Kaze weather provider for Llama Stack."""
|
||||
|
||||
from .config import KazeProviderConfig
|
||||
from .kaze import WeatherKazeAdapter
|
||||
|
||||
__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"]
|
||||
|
||||
|
||||
async def get_adapter_impl(config: KazeProviderConfig, _deps):
|
||||
from .kaze import WeatherKazeAdapter
|
||||
|
||||
impl = WeatherKazeAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
```
|
||||
|
||||
Configuration:
|
||||
|
||||
```python
|
||||
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class KazeProviderConfig(BaseModel):
|
||||
"""Configuration for the Kaze weather provider."""
|
||||
|
||||
base_url: str = Field(
|
||||
"https://api.kaze.io/v1",
|
||||
description="Base URL for the Kaze weather API",
|
||||
)
|
||||
```
|
||||
|
||||
Main implementation:
|
||||
|
||||
```python
|
||||
# llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py
|
||||
from llama_stack_api_weather.api import WeatherProvider
|
||||
|
||||
from .config import KazeProviderConfig
|
||||
|
||||
|
||||
class WeatherKazeAdapter(WeatherProvider):
|
||||
"""Kaze weather provider implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: KazeProviderConfig,
|
||||
) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_available_locations(self) -> dict[str, list[str]]:
|
||||
"""Get available weather locations."""
|
||||
return {"locations": ["Paris", "Tokyo"]}
|
||||
```
|
||||
|
||||
5. Create the provider specification:
|
||||
|
||||
```yaml
|
||||
# ~/.llama/providers.d/remote/weather/kaze.yaml
|
||||
adapter:
|
||||
adapter_type: kaze
|
||||
pip_packages: ["llama_stack_provider_kaze"]
|
||||
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
|
||||
module: llama_stack_provider_kaze
|
||||
optional_api_dependencies: []
|
||||
```
|
||||
|
||||
6. Install the provider package:
|
||||
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
7. Configure Llama Stack to use the provider:
|
||||
|
||||
```yaml
|
||||
# ~/.llama/run-byoa.yaml
|
||||
version: "2"
|
||||
image_name: "llama-stack-api-weather"
|
||||
apis:
|
||||
- weather
|
||||
providers:
|
||||
weather:
|
||||
- provider_id: kaze
|
||||
provider_type: remote::kaze
|
||||
config: {}
|
||||
external_apis_dir: ~/.llama/apis.d
|
||||
external_providers_dir: ~/.llama/providers.d
|
||||
server:
|
||||
port: 8321
|
||||
```
|
||||
|
||||
8. Run the server:
|
||||
|
||||
```bash
|
||||
python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml
|
||||
```
|
||||
|
||||
9. Test the API:
|
||||
|
||||
```bash
|
||||
curl -sSf http://127.0.0.1:8321/v1/weather/locations
|
||||
{"locations":["Paris","Tokyo"]}%
|
||||
```
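The same check can be run from Python. A minimal sketch, assuming the server from step 8 is running locally and the third-party `requests` package is installed (it is not a Llama Stack dependency):

```python
# Hypothetical verification script; mirrors the curl call above.
import requests

resp = requests.get("http://127.0.0.1:8321/v1/weather/locations", timeout=10)
resp.raise_for_status()
print(resp.json())  # expected: {"locations": ["Paris", "Tokyo"]}
```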
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Package Naming**: Use a clear and descriptive name for your API package.
|
||||
|
||||
2. **Version Management**: Keep your API package versioned and compatible with the Llama Stack version you're using.
|
||||
|
||||
3. **Dependencies**: Only include the minimum required dependencies in your API package.
|
||||
|
||||
4. **Documentation**: Include clear documentation in your API package about:
|
||||
- Installation requirements
|
||||
- Configuration options
|
||||
- API endpoints and usage
|
||||
- Any limitations or known issues
|
||||
|
||||
5. **Testing**: Include tests in your API package to ensure it works correctly with Llama Stack.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If your external API isn't being loaded:
|
||||
|
||||
1. Check that the `external_apis_dir` path is correct and accessible.
|
||||
2. Verify that the YAML files are properly formatted.
|
||||
3. Ensure all required Python packages are installed.
|
||||
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
|
||||
5. Verify that the API package is installed in your Python environment.
|
|
@ -11,6 +11,7 @@ Here are some key topics that will help you build effective agents:
|
|||
- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
|
||||
- **[Agent](agent)**: Understand the components and design patterns of the Llama Stack agent framework.
|
||||
- **[Agent Execution Loop](agent_execution_loop)**: Understand how agents process information, make decisions, and execute actions in a continuous loop.
|
||||
- **[Agents vs Responses API](responses_vs_agents)**: Learn the differences between the Agents API and Responses API, and when to use each one.
|
||||
- **[Tools](tools)**: Extend your agents' capabilities by integrating with external tools and APIs.
|
||||
- **[Evals](evals)**: Evaluate your agents' effectiveness and identify areas for improvement.
|
||||
- **[Telemetry](telemetry)**: Monitor and analyze your agents' performance and behavior.
|
||||
|
@ -23,6 +24,7 @@ Here are some key topics that will help you build effective agents:
|
|||
rag
|
||||
agent
|
||||
agent_execution_loop
|
||||
responses_vs_agents
|
||||
tools
|
||||
evals
|
||||
telemetry
|
||||
|
|
177 docs/source/building_applications/responses_vs_agents.md (new file)
|
@ -0,0 +1,177 @@
|
|||
# Agents vs OpenAI Responses API
|
||||
|
||||
Llama Stack (LLS) provides two different APIs for building AI applications with tool calling capabilities: the **Agents API** and the **OpenAI Responses API**. While both enable AI systems to use tools and maintain full conversation history, they serve different use cases and have distinct characteristics.
|
||||
|
||||
> **Note:** For simple and basic inferencing, you may want to use the [Chat Completions API](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions) directly, before progressing to the Agents or Responses APIs.
|
||||
|
||||
## Overview
|
||||
|
||||
### LLS Agents API
|
||||
The Agents API is a full-featured, stateful system designed for complex, multi-turn conversations. It maintains conversation state through persistent sessions identified by a unique session ID. The API supports comprehensive agent lifecycle management, detailed execution tracking, and rich metadata about each interaction through a structured session/turn/step hierarchy. The API can orchestrate multiple tool calls within a single turn.
|
||||
|
||||
### OpenAI Responses API
|
||||
The OpenAI Responses API is a full-featured, stateful system designed for complex, multi-turn conversations, with direct compatibility with OpenAI's conversational patterns enhanced by Llama Stack's tool calling capabilities. It maintains conversation state by chaining responses through a `previous_response_id`, allowing interactions to branch or continue from any prior point. Each response can perform multiple tool calls within a single turn.
|
||||
|
||||
### Key Differences
|
||||
The LLS Agents API uses the Chat Completions API on the backend for inference as it's the industry standard for building AI applications and most LLM providers are compatible with this API. For a detailed comparison between Responses and Chat Completions, see [OpenAI's documentation](https://platform.openai.com/docs/guides/responses-vs-chat-completions).
|
||||
|
||||
Additionally, Agents let you specify input/output shields whereas Responses do not (though support is planned). Agents use a linear conversation model referenced by a single session ID. Responses, on the other hand, support branching, where each response can serve as a fork point, and conversations are tracked by the latest response ID. Responses also let you dynamically choose the model, vector store, files, MCP servers, and more on each inference call, enabling more complex workflows. Agents require a static configuration for these components at the start of the session.
|
||||
|
||||
Today the Agents and Responses APIs can be used independently depending on the use case. But it is also productive to treat the two APIs as complementary. It is not currently supported, but it is planned for the LLS Agents API to optionally use the Responses API as its backend instead of the default Chat Completions API, enabling a combination of the safety features of Agents with the dynamic configuration and branching capabilities of Responses.
|
||||
|
||||
| Feature | LLS Agents API | OpenAI Responses API |
|
||||
|---------|------------|---------------------|
|
||||
| **Conversation Management** | Linear persistent sessions | Can branch from any previous response ID |
|
||||
| **Input/Output Safety Shields** | Supported | Not yet supported |
|
||||
| **Per-call Flexibility** | Static per-session configuration | Dynamic per-call configuration |
|
||||
|
||||
## Use Case Example: Research with Multiple Search Methods
|
||||
|
||||
Let's compare how both APIs handle a research task where we need to:
|
||||
1. Search for current information and examples
|
||||
2. Access different information sources dynamically
|
||||
3. Continue the conversation based on search results
|
||||
|
||||
### Agents API: Session-based configuration with safety shields
|
||||
|
||||
```python
|
||||
# Create agent with static session configuration
|
||||
agent = Agent(
|
||||
client,
|
||||
model="Llama3.2-3B-Instruct",
|
||||
instructions="You are a helpful coding assistant",
|
||||
tools=[
|
||||
{
|
||||
"name": "builtin::rag/knowledge_search",
|
||||
"args": {"vector_db_ids": ["code_docs"]},
|
||||
},
|
||||
"builtin::code_interpreter",
|
||||
],
|
||||
input_shields=["llama_guard"],
|
||||
output_shields=["llama_guard"],
|
||||
)
|
||||
|
||||
session_id = agent.create_session("code_session")
|
||||
|
||||
# First turn: Search and execute
|
||||
response1 = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Find examples of sorting algorithms and run a bubble sort on [3,1,4,1,5]",
|
||||
},
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Continue conversation in same session
|
||||
response2 = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Now optimize that code and test it with a larger dataset",
|
||||
},
|
||||
],
|
||||
session_id=session_id, # Same session, maintains full context
|
||||
)
|
||||
|
||||
# Agents API benefits:
|
||||
# ✅ Safety shields protect against malicious code execution
|
||||
# ✅ Session maintains context between code executions
|
||||
# ✅ Consistent tool configuration throughout conversation
|
||||
print(f"First result: {response1.output_message.content}")
|
||||
print(f"Optimization: {response2.output_message.content}")
|
||||
```
|
||||
|
||||
### Responses API: Dynamic per-call configuration with branching
|
||||
|
||||
```python
|
||||
# First response: Use web search for latest algorithms
|
||||
response1 = client.responses.create(
|
||||
model="Llama3.2-3B-Instruct",
|
||||
input="Search for the latest efficient sorting algorithms and their performance comparisons",
|
||||
tools=[
|
||||
{
|
||||
"type": "web_search",
|
||||
},
|
||||
], # Web search for current information
|
||||
)
|
||||
|
||||
# Continue conversation: Switch to file search for local docs
|
||||
response2 = client.responses.create(
|
||||
model="Llama3.2-1B-Instruct", # Switch to faster model
|
||||
input="Now search my uploaded files for existing sorting implementations",
|
||||
tools=[
|
||||
{ # Using Responses API built-in tools
|
||||
"type": "file_search",
|
||||
"vector_store_ids": ["vs_abc123"], # Vector store containing uploaded files
|
||||
},
|
||||
],
|
||||
previous_response_id=response1.id,
|
||||
)
|
||||
|
||||
# Branch from first response: Try different search approach
|
||||
response3 = client.responses.create(
|
||||
model="Llama3.2-3B-Instruct",
|
||||
input="Instead, search the web for Python-specific sorting best practices",
|
||||
tools=[{"type": "web_search"}], # Different web search query
|
||||
previous_response_id=response1.id, # Branch from response1
|
||||
)
|
||||
|
||||
# Responses API benefits:
|
||||
# ✅ Dynamic tool switching (web search ↔ file search per call)
|
||||
# ✅ OpenAI-compatible tool patterns (web_search, file_search)
|
||||
# ✅ Branch conversations to explore different information sources
|
||||
# ✅ Model flexibility per search type
|
||||
print(f"Web search results: {response1.output_message.content}")
|
||||
print(f"File search results: {response2.output_message.content}")
|
||||
print(f"Alternative web search: {response3.output_message.content}")
|
||||
```
|
||||
|
||||
Both APIs demonstrate distinct strengths that make them valuable on their own for different scenarios. The Agents API excels in providing structured, safety-conscious workflows with persistent session management, while the Responses API offers flexibility through dynamic configuration and OpenAI compatible tool patterns.
|
||||
|
||||
## Use Case Examples
|
||||
|
||||
### 1. **Research and Analysis with Safety Controls**
|
||||
**Best Choice: Agents API**
|
||||
|
||||
**Scenario:** You're building a research assistant for a financial institution that needs to analyze market data, execute code to process financial models, and search through internal compliance documents. The system must ensure all interactions are logged for regulatory compliance and protected by safety shields to prevent malicious code execution or data leaks.
|
||||
|
||||
**Why Agents API?** The Agents API provides persistent session management for iterative research workflows, built-in safety shields to protect against malicious code in financial models, and structured execution logs (session/turn/step) required for regulatory compliance. The static tool configuration ensures consistent access to your knowledge base and code interpreter throughout the entire research session.
|
||||
|
||||
### 2. **Dynamic Information Gathering with Branching Exploration**
|
||||
**Best Choice: Responses API**
|
||||
|
||||
**Scenario:** You're building a competitive intelligence tool that helps businesses research market trends. Users need to dynamically switch between web search for current market data and file search through uploaded industry reports. They also want to branch conversations to explore different market segments simultaneously and experiment with different models for various analysis types.
|
||||
|
||||
**Why Responses API?** The Responses API's branching capability lets users explore multiple market segments from any research point. Dynamic per-call configuration allows switching between web search and file search as needed, while experimenting with different models (faster models for quick searches, more powerful models for deep analysis). The OpenAI-compatible tool patterns make integration straightforward.
|
||||
|
||||
### 3. **OpenAI Migration with Advanced Tool Capabilities**
|
||||
**Best Choice: Responses API**
|
||||
|
||||
**Scenario:** You have an existing application built with OpenAI's Assistants API that uses file search and web search capabilities. You want to migrate to Llama Stack for better performance and cost control while maintaining the same tool calling patterns and adding new capabilities like dynamic vector store selection.
|
||||
|
||||
**Why Responses API?** The Responses API provides full OpenAI tool compatibility (`web_search`, `file_search`) with identical syntax, making migration seamless. The dynamic per-call configuration enables advanced features like switching vector stores per query or changing models based on query complexity - capabilities that extend beyond basic OpenAI functionality while maintaining compatibility.
|
||||
|
||||
### 4. **Educational Programming Tutor**
|
||||
**Best Choice: Agents API**
|
||||
|
||||
**Scenario:** You're building a programming tutor that maintains student context across multiple sessions, safely executes code exercises, and tracks learning progress with audit trails for educators.
|
||||
|
||||
**Why Agents API?** Persistent sessions remember student progress across multiple interactions, safety shields prevent malicious code execution while allowing legitimate programming exercises, and structured execution logs help educators track learning patterns.
|
||||
|
||||
### 5. **Advanced Software Debugging Assistant**
|
||||
**Best Choice: Agents API with Responses Backend**
|
||||
|
||||
**Scenario:** You're building a debugging assistant that helps developers troubleshoot complex issues. It needs to maintain context throughout a debugging session, safely execute diagnostic code, switch between different analysis tools dynamically, and branch conversations to explore multiple potential causes simultaneously.
|
||||
|
||||
**Why Agents + Responses?** The Agent provides safety shields for code execution and session management for the overall debugging workflow. The underlying Responses API enables dynamic model selection and flexible tool configuration per query, while branching lets you explore different theories (memory leak vs. concurrency issue) from the same debugging point and compare results.
|
||||
|
||||
> **Note:** The ability to use Responses API as the backend for Agents is not yet implemented but is planned for a future release. Currently, Agents use Chat Completions API as their backend by default.
|
||||
|
||||
## For More Information
|
||||
|
||||
- **LLS Agents API**: For detailed information on creating and managing agents, see the [Agents documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent.html)
|
||||
- **OpenAI Responses API**: For information on using the OpenAI-compatible responses API, see the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/responses)
|
||||
- **Chat Completions API**: For the default backend API used by Agents, see the [Chat Completions providers documentation](https://llama-stack.readthedocs.io/en/latest/providers/index.html#chat-completions)
|
||||
- **Agent Execution Loop**: For understanding how agents process turns and steps in their execution, see the [Agent Execution Loop documentation](https://llama-stack.readthedocs.io/en/latest/building_applications/agent_execution_loop.html)
|
|
@ -10,9 +10,11 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
|
|||
- **Eval**: generate outputs (via Inference or Agents) and perform scoring
|
||||
- **VectorIO**: perform operations on vector stores, such as adding documents, searching, and deleting documents
|
||||
- **Telemetry**: collect telemetry data from the system
|
||||
- **Post Training**: fine-tune a model
|
||||
- **Tool Runtime**: interact with various tools and protocols
|
||||
- **Responses**: generate responses from an LLM using this OpenAI compatible API.
|
||||
|
||||
We are working on adding a few more APIs to complete the application lifecycle. These will include:
|
||||
- **Batch Inference**: run inference on a dataset of inputs
|
||||
- **Batch Agents**: run agents on a dataset of inputs
|
||||
- **Post Training**: fine-tune a model
|
||||
- **Synthetic Data Generation**: generate synthetic data for model development
|
||||
|
|
|
@ -504,6 +504,47 @@ created by users sharing a team with them:
|
|||
description: any user has read access to any resource created by a user with the same team
|
||||
```
|
||||
|
||||
#### API Endpoint Authorization with Scopes
|
||||
|
||||
In addition to resource-based access control, Llama Stack supports endpoint-level authorization using OAuth 2.0 style scopes. When authentication is enabled, specific API endpoints require users to have particular scopes in their authentication token.
|
||||
|
||||
**Scope-Gated APIs:**
|
||||
The following APIs are currently gated by scopes:
|
||||
|
||||
- **Telemetry API** (scope: `telemetry.read`):
|
||||
- `POST /telemetry/traces` - Query traces
|
||||
- `GET /telemetry/traces/{trace_id}` - Get trace by ID
|
||||
- `GET /telemetry/traces/{trace_id}/spans/{span_id}` - Get span by ID
|
||||
- `POST /telemetry/spans/{span_id}/tree` - Get span tree
|
||||
- `POST /telemetry/spans` - Query spans
|
||||
- `POST /telemetry/metrics/{metric_name}` - Query metrics
|
||||
|
||||
**Authentication Configuration:**
|
||||
|
||||
For **JWT/OAuth2 providers**, scopes should be included in the JWT's claims:
|
||||
```json
|
||||
{
|
||||
"sub": "user123",
|
||||
"scope": "telemetry.read",
|
||||
"aud": "llama-stack"
|
||||
}
|
||||
```
|
||||
|
||||
For **custom authentication providers**, the endpoint must return user attributes including the `scopes` array:
|
||||
```json
|
||||
{
|
||||
"principal": "user123",
|
||||
"attributes": {
|
||||
"scopes": ["telemetry.read"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Behavior:**
|
||||
- Users without the required scope receive a 403 Forbidden response
|
||||
- When authentication is disabled, scope checks are bypassed
|
||||
- Endpoints without `required_scope` work normally for all authenticated users
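For API authors, the gating is declared on the endpoint itself. A minimal sketch of how a protocol method opts into a scope check via the `required_scope` argument of `@webmethod` (the route and scope mirror the Telemetry API; the protocol name is made up for illustration):

```python
# Sketch only: `ScopedTelemetry` is a hypothetical protocol; `required_scope`
# is the same mechanism the Telemetry API uses for its endpoints.
from typing import Protocol

from llama_stack.schema_utils import webmethod


class ScopedTelemetry(Protocol):
    @webmethod(route="/telemetry/spans", method="POST", required_scope="telemetry.read")
    async def query_spans(self, attribute_filters: list | None = None):
        """Only callers whose token carries `telemetry.read` reach this handler."""
        ...
```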
|
||||
|
||||
### Quota Configuration
|
||||
|
||||
The `quota` section allows you to enable server-side request throttling for both
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||
# NVIDIA Distribution
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
|
|||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `url` | `<class 'str'>` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
|
||||
|
||||
|
|
|
@ -9,8 +9,7 @@ Ollama inference provider for running local models through the Ollama runtime.
|
|||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `url` | `<class 'str'>` | No | http://localhost:11434 | |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | refresh and re-register models periodically |
|
||||
| `refresh_models_interval` | `<class 'int'>` | No | 300 | interval in seconds to refresh models |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ Together AI inference provider for open-source models and collaborative AI devel
|
|||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
|
||||
| `url` | `<class 'str'>` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
|
||||
| `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
|
||||
|
||||
|
|
|
@ -13,7 +13,6 @@ Remote vLLM inference provider for connecting to vLLM servers.
|
|||
| `api_token` | `str \| None` | No | fake | The API token |
|
||||
| `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
|
||||
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically |
|
||||
| `refresh_models_interval` | `<class 'int'>` | No | 300 | Interval in seconds to refresh models |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
@ -4,15 +4,83 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from enum import Enum, EnumMeta
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class DynamicApiMeta(EnumMeta):
|
||||
def __new__(cls, name, bases, namespace):
|
||||
# Store the original enum values
|
||||
original_values = {k: v for k, v in namespace.items() if not k.startswith("_")}
|
||||
|
||||
# Create the enum class
|
||||
cls = super().__new__(cls, name, bases, namespace)
|
||||
|
||||
# Store the original values for reference
|
||||
cls._original_values = original_values
|
||||
# Initialize _dynamic_values
|
||||
cls._dynamic_values = {}
|
||||
|
||||
return cls
|
||||
|
||||
def __call__(cls, value):
|
||||
try:
|
||||
return super().__call__(value)
|
||||
except ValueError as e:
|
||||
# If this value was already dynamically added, return it
|
||||
if value in cls._dynamic_values:
|
||||
return cls._dynamic_values[value]
|
||||
|
||||
# If the value doesn't exist, create a new enum member
|
||||
# Create a new member name from the value
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return the existing member
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Instead of creating a new member, raise ValueError to force users to use Api.add() to
|
||||
# register new APIs explicitly
|
||||
raise ValueError(f"API '{value}' does not exist. Use Api.add() to register new APIs.") from e
|
||||
|
||||
def __iter__(cls):
|
||||
# Allow iteration over both static and dynamic members
|
||||
yield from super().__iter__()
|
||||
if hasattr(cls, "_dynamic_values"):
|
||||
yield from cls._dynamic_values.values()
|
||||
|
||||
def add(cls, value):
|
||||
"""
|
||||
Add a new API to the enum.
|
||||
Used to register external APIs.
|
||||
"""
|
||||
member_name = value.lower().replace("-", "_")
|
||||
|
||||
# If this member name already exists in the enum, return it
|
||||
if member_name in cls._member_map_:
|
||||
return cls._member_map_[member_name]
|
||||
|
||||
# Create a new enum member
|
||||
member = object.__new__(cls)
|
||||
member._name_ = member_name
|
||||
member._value_ = value
|
||||
|
||||
# Add it to the enum class
|
||||
cls._member_map_[member_name] = member
|
||||
cls._member_names_.append(member_name)
|
||||
cls._member_type_ = str
|
||||
|
||||
# Store it in our dynamic values
|
||||
cls._dynamic_values[value] = member
|
||||
|
||||
return member
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Api(Enum):
|
||||
class Api(Enum, metaclass=DynamicApiMeta):
|
||||
providers = "providers"
|
||||
inference = "inference"
|
||||
safety = "safety"
|
||||
|
@ -54,3 +122,12 @@ class Error(BaseModel):
|
|||
title: str
|
||||
detail: str
|
||||
instance: str | None = None
|
||||
|
||||
|
||||
class ExternalApiSpec(BaseModel):
|
||||
"""Specification for an external API implementation."""
|
||||
|
||||
module: str = Field(..., description="Python module containing the API implementation")
|
||||
name: str = Field(..., description="Name of the API")
|
||||
pip_packages: list[str] = Field(default=[], description="List of pip packages to install the API")
|
||||
protocol: str = Field(..., description="Name of the protocol class for the API")
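Together with `DynamicApiMeta` above, this is what lets external APIs register themselves at runtime. A small illustrative sketch (not part of the diff) of the intended flow:

```python
# Illustrative only: how a parsed external API spec becomes a dynamic Api member.
from llama_stack.apis.datatypes import Api, ExternalApiSpec

spec = ExternalApiSpec(
    module="llama_stack_api_weather",
    name="weather",
    pip_packages=["llama-stack-api-weather"],
    protocol="WeatherAPI",
)
weather = Api.add(spec.name)      # registers a new enum member on the fly
assert Api("weather") is weather  # later lookups resolve to the same member
```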
|
||||
|
|
|
@ -819,12 +819,6 @@ class OpenAIEmbeddingsResponse(BaseModel):
|
|||
class ModelStore(Protocol):
|
||||
async def get_model(self, identifier: str) -> Model: ...
|
||||
|
||||
async def update_registered_llm_models(
|
||||
self,
|
||||
provider_id: str,
|
||||
models: list[Model],
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class TextTruncation(Enum):
|
||||
"""Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left.
|
||||
|
|
|
@ -22,6 +22,8 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
|
|||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
REQUIRED_SCOPE = "telemetry.read"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
|
@ -259,7 +261,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces", method="POST")
|
||||
@webmethod(route="/telemetry/traces", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition] | None = None,
|
||||
|
@ -277,7 +279,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET", required_scope=REQUIRED_SCOPE)
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
"""Get a trace by its ID.
|
||||
|
||||
|
@ -286,7 +288,9 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
|
||||
@webmethod(
|
||||
route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET", required_scope=REQUIRED_SCOPE
|
||||
)
|
||||
async def get_span(self, trace_id: str, span_id: str) -> Span:
|
||||
"""Get a span by its ID.
|
||||
|
||||
|
@ -296,7 +300,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
|
||||
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
|
@ -312,7 +316,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/spans", method="POST")
|
||||
@webmethod(route="/telemetry/spans", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: list[QueryCondition],
|
||||
|
@ -345,7 +349,7 @@ class Telemetry(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
|
||||
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST", required_scope=REQUIRED_SCOPE)
|
||||
async def query_metrics(
|
||||
self,
|
||||
metric_name: str,
|
||||
|
|
|
@ -36,6 +36,7 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.stack import replace_env_vars
|
||||
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
|
||||
|
@ -404,6 +405,29 @@ def _run_stack_build_command_from_build_config(
|
|||
to_write = json.loads(build_config.model_dump_json())
|
||||
f.write(yaml.dump(to_write, sort_keys=False))
|
||||
|
||||
# We first install the external APIs so that the build process can use them and discover the
|
||||
# providers dependencies
|
||||
if build_config.external_apis_dir:
|
||||
cprint("Installing external APIs", color="yellow", file=sys.stderr)
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
# install the external APIs
|
||||
packages = []
|
||||
for _, api_spec in external_apis.items():
|
||||
if api_spec.pip_packages:
|
||||
packages.extend(api_spec.pip_packages)
|
||||
cprint(
|
||||
f"Installing {api_spec.name} with pip packages {api_spec.pip_packages}",
|
||||
color="yellow",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return_code = run_command(["uv", "pip", "install", *packages])
|
||||
if return_code != 0:
|
||||
packages_str = ", ".join(packages)
|
||||
raise RuntimeError(
|
||||
f"Failed to install external APIs packages: {packages_str} (return code: {return_code})"
|
||||
)
|
||||
|
||||
return_code = build_image(
|
||||
build_config,
|
||||
build_file_path,
|
||||
|
|
|
@ -14,6 +14,7 @@ from termcolor import cprint
|
|||
|
||||
from llama_stack.distribution.datatypes import BuildConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.utils.exec import run_command
|
||||
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||
from llama_stack.providers.datatypes import Api
|
||||
|
@ -105,6 +106,11 @@ def build_image(
|
|||
|
||||
normal_deps, special_deps = get_provider_dependencies(build_config)
|
||||
normal_deps += SERVER_DEPENDENCIES
|
||||
if build_config.external_apis_dir:
|
||||
external_apis = load_external_apis(build_config)
|
||||
if external_apis:
|
||||
for _, api_spec in external_apis.items():
|
||||
normal_deps.extend(api_spec.pip_packages)
|
||||
|
||||
if build_config.image_type == LlamaStackImageType.CONTAINER.value:
|
||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh")
|
||||
|
|
|
@ -36,6 +36,11 @@ LLAMA_STACK_RUN_CONFIG_VERSION = 2
|
|||
RoutingKey = str | list[str]
|
||||
|
||||
|
||||
class RegistryEntrySource(StrEnum):
|
||||
via_register_api = "via_register_api"
|
||||
listed_from_provider = "listed_from_provider"
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
principal: str
|
||||
# further attributes that may be used for access control decisions
|
||||
|
@ -50,6 +55,7 @@ class ResourceWithOwner(Resource):
|
|||
resource. This can be used to constrain access to the resource."""
|
||||
|
||||
owner: User | None = None
|
||||
source: RegistryEntrySource = RegistryEntrySource.via_register_api
|
||||
|
||||
|
||||
# Use the extended Resource for all routable objects
|
||||
|
@ -381,6 +387,11 @@ a default SQLite store will be used.""",
|
|||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
external_apis_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
def validate_external_providers_dir(cls, v):
|
||||
|
@ -412,6 +423,10 @@ class BuildConfig(BaseModel):
|
|||
default_factory=list,
|
||||
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
|
||||
)
|
||||
external_apis_dir: Path | None = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external API implementations. The APIs code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
@field_validator("external_providers_dir")
|
||||
@classmethod
|
||||
|
|
|
@ -12,6 +12,7 @@ from typing import Any
|
|||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
|
@ -133,16 +134,34 @@ def get_provider_registry(
|
|||
ValueError: If any provider spec is invalid
|
||||
"""
|
||||
|
||||
ret: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
registry: dict[Api, dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
logger.debug(f"Importing module {name}")
|
||||
try:
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
# Refresh providable APIs with external APIs if any
|
||||
external_apis = load_external_apis(config)
|
||||
for api, api_spec in external_apis.items():
|
||||
name = api_spec.name.lower()
|
||||
logger.info(f"Importing external API {name} module {api_spec.module}")
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
registry[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except (ImportError, AttributeError) as e:
|
||||
# Populate the registry with an empty dict to avoid breaking the provider registry
|
||||
# This assumes that the in-tree provider(s) are not available for this API, which means
|
||||
# that users will need to use external providers for this API.
|
||||
registry[api] = {}
|
||||
logger.error(
|
||||
f"Failed to import external API {name}: {e}. Could not populate the in-tree provider(s) registry for {api.name}. \n"
|
||||
"Install the API package to load any in-tree providers for this API."
|
||||
)
|
||||
|
||||
# Check if config has the external_providers_dir attribute
|
||||
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
|
||||
external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
|
||||
|
@ -175,11 +194,9 @@ def get_provider_registry(
|
|||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in ret[api]:
|
||||
if provider_type_key in registry[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
ret[api][provider_type_key] = spec
|
||||
registry[api][provider_type_key] = spec
|
||||
logger.info(f"Successfully loaded external provider {provider_type_key}")
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
|
@ -187,4 +204,4 @@ def get_provider_registry(
|
|||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
return ret
|
||||
return registry
|
||||
|
|
54  llama_stack/distribution/external.py  Normal file
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import yaml

from llama_stack.apis.datatypes import Api, ExternalApiSpec
from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig
from llama_stack.log import get_logger

logger = get_logger(name=__name__, category="core")


def load_external_apis(config: StackRunConfig | BuildConfig | None) -> dict[Api, ExternalApiSpec]:
    """Load external API specifications from the configured directory.

    Args:
        config: StackRunConfig or BuildConfig containing the external APIs directory path

    Returns:
        A dictionary mapping API names to their specifications
    """
    if not config or not config.external_apis_dir:
        return {}

    external_apis_dir = config.external_apis_dir.expanduser().resolve()
    if not external_apis_dir.is_dir():
        logger.error(f"External APIs directory is not a directory: {external_apis_dir}")
        return {}

    logger.info(f"Loading external APIs from {external_apis_dir}")
    external_apis: dict[Api, ExternalApiSpec] = {}

    # Look for YAML files in the external APIs directory
    for yaml_path in external_apis_dir.glob("*.yaml"):
        try:
            with open(yaml_path) as f:
                spec_data = yaml.safe_load(f)

            spec = ExternalApiSpec(**spec_data)
            api = Api.add(spec.name)
            logger.info(f"Loaded external API spec for {spec.name} from {yaml_path}")
            external_apis[api] = spec
        except yaml.YAMLError as yaml_err:
            logger.error(f"Failed to parse YAML file {yaml_path}: {yaml_err}")
            raise
        except Exception:
            logger.exception(f"Failed to load external API spec from {yaml_path}")
            raise

    return external_apis
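# --- Illustrative sketch (not part of the diff) ---
# How load_external_apis() above is typically fed. The YAML keys are inferred
# from the ExternalApiSpec fields the rest of this change reads (name, module,
# protocol); the real spec may require more. The stand-in config object only
# needs the external_apis_dir attribute the function actually touches.
import tempfile
from pathlib import Path
from types import SimpleNamespace

from llama_stack.distribution.external import load_external_apis

SPEC_YAML = """\
name: weather            # registered via Api.add("weather")
module: my_weather_api   # hypothetical importable module for this API
protocol: WeatherAPI     # protocol class looked up inside that module
"""

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "weather.yaml").write_text(SPEC_YAML)
    fake_config = SimpleNamespace(external_apis_dir=Path(tmp))  # stand-in for StackRunConfig
    specs = load_external_apis(fake_config)  # type: ignore[arg-type]
    print(sorted(spec.name for spec in specs.values()))  # expected: ["weather"]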
@ -16,6 +16,7 @@ from llama_stack.apis.inspect import (
|
|||
VersionInfo,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.server.routes import get_all_api_routes
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
|
||||
|
@ -42,7 +43,8 @@ class DistributionInspectImpl(Inspect):
|
|||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
ret = []
|
||||
all_endpoints = get_all_api_routes()
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
for api, endpoints in all_endpoints.items():
|
||||
# Always include provider and inspect APIs, filter others based on run config
|
||||
if api.value in ["providers", "inspect"]:
|
||||
|
@ -53,7 +55,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e in endpoints
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
@ -66,7 +69,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e in endpoints
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -161,7 +161,13 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
|||
if not self.skip_logger_removal:
|
||||
self._remove_root_logger_handlers()
|
||||
|
||||
return self.loop.run_until_complete(self.async_client.initialize())
|
||||
# use a new event loop to avoid interfering with the main event loop
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(self.async_client.initialize())
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def _remove_root_logger_handlers(self):
|
||||
"""
|
||||
|
@ -353,13 +359,15 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
|
||||
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body, field_names = self._handle_file_uploads(options, body)
|
||||
|
||||
body = self._convert_body(path, options.method, body, exclude_params=set(field_names))
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
try:
|
||||
result = await matched_func(**body)
|
||||
finally:
|
||||
|
@ -409,12 +417,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
path = options.url
|
||||
body = options.params or {}
|
||||
body |= options.json_data or {}
|
||||
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
|
||||
func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls)
|
||||
body |= path_params
|
||||
|
||||
body = self._convert_body(path, options.method, body)
|
||||
|
||||
await start_trace(route, {"__location__": "library_client"})
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
await start_trace(trace_path, {"__location__": "library_client"})
|
||||
|
||||
async def gen():
|
||||
try:
|
||||
|
@ -445,8 +454,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
# we use asynchronous impl always internally and channel all requests to AsyncLlamaStackClient
|
||||
# however, the top-level caller may be a SyncAPIClient -- so its stream_cls might be a Stream (SyncStream)
|
||||
# so we need to convert it to AsyncStream
|
||||
# mypy can't track runtime variables inside the [...] of a generic, so ignore that check
|
||||
args = get_args(stream_cls)
|
||||
stream_cls = AsyncStream[args[0]]
|
||||
stream_cls = AsyncStream[args[0]] # type: ignore[valid-type]
|
||||
response = AsyncAPIResponse(
|
||||
raw=mock_response,
|
||||
client=self,
|
||||
|
@ -468,7 +478,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
|||
|
||||
exclude_params = exclude_params or set()
|
||||
|
||||
func, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
func, _, _, _ = find_matching_route(method, path, self.route_impls)
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
|
|
|
@@ -101,3 +101,15 @@ def get_authenticated_user() -> User | None:
    if not provider_data:
        return None
    return provider_data.get("__authenticated_user")


def user_from_scope(scope: dict) -> User | None:
    """Create a User object from ASGI scope data (set by authentication middleware)"""
    user_attributes = scope.get("user_attributes", {})
    principal = scope.get("principal", "")

    # auth not enabled
    if not principal and not user_attributes:
        return None

    return User(principal=principal, attributes=user_attributes)
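# --- Illustrative sketch (not part of the diff) ---
# user_from_scope() only reads the two keys the auth middleware stores on the
# ASGI scope. With auth disabled both keys are absent and it returns None,
# which the scope check later in auth.py treats as "allow everything".
from llama_stack.distribution.request_headers import user_from_scope

authenticated_scope = {
    "type": "http",
    "principal": "alice",
    "user_attributes": {"scopes": ["models:write"], "teams": ["ml-infra"]},
}
anonymous_scope = {"type": "http"}  # no authentication middleware ran

user = user_from_scope(authenticated_scope)
assert user is not None and user.principal == "alice"
assert user_from_scope(anonymous_scope) is None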
@ -11,6 +11,7 @@ from llama_stack.apis.agents import Agents
|
|||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.datatypes import ExternalApiSpec
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference, InferenceProvider
|
||||
|
@ -35,6 +36,7 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.external import load_external_apis
|
||||
from llama_stack.distribution.store import DistributionRegistry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -59,8 +61,16 @@ class InvalidProviderError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
def api_protocol_map() -> dict[Api, Any]:
|
||||
return {
|
||||
def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> dict[Api, Any]:
|
||||
"""Get a mapping of API types to their protocol classes.
|
||||
|
||||
Args:
|
||||
external_apis: Optional dictionary of external API specifications
|
||||
|
||||
Returns:
|
||||
Dictionary mapping API types to their protocol classes
|
||||
"""
|
||||
protocols = {
|
||||
Api.providers: ProvidersAPI,
|
||||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
|
@ -83,10 +93,23 @@ def api_protocol_map() -> dict[Api, Any]:
|
|||
Api.files: Files,
|
||||
}
|
||||
|
||||
if external_apis:
|
||||
for api, api_spec in external_apis.items():
|
||||
try:
|
||||
module = importlib.import_module(api_spec.module)
|
||||
api_class = getattr(module, api_spec.protocol)
|
||||
|
||||
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
|
||||
protocols[api] = api_class
|
||||
except (ImportError, AttributeError):
|
||||
logger.exception(f"Failed to load external API {api_spec.name}")
|
||||
|
||||
return protocols
|
||||
|
||||
|
||||
def api_protocol_map_for_compliance_check(config: Any) -> dict[Api, Any]:
|
||||
external_apis = load_external_apis(config)
|
||||
return {
|
||||
**api_protocol_map(),
|
||||
**api_protocol_map(external_apis),
|
||||
Api.inference: InferenceProvider,
|
||||
}
|
||||
|
||||
|
@ -250,7 +273,7 @@ async def instantiate_providers(
|
|||
dist_registry: DistributionRegistry,
|
||||
run_config: StackRunConfig,
|
||||
policy: list[AccessRule],
|
||||
) -> dict:
|
||||
) -> dict[Api, Any]:
|
||||
"""Instantiates providers asynchronously while managing dependencies."""
|
||||
impls: dict[Api, Any] = {}
|
||||
inner_impls_by_provider_id: dict[str, dict[str, Any]] = {f"inner-{x.value}": {} for x in router_apis}
|
||||
|
@ -360,7 +383,7 @@ async def instantiate_provider(
|
|||
impl.__provider_spec__ = provider_spec
|
||||
impl.__provider_config__ = config
|
||||
|
||||
protocols = api_protocol_map_for_compliance_check()
|
||||
protocols = api_protocol_map_for_compliance_check(run_config)
|
||||
additional_protocols = additional_protocols_map()
|
||||
# TODO: check compliance for special tool groups
|
||||
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
|
||||
|
|
|
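# --- Illustrative sketch (not part of the diff) ---
# What the resolver change above expects an external API module to export: a
# protocol class named by api_spec.protocol, plus an available_providers()
# function that the registry loader calls. Every name here (my_weather_api,
# WeatherAPI, get_forecast) is hypothetical, not something the diff defines.
from typing import Protocol

from llama_stack.providers.datatypes import ProviderSpec


class WeatherAPI(Protocol):
    async def get_forecast(self, city: str) -> dict: ...


def available_providers() -> list[ProviderSpec]:
    # An empty list is valid: the registry entry stays empty and the API can
    # only be served by external providers.
    return []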
@ -117,6 +117,9 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for p in self.impls_by_provider_id.values():
|
||||
await p.shutdown()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
|
||||
from .benchmarks import BenchmarksRoutingTable
|
||||
from .datasets import DatasetsRoutingTable
|
||||
|
@ -206,7 +209,6 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if obj.type == ResourceType.model.value:
|
||||
await self.dist_registry.register(registered_obj)
|
||||
return registered_obj
|
||||
|
||||
else:
|
||||
await self.dist_registry.register(obj)
|
||||
return obj
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any
|
|||
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
|
||||
from llama_stack.distribution.datatypes import (
|
||||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
|
@ -19,6 +20,26 @@ logger = get_logger(name=__name__, category="core")
|
|||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
listed_providers: set[str] = set()
|
||||
|
||||
async def refresh(self) -> None:
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
refresh = await provider.should_refresh_models()
|
||||
if not (refresh or provider_id in self.listed_providers):
|
||||
continue
|
||||
|
||||
try:
|
||||
models = await provider.list_models()
|
||||
except Exception as e:
|
||||
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
|
||||
continue
|
||||
|
||||
self.listed_providers.add(provider_id)
|
||||
if models is None:
|
||||
continue
|
||||
|
||||
await self.update_registered_models(provider_id, models)
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
|
||||
|
@ -81,6 +102,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
source=RegistryEntrySource.via_register_api,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
@ -91,7 +113,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
raise ValueError(f"Model {model_id} not found")
|
||||
await self.unregister_object(existing_model)
|
||||
|
||||
async def update_registered_llm_models(
|
||||
async def update_registered_models(
|
||||
self,
|
||||
provider_id: str,
|
||||
models: list[Model],
|
||||
|
@ -102,18 +124,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
# from run.yaml) that we need to keep track of
|
||||
model_ids = {}
|
||||
for model in existing_models:
|
||||
# we leave embeddings models alone because often we don't get metadata
|
||||
# (embedding dimension, etc.) from the provider
|
||||
if model.provider_id == provider_id and model.model_type == ModelType.llm:
|
||||
if model.provider_id != provider_id:
|
||||
continue
|
||||
if model.source == RegistryEntrySource.via_register_api:
|
||||
model_ids[model.provider_resource_id] = model.identifier
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
continue
|
||||
|
||||
logger.debug(f"unregistering model {model.identifier}")
|
||||
await self.unregister_object(model)
|
||||
|
||||
for model in models:
|
||||
if model.model_type != ModelType.llm:
|
||||
continue
|
||||
if model.provider_resource_id in model_ids:
|
||||
model.identifier = model_ids[model.provider_resource_id]
|
||||
# avoid overwriting a non-provider-registered model entry
|
||||
continue
|
||||
|
||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
await self.register_object(
|
||||
|
@ -123,5 +146,6 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_id=provider_id,
|
||||
metadata=model.metadata,
|
||||
model_type=model.model_type,
|
||||
source=RegistryEntrySource.listed_from_provider,
|
||||
)
|
||||
)
|
||||
|
|
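# --- Illustrative sketch (not part of the diff) ---
# The refresh flow above reduces to two rules, restated here in plain Python
# for readability. This paraphrases update_registered_models(); it is not a
# drop-in replacement for it.
from llama_stack.distribution.datatypes import RegistryEntrySource


def keep_on_refresh(entry_provider_id: str, provider_id: str, source: RegistryEntrySource) -> bool:
    """An existing entry survives a refresh of `provider_id` if it belongs to a
    different provider or was registered explicitly via the register API."""
    if entry_provider_id != provider_id:
        return True
    return source == RegistryEntrySource.via_register_api


def should_register_listed_model(provider_resource_id: str, register_api_ids: set[str]) -> bool:
    """A provider-listed model is added only when it does not collide with an
    entry the user registered themselves."""
    return provider_resource_id not in register_api_ids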
|
@ -7,9 +7,12 @@
|
|||
import json
|
||||
|
||||
import httpx
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig
|
||||
from llama_stack.distribution.datatypes import AuthenticationConfig, User
|
||||
from llama_stack.distribution.request_headers import user_from_scope
|
||||
from llama_stack.distribution.server.auth_providers import create_auth_provider
|
||||
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name=__name__, category="auth")
|
||||
|
@ -78,12 +81,14 @@ class AuthenticationMiddleware:
|
|||
access resources that don't have access_attributes defined.
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_config: AuthenticationConfig):
|
||||
def __init__(self, app, auth_config: AuthenticationConfig, impls):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.auth_provider = create_auth_provider(auth_config)
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
# First, handle authentication
|
||||
headers = dict(scope.get("headers", []))
|
||||
auth_header = headers.get(b"authorization", b"").decode()
|
||||
|
||||
|
@ -121,15 +126,50 @@ class AuthenticationMiddleware:
|
|||
f"Authentication successful: {validation_result.principal} with {len(validation_result.attributes)} attributes"
|
||||
)
|
||||
|
||||
# Scope-based API access control
|
||||
path = scope.get("path", "")
|
||||
method = scope.get("method", hdrs.METH_GET)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
|
||||
try:
|
||||
_, _, _, webmethod = find_matching_route(method, path, self.route_impls)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
if webmethod.required_scope:
|
||||
user = user_from_scope(scope)
|
||||
if not _has_required_scope(webmethod.required_scope, user):
|
||||
return await self._send_auth_error(
|
||||
send,
|
||||
f"Access denied: user does not have required scope: {webmethod.required_scope}",
|
||||
status=403,
|
||||
)
|
||||
|
||||
return await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send, message):
|
||||
async def _send_auth_error(self, send, message, status=401):
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": 401,
|
||||
"status": status,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
}
|
||||
)
|
||||
error_msg = json.dumps({"error": {"message": message}}).encode()
|
||||
error_key = "message" if status == 401 else "detail"
|
||||
error_msg = json.dumps({"error": {error_key: message}}).encode()
|
||||
await send({"type": "http.response.body", "body": error_msg})
|
||||
|
||||
|
||||
def _has_required_scope(required_scope: str, user: User | None) -> bool:
    # if no user, assume auth is not enabled
    if not user:
        return True

    if not user.attributes:
        return False

    user_scopes = user.attributes.get("scopes", [])
    return required_scope in user_scopes
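# --- Illustrative sketch (not part of the diff) ---
# The three interesting cases for the scope check above: auth disabled (no
# user), a user carrying the scope, and a user without it. The scope string is
# an example value, and the import path assumes this middleware lives at
# llama_stack/distribution/server/auth.py alongside auth_providers.
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.server.auth import _has_required_scope

admin = User(principal="alice", attributes={"scopes": ["models:write", "models:read"]})
reader = User(principal="bob", attributes={"scopes": ["models:read"]})

assert _has_required_scope("models:write", None)        # auth not enabled -> allowed
assert _has_required_scope("models:write", admin)       # scope present -> allowed
assert not _has_required_scope("models:write", reader)  # missing scope -> middleware returns 403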
|
@ -12,17 +12,18 @@ from typing import Any
|
|||
from aiohttp import hdrs
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
EndpointFunc = Callable[..., Any]
|
||||
PathParams = dict[str, str]
|
||||
RouteInfo = tuple[EndpointFunc, str]
|
||||
RouteInfo = tuple[EndpointFunc, str, WebMethod]
|
||||
PathImpl = dict[str, RouteInfo]
|
||||
RouteImpls = dict[str, PathImpl]
|
||||
RouteMatch = tuple[EndpointFunc, PathParams, str]
|
||||
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
|
||||
|
@ -31,10 +32,12 @@ def toolgroup_protocol_map():
|
|||
}
|
||||
|
||||
|
||||
def get_all_api_routes() -> dict[Api, list[Route]]:
|
||||
def get_all_api_routes(
|
||||
external_apis: dict[Api, ExternalApiSpec] | None = None,
|
||||
) -> dict[Api, list[tuple[Route, WebMethod]]]:
|
||||
apis = {}
|
||||
|
||||
protocols = api_protocol_map()
|
||||
protocols = api_protocol_map(external_apis)
|
||||
toolgroup_protocols = toolgroup_protocol_map()
|
||||
for api, protocol in protocols.items():
|
||||
routes = []
|
||||
|
@ -65,7 +68,7 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
|||
else:
|
||||
http_method = hdrs.METH_POST
|
||||
routes.append(
|
||||
Route(path=path, methods=[http_method], name=name, endpoint=None)
|
||||
(Route(path=path, methods=[http_method], name=name, endpoint=None), webmethod)
|
||||
) # setting endpoint to None since don't use a Router object
|
||||
|
||||
apis[api] = routes
|
||||
|
@ -73,8 +76,8 @@ def get_all_api_routes() -> dict[Api, list[Route]]:
|
|||
return apis
|
||||
|
||||
|
||||
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
||||
routes = get_all_api_routes()
|
||||
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
|
||||
api_to_routes = get_all_api_routes(external_apis)
|
||||
route_impls: RouteImpls = {}
|
||||
|
||||
def _convert_path_to_regex(path: str) -> str:
|
||||
|
@ -88,10 +91,10 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
|||
|
||||
return f"^{pattern}$"
|
||||
|
||||
for api, api_routes in routes.items():
|
||||
for api, api_routes in api_to_routes.items():
|
||||
if api not in impls:
|
||||
continue
|
||||
for route in api_routes:
|
||||
for route, webmethod in api_routes:
|
||||
impl = impls[api]
|
||||
func = getattr(impl, route.name)
|
||||
# Get the first (and typically only) method from the set, filtering out HEAD
|
||||
|
@ -104,6 +107,7 @@ def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
|
|||
route_impls[method][_convert_path_to_regex(route.path)] = (
|
||||
func,
|
||||
route.path,
|
||||
webmethod,
|
||||
)
|
||||
|
||||
return route_impls
|
||||
|
@ -118,7 +122,7 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
|
|||
route_impls: A dictionary of endpoint implementations
|
||||
|
||||
Returns:
|
||||
A tuple of (endpoint_function, path_params, descriptive_name)
|
||||
A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching endpoint is found
|
||||
|
@ -127,11 +131,11 @@ def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> Rout
|
|||
if not impls:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
for regex, (func, descriptive_name) in impls.items():
|
||||
for regex, (func, route_path, webmethod) in impls.items():
|
||||
match = re.match(regex, path)
|
||||
if match:
|
||||
# Extract named groups from the regex match
|
||||
path_params = match.groupdict()
|
||||
return func, path_params, descriptive_name
|
||||
return func, path_params, route_path, webmethod
|
||||
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
|
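# --- Illustrative sketch (not part of the diff) ---
# Callers of find_matching_route() now unpack four values instead of three;
# the WebMethod exposes descriptive_name (used for trace naming) and
# required_scope (used by the auth middleware). `impls` stands in for the dict
# of resolved API implementations the server already holds.
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls


def trace_name_for(method: str, path: str, impls) -> str:
    route_impls = initialize_route_impls(impls)
    func, path_params, route_path, webmethod = find_matching_route(method, path, route_impls)
    # Prefer the human-readable name when the endpoint declares one.
    return webmethod.descriptive_name or route_path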
|
@ -40,7 +40,12 @@ from llama_stack.distribution.datatypes import (
|
|||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.distribution.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.distribution.request_headers import PROVIDER_DATA_VAR, User, request_provider_data_context
|
||||
from llama_stack.distribution.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.distribution.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.server.routes import (
|
||||
find_matching_route,
|
||||
|
@ -222,9 +227,7 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
@functools.wraps(func)
|
||||
async def route_handler(request: Request, **kwargs):
|
||||
# Get auth attributes from the request scope
|
||||
user_attributes = request.scope.get("user_attributes", {})
|
||||
principal = request.scope.get("principal", "")
|
||||
user = User(principal=principal, attributes=user_attributes)
|
||||
user = user_from_scope(request.scope)
|
||||
|
||||
await log_request_pre_validation(request)
|
||||
|
||||
|
@ -282,9 +285,10 @@ def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
|
|||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app, impls):
|
||||
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
|
||||
self.app = app
|
||||
self.impls = impls
|
||||
self.external_apis = external_apis
|
||||
# FastAPI built-in paths that should bypass custom routing
|
||||
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
|
||||
|
||||
|
@ -301,10 +305,12 @@ class TracingMiddleware:
|
|||
return await self.app(scope, receive, send)
|
||||
|
||||
if not hasattr(self, "route_impls"):
|
||||
self.route_impls = initialize_route_impls(self.impls)
|
||||
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
|
||||
|
||||
try:
|
||||
_, _, trace_path = find_matching_route(scope.get("method", hdrs.METH_GET), path, self.route_impls)
|
||||
_, _, route_path, webmethod = find_matching_route(
|
||||
scope.get("method", hdrs.METH_GET), path, self.route_impls
|
||||
)
|
||||
except ValueError:
|
||||
# If no matching endpoint is found, pass through to FastAPI
|
||||
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
|
||||
|
@ -321,6 +327,7 @@ class TracingMiddleware:
|
|||
if tracestate:
|
||||
trace_attributes["tracestate"] = tracestate
|
||||
|
||||
trace_path = webmethod.descriptive_name or route_path
|
||||
trace_context = await start_trace(trace_path, trace_attributes)
|
||||
|
||||
async def send_with_trace_id(message):
|
||||
|
@ -432,10 +439,21 @@ def main(args: argparse.Namespace | None = None):
|
|||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||
app.add_middleware(ClientVersionMiddleware)
|
||||
|
||||
# Add authentication middleware if configured
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if config.server.auth:
|
||||
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
|
||||
else:
|
||||
if config.server.quota:
|
||||
quota = config.server.quota
|
||||
|
@ -466,24 +484,14 @@ def main(args: argparse.Namespace | None = None):
|
|||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
try:
|
||||
# Create and set the event loop that will be used for both construction and server runtime
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Construct the stack in the persistent event loop
|
||||
impls = loop.run_until_complete(construct_stack(config))
|
||||
|
||||
except InvalidProviderError as e:
|
||||
logger.error(f"Error: {str(e)}")
|
||||
sys.exit(1)
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig(), {}))
|
||||
|
||||
all_routes = get_all_api_routes()
|
||||
# Load external APIs if configured
|
||||
external_apis = load_external_apis(config)
|
||||
all_routes = get_all_api_routes(external_apis)
|
||||
|
||||
if config.apis:
|
||||
apis_to_serve = set(config.apis)
|
||||
|
@ -502,9 +510,12 @@ def main(args: argparse.Namespace | None = None):
|
|||
api = Api(api_str)
|
||||
|
||||
routes = all_routes[api]
|
||||
impl = impls[api]
|
||||
try:
|
||||
impl = impls[api]
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Could not find provider implementation for {api} API") from e
|
||||
|
||||
for route in routes:
|
||||
for route, _ in routes:
|
||||
if not hasattr(impl, route.name):
|
||||
# ideally this should be a typing violation already
|
||||
raise ValueError(f"Could not find method {route.name} on {impl}!")
|
||||
|
@ -533,7 +544,7 @@ def main(args: argparse.Namespace | None = None):
|
|||
app.exception_handler(Exception)(global_exception_handler)
|
||||
|
||||
app.__llama_stack_impls__ = impls
|
||||
app.add_middleware(TracingMiddleware, impls=impls)
|
||||
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
|
||||
|
||||
import uvicorn
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import importlib.resources
|
||||
import os
|
||||
import re
|
||||
|
@ -38,6 +39,7 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
|
||||
from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
|
||||
from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
|
||||
from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl
|
||||
from llama_stack.distribution.store.registry import create_dist_registry
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -90,6 +92,9 @@ RESOURCES = [
|
|||
]
|
||||
|
||||
|
||||
REGISTRY_REFRESH_INTERVAL_SECONDS = 300
|
||||
|
||||
|
||||
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||
for rsrc, api, register_method, list_method in RESOURCES:
|
||||
objects = getattr(run_config, rsrc)
|
||||
|
@ -324,9 +329,33 @@ async def construct_stack(
|
|||
add_internal_implementations(impls, run_config)
|
||||
|
||||
await register_resources(run_config, impls)
|
||||
|
||||
task = asyncio.create_task(refresh_registry(impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
logger.error("Model refresh task cancelled")
|
||||
elif task.exception():
|
||||
logger.error(f"Model refresh task failed: {task.exception()}")
|
||||
traceback.print_exception(task.exception())
|
||||
else:
|
||||
logger.debug("Model refresh task completed")
|
||||
|
||||
task.add_done_callback(cb)
|
||||
return impls
|
||||
|
||||
|
||||
async def refresh_registry(impls: dict[Api, Any]):
    routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
    while True:
        for routing_table in routing_tables:
            await routing_table.refresh()

        await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def get_stack_run_config_from_template(template: str) -> StackRunConfig:
|
||||
template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml"
|
||||
|
||||
|
|
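# --- Illustrative sketch (not part of the diff) ---
# The same "fire-and-forget task plus done callback" pattern construct_stack()
# uses above, shrunk to something runnable on its own. The short interval and
# the ticks list are stand-ins for routing_table.refresh() and the 300-second
# REGISTRY_REFRESH_INTERVAL_SECONDS default.
import asyncio

REFRESH_INTERVAL_SECONDS = 0.1  # stand-in value


async def refresh_loop(ticks: list[int]) -> None:
    while True:
        ticks.append(len(ticks))  # stand-in for routing_table.refresh()
        await asyncio.sleep(REFRESH_INTERVAL_SECONDS)


async def main() -> None:
    ticks: list[int] = []
    task = asyncio.create_task(refresh_loop(ticks))

    def cb(t: asyncio.Task) -> None:
        # Mirrors the callback in construct_stack(): surface silent failures.
        if t.cancelled():
            print("refresh task cancelled")
        elif t.exception():
            print(f"refresh task failed: {t.exception()}")

    task.add_done_callback(cb)
    await asyncio.sleep(0.35)
    task.cancel()
    print(f"refreshed {len(ticks)} times before shutdown")


asyncio.run(main())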
|
@ -6,6 +6,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from logging.config import dictConfig
|
||||
|
||||
|
@ -30,6 +31,7 @@ CATEGORIES = [
|
|||
"eval",
|
||||
"tools",
|
||||
"client",
|
||||
"telemetry",
|
||||
]
|
||||
|
||||
# Initialize category levels with default level
|
||||
|
@ -113,6 +115,11 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
|
|||
return category_levels
|
||||
|
||||
|
||||
def strip_rich_markup(text):
    """Remove Rich markup tags like [dim], [bold magenta], etc."""
    return re.sub(r"\[/?[a-zA-Z0-9 _#=,]+\]", "", text)
|
||||
|
||||
|
||||
class CustomRichHandler(RichHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs["console"] = Console(width=150)
|
||||
|
@ -131,6 +138,19 @@ class CustomRichHandler(RichHandler):
|
|||
self.markup = original_markup
|
||||
|
||||
|
||||
class CustomFileHandler(logging.FileHandler):
|
||||
def __init__(self, filename, mode="a", encoding=None, delay=False):
|
||||
super().__init__(filename, mode, encoding, delay)
|
||||
# Default formatter to match console output
|
||||
self.default_formatter = logging.Formatter("%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s")
|
||||
self.setFormatter(self.default_formatter)
|
||||
|
||||
def emit(self, record):
|
||||
if hasattr(record, "msg"):
|
||||
record.msg = strip_rich_markup(str(record.msg))
|
||||
super().emit(record)
|
||||
|
||||
|
||||
def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None:
|
||||
"""
|
||||
Configure logging based on the provided category log levels and an optional log file.
|
||||
|
@ -167,8 +187,7 @@ def setup_logging(category_levels: dict[str, int], log_file: str | None) -> None
|
|||
# Add a file handler if log_file is set
|
||||
if log_file:
|
||||
handlers["file"] = {
|
||||
"class": "logging.FileHandler",
|
||||
"formatter": "rich",
|
||||
"()": CustomFileHandler,
|
||||
"filename": log_file,
|
||||
"mode": "a",
|
||||
"encoding": "utf-8",
|
||||
|
|
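# --- Illustrative sketch (not part of the diff) ---
# What CustomFileHandler does to a console-style message before writing it:
# Rich markup is useful on the terminal but noise in a plain log file. Note
# the regex strips any short bracketed token, not just known Rich tags.
from llama_stack.log import strip_rich_markup

line = "[dim]12:00:00.000[/dim] [bold magenta]request received[/bold magenta]"
print(strip_rich_markup(line))  # -> "12:00:00.000 request received"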
|
@ -47,6 +47,17 @@ class ModelsProtocolPrivate(Protocol):
|
|||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
||||
# the Stack router will query each provider for their list of models
|
||||
# if a `refresh_interval_seconds` is provided, this method will be called
|
||||
# periodically to refresh the list of models
|
||||
#
|
||||
# NOTE: each model returned will be registered with the model registry. this means
|
||||
# a callback to the `register_model()` method will be made. this is duplicative and
|
||||
# may be removed in the future.
|
||||
async def list_models(self) -> list[Model] | None: ...
|
||||
|
||||
async def should_refresh_models(self) -> bool: ...
|
||||
|
||||
|
||||
class ShieldsProtocolPrivate(Protocol):
|
||||
async def register_shield(self, shield: Shield) -> None: ...
|
||||
|
|
|
@ -102,6 +102,12 @@ class MetaReferenceInferenceImpl(
|
|||
if self.config.create_distributed_process_group:
|
||||
self.generator.stop()
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return None
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
@ -41,6 +42,8 @@ class SentenceTransformersInferenceImpl(
|
|||
InferenceProvider,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -50,6 +53,22 @@ class SentenceTransformersInferenceImpl(
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
return [
|
||||
Model(
|
||||
identifier="all-MiniLM-L6-v2",
|
||||
provider_resource_id="all-MiniLM-L6-v2",
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
return model
|
||||
|
||||
|
|
|
@ -11,19 +11,9 @@ from opentelemetry.sdk.trace import ReadableSpan
|
|||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
from opentelemetry.trace.status import StatusCode
|
||||
|
||||
# Colors for console output
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
logger = get_logger(name="console_span_processor", category="telemetry")
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
|
@ -35,34 +25,21 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.start_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[START]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]")
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[END]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]"
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
|
||||
span_context += " [bold red][ERROR][/bold red]"
|
||||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
|
||||
|
||||
span_context += f" [{span.status.status_code}]"
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
|
||||
|
||||
print(span_context)
|
||||
span_context += f" ({duration_ms:.2f}ms)"
|
||||
logger.info(span_context)
|
||||
|
||||
if self.print_attributes and span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
|
@ -71,31 +48,26 @@ class ConsoleSpanProcessor(SpanProcessor):
|
|||
str_value = str(value)
|
||||
if len(str_value) > 1000:
|
||||
str_value = str_value[:997] + "..."
|
||||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
logger.info(f" [dim]{key}[/dim]: {str_value}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.fromtimestamp(event.timestamp / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, dict | list):
|
||||
if isinstance(message, dict) or isinstance(message, list):
|
||||
message = json.dumps(message, indent=2)
|
||||
|
||||
severity_colors = {
|
||||
"error": f"{COLORS['bold']}{COLORS['red']}",
|
||||
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
|
||||
"info": COLORS["white"],
|
||||
"debug": COLORS["dim"],
|
||||
}
|
||||
msg_color = severity_colors.get(severity, COLORS["white"])
|
||||
|
||||
print(f" {event_time} {msg_color}[{severity.upper()}] {message}{COLORS['reset']}")
|
||||
|
||||
severity_color = {
|
||||
"error": "red",
|
||||
"warn": "yellow",
|
||||
"info": "white",
|
||||
"debug": "dim",
|
||||
}.get(severity, "white")
|
||||
logger.info(f" {event_time} [bold {severity_color}][{severity.upper()}][/bold {severity_color}] {message}")
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
|
||||
logger.info(f"/r[dim]{key}[/dim]: {value}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksImplConfig(BaseModel):
|
||||
class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks server",
|
||||
|
|
|
@ -70,7 +70,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: FireworksImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -13,8 +13,10 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
|
|||
|
||||
class OllamaImplConfig(BaseModel):
|
||||
url: str = DEFAULT_OLLAMA_URL
|
||||
refresh_models: bool = Field(default=False, description="refresh and re-register models periodically")
|
||||
refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models")
|
||||
refresh_models: bool = Field(
|
||||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
|
||||
|
|
|
@ -98,14 +98,16 @@ class OllamaInferenceAdapter(
|
|||
def __init__(self, config: OllamaImplConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
|
||||
self.config = config
|
||||
self._client = None
|
||||
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
|
||||
self._openai_client = None
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = AsyncClient(host=self.config.url)
|
||||
return self._client
|
||||
# ollama client attaches itself to the current event loop (sadly?)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop not in self._clients:
|
||||
self._clients[loop] = AsyncClient(host=self.config.url)
|
||||
return self._clients[loop]
|
||||
|
||||
@property
|
||||
def openai_client(self) -> AsyncOpenAI:
|
||||
|
@ -121,59 +123,61 @@ class OllamaInferenceAdapter(
|
|||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
if self.config.refresh_models:
|
||||
logger.debug("ollama starting background model refresh task")
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
if task.cancelled():
|
||||
import traceback
|
||||
|
||||
logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
logger.error(f"ollama background refresh task died: {task.exception()}")
|
||||
else:
|
||||
logger.error("ollama background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
# Wait for model store to be available (with timeout)
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
provider_id = self.__provider_id__
|
||||
while True:
|
||||
try:
|
||||
response = await self.client.list()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list models: {str(e)}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
response = await self.client.list()
|
||||
|
||||
# always add the two embedding models which can be pulled on demand
|
||||
models = [
|
||||
Model(
|
||||
identifier="all-minilm:l6-v2",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
# add all-minilm alias
|
||||
Model(
|
||||
identifier="all-minilm",
|
||||
provider_resource_id="all-minilm:l6-v2",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
"context_length": 512,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
"context_length": 8192,
|
||||
},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if m.details.family in ["bert"]:
|
||||
continue
|
||||
|
||||
models = []
|
||||
for m in response.models:
|
||||
model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm
|
||||
if model_type == ModelType.embedding:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.model,
|
||||
provider_resource_id=m.model,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
logger.debug(f"ollama refreshed model list ({len(models)} models)")
|
||||
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
)
|
||||
return models
|
||||
|
||||
async def health(self) -> HealthResponse:
|
||||
"""
|
||||
|
@ -190,12 +194,7 @@ class OllamaInferenceAdapter(
|
|||
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if hasattr(self, "_refresh_task") and not self._refresh_task.done():
|
||||
logger.debug("ollama cancelling background refresh task")
|
||||
self._refresh_task.cancel()
|
||||
|
||||
self._client = None
|
||||
self._openai_client = None
|
||||
self._clients.clear()
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
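# --- Illustrative sketch (not part of the diff) ---
# The per-event-loop cache behind OllamaInferenceAdapter.client above, with a
# generic placeholder instead of ollama.AsyncClient: some async clients bind
# to the loop they were created on, so the adapter keys a fresh client per
# running loop rather than reusing one instance across loops.
import asyncio


class LoopBoundClients:
    def __init__(self) -> None:
        self._clients: dict[asyncio.AbstractEventLoop, object] = {}

    @property
    def client(self) -> object:
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = object()  # stand-in for AsyncClient(host=...)
        return self._clients[loop]


holder = LoopBoundClients()


async def grab() -> object:
    return holder.client


first = asyncio.run(grab())
second = asyncio.run(grab())  # a new event loop -> a new client instance
print(first is second)  # False: each loop gets its own client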
|
@ -12,11 +12,6 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
# the models w/ "openai/" prefix are the litellm specific model names.
|
||||
# they should be deprecated in favor of the canonical openai model names.
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openai/chatgpt-4o-latest",
|
||||
"gpt-3.5-turbo-0125",
|
||||
"gpt-3.5-turbo",
|
||||
"gpt-3.5-turbo-instruct",
|
||||
|
@ -43,8 +38,6 @@ class EmbeddingModelInfo:
|
|||
|
||||
|
||||
EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
|
||||
"openai/text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"openai/text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
|
||||
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
|
||||
}
|
||||
|
|
|
@ -6,13 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from pydantic import Field, SecretStr
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TogetherImplConfig(BaseModel):
|
||||
class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||
url: str = Field(
|
||||
default="https://api.together.xyz/v1",
|
||||
description="The URL for the Together AI server",
|
||||
|
|
|
@ -66,7 +66,7 @@ logger = get_logger(name=__name__, category="inference")
|
|||
|
||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||
def __init__(self, config: TogetherImplConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
|
|
@ -33,10 +33,6 @@ class VLLMInferenceAdapterConfig(BaseModel):
|
|||
default=False,
|
||||
description="Whether to refresh models periodically",
|
||||
)
|
||||
refresh_models_interval: int = Field(
|
||||
default=300,
|
||||
description="Interval in seconds to refresh models",
|
||||
)
|
||||
|
||||
@field_validator("tls_verify")
|
||||
@classmethod
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
@ -293,7 +292,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
# automatically set by the resolver when instantiating the provider
|
||||
__provider_id__: str
|
||||
model_store: ModelStore | None = None
|
||||
_refresh_task: asyncio.Task | None = None
|
||||
|
||||
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
|
||||
self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
|
||||
|
@ -301,65 +299,30 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self.config.url:
|
||||
# intentionally don't raise an error here, we want to allow the provider to be "dormant"
|
||||
# or available in distributions like "starter" without causing a ruckus
|
||||
return
|
||||
pass
|
||||
|
||||
if self.config.refresh_models:
|
||||
self._refresh_task = asyncio.create_task(self._refresh_models())
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
||||
if task.cancelled():
|
||||
log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}")
|
||||
elif task.exception():
|
||||
# print the stack trace for the exception
|
||||
exc = task.exception()
|
||||
log.error(f"vLLM background refresh task died: {exc}")
|
||||
traceback.print_exception(exc)
|
||||
else:
|
||||
log.error("vLLM background refresh task completed unexpectedly")
|
||||
|
||||
self._refresh_task.add_done_callback(cb)
|
||||
|
||||
async def _refresh_models(self) -> None:
|
||||
provider_id = self.__provider_id__
|
||||
waited_time = 0
|
||||
while not self.model_store and waited_time < 60:
|
||||
await asyncio.sleep(1)
|
||||
waited_time += 1
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store not set after waiting 60 seconds")
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
||||
async def list_models(self) -> list[Model] | None:
|
||||
self._lazy_initialize_client()
|
||||
assert self.client is not None # mypy
|
||||
while True:
|
||||
try:
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=provider_id,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
await self.model_store.update_registered_llm_models(provider_id, models)
|
||||
log.debug(f"vLLM refreshed model list ({len(models)} models)")
|
||||
except Exception as e:
|
||||
log.error(f"vLLM background refresh task failed: {e}")
|
||||
await asyncio.sleep(self.config.refresh_models_interval)
|
||||
models = []
|
||||
async for m in self.client.models.list():
|
||||
model_type = ModelType.llm # unclear how to determine embedding vs. llm models
|
||||
models.append(
|
||||
Model(
|
||||
identifier=m.id,
|
||||
provider_resource_id=m.id,
|
||||
provider_id=self.__provider_id__,
|
||||
metadata={},
|
||||
model_type=model_type,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
pass
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
|
|
@ -20,6 +20,13 @@ from llama_stack.providers.utils.inference import (
|
|||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
class RemoteInferenceProviderConfig(BaseModel):
|
||||
allowed_models: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||
)
|
||||
|
||||
|
||||
# TODO: this class is more confusing than useful right now. We need to make it
|
||||
# more closer to the Model class.
|
||||
class ProviderModelEntry(BaseModel):
|
||||
|
@ -65,7 +72,10 @@ def build_model_entry(provider_model_id: str, model_descriptor: str) -> Provider
|
|||
|
||||
|
||||
class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||
def __init__(self, model_entries: list[ProviderModelEntry]):
|
||||
__provider_id__: str
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||
self.allowed_models = allowed_models
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
|
@ -79,6 +89,27 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
self.alias_to_provider_id_map[entry.llama_model] = entry.provider_model_id
|
||||
self.provider_id_to_llama_model_map[entry.provider_model_id] = entry.llama_model
|
||||
|
||||
    async def list_models(self) -> list[Model] | None:
        models = []
        for entry in self.model_entries:
            ids = [entry.provider_model_id] + entry.aliases
            for id in ids:
                if self.allowed_models and id not in self.allowed_models:
                    continue
                models.append(
                    Model(
                        model_id=id,
                        provider_resource_id=entry.provider_model_id,
                        model_type=ModelType.llm,
                        metadata=entry.metadata,
                        provider_id=self.__provider_id__,
                    )
                )
        return models

    async def should_refresh_models(self) -> bool:
        return False
|
||||
|
||||
def get_provider_model_id(self, identifier: str) -> str | None:
|
||||
return self.alias_to_provider_id_map.get(identifier, None)
|
||||
|
||||
|
|
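# --- Illustrative sketch (not part of the diff) ---
# How the new allowed_models knob flows from a provider config into
# ModelRegistryHelper, mirroring the Fireworks/Together adapter changes above.
# The model ids and descriptors are example values, and this assumes
# build_model_entry() simply records them without further validation.
from llama_stack.providers.utils.inference.model_registry import (
    ModelRegistryHelper,
    RemoteInferenceProviderConfig,
    build_model_entry,
)

# allowed_models comes from the provider's run config; None means "no filter".
config = RemoteInferenceProviderConfig(allowed_models=["llama-v3p1-8b-instruct"])

entries = [
    build_model_entry("llama-v3p1-8b-instruct", "Llama3.1-8B-Instruct"),
    build_model_entry("llama-v3p1-70b-instruct", "Llama3.1-70B-Instruct"),
]
helper = ModelRegistryHelper(entries, allowed_models=config.allowed_models)
# helper.list_models() (async) now skips every id not present in allowed_models.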
|
@ -4,13 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from enum import Enum
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession
|
||||
from mcp import ClientSession, McpError
|
||||
from mcp import types as mcp_types
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from llama_stack.apis.common.content_types import ImageContentItem, InterleavedContentItem, TextContentItem
|
||||
from llama_stack.apis.tools import (
|
||||
|
@ -21,31 +24,61 @@ from llama_stack.apis.tools import (
|
|||
)
|
||||
from llama_stack.distribution.datatypes import AuthenticationRequiredError
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.tools.ttl_dict import TTLDict
|
||||
|
||||
logger = get_logger(__name__, category="tools")
|
||||
|
||||
protocol_cache = TTLDict(ttl_seconds=3600)
|
||||
|
||||
|
||||
class MCPProtol(Enum):
|
||||
UNKNOWN = 0
|
||||
STREAMABLE_HTTP = 1
|
||||
SSE = 2
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client_wrapper(endpoint: str, headers: dict[str, str]):
|
||||
try:
|
||||
async with sse_client(endpoint, headers=headers) as streams:
|
||||
async with ClientSession(*streams) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
except* httpx.HTTPStatusError as eg:
|
||||
for exc in eg.exceptions:
|
||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
||||
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
||||
err = cast(httpx.HTTPStatusError, exc)
|
||||
if err.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
raise
|
||||
async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerator[ClientSession, Any]:
|
||||
# we use a ttl'd dict to cache the happy path protocol for each endpoint
|
||||
# but, we always fall back to trying the other protocol if we cannot initialize the session
|
||||
connection_strategies = [MCPProtol.STREAMABLE_HTTP, MCPProtol.SSE]
|
||||
mcp_protocol = protocol_cache.get(endpoint, default=MCPProtol.UNKNOWN)
|
||||
if mcp_protocol == MCPProtol.SSE:
|
||||
connection_strategies = [MCPProtol.SSE, MCPProtol.STREAMABLE_HTTP]
|
||||
|
||||
for i, strategy in enumerate(connection_strategies):
|
||||
try:
|
||||
client = streamablehttp_client
|
||||
if strategy == MCPProtol.SSE:
|
||||
client = sse_client
|
||||
async with client(endpoint, headers=headers) as client_streams:
|
||||
async with ClientSession(read_stream=client_streams[0], write_stream=client_streams[1]) as session:
|
||||
await session.initialize()
|
||||
protocol_cache[endpoint] = strategy
|
||||
yield session
|
||||
return
|
||||
except* httpx.HTTPStatusError as eg:
|
||||
for exc in eg.exceptions:
|
||||
# mypy does not currently narrow the type of `eg.exceptions` based on the `except*` filter,
|
||||
# so we explicitly cast each item to httpx.HTTPStatusError. This is safe because
|
||||
# `except* httpx.HTTPStatusError` guarantees all exceptions in `eg.exceptions` are of that type.
|
||||
err = cast(httpx.HTTPStatusError, exc)
|
||||
if err.response.status_code == 401:
|
||||
raise AuthenticationRequiredError(exc) from exc
|
||||
if i == len(connection_strategies) - 1:
|
||||
raise
|
||||
except* McpError:
|
||||
if i < len(connection_strategies) - 1:
|
||||
logger.warning(
|
||||
f"failed to connect via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefsResponse:
|
||||
tools = []
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
async with client_wrapper(endpoint, headers) as session:
|
||||
tools_result = await session.list_tools()
|
||||
for tool in tools_result.tools:
|
||||
parameters = []
|
||||
|
@ -73,7 +106,7 @@ async def list_mcp_tools(endpoint: str, headers: dict[str, str]) -> ListToolDefs
|
|||
async def invoke_mcp_tool(
|
||||
endpoint: str, headers: dict[str, str], tool_name: str, kwargs: dict[str, Any]
|
||||
) -> ToolInvocationResult:
|
||||
async with sse_client_wrapper(endpoint, headers) as session:
|
||||
async with client_wrapper(endpoint, headers) as session:
|
||||
result = await session.call_tool(tool_name, kwargs)
|
||||
|
||||
content: list[InterleavedContentItem] = []
|
||||
|
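Stripped of the MCP specifics, client_wrapper above boils down to "try the cached protocol first, fall back once, remember what worked". A hedged sketch of that strategy follows, with generic connector factories standing in for streamablehttp_client/sse_client; it is not the real implementation.

from contextlib import asynccontextmanager
from enum import Enum


class Protocol(Enum):
    STREAMABLE_HTTP = 1
    SSE = 2


protocol_cache: dict[str, Protocol] = {}  # stand-in for the TTLDict(ttl_seconds=3600) cache above


@asynccontextmanager
async def connect_with_fallback(endpoint: str, connectors: dict):
    # connectors maps Protocol -> async context manager factory; the real code
    # narrows the caught exceptions to httpx.HTTPStatusError / McpError.
    order = [Protocol.STREAMABLE_HTTP, Protocol.SSE]
    if protocol_cache.get(endpoint) == Protocol.SSE:
        order = [Protocol.SSE, Protocol.STREAMABLE_HTTP]

    last_error: Exception | None = None
    for strategy in order:
        try:
            async with connectors[strategy](endpoint) as session:
                protocol_cache[endpoint] = strategy  # remember the happy path
                yield session
                return
        except Exception as exc:  # broad catch kept only for brevity in this sketch
            last_error = exc
    raise last_error if last_error else RuntimeError("no connection strategy succeeded")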
|
llama_stack/providers/utils/tools/ttl_dict.py (new file, 70 lines)
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import time
from threading import RLock
from typing import Any


class TTLDict(dict):
    """
    A dictionary with a ttl for each item
    """

    def __init__(self, ttl_seconds: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ttl_seconds = ttl_seconds
        self._expires: dict[Any, Any] = {}  # expires holds when an item will expire
        self._lock = RLock()

        if args or kwargs:
            for k, v in self.items():
                self.__setitem__(k, v)

    def __delitem__(self, key):
        with self._lock:
            del self._expires[key]
            super().__delitem__(key)

    def __setitem__(self, key, value):
        with self._lock:
            self._expires[key] = time.monotonic() + self.ttl_seconds
            super().__setitem__(key, value)

    def _is_expired(self, key):
        if key not in self._expires:
            return False
        return time.monotonic() > self._expires[key]

    def __getitem__(self, key):
        with self._lock:
            if self._is_expired(key):
                del self._expires[key]
                super().__delitem__(key)
                raise KeyError(f"{key} has expired and was removed")

            return super().__getitem__(key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def __contains__(self, key):
        try:
            _ = self[key]
            return True
        except KeyError:
            return False

    def __repr__(self):
        with self._lock:
            for key in self.keys():
                if self._is_expired(key):
                    del self._expires[key]
                    super().__delitem__(key)

            return f"TTLDict({self.ttl_seconds}, {super().__repr__()})"
|
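A quick usage sketch for the TTLDict above; the endpoint URL is just an example value.

import time

cache = TTLDict(ttl_seconds=3600)  # the tool runtime caches MCP protocols for an hour
cache["https://mcp.example.com"] = "streamable_http"

assert "https://mcp.example.com" in cache
assert cache.get("https://other.example.com", default="unknown") == "unknown"

short_lived = TTLDict(ttl_seconds=0.01)
short_lived["k"] = "v"
time.sleep(0.05)
assert short_lived.get("k") is None  # expired entries behave like missing keys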
|
@ -22,6 +22,7 @@ class WebMethod:
|
|||
# A descriptive name of the corresponding span created by tracing
|
||||
descriptive_name: str | None = None
|
||||
experimental: bool | None = False
|
||||
required_scope: str | None = None
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable[..., Any])
|
||||
|
@ -36,6 +37,7 @@ def webmethod(
|
|||
raw_bytes_request_body: bool | None = False,
|
||||
descriptive_name: str | None = None,
|
||||
experimental: bool | None = False,
|
||||
required_scope: str | None = None,
|
||||
) -> Callable[[T], T]:
|
||||
"""
|
||||
Decorator that supplies additional metadata to an endpoint operation function.
|
||||
|
@ -45,6 +47,7 @@ def webmethod(
|
|||
:param request_examples: Sample requests that the operation might take. Pass a list of objects, not JSON.
|
||||
:param response_examples: Sample responses that the operation might produce. Pass a list of objects, not JSON.
|
||||
:param experimental: True if the operation is experimental and subject to change.
|
||||
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
|
||||
"""
|
||||
|
||||
def wrap(func: T) -> T:
|
||||
|
@ -57,6 +60,7 @@ def webmethod(
|
|||
raw_bytes_request_body=raw_bytes_request_body,
|
||||
descriptive_name=descriptive_name,
|
||||
experimental=experimental,
|
||||
required_scope=required_scope,
|
||||
)
|
||||
return func
|
||||
|
||||
|
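For context, a hypothetical endpoint showing how the new required_scope metadata (alongside the existing route/method arguments) gets attached. The route and the "monitoring.viewer" scope follow the docstring's example and are not real Llama Stack endpoints.

from typing import Protocol

from llama_stack.schema_utils import webmethod


class MonitoringAPI(Protocol):
    @webmethod(route="/metrics/summary", method="GET", required_scope="monitoring.viewer")
    async def get_metrics_summary(self) -> dict[str, float]:
        """Hypothetical endpoint gated behind the monitoring.viewer scope."""
        ...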
|
|
@ -785,21 +785,6 @@ models:
|
|||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.2-3B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/gpt-4o
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o-mini
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/gpt-4o-mini
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/chatgpt-4o-latest
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/chatgpt-4o-latest
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-3.5-turbo-0125
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
|
@ -870,20 +855,6 @@ models:
|
|||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: o4-mini
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 1536
|
||||
context_length: 8192
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-small
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/text-embedding-3-small
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 3072
|
||||
context_length: 8192
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-large
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/text-embedding-3-large
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1536
|
||||
context_length: 8192
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# NVIDIA Distribution
|
||||
|
||||
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||
|
|
|
@ -785,21 +785,6 @@ models:
|
|||
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
|
||||
provider_model_id: Llama3.2-3B
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/gpt-4o
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o-mini
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/gpt-4o-mini
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/chatgpt-4o-latest
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/chatgpt-4o-latest
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-3.5-turbo-0125
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
|
@ -870,20 +855,6 @@ models:
|
|||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: o4-mini
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 1536
|
||||
context_length: 8192
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-small
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/text-embedding-3-small
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 3072
|
||||
context_length: 8192
|
||||
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-large
|
||||
provider_id: ${env.ENABLE_OPENAI:=__disabled__}
|
||||
provider_model_id: openai/text-embedding-3-large
|
||||
model_type: embedding
|
||||
- metadata:
|
||||
embedding_dimension: 1536
|
||||
context_length: 8192
|
||||
|
|
llama_stack/ui/package-lock.json (generated; 15 changed lines)
@@ -15,7 +15,7 @@
|
|||
"@radix-ui/react-tooltip": "^1.2.6",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"llama-stack-client": "^0.2.14",
|
||||
"llama-stack-client": "^0.2.15",
|
||||
"lucide-react": "^0.510.0",
|
||||
"next": "15.3.3",
|
||||
"next-auth": "^4.24.11",
|
||||
|
@ -6468,14 +6468,15 @@
|
|||
}
|
||||
},
|
||||
"node_modules/form-data": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.2.tgz",
|
||||
"integrity": "sha512-hGfm/slu0ZabnNt4oaRZ6uREyfCj6P4fT/n6A1rGV+Z0VdGXjfOhVUpkn6qVQONHGIFwmveGXyDs75+nr6FM8w==",
|
||||
"version": "4.0.4",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz",
|
||||
"integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"asynckit": "^0.4.0",
|
||||
"combined-stream": "^1.0.8",
|
||||
"es-set-tostringtag": "^2.1.0",
|
||||
"hasown": "^2.0.2",
|
||||
"mime-types": "^2.1.12"
|
||||
},
|
||||
"engines": {
|
||||
|
@ -9099,9 +9100,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/llama-stack-client": {
|
||||
"version": "0.2.14",
|
||||
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz",
|
||||
"integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==",
|
||||
"version": "0.2.15",
|
||||
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.15.tgz",
|
||||
"integrity": "sha512-onfYzgPWAxve4uP7BuiK/ZdEC7w6X1PIXXXpQY57qZC7C4xUAM5kwfT3JWIe/jE22Lwc2vTN1ScfYlAYcoYAsg==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"@types/node": "^18.11.18",
|
||||
|
|
|
@ -12,21 +12,23 @@ set -euo pipefail
|
|||
failed=0
|
||||
|
||||
# Find all workflow YAML files
|
||||
|
||||
# Use GitHub Actions error format
|
||||
# ::error file={name},line={line},col={col}::{message}
|
||||
|
||||
for file in $(find .github/workflows/ -type f \( -name "*.yml" -o -name "*.yaml" \)); do
|
||||
IFS=$'\n'
|
||||
# Grep for `uses:` lines that look like actions
|
||||
for line in $(grep -E '^.*uses:[^@]+@[^ ]+' "$file"); do
|
||||
# Extract the ref part after the last @
|
||||
# Get line numbers for each 'uses:'
|
||||
while IFS= read -r match; do
|
||||
line_num=$(echo "$match" | cut -d: -f1)
|
||||
line=$(echo "$match" | cut -d: -f2-)
|
||||
ref=$(echo "$line" | sed -E 's/.*@([A-Za-z0-9._-]+).*/\1/')
|
||||
# Check if ref is a 40-character hex string (full SHA).
|
||||
#
|
||||
# Note: strictly speaking, this could also be a tag or branch name, but
|
||||
# we'd have to pull this info from the remote. Meh.
|
||||
if ! [[ $ref =~ ^[0-9a-fA-F]{40}$ ]]; then
|
||||
echo "ERROR: $file uses non-SHA action ref: $line"
|
||||
# Output in GitHub Actions annotation format
|
||||
echo "::error file=$file,line=$line_num::uses non-SHA action ref: $line"
|
||||
failed=1
|
||||
fi
|
||||
done
|
||||
done < <(grep -n -E '^.*uses:[^@]+@[^ ]+' "$file")
|
||||
done
|
||||
|
||||
exit $failed
|
||||
|
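The core of the check is the full-SHA regex; a tiny Python equivalent is sketched below (the hash string is a made-up example, not a real pin).

import re

FULL_SHA = re.compile(r"^[0-9a-fA-F]{40}$")

assert FULL_SHA.match("0123456789abcdef0123456789abcdef01234567")  # pinned SHA: accepted
assert not FULL_SHA.match("v4.2.2")  # tag ref: flagged
assert not FULL_SHA.match("main")    # branch ref: flagged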
|
|
@ -15,11 +15,40 @@ set -Eeuo pipefail
|
|||
PORT=8321
|
||||
OLLAMA_PORT=11434
|
||||
MODEL_ALIAS="llama3.2:3b"
|
||||
SERVER_IMAGE="docker.io/llamastack/distribution-ollama:0.2.2"
|
||||
WAIT_TIMEOUT=300
|
||||
SERVER_IMAGE="docker.io/llamastack/distribution-starter:latest"
|
||||
WAIT_TIMEOUT=30
|
||||
TEMP_LOG=""
|
||||
|
||||
# Cleanup function to remove temporary files
|
||||
cleanup() {
|
||||
if [ -n "$TEMP_LOG" ] && [ -f "$TEMP_LOG" ]; then
|
||||
rm -f "$TEMP_LOG"
|
||||
fi
|
||||
}
|
||||
|
||||
# Set up trap to clean up on exit, error, or interrupt
|
||||
trap cleanup EXIT ERR INT TERM
|
||||
|
||||
log(){ printf "\e[1;32m%s\e[0m\n" "$*"; }
|
||||
die(){ printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2; exit 1; }
|
||||
die(){
|
||||
printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2
|
||||
printf "\e[1;31m🐛 Report an issue @ https://github.com/meta-llama/llama-stack/issues if you think it's a bug\e[0m\n" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Helper function to execute command with logging
|
||||
execute_with_log() {
|
||||
local cmd=("$@")
|
||||
TEMP_LOG=$(mktemp)
|
||||
if ! "${cmd[@]}" > "$TEMP_LOG" 2>&1; then
|
||||
log "❌ Command failed; dumping output:"
|
||||
log "Command that failed: ${cmd[*]}"
|
||||
log "Command output:"
|
||||
cat "$TEMP_LOG"
|
||||
return 1
|
||||
fi
|
||||
return 0
|
||||
}
|
||||
|
||||
wait_for_service() {
|
||||
local url="$1"
|
||||
|
@ -27,7 +56,7 @@ wait_for_service() {
|
|||
local timeout="$3"
|
||||
local name="$4"
|
||||
local start ts
|
||||
log "⏳ Waiting for ${name}…"
|
||||
log "⏳ Waiting for ${name}..."
|
||||
start=$(date +%s)
|
||||
while true; do
|
||||
if curl --retry 5 --retry-delay 1 --retry-max-time "$timeout" --retry-all-errors --silent --fail "$url" 2>/dev/null | grep -q "$pattern"; then
|
||||
|
@ -38,24 +67,24 @@ wait_for_service() {
|
|||
return 1
|
||||
fi
|
||||
printf '.'
|
||||
sleep 1
|
||||
done
|
||||
printf '\n'
|
||||
return 0
|
||||
}
|
||||
|
||||
usage() {
|
||||
cat << EOF
|
||||
📚 Llama-Stack Deployment Script
|
||||
📚 Llama Stack Deployment Script
|
||||
|
||||
Description:
|
||||
This script sets up and deploys Llama-Stack with Ollama integration in containers.
|
||||
This script sets up and deploys Llama Stack with Ollama integration in containers.
|
||||
It handles both Docker and Podman runtimes and includes automatic platform detection.
|
||||
|
||||
Usage:
|
||||
$(basename "$0") [OPTIONS]
|
||||
|
||||
Options:
|
||||
-p, --port PORT Server port for Llama-Stack (default: ${PORT})
|
||||
-p, --port PORT Server port for Llama Stack (default: ${PORT})
|
||||
-o, --ollama-port PORT Ollama service port (default: ${OLLAMA_PORT})
|
||||
-m, --model MODEL Model alias to use (default: ${MODEL_ALIAS})
|
||||
-i, --image IMAGE Server image (default: ${SERVER_IMAGE})
|
||||
|
@ -129,15 +158,15 @@ fi
|
|||
# CONTAINERS_MACHINE_PROVIDER=libkrun podman machine init
|
||||
if [ "$ENGINE" = "podman" ] && [ "$(uname -s)" = "Darwin" ]; then
|
||||
if ! podman info &>/dev/null; then
|
||||
log "⌛️ Initializing Podman VM…"
|
||||
log "⌛️ Initializing Podman VM..."
|
||||
podman machine init &>/dev/null || true
|
||||
podman machine start &>/dev/null || true
|
||||
|
||||
log "⌛️ Waiting for Podman API…"
|
||||
log "⌛️ Waiting for Podman API..."
|
||||
until podman info &>/dev/null; do
|
||||
sleep 1
|
||||
done
|
||||
log "✅ Podman VM is up"
|
||||
log "✅ Podman VM is up."
|
||||
fi
|
||||
fi
|
||||
|
||||
|
@ -145,8 +174,10 @@ fi
|
|||
for name in ollama-server llama-stack; do
|
||||
ids=$($ENGINE ps -aq --filter "name=^${name}$")
|
||||
if [ -n "$ids" ]; then
|
||||
log "⚠️ Found existing container(s) for '${name}', removing…"
|
||||
$ENGINE rm -f "$ids" > /dev/null 2>&1
|
||||
log "⚠️ Found existing container(s) for '${name}', removing..."
|
||||
if ! execute_with_log $ENGINE rm -f "$ids"; then
|
||||
die "Container cleanup failed"
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
|
@ -154,28 +185,32 @@ done
|
|||
# 0. Create a shared network
|
||||
###############################################################################
|
||||
if ! $ENGINE network inspect llama-net >/dev/null 2>&1; then
|
||||
log "🌐 Creating network…"
|
||||
$ENGINE network create llama-net >/dev/null 2>&1
|
||||
log "🌐 Creating network..."
|
||||
if ! execute_with_log $ENGINE network create llama-net; then
|
||||
die "Network creation failed"
|
||||
fi
|
||||
fi
|
||||
|
||||
###############################################################################
|
||||
# 1. Ollama
|
||||
###############################################################################
|
||||
log "🦙 Starting Ollama…"
|
||||
$ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \
|
||||
log "🦙 Starting Ollama..."
|
||||
if ! execute_with_log $ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \
|
||||
--network llama-net \
|
||||
-p "${OLLAMA_PORT}:${OLLAMA_PORT}" \
|
||||
docker.io/ollama/ollama > /dev/null 2>&1
|
||||
docker.io/ollama/ollama > /dev/null 2>&1; then
|
||||
die "Ollama startup failed"
|
||||
fi
|
||||
|
||||
if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then
|
||||
log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
|
||||
log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
|
||||
$ENGINE logs --tail 200 ollama-server
|
||||
die "Ollama startup failed"
|
||||
fi
|
||||
|
||||
log "📦 Ensuring model is pulled: ${MODEL_ALIAS}…"
|
||||
if ! $ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}" > /dev/null 2>&1; then
|
||||
log "❌ Failed to pull model ${MODEL_ALIAS}; dumping container logs:"
|
||||
log "📦 Ensuring model is pulled: ${MODEL_ALIAS}..."
|
||||
if ! execute_with_log $ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}"; then
|
||||
log "❌ Failed to pull model ${MODEL_ALIAS}; dumping container logs:"
|
||||
$ENGINE logs --tail 200 ollama-server
|
||||
die "Model pull failed"
|
||||
fi
|
||||
|
@ -187,25 +222,29 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
|
|||
--network llama-net \
|
||||
-p "${PORT}:${PORT}" \
|
||||
"${SERVER_IMAGE}" --port "${PORT}" \
|
||||
--env INFERENCE_MODEL="${MODEL_ALIAS}" \
|
||||
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" )
|
||||
--env OLLAMA_INFERENCE_MODEL="${MODEL_ALIAS}" \
|
||||
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
|
||||
--env ENABLE_OLLAMA=ollama)
|
||||
|
||||
log "🦙 Starting Llama‑Stack…"
|
||||
$ENGINE "${cmd[@]}" > /dev/null 2>&1
|
||||
log "🦙 Starting Llama Stack..."
|
||||
if ! execute_with_log $ENGINE "${cmd[@]}"; then
|
||||
die "Llama Stack startup failed"
|
||||
fi
|
||||
|
||||
if ! wait_for_service "http://127.0.0.1:${PORT}/v1/health" "OK" "$WAIT_TIMEOUT" "Llama-Stack API"; then
|
||||
log "❌ Llama-Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
|
||||
if ! wait_for_service "http://127.0.0.1:${PORT}/v1/health" "OK" "$WAIT_TIMEOUT" "Llama Stack API"; then
|
||||
log "❌ Llama Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
|
||||
$ENGINE logs --tail 200 llama-stack
|
||||
die "Llama-Stack startup failed"
|
||||
die "Llama Stack startup failed"
|
||||
fi
|
||||
|
||||
###############################################################################
|
||||
# Done
|
||||
###############################################################################
|
||||
log ""
|
||||
log "🎉 Llama‑Stack is ready!"
|
||||
log "🎉 Llama Stack is ready!"
|
||||
log "👉 API endpoint: http://localhost:${PORT}"
|
||||
log "📖 Documentation: https://llama-stack.readthedocs.io/en/latest/references/index.html"
|
||||
log "💻 To access the llama‑stack CLI, exec into the container:"
|
||||
log "💻 To access the llama stack CLI, exec into the container:"
|
||||
log " $ENGINE exec -ti llama-stack bash"
|
||||
log "🐛 Report an issue @ https://github.com/meta-llama/llama-stack/issues if you think it's a bug"
|
||||
log ""
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
# Ollama external provider for Llama Stack
|
||||
|
||||
Template code to create a new external provider for Llama Stack.
|
|
@ -1,7 +0,0 @@
|
|||
adapter:
|
||||
adapter_type: custom_ollama
|
||||
pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"]
|
||||
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
|
||||
module: llama_stack_provider_ollama
|
||||
api_dependencies: []
|
||||
optional_api_dependencies: []
|
|
@ -1,43 +0,0 @@
|
|||
[project]
|
||||
dependencies = [
|
||||
"llama-stack",
|
||||
"pydantic",
|
||||
"ollama",
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
]
|
||||
|
||||
name = "llama-stack-provider-ollama"
|
||||
version = "0.1.0"
|
||||
description = "External provider for Ollama using the Llama Stack API"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
|
@ -1,124 +0,0 @@
|
|||
version: 2
|
||||
image_name: ollama
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: ollama
|
||||
provider_type: remote::ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:=http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
agents_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/agents_store.db
|
||||
responses_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/responses_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:=\u200b}"
|
||||
sinks: ${env.TELEMETRY_SINKS:=console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/trace_store.db
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
|
||||
max_results: 3
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:+}
|
||||
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ollama}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: custom_ollama
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: custom_ollama
|
||||
provider_model_id: all-minilm:l6-v2
|
||||
model_type: embedding
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
server:
|
||||
port: 8321
|
||||
external_providers_dir: ~/.llama/providers.d
|
|
@ -2,8 +2,9 @@ version: '2'
|
|||
distribution_spec:
|
||||
description: Custom distro for CI tests
|
||||
providers:
|
||||
inference:
|
||||
- remote::custom_ollama
|
||||
image_type: container
|
||||
weather:
|
||||
- remote::kaze
|
||||
image_type: venv
|
||||
image_name: ci-test
|
||||
external_providers_dir: ~/.llama/providers.d
|
||||
external_apis_dir: ~/.llama/apis.d
|
tests/external/kaze.yaml (new vendored file, 6 lines)
@@ -0,0 +1,6 @@
adapter:
  adapter_type: kaze
  pip_packages: ["tests/external/llama-stack-provider-kaze"]
  config_class: llama_stack_provider_kaze.config.KazeProviderConfig
  module: llama_stack_provider_kaze
optional_api_dependencies: []
|
tests/external/llama-stack-api-weather/pyproject.toml (new vendored file, 15 lines)
@@ -0,0 +1,15 @@
[project]
|
||||
name = "llama-stack-api-weather"
|
||||
version = "0.1.0"
|
||||
description = "Weather API for Llama Stack"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["llama-stack", "pydantic"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
include = ["llama_stack_api_weather", "llama_stack_api_weather.*"]
|
tests/external/llama-stack-api-weather/src/llama_stack_api_weather/__init__.py (new vendored file, 11 lines)
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Weather API for Llama Stack."""
|
||||
|
||||
from .weather import WeatherProvider, available_providers
|
||||
|
||||
__all__ = ["WeatherProvider", "available_providers"]
|
tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py (new vendored file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
|
||||
from llama_stack.schema_utils import webmethod
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
RemoteProviderSpec(
|
||||
api=Api.weather,
|
||||
provider_type="remote::kaze",
|
||||
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="kaze",
|
||||
module="llama_stack_provider_kaze",
|
||||
pip_packages=["llama_stack_provider_kaze"],
|
||||
config_class="llama_stack_provider_kaze.KazeProviderConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class WeatherProvider(Protocol):
|
||||
"""
|
||||
A protocol for the Weather API.
|
||||
"""
|
||||
|
||||
@webmethod(route="/weather/locations", method="GET")
|
||||
async def get_available_locations() -> dict[str, list[str]]:
|
||||
"""
|
||||
Get the available locations.
|
||||
"""
|
||||
...
|
tests/external/llama-stack-provider-kaze/pyproject.toml (new vendored file, 15 lines)
@@ -0,0 +1,15 @@
[project]
|
||||
name = "llama-stack-provider-kaze"
|
||||
version = "0.1.0"
|
||||
description = "Kaze weather provider for Llama Stack"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["llama-stack", "pydantic", "aiohttp"]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
include = ["llama_stack_provider_kaze", "llama_stack_provider_kaze.*"]
|
tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/__init__.py (new vendored file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Kaze weather provider for Llama Stack."""
|
||||
|
||||
from .config import KazeProviderConfig
|
||||
from .kaze import WeatherKazeAdapter
|
||||
|
||||
__all__ = ["KazeProviderConfig", "WeatherKazeAdapter"]
|
||||
|
||||
|
||||
async def get_adapter_impl(config: KazeProviderConfig, _deps):
|
||||
from .kaze import WeatherKazeAdapter
|
||||
|
||||
impl = WeatherKazeAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/config.py (new vendored file, 11 lines)
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KazeProviderConfig(BaseModel):
|
||||
"""Configuration for the Kaze weather provider."""
|
tests/external/llama-stack-provider-kaze/src/llama_stack_provider_kaze/kaze.py (new vendored file, 26 lines)
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack_api_weather.weather import WeatherProvider
|
||||
|
||||
from .config import KazeProviderConfig
|
||||
|
||||
|
||||
class WeatherKazeAdapter(WeatherProvider):
|
||||
"""Kaze weather provider implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: KazeProviderConfig,
|
||||
) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def get_available_locations(self) -> dict[str, list[str]]:
|
||||
"""Get available weather locations."""
|
||||
return {"locations": ["Paris", "Tokyo"]}
|
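A minimal way to exercise the external provider directly, outside the stack, based only on the files above (KazeProviderConfig has no required fields; the empty dict stands in for the deps argument):

import asyncio

from llama_stack_provider_kaze import KazeProviderConfig, get_adapter_impl


async def main() -> None:
    adapter = await get_adapter_impl(KazeProviderConfig(), {})
    print(await adapter.get_available_locations())  # {"locations": ["Paris", "Tokyo"]}


asyncio.run(main())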
tests/external/run-byoa.yaml (new vendored file, 13 lines)
@@ -0,0 +1,13 @@
version: "2"
|
||||
image_name: "llama-stack-api-weather"
|
||||
apis:
|
||||
- weather
|
||||
providers:
|
||||
weather:
|
||||
- provider_id: kaze
|
||||
provider_type: remote::kaze
|
||||
config: {}
|
||||
external_apis_dir: ~/.llama/apis.d
|
||||
external_providers_dir: ~/.llama/providers.d
|
||||
server:
|
||||
port: 8321
|
tests/external/weather.yaml (new vendored file, 4 lines)
@@ -0,0 +1,4 @@
module: llama_stack_api_weather
name: weather
pip_packages: ["tests/external/llama-stack-api-weather"]
protocol: WeatherProvider
|
|
@ -179,9 +179,7 @@ def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_model
|
|||
model=text_model_id,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
extra_body={
|
||||
"prompt_logprobs": prompt_logprobs,
|
||||
},
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
assert len(response.choices) > 0
|
||||
choice = response.choices[0]
|
||||
|
@ -196,9 +194,7 @@ def test_openai_completion_guided_choice(llama_stack_client, client_with_models,
|
|||
model=text_model_id,
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
extra_body={
|
||||
"guided_choice": ["joy", "sadness"],
|
||||
},
|
||||
guided_choice=["joy", "sadness"],
|
||||
)
|
||||
assert len(response.choices) > 0
|
||||
choice = response.choices[0]
|
||||
|
|
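The test change above reflects that vLLM-specific sampling controls are now accepted as top-level kwargs instead of being tucked into extra_body. A hedged sketch of the new call shape, assuming an OpenAI-compatible client fixture named openai_client (the prompt text is illustrative):

response = openai_client.completions.create(
    model=text_model_id,
    prompt="Classify the sentiment of this sentence.",
    stream=False,
    guided_choice=["joy", "sadness"],  # previously passed via extra_body
)
assert len(response.choices) > 0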
|
@ -15,6 +15,7 @@ from llama_stack.apis.models import Model, ModelType
|
|||
from llama_stack.apis.shields.shields import Shield
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.distribution.datatypes import RegistryEntrySource
|
||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||
|
@ -45,6 +46,30 @@ class InferenceImpl(Impl):
|
|||
async def unregister_model(self, model_id: str):
|
||||
return model_id
|
||||
|
||||
async def should_refresh_models(self):
|
||||
return False
|
||||
|
||||
async def list_models(self):
|
||||
return [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="provider-model-2",
|
||||
provider_resource_id="provider-model-2",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 512},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
|
||||
async def shutdown(self):
|
||||
pass
|
||||
|
||||
|
||||
class SafetyImpl(Impl):
|
||||
def __init__(self):
|
||||
|
@ -378,3 +403,170 @@ async def test_models_fallback_lookup_behavior(cached_disk_dist_registry):
|
|||
raise AssertionError("Should have raised ValueError for non-existent model")
|
||||
except ValueError as e:
|
||||
assert "not found" in str(e)
|
||||
|
||||
|
||||
async def test_models_source_tracking_default(cached_disk_dist_registry):
|
||||
"""Test that models registered via register_model get default source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register model via register_model (should get default source)
|
||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
model = models.data[0]
|
||||
assert model.source == RegistryEntrySource.via_register_api
|
||||
assert model.identifier == "test_provider/user-model"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_tracking_provider(cached_disk_dist_registry):
|
||||
"""Test that models registered via update_registered_models get provider source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Simulate provider refresh by calling update_registered_models
|
||||
provider_models = [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="provider-model-2",
|
||||
provider_resource_id="provider-model-2",
|
||||
provider_id="test_provider",
|
||||
metadata={"embedding_dimension": 512},
|
||||
model_type=ModelType.embedding,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models)
|
||||
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# All models should have provider source
|
||||
for model in models.data:
|
||||
assert model.source == RegistryEntrySource.listed_from_provider
|
||||
assert model.provider_id == "test_provider"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_preserves_default(cached_disk_dist_registry):
|
||||
"""Test that provider refresh preserves user-registered models with default source."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# First register a user model with same provider_resource_id as provider will later provide
|
||||
await table.register_model(
|
||||
model_id="my-custom-alias", provider_model_id="provider-model-1", provider_id="test_provider"
|
||||
)
|
||||
|
||||
# Verify user model is registered with default source
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 1
|
||||
user_model = models.data[0]
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.identifier == "my-custom-alias"
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
# Now simulate provider refresh
|
||||
provider_models = [
|
||||
Model(
|
||||
identifier="provider-model-1",
|
||||
provider_resource_id="provider-model-1",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
Model(
|
||||
identifier="different-model",
|
||||
provider_resource_id="different-model",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models)
|
||||
|
||||
# Verify user model with alias is preserved, but provider added new model
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# Find the user model and provider model
|
||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
|
||||
|
||||
assert user_model is not None
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert user_model.provider_resource_id == "provider-model-1"
|
||||
|
||||
assert provider_model is not None
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
assert provider_model.provider_resource_id == "different-model"
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_models_source_interaction_cleanup_provider_models(cached_disk_dist_registry):
|
||||
"""Test that provider refresh removes old provider models but keeps default ones."""
|
||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||
await table.initialize()
|
||||
|
||||
# Register a user model
|
||||
await table.register_model(model_id="user-model", provider_id="test_provider")
|
||||
|
||||
# Add some provider models
|
||||
provider_models_v1 = [
|
||||
Model(
|
||||
identifier="provider-model-old",
|
||||
provider_resource_id="provider-model-old",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v1)
|
||||
|
||||
# Verify we have both user and provider models
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
# Now update with new provider models (should remove old provider models)
|
||||
provider_models_v2 = [
|
||||
Model(
|
||||
identifier="provider-model-new",
|
||||
provider_resource_id="provider-model-new",
|
||||
provider_id="test_provider",
|
||||
metadata={},
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
await table.update_registered_models("test_provider", provider_models_v2)
|
||||
|
||||
# Should have user model + new provider model, old provider model gone
|
||||
models = await table.list_models()
|
||||
assert len(models.data) == 2
|
||||
|
||||
identifiers = {m.identifier for m in models.data}
|
||||
assert "test_provider/user-model" in identifiers # User model preserved
|
||||
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||
assert "provider-model-old" not in identifiers # Old provider model removed
|
||||
|
||||
# Verify sources are correct
|
||||
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
|
||||
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
|
|
@ -5,321 +5,353 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import CompletionMessage, UserMessage
|
||||
from llama_stack.apis.resource import ResourceType
|
||||
from llama_stack.apis.safety import RunShieldResponse, ViolationLevel
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.models.llama.datatypes import StopReason
|
||||
from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
||||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
||||
class TestNVIDIASafetyAdapter(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
"""Test implementation that provides the required shield_store."""
|
||||
|
||||
# Initialize the adapter
|
||||
self.config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
self.adapter = NVIDIASafetyAdapter(config=self.config)
|
||||
self.shield_store = AsyncMock()
|
||||
self.adapter.shield_store = self.shield_store
|
||||
def __init__(self, config: NVIDIASafetyConfig, shield_store):
|
||||
super().__init__(config)
|
||||
self.shield_store = shield_store
|
||||
|
||||
# Mock the HTTP request methods
|
||||
self.guardrails_post_patcher = patch(
|
||||
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
|
||||
)
|
||||
self.mock_guardrails_post = self.guardrails_post_patcher.start()
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
self.guardrails_post_patcher.stop()
|
||||
@pytest.fixture
|
||||
def nvidia_adapter():
|
||||
"""Set up the NVIDIASafetyAdapter for testing."""
|
||||
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, run_async):
|
||||
self.run_async = run_async
|
||||
# Initialize the adapter
|
||||
config = NVIDIASafetyConfig(
|
||||
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
|
||||
)
|
||||
|
||||
def _assert_request(
|
||||
self,
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
# Create a mock shield store that implements the ShieldStore protocol
|
||||
shield_store = AsyncMock()
|
||||
shield_store.get_shield = AsyncMock()
|
||||
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
|
||||
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
return adapter
|
||||
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
@pytest.fixture
|
||||
def mock_guardrails_post():
|
||||
"""Mock the HTTP request methods."""
|
||||
with patch("llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post") as mock_post:
|
||||
mock_post.return_value = {"status": "allowed"}
|
||||
yield mock_post
|
||||
|
||||
def test_register_shield_with_valid_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
def _assert_request(
|
||||
mock_call: MagicMock,
|
||||
expected_url: str,
|
||||
expected_headers: dict[str, str] | None = None,
|
||||
expected_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to verify request details in mock API calls.
|
||||
|
||||
def test_register_shield_without_id(self):
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
Args:
|
||||
mock_call: The MagicMock object that was called
|
||||
expected_url: The expected URL to which the request was made
|
||||
expected_headers: Optional dictionary of expected request headers
|
||||
expected_json: Optional dictionary of expected JSON payload
|
||||
"""
|
||||
call_args = mock_call.call_args
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with self.assertRaises(ValueError):
|
||||
self.run_async(self.adapter.register_shield(shield))
|
||||
# Check URL
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
def test_run_shield_allowed(self):
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type="shield",
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
self.shield_store.get_shield.return_value = shield
|
||||
# Check headers if provided
|
||||
if expected_headers:
|
||||
for key, value in expected_headers.items():
|
||||
assert call_args[1]["headers"][key] == value
|
||||
|
||||
# Mock Guardrails API response
|
||||
self.mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
# Check JSON if provided
|
||||
if expected_json:
|
||||
for key, value in expected_json.items():
|
||||
if isinstance(value, dict):
|
||||
for nested_key, nested_value in value.items():
|
||||
assert call_args[1]["json"][key][nested_key] == nested_value
|
||||
else:
|
||||
assert call_args[1]["json"][key] == value
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason="end_of_message",
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = self.run_async(self.adapter.run_shield(shield_id, messages))
|
||||
|
||||
# Verify the shield store was called
|
||||
self.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
async def test_register_shield_with_valid_id(nvidia_adapter):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
self.mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier="test-shield",
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
|
||||
# Register the shield
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
|
||||
async def test_register_shield_without_id(nvidia_adapter):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier="test-shield",
|
||||
provider_resource_id="",
|
||||
)
|
||||
|
||||
# Register the shield should raise a ValueError
|
||||
with pytest.raises(ValueError):
|
||||
await adapter.register_shield(shield)
|
||||
|
||||
|
||||
async def test_run_shield_allowed(nvidia_adapter, mock_guardrails_post):
|
||||
adapter = nvidia_adapter
|
||||
|
||||
# Set up the shield
|
||||
shield_id = "test-shield"
|
||||
shield = Shield(
|
||||
provider_id="nvidia",
|
||||
type=ResourceType.shield,
|
||||
identifier=shield_id,
|
||||
provider_resource_id="test-model",
|
||||
)
|
||||
adapter.shield_store.get_shield.return_value = shield
|
||||
|
||||
# Mock Guardrails API response
|
||||
mock_guardrails_post.return_value = {"status": "allowed"}
|
||||
|
||||
# Run the shield
|
||||
messages = [
|
||||
UserMessage(role="user", content="Hello, how are you?"),
|
||||
CompletionMessage(
|
||||
role="assistant",
|
||||
content="I'm doing well, thank you for asking!",
|
||||
stop_reason=StopReason.end_of_message,
|
||||
tool_calls=[],
|
||||
),
|
||||
]
|
||||
result = await adapter.run_shield(shield_id, messages)
|
||||
|
||||
# Verify the shield store was called
|
||||
adapter.shield_store.get_shield.assert_called_once_with(shield_id)
|
||||
|
||||
# Verify the Guardrails API was called correctly
|
||||
mock_guardrails_post.assert_called_once_with(
|
||||
path="/v1/guardrail/checks",
|
||||
data={
|
||||
"model": shield_id,
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
|
||||
],
|
||||
"temperature": 1.0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
"max_tokens": 160,
|
||||
"stream": False,
|
||||
"guardrails": {
|
||||
"config_id": "self-check",
|
||||
},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
# Verify the result
|
||||
assert isinstance(result, RunShieldResponse)
|
||||
assert result.violation is None
|
||||
|
||||
async def test_run_shield_blocked(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    # Set up the shield
    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    result = await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        },
    )

    # Verify the result
    assert result.violation is not None
    assert isinstance(result, RunShieldResponse)
    assert result.violation.user_message == "Sorry I cannot do this."
    assert result.violation.violation_level == ViolationLevel.ERROR
    assert result.violation.metadata == {"reason": "harmful_content"}

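For reference, the payload asserted in these tests corresponds to a plain POST against the NeMo Guardrails checks endpoint. A hedged sketch of the equivalent raw request is below; the base URL comes from NVIDIA_GUARDRAILS_URL, while the use of httpx and the absence of auth headers are assumptions for illustration, not details taken from this diff.

# Illustrative only: the request the adapter is expected to issue for the "allowed" case.
# Endpoint path and payload mirror the assertions above; transport details are assumptions.
import os

import httpx

payload = {
    "model": "test-shield",
    "messages": [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
    ],
    "temperature": 1.0,
    "top_p": 1,
    "frequency_penalty": 0,
    "presence_penalty": 0,
    "max_tokens": 160,
    "stream": False,
    "guardrails": {"config_id": "self-check"},
}

base_url = os.environ.get("NVIDIA_GUARDRAILS_URL", "http://nemo.test")
response = httpx.post(f"{base_url}/v1/guardrail/checks", json=payload)
# {"status": "allowed"} maps to RunShieldResponse(violation=None);
# {"status": "blocked", ...} maps to a violation at ViolationLevel.ERROR.
print(response.json())
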
async def test_run_shield_not_found(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    # Set up shield store to return None
    shield_id = "non-existent-shield"
    adapter.shield_store.get_shield.return_value = None

    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
    ]

    with pytest.raises(ValueError):
        await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was not called
    mock_guardrails_post.assert_not_called()

async def test_run_shield_http_error(nvidia_adapter, mock_guardrails_post):
    adapter = nvidia_adapter

    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type=ResourceType.shield,
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    adapter.shield_store.get_shield.return_value = shield

    # Mock Guardrails API to raise an exception
    error_msg = "API Error: 500 Internal Server Error"
    mock_guardrails_post.side_effect = Exception(error_msg)

    # Running the shield should raise an exception
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason=StopReason.end_of_message,
            tool_calls=[],
        ),
    ]
    with pytest.raises(Exception) as exc_info:
        await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    adapter.shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        },
    )

    # Verify the exception message
    assert error_msg in str(exc_info.value)

def test_init_nemo_guardrails():
    from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    test_config_id = "test-custom-config-id"
    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        config_id=test_config_id,
    )
    # Initialize with default parameters
    test_model = "test-model"
    guardrails = NeMoGuardrails(config, test_model)

    # Verify the attributes are set correctly
    assert guardrails.config_id == test_config_id
    assert guardrails.model == test_model
    assert guardrails.threshold == 0.9  # Default value
    assert guardrails.temperature == 1.0  # Default value
    assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

    # Initialize with custom parameters
    guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)

    # Verify the attributes are set correctly
    assert guardrails.config_id == test_config_id
    assert guardrails.model == test_model
    assert guardrails.threshold == 0.8
    assert guardrails.temperature == 0.7
    assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]

def test_init_nemo_guardrails_invalid_temperature():
    from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails

    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
        config_id="test-custom-config-id",
    )
    with pytest.raises(ValueError):
        NeMoGuardrails(config, "test-model", temperature=0)

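As a usage note, the behavior pinned down by these tests is what a client observes when running a shield backed by the nvidia safety provider. A minimal sketch against the llama-stack client follows; the shield id, server URL, and registration arguments are illustrative assumptions, not values from this diff.

# Illustrative sketch only: exercising the NVIDIA safety provider from a client.
# Shield id, server URL, and registration arguments are assumptions for this example.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Register a shield served by the nvidia safety provider (maps to a Guardrails config id).
client.shields.register(
    shield_id="nemo-guardrails",
    provider_id="nvidia",
    provider_shield_id="self-check",
)

# Run the shield over a conversation; a blocked response surfaces as a violation.
response = client.safety.run_shield(
    shield_id="nemo-guardrails",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    params={},
)
print(response.violation)
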
@@ -19,7 +19,8 @@ from llama_stack.distribution.datatypes import (
    OAuth2JWKSConfig,
    OAuth2TokenAuthConfig,
)
from llama_stack.distribution.request_headers import User
from llama_stack.distribution.server.auth import AuthenticationMiddleware, _has_required_scope
from llama_stack.distribution.server.auth_providers import (
    get_attributes_from_claims,
)

@@ -73,7 +74,7 @@ def http_app(mock_auth_endpoint):
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():

@@ -111,7 +112,50 @@ def mock_http_middleware(mock_auth_endpoint):
        ),
        access_policy=[],
    )
    return AuthenticationMiddleware(mock_app, auth_config, {}), mock_app


@pytest.fixture
def mock_impls():
    """Mock implementations for scope testing"""
    return {}


@pytest.fixture
def scope_middleware_with_mocks(mock_auth_endpoint):
    """Create AuthenticationMiddleware with mocked route implementations"""
    mock_app = AsyncMock()
    auth_config = AuthenticationConfig(
        provider_config=CustomAuthConfig(
            type=AuthProviderType.CUSTOM,
            endpoint=mock_auth_endpoint,
        ),
        access_policy=[],
    )
    middleware = AuthenticationMiddleware(mock_app, auth_config, {})

    # Mock the route_impls to simulate finding routes with required scopes
    from llama_stack.schema_utils import WebMethod

    scoped_webmethod = WebMethod(route="/test/scoped", method="POST", required_scope="test.read")

    public_webmethod = WebMethod(route="/test/public", method="GET")

    # Mock the route finding logic
    def mock_find_matching_route(method, path, route_impls):
        if method == "POST" and path == "/test/scoped":
            return None, {}, "/test/scoped", scoped_webmethod
        elif method == "GET" and path == "/test/public":
            return None, {}, "/test/public", public_webmethod
        else:
            raise ValueError("No matching route")

    import llama_stack.distribution.server.auth

    llama_stack.distribution.server.auth.find_matching_route = mock_find_matching_route
    llama_stack.distribution.server.auth.initialize_route_impls = lambda impls: {}

    return middleware, mock_app


async def mock_post_success(*args, **kwargs):

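The fixture above stubs route lookup to return a WebMethod carrying required_scope; that metadata is what the middleware consults before dispatching a request. A hypothetical sketch of declaring such a route is below; whether the webmethod decorator accepts required_scope directly is an assumption here, since the diff itself only constructs WebMethod objects.

# Hypothetical sketch: declaring a scope-protected route. The decorator kwarg is an
# assumption for illustration; the diff only shows WebMethod constructed directly.
from llama_stack.schema_utils import WebMethod, webmethod


class ExampleAPI:
    @webmethod(route="/example/items", method="POST", required_scope="example.write")
    async def create_item(self, name: str) -> dict:
        # Only callers whose auth attributes include the "example.write" scope reach this.
        return {"name": name}


# Equivalent metadata, constructed directly the way the test fixture does:
scoped = WebMethod(route="/example/items", method="POST", required_scope="example.write")
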
@@ -138,6 +182,36 @@ async def mock_post_exception(*args, **kwargs):
    raise Exception("Connection error")


async def mock_post_success_with_scope(*args, **kwargs):
    """Mock auth response for user with test.read scope"""
    return MockResponse(
        200,
        {
            "message": "Authentication successful",
            "principal": "test-user",
            "attributes": {
                "scopes": ["test.read", "other.scope"],
                "roles": ["user"],
            },
        },
    )


async def mock_post_success_no_scope(*args, **kwargs):
    """Mock auth response for user without required scope"""
    return MockResponse(
        200,
        {
            "message": "Authentication successful",
            "principal": "test-user",
            "attributes": {
                "scopes": ["other.scope"],
                "roles": ["user"],
            },
        },
    )


# HTTP Endpoint Tests
def test_missing_auth_header(http_client):
    response = http_client.get("/test")

@@ -252,7 +326,7 @@ def oauth2_app():
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():

@@ -351,7 +425,7 @@ def oauth2_app_with_jwks_token():
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():

@@ -442,7 +516,7 @@ def introspection_app(mock_introspection_endpoint):
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():

@@ -472,7 +546,7 @@ def introspection_app_with_custom_mapping(mock_introspection_endpoint):
        ),
        access_policy=[],
    )
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():

@@ -581,3 +655,122 @@ def test_valid_introspection_with_custom_mapping_authentication(
    )
    assert response.status_code == 200
    assert response.json() == {"message": "Authentication successful"}


# Scope-based authorization tests
@patch("httpx.AsyncClient.post", new=mock_post_success_with_scope)
async def test_scope_authorization_success(scope_middleware_with_mocks, valid_api_key):
    """Test that user with required scope can access protected endpoint"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/scoped",
        "method": "POST",
        "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
    }

    await middleware(scope, mock_receive, mock_send)

    # Should call the downstream app (no 403 error sent)
    mock_app.assert_called_once_with(scope, mock_receive, mock_send)
    mock_send.assert_not_called()


@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_scope_authorization_denied(scope_middleware_with_mocks, valid_api_key):
    """Test that user without required scope gets 403 access denied"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/scoped",
        "method": "POST",
        "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
    }

    await middleware(scope, mock_receive, mock_send)

    # Should send 403 error, not call downstream app
    mock_app.assert_not_called()
    assert mock_send.call_count == 2  # start + body

    # Check the response
    start_call = mock_send.call_args_list[0][0][0]
    assert start_call["status"] == 403

    body_call = mock_send.call_args_list[1][0][0]
    body_text = body_call["body"].decode()
    assert "Access denied" in body_text
    assert "test.read" in body_text


@patch("httpx.AsyncClient.post", new=mock_post_success_no_scope)
async def test_public_endpoint_no_scope_required(scope_middleware_with_mocks, valid_api_key):
    """Test that public endpoints work without specific scopes"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/public",
        "method": "GET",
        "headers": [(b"authorization", f"Bearer {valid_api_key}".encode())],
    }

    await middleware(scope, mock_receive, mock_send)

    # Should call the downstream app (no error)
    mock_app.assert_called_once_with(scope, mock_receive, mock_send)
    mock_send.assert_not_called()


async def test_scope_authorization_no_auth_disabled(scope_middleware_with_mocks):
    """Test that when auth is disabled (no user), scope checks are bypassed"""
    middleware, mock_app = scope_middleware_with_mocks
    mock_receive = AsyncMock()
    mock_send = AsyncMock()

    scope = {
        "type": "http",
        "path": "/test/scoped",
        "method": "POST",
        "headers": [],  # No authorization header
    }

    await middleware(scope, mock_receive, mock_send)

    # Should send 401 auth error, not call downstream app
    mock_app.assert_not_called()
    assert mock_send.call_count == 2  # start + body

    # Check the response
    start_call = mock_send.call_args_list[0][0][0]
    assert start_call["status"] == 401

    body_call = mock_send.call_args_list[1][0][0]
    body_text = body_call["body"].decode()
    assert "Authentication required" in body_text


def test_has_required_scope_function():
    """Test the _has_required_scope function directly"""
    # Test user with required scope
    user_with_scope = User(principal="test-user", attributes={"scopes": ["test.read", "other.scope"]})
    assert _has_required_scope("test.read", user_with_scope)

    # Test user without required scope
    user_without_scope = User(principal="test-user", attributes={"scopes": ["other.scope"]})
    assert not _has_required_scope("test.read", user_without_scope)

    # Test user with no scopes attribute
    user_no_scopes = User(principal="test-user", attributes={})
    assert not _has_required_scope("test.read", user_no_scopes)

    # Test no user (auth disabled)
    assert _has_required_scope("test.read", None)
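The last test pins down the contract of _has_required_scope: with auth disabled (no user) the check passes, otherwise the required scope must appear in the user's scopes attribute. Below is a sketch consistent with that contract, assuming the real helper in llama_stack.distribution.server.auth behaves the same way; it is an illustration, not that code.

# Sketch consistent with test_has_required_scope_function above; not the actual implementation.
from llama_stack.distribution.request_headers import User


def _has_required_scope(required_scope: str, user: User | None) -> bool:
    # When authentication is disabled there is no user object; scope checks are bypassed.
    if user is None:
        return True
    user_scopes = (user.attributes or {}).get("scopes", [])
    return required_scope in user_scopes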