Merge branch 'main' into nvidia-e2e-notebook

Jash Gulabrai 2025-05-19 09:23:07 -04:00
commit 51b68b4be6
234 changed files with 21943 additions and 7540 deletions

.github/CODEOWNERS

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @dltn @raghotham @dineshyv @vladimirivic @sixianyi0721 @ehhuang @terrytangyuan @SLR722 @leseb
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning

.github/TRIAGERS.md

@@ -1,2 +1,2 @@
 # This file documents Triage members in the Llama Stack community
-@franciscojavierarceo @leseb
+@bbrowning @booxter @franciscojavierarceo @leseb

@@ -28,12 +28,13 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
+          activate-environment: true

       - name: Set Up Environment and Install Dependencies
         run: |
@@ -43,7 +44,7 @@ jobs:
       - name: Install minikube
         if: ${{ matrix.auth-provider == 'kubernetes' }}
-        uses: medyagh/setup-minikube@latest
+        uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19

       - name: Start minikube
         if: ${{ matrix.auth-provider == 'kubernetes' }}
@@ -74,32 +75,8 @@ jobs:
           cat <<'EOF' > $run_dir/run.yaml
           version: '2'
           image_name: kube
-          apis:
-          - agents
-          - datasetio
-          - eval
-          - inference
-          - safety
-          - scoring
-          - telemetry
-          - tool_runtime
-          - vector_io
-          providers:
-            agents:
-            - provider_id: meta-reference
-              provider_type: inline::meta-reference
-              config:
-                persistence_store:
-                  type: sqlite
-                  namespace: null
-                  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
-            telemetry:
-            - provider_id: meta-reference
-              provider_type: inline::meta-reference
-              config:
-                service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
-                sinks: ${env.TELEMETRY_SINKS:console,sqlite}
-                sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
+          apis: []
+          providers: {}
           server:
             port: 8321
           EOF

@@ -33,7 +33,7 @@ jobs:
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
           activate-environment: true
@@ -58,7 +58,7 @@ jobs:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
         run: |
           source .venv/bin/activate
-          nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv > server.log 2>&1 &
+          LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &

       - name: Wait for Llama Stack server to be ready
         if: matrix.client-type == 'http'
@@ -85,6 +85,11 @@
             echo "Ollama health check failed"
             exit 1
           fi

+      - name: Check Storage and Memory Available Before Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h
       - name: Run Integration Tests
         env:
@@ -100,13 +105,20 @@
           --text-model="meta-llama/Llama-3.2-3B-Instruct" \
           --embedding-model=all-MiniLM-L6-v2

+      - name: Check Storage and Memory Available After Tests
+        if: ${{ always() }}
+        run: |
+          free -h
+          df -h
+
       - name: Write ollama logs to file
+        if: ${{ always() }}
         run: |
           sudo journalctl -u ollama.service > ollama.log

       - name: Upload all logs to artifacts
-        if: always()
-        uses: actions/upload-artifact@v4
+        if: ${{ always() }}
+        uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
         with:
           name: logs-${{ github.run_id }}-${{ github.run_attempt }}-${{ matrix.client-type }}-${{ matrix.test-type }}
           path: |

@@ -56,7 +56,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
@@ -94,7 +94,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
@@ -120,7 +120,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
@@ -132,9 +132,9 @@ jobs:
       - name: Build a single provider
         run: |
-          yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
-          yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
-          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
+          yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml
+          yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml

       - name: Inspect the container image entrypoint
         run: |
@@ -158,7 +158,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
@@ -174,14 +174,14 @@
             .image_type = "container" |
             .image_name = "ubi9-test" |
             .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest"
-          ' llama_stack/templates/dev/build.yaml
+          ' llama_stack/templates/starter/build.yaml

       - name: Build dev container (UBI9)
         env:
           USE_COPY_NOT_MOUNT: "true"
           LLAMA_STACK_DIR: "."
         run: |
-          uv run llama stack build --config llama_stack/templates/dev/build.yaml
+          uv run llama stack build --config llama_stack/templates/starter/build.yaml

       - name: Inspect UBI9 image
         run: |

@@ -23,10 +23,10 @@ jobs:
       # container and point 'uv pip install' to the correct path...
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@v6
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: "3.10"
@@ -47,8 +47,8 @@ jobs:
       - name: Create provider configuration
         run: |
-          mkdir -p /tmp/providers.d/remote/inference
-          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
+          mkdir -p /home/runner/.llama/providers.d/remote/inference
+          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml

       - name: Build distro from config file
         run: |
@@ -66,7 +66,7 @@ jobs:
       - name: Wait for Llama Stack server to be ready
         run: |
           for i in {1..30}; do
-            if ! grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+            if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then
               echo "Waiting for Llama Stack server to load the provider..."
               sleep 1
             else
@@ -75,4 +75,5 @@ jobs:
             fi
           done
           echo "Provider failed to load"
+          cat server.log
           exit 1

@@ -37,7 +37,7 @@ jobs:
         with:
           python-version: ${{ matrix.python }}

-      - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+      - uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
         with:
           python-version: ${{ matrix.python }}
           enable-cache: false

@@ -14,6 +14,8 @@ on:
       - 'docs/**'
       - 'pyproject.toml'
       - '.github/workflows/update-readthedocs.yml'
+    tags:
+      - '*'

   pull_request:
     branches:
       - main
@@ -41,7 +43,7 @@ jobs:
           python-version: '3.11'

       - name: Install the latest version of uv
-        uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+        uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1

       - name: Sync with uv
         run: uv sync --extra docs
@@ -61,7 +63,10 @@ jobs:
           response=$(curl -X POST \
             -H "Content-Type: application/json" \
-            -d "{\"token\": \"$TOKEN\"}" \
+            -d "{
+              \"token\": \"$TOKEN\",
+              \"version\": \"$GITHUB_REF_NAME\"
+            }" \
             https://readthedocs.org/api/v2/webhook/llama-stack/289768/)

           echo "Response: $response"

@@ -106,6 +106,14 @@ repos:
         pass_filenames: false
         require_serial: true
         files: ^llama_stack/apis/|^docs/openapi_generator/
+      - id: check-workflows-use-hashes
+        name: Check GitHub Actions use SHA-pinned actions
+        entry: ./scripts/check-workflows-use-hashes.sh
+        language: system
+        pass_filenames: false
+        require_serial: true
+        always_run: true
+        files: ^\.github/workflows/.*\.ya?ml$

 ci:
     autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
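
Note: for orientation only, a minimal Python sketch of the kind of check this hook performs. The repository's actual hook is the shell script `./scripts/check-workflows-use-hashes.sh` named above; this standalone script, its regex, and its output format are assumptions, not the project's implementation.

```python
# Hypothetical stand-in for scripts/check-workflows-use-hashes.sh:
# flag any `uses: owner/repo@ref` whose ref is not a 40-character commit SHA.
import re
import sys
from pathlib import Path

USES_RE = re.compile(r"^\s*-?\s*uses:\s*([\w.-]+/[\w./-]+)@([\w.-]+)", re.MULTILINE)

def main() -> int:
    bad = []
    for wf in Path(".github/workflows").glob("*.y*ml"):
        for action, ref in USES_RE.findall(wf.read_text()):
            if not re.fullmatch(r"[0-9a-f]{40}", ref):
                bad.append(f"{wf}: {action}@{ref}")
    for line in bad:
        print(f"not SHA-pinned: {line}", file=sys.stderr)
    return 1 if bad else 0

if __name__ == "__main__":
    raise SystemExit(main())
```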

@@ -1,5 +1,26 @@
 # Changelog

+# v0.2.5
+Published on: 2025-05-04T20:16:49Z
+
+---
+
+# v0.2.4
+Published on: 2025-04-29T17:26:01Z
+
+## Highlights
+* One-liner to install and run Llama Stack yay! by @reluctantfuturist in https://github.com/meta-llama/llama-stack/pull/1383
+* support for NVIDIA NeMo datastore by @raspawar in https://github.com/meta-llama/llama-stack/pull/1852
+* (yuge!) Kubernetes authentication by @leseb in https://github.com/meta-llama/llama-stack/pull/1778
+* (yuge!) OpenAI Responses API by @bbrowning in https://github.com/meta-llama/llama-stack/pull/1989
+* add api.llama provider, llama-guard-4 model by @ashwinb in https://github.com/meta-llama/llama-stack/pull/2058
+
+---
+
 # v0.2.3
 Published on: 2025-04-25T22:46:21Z

@@ -110,25 +110,9 @@ uv run pre-commit run --all-files
 > [!CAUTION]
 > Before pushing your changes, make sure that the pre-commit hooks have passed successfully.

-## Running unit tests
+## Running tests

-You can run the unit tests by running:
+You can find the Llama Stack testing documentation here [here](tests/README.md).

-```bash
-source .venv/bin/activate
-./scripts/unit-tests.sh
-```
-
-If you'd like to run for a non-default version of Python (currently 3.10), pass `PYTHON_VERSION` variable as follows:
-
-```
-source .venv/bin/activate
-PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
-```
-
-## Running integration tests
-
-You can run integration tests following the instructions [here](tests/integration/README.md).

 ## Adding a new dependency to the project
@@ -153,6 +137,8 @@ uv sync
   justification for bypassing the check.
 * When using `# type: ignore` to suppress a mypy warning, include a comment explaining the
   justification for bypassing the check.
+* Don't use unicode characters in the codebase. ASCII-only is preferred for compatibility or
+  readability reasons.

 ## Common Tasks

@@ -7,7 +7,7 @@
 [![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)

-[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
+[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)

 ### ✨🎉 Llama 4 Support 🎉✨

 We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -1050,8 +1050,6 @@
     "text/html": [
      "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">ToolGroup</span><span style=\"font-weight: bold\">(</span>\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">identifier</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
-     "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'code-interpreter'</span>,\n",
-     "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">provider_resource_id</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'builtin::code_interpreter'</span>,\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">type</span>=<span style=\"color: #008000; text-decoration-color: #008000\">'tool_group'</span>,\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">args</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>,\n",
      "<span style=\"color: #7fbf7f; text-decoration-color: #7fbf7f\">│ </span><span style=\"color: #808000; text-decoration-color: #808000\">mcp_endpoint</span>=<span style=\"color: #800080; text-decoration-color: #800080; font-style: italic\">None</span>\n",
@@ -1061,7 +1059,6 @@
     "text/plain": [
      "\u001b[1;35mToolGroup\u001b[0m\u001b[1m(\u001b[0m\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33midentifier\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
-     "\u001b[2;32m│ \u001b[0m\u001b[33mprovider_id\u001b[0m=\u001b[32m'code-interpreter'\u001b[0m,\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33mprovider_resource_id\u001b[0m=\u001b[32m'builtin::code_interpreter'\u001b[0m,\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33mtype\u001b[0m=\u001b[32m'tool_group'\u001b[0m,\n",
      "\u001b[2;32m│ \u001b[0m\u001b[33margs\u001b[0m=\u001b[3;35mNone\u001b[0m,\n",

@@ -337,9 +337,6 @@
  " provider_id: tavily-search\n",
  " provider_type: remote::tavily-search\n",
  " - config: <span style=\"font-weight: bold\">{}</span>\n",
- " provider_id: code-interpreter\n",
- " provider_type: inlin<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e::c</span>ode-interpreter\n",
- " - config: <span style=\"font-weight: bold\">{}</span>\n",
  " provider_id: rag-runtime\n",
  " provider_type: inline::rag-runtime\n",
  " - config: <span style=\"font-weight: bold\">{}</span>\n",
@@ -378,10 +375,6 @@
  " toolgroup_id: builtin::rag\n",
  "- args: null\n",
  " mcp_endpoint: null\n",
- " provider_id: code-interpreter\n",
- " toolgroup_id: builtin::code_interpreter\n",
- "- args: null\n",
- " mcp_endpoint: null\n",
  " provider_id: wolfram-alpha\n",
  " toolgroup_id: builtin::wolfram_alpha\n",
  "vector_dbs: <span style=\"font-weight: bold\">[]</span>\n",
@@ -617,9 +610,6 @@
  " provider_id: tavily-search\n",
  " provider_type: remote::tavily-search\n",
  " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
- " provider_id: code-interpreter\n",
- " provider_type: inlin\u001b[1;92me::c\u001b[0mode-interpreter\n",
- " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
  " provider_id: rag-runtime\n",
  " provider_type: inline::rag-runtime\n",
  " - config: \u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n",
@@ -658,10 +648,6 @@
  " toolgroup_id: builtin::rag\n",
  "- args: null\n",
  " mcp_endpoint: null\n",
- " provider_id: code-interpreter\n",
- " toolgroup_id: builtin::code_interpreter\n",
- "- args: null\n",
- " mcp_endpoint: null\n",
  " provider_id: wolfram-alpha\n",
  " toolgroup_id: builtin::wolfram_alpha\n",
  "vector_dbs: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",

@@ -840,7 +840,6 @@
  " \"memory_optimizations.rst\",\n",
  " \"chat.rst\",\n",
  " \"llama3.rst\",\n",
- " \"datasets.rst\",\n",
  " \"qat_finetune.rst\",\n",
  " \"lora_finetune.rst\",\n",
  "]\n",
@@ -1586,7 +1585,6 @@
  " \"memory_optimizations.rst\",\n",
  " \"chat.rst\",\n",
  " \"llama3.rst\",\n",
- " \"datasets.rst\",\n",
  " \"qat_finetune.rst\",\n",
  " \"lora_finetune.rst\",\n",
  "]\n",

@@ -44,7 +44,7 @@ def main(output_dir: str):
     if return_type_errors:
         print("\nAPI Method Return Type Validation Errors:\n")
         for error in return_type_errors:
-            print(error)
+            print(error, file=sys.stderr)
         sys.exit(1)
     now = str(datetime.now())
     print(

@@ -759,7 +759,7 @@ class Generator:
         )

         return Operation(
-            tags=[op.defining_class.__name__],
+            tags=[getattr(op.defining_class, "API_NAMESPACE", op.defining_class.__name__)],
             summary=None,
             # summary=doc_string.short_description,
             description=description,
@@ -805,6 +805,8 @@ class Generator:
         operation_tags: List[Tag] = []
         for cls in endpoint_classes:
             doc_string = parse_type(cls)
+            if hasattr(cls, "API_NAMESPACE") and cls.API_NAMESPACE != cls.__name__:
+                continue
             operation_tags.append(
                 Tag(
                     name=cls.__name__,
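
Note: a tiny self-contained sketch of the tag resolution this hunk introduces. The classes below are hypothetical; only the `getattr` fallback mirrors the generator change.

```python
# Hypothetical endpoint classes; only the getattr lookup mirrors the diff above.
class Inference:
    """Primary API class: its tag falls back to the class name."""

class BatchInference:
    """Sub-API whose operations should be grouped under the Inference tag."""
    API_NAMESPACE = "Inference"

def tag_for(cls) -> str:
    # Prefer API_NAMESPACE when present, else use the class name.
    return getattr(cls, "API_NAMESPACE", cls.__name__)

assert tag_for(Inference) == "Inference"
assert tag_for(BatchInference) == "Inference"  # grouped; the second hunk skips its duplicate Tag
```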

@@ -174,14 +174,64 @@ def _validate_list_parameters_contain_data(method) -> str | None:
         return "does not have a mandatory data attribute containing the list of objects"


+def _validate_has_ellipsis(method) -> str | None:
+    source = inspect.getsource(method)
+    if "..." not in source and not "NotImplementedError" in source:
+        return "does not contain ellipsis (...) in its implementation"
+
+
+def _validate_has_return_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is not None and return_type != type(None) and ":returns:" not in source:
+        return "does not have a ':returns:' in its docstring"
+
+
+def _validate_has_params_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    sig = inspect.signature(method)
+    # Only check if the method has more than one parameter
+    if len(sig.parameters) > 1 and ":param" not in source:
+        return "does not have a ':param' in its docstring"
+
+
+def _validate_has_no_return_none_in_docstring(method) -> str | None:
+    source = inspect.getsource(method)
+    return_type = method.__annotations__.get('return')
+    if return_type is None and ":returns: None" in source:
+        return "has a ':returns: None' in its docstring which is redundant for None-returning functions"
+
+
+def _validate_docstring_lines_end_with_dot(method) -> str | None:
+    docstring = inspect.getdoc(method)
+    if docstring is None:
+        return None
+
+    lines = docstring.split('\n')
+    for line in lines:
+        line = line.strip()
+        if line and not any(line.endswith(char) for char in '.:{}[]()",'):
+            return f"docstring line '{line}' does not end with a valid character: . : {{ }} [ ] ( ) , \""
+
+
 _VALIDATORS = {
     "GET": [
         _validate_api_method_return_type,
         _validate_list_parameters_contain_data,
         _validate_api_method_doesnt_return_list,
+        _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_docstring_lines_end_with_dot,
     ],
     "DELETE": [
         _validate_api_delete_method_returns_none,
+        _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring
+    ],
+    "POST": [
+        _validate_has_ellipsis,
+        _validate_has_return_in_docstring,
+        _validate_has_params_in_docstring,
+        _validate_has_no_return_none_in_docstring,
+        _validate_docstring_lines_end_with_dot,
     ],
 }
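
Note: to make the new docstring checks concrete, here is a small self-contained illustration. The validator body is copied from the hunk above; the two sample methods are made up for the demo.

```python
import inspect

# Copied from the diff above: flags methods with >1 parameter and no ':param' docs.
def _validate_has_params_in_docstring(method) -> str | None:
    source = inspect.getsource(method)
    sig = inspect.signature(method)
    if len(sig.parameters) > 1 and ":param" not in source:
        return "does not have a ':param' in its docstring"

async def get_model(self, model_id: str):
    """Get a model by its identifier."""  # missing parameter docs, so it gets flagged
    ...

async def get_model_documented(self, model_id: str):
    """Get a model by its identifier.

    :param model_id: The identifier of the model to get.
    :returns: The model.
    """
    ...

print(_validate_has_params_in_docstring(get_model))             # error string
print(_validate_has_params_in_docstring(get_model_documented))  # None
```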

@@ -51,6 +51,7 @@ chunks = [
         "mime_type": "text/plain",
         "metadata": {
             "document_id": "doc1",
+            "author": "Jane Doe",
         },
     },
 ]
@@ -98,6 +99,17 @@ results = client.tool_runtime.rag_tool.query(
 )
 ```

+You can configure how the RAG tool adds metadata to the context if you find it useful for your application. Simply add:
+```python
+# Query documents
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="What do you know about...",
+    query_config={
+        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+    },
+)
+```
 ### Building RAG-Enhanced Agents

 One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
@@ -115,6 +127,12 @@ agent = Agent(
             "name": "builtin::rag/knowledge_search",
             "args": {
                 "vector_db_ids": [vector_db_id],
+                # Defaults
+                "query_config": {
+                    "chunk_size_in_tokens": 512,
+                    "chunk_overlap_in_tokens": 0,
+                    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
+                },
             },
         }
     ],
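
Note: the `chunk_template` added above is an ordinary format string. A minimal sketch of how such a template could be rendered, under the assumption that the placeholders are filled with `str.format`; the `Chunk` class and helper here are illustrative stand-ins, not the RAG tool's internals.

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins for illustration only.
@dataclass
class Chunk:
    content: str
    metadata: dict = field(default_factory=dict)

def render_chunk(index: int, chunk: Chunk, template: str) -> str:
    # Mirrors the placeholders used by query_config["chunk_template"].
    return template.format(index=index, chunk=chunk, metadata=chunk.metadata)

template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
chunk = Chunk(content="Llama Stack is ...", metadata={"document_id": "doc1", "author": "Jane Doe"})
print(render_chunk(1, chunk, template))
```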

@@ -165,34 +165,6 @@ all_tools = client.tools.list_tools()
 group_tools = client.tools.list_tools(toolgroup_id="search_tools")
 ```

-## Simple Example: Using an Agent with the Code-Interpreter Tool
-
-```python
-from llama_stack_client import Agent
-
-# Instantiate the AI agent with the given configuration
-agent = Agent(
-    client,
-    name="code-interpreter",
-    description="A code interpreter agent for executing Python code snippets",
-    instructions="""
-    You are a highly reliable, concise, and precise assistant.
-    Always show the generated code, never generate your own code, and never anticipate results.
-    """,
-    model="meta-llama/Llama-3.2-3B-Instruct",
-    tools=["builtin::code_interpreter"],
-    max_infer_iters=5,
-)
-
-# Start a session
-session_id = agent.create_session("tool_session")
-
-# Send a query to the AI agent for code execution
-response = agent.create_turn(
-    messages=[{"role": "user", "content": "Run this code: print(3 ** 4 - 5 * 2)"}],
-    session_id=session_id,
-)
-```

 ## Simple Example 2: Using an Agent with the Web Search Tool
 1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
 2. [Optional] Provide the API key directly to the Llama Stack server

@@ -110,6 +110,8 @@ html_theme_options = {
     "canonical_url": "https://github.com/meta-llama/llama-stack",
     "collapse_navigation": False,
     # "style_nav_header_background": "#c3c9d4",
+    'display_version': True,
+    'version_selector': True,
 }

 default_dark_mode = False

@@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla
 - Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.)
 - Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally.
 - Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary.
-- Update any distribution {repopath}`Templates::llama_stack/templates/` build.yaml and run.yaml files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.
+- Update any distribution {repopath}`Templates::llama_stack/templates/` `build.yaml` and `run.yaml` files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation.

 Here are some example PRs to help you get started:
@@ -33,6 +33,7 @@ Note that each provider's `sample_run_config()` method (in the configuration cla
 Unit tests are located in {repopath}`tests/unit`. Provider-specific unit tests are located in {repopath}`tests/unit/providers`. These tests are all run automatically as part of the CI process.

+Consult {repopath}`tests/unit/README.md` for more details on how to run the tests manually.

 ### 3. Additional end-to-end testing

@@ -178,7 +178,7 @@ image_name: ollama
 image_type: conda

 # If some providers are external, you can specify the path to the implementation
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 ```
@@ -206,7 +206,7 @@ distribution_spec:
 image_type: container
 image_name: ci-test
 # Path to external provider implementations
-external_providers_dir: /etc/llama-stack/providers.d
+external_providers_dir: ~/.llama/providers.d
 ```

 Here's an example for a custom Ollama provider:
@@ -271,7 +271,7 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con
 ```
 llama stack run -h
-usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
+usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] [--tls-certfile TLS_CERTFILE]
                        [--image-type {conda,container,venv}]
                        config
@@ -285,7 +285,6 @@ options:
   --port PORT           Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
   --image-name IMAGE_NAME
                         Name of the image to run. Defaults to the current environment (default: None)
-  --disable-ipv6        Disable IPv6 support (default: False)
   --env KEY=VALUE       Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
   --tls-keyfile TLS_KEYFILE
                         Path to TLS key file for HTTPS (default: None)

@@ -172,7 +172,7 @@ spec:
       - name: llama-stack
         image: localhost/llama-stack-run-k8s:latest
         imagePullPolicy: IfNotPresent
-        command: ["python", "-m", "llama_stack.distribution.server.server", "--yaml-config", "/app/config.yaml"]
+        command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"]
         ports:
           - containerPort: 5000
         volumeMounts:

@@ -18,7 +18,7 @@ The `llamastack/distribution-watsonx` distribution consists of the following pro
 | agents | `inline::meta-reference` |
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
-| inference | `remote::watsonx` |
+| inference | `remote::watsonx`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -70,7 +70,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-watsonx \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env WATSONX_API_KEY=$WATSONX_API_KEY \
   --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \

@@ -52,7 +52,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-cerebras \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env CEREBRAS_API_KEY=$CEREBRAS_API_KEY
 ```

@@ -23,7 +23,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -155,7 +155,7 @@ docker run \
   -v $HOME/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-dell \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env DEH_URL=$DEH_URL \

@@ -144,7 +144,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./run.yaml:/root/my-run.yaml \
   llamastack/distribution-nvidia \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY
 ```

@@ -19,6 +19,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov
 | datasetio | `remote::huggingface`, `inline::localfs` |
 | eval | `inline::meta-reference` |
 | inference | `remote::ollama` |
+| post_training | `inline::huggingface` |
 | safety | `inline::llama-guard` |
 | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
 | telemetry | `inline::meta-reference` |
@@ -97,7 +98,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/ollama/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-ollama \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env SAFETY_MODEL=$SAFETY_MODEL \

@@ -233,7 +233,7 @@ docker run \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1
@@ -255,7 +255,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/remote-vllm/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-remote-vllm \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \

@@ -16,10 +16,10 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 | API | Provider(s) |
 |-----|-------------|
 | agents | `inline::meta-reference` |
-| inference | `remote::sambanova` |
+| inference | `remote::sambanova`, `inline::sentence-transformers` |
 | safety | `inline::llama-guard` |
 | telemetry | `inline::meta-reference` |
-| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime` |
+| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@@ -28,22 +28,22 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
 The following environment variables can be configured:

 - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
-- `SAMBANOVA_API_KEY`: SambaNova.AI API Key (default: ``)
+- `SAMBANOVA_API_KEY`: SambaNova API Key (default: ``)

 ### Models

 The following models are available by default:

-- `Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
-- `Meta-Llama-3.1-70B-Instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
-- `Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
-- `Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
-- `Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
-- `Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
-- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
-- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
-- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
-- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `sambanova/Meta-Llama-3.1-8B-Instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `sambanova/Meta-Llama-3.1-405B-Instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `sambanova/Meta-Llama-3.2-1B-Instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `sambanova/Meta-Llama-3.2-3B-Instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `sambanova/Meta-Llama-3.3-70B-Instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
+- `sambanova/Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `sambanova/Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `sambanova/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `sambanova/Llama-4-Maverick-17B-128E-Instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
+- `sambanova/Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`

 ### Prerequisite: API Keys

@@ -117,7 +117,7 @@ docker run \
   -v ~/.llama:/root/.llama \
   -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \
   llamastack/distribution-tgi \
-  --yaml-config /root/my-run.yaml \
+  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT \
   --env INFERENCE_MODEL=$INFERENCE_MODEL \
   --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \

@@ -42,7 +42,7 @@ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | ie
 Setup your virtual environment.

 ```bash
-uv venv --python 3.10
+uv sync --python 3.10
 source .venv/bin/activate
 ```
 ## Step 2: Run Llama Stack
@@ -445,7 +445,6 @@ from llama_stack_client import LlamaStackClient
 from llama_stack_client import Agent, AgentEventLogger
 from llama_stack_client.types import Document
 import uuid
-from termcolor import cprint

 client = LlamaStackClient(base_url="http://localhost:8321")
@@ -463,7 +462,6 @@ urls = [
     "memory_optimizations.rst",
     "chat.rst",
     "llama3.rst",
-    "datasets.rst",
     "qat_finetune.rst",
     "lora_finetune.rst",
 ]

@@ -10,7 +10,7 @@ Llama Stack supports external providers that live outside of the main codebase.
 To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:

 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```

 ## Directory Structure
@@ -53,7 +53,7 @@ Here's a list of known external providers that you can use with Llama Stack:
 | Name | Description | API | Type | Repository |
 |------|-------------|-----|------|------------|
 | KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
-| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
+| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) |
 | RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
 | TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) |
@@ -182,7 +182,7 @@ dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
 3. Create the provider specification:

 ```yaml
-# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
+# ~/.llama/providers.d/remote/inference/custom_ollama.yaml
 adapter:
   adapter_type: custom_ollama
   pip_packages: ["ollama", "aiohttp"]
@@ -201,7 +201,7 @@ uv pip install -e .
 5. Configure Llama Stack to use external providers:

 ```yaml
-external_providers_dir: /etc/llama-stack/providers.d/
+external_providers_dir: ~/.llama/providers.d/
 ```

 The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
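
Note: a hedged, self-contained sketch of what it looks like to enumerate specs under the documented `~/.llama/providers.d/<type>/<api>/<name>.yaml` layout. This is illustrative only and does not reflect Llama Stack's actual provider loader.

```python
# List external provider specs found under external_providers_dir (illustrative only).
from pathlib import Path

import yaml  # requires pyyaml

providers_dir = Path("~/.llama/providers.d").expanduser()

for spec_path in sorted(providers_dir.glob("*/*/*.yaml")):
    # Layout assumed from the docs above: <type>/<api>/<provider>.yaml
    provider_type, api = spec_path.parts[-3], spec_path.parts[-2]
    spec = yaml.safe_load(spec_path.read_text())
    adapter = spec.get("adapter", {})
    print(f"{provider_type}::{adapter.get('adapter_type')} ({api}) -> {spec_path.name}")
```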

@@ -253,8 +253,6 @@ llama-stack-client toolgroups list
 +---------------------------+------------------+------+---------------+
 | identifier                | provider_id      | args | mcp_endpoint  |
 +===========================+==================+======+===============+
-| builtin::code_interpreter | code-interpreter | None | None          |
-+---------------------------+------------------+------+---------------+
 | builtin::rag              | rag-runtime      | None | None          |
 +---------------------------+------------------+------+---------------+
 | builtin::websearch        | tavily-search    | None | None          |
 +---------------------------+------------------+------+---------------+

@@ -38,6 +38,67 @@ wait_for_service() {
     return 0
 }

+usage() {
+    cat << EOF
+📚 Llama-Stack Deployment Script
+
+Description:
+    This script sets up and deploys Llama-Stack with Ollama integration in containers.
+    It handles both Docker and Podman runtimes and includes automatic platform detection.
+
+Usage:
+    $(basename "$0") [OPTIONS]
+
+Options:
+    -p, --port PORT            Server port for Llama-Stack (default: ${PORT})
+    -o, --ollama-port PORT     Ollama service port (default: ${OLLAMA_PORT})
+    -m, --model MODEL          Model alias to use (default: ${MODEL_ALIAS})
+    -i, --image IMAGE          Server image (default: ${SERVER_IMAGE})
+    -t, --timeout SECONDS      Service wait timeout in seconds (default: ${WAIT_TIMEOUT})
+    -h, --help                 Show this help message
+
+For more information:
+    Documentation: https://llama-stack.readthedocs.io/
+    GitHub: https://github.com/meta-llama/llama-stack
+
+Report issues:
+    https://github.com/meta-llama/llama-stack/issues
+EOF
+}
+
+# Parse command line arguments
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -h|--help)
+            usage
+            exit 0
+            ;;
+        -p|--port)
+            PORT="$2"
+            shift 2
+            ;;
+        -o|--ollama-port)
+            OLLAMA_PORT="$2"
+            shift 2
+            ;;
+        -m|--model)
+            MODEL_ALIAS="$2"
+            shift 2
+            ;;
+        -i|--image)
+            SERVER_IMAGE="$2"
+            shift 2
+            ;;
+        -t|--timeout)
+            WAIT_TIMEOUT="$2"
+            shift 2
+            ;;
+        *)
+            die "Unknown option: $1"
+            ;;
+    esac
+done
+
 if command -v docker &> /dev/null; then
     ENGINE="docker"
 elif command -v podman &> /dev/null; then

@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import sys
 from collections.abc import AsyncIterator
 from datetime import datetime
 from enum import Enum
@@ -12,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 from pydantic import BaseModel, ConfigDict, Field

 from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
+from llama_stack.apis.common.responses import PaginatedResponse
 from llama_stack.apis.inference import (
     CompletionMessage,
     ResponseFormat,
@@ -29,12 +31,20 @@ from llama_stack.apis.tools import ToolDef
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

 from .openai_responses import (
-    OpenAIResponseInputMessage,
+    OpenAIResponseInput,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
 )

+# TODO: use enum.StrEnum when we drop support for python 3.10
+if sys.version_info >= (3, 11):
+    from enum import StrEnum
+else:
+
+    class StrEnum(str, Enum):
+        """Backport of StrEnum for Python 3.10 and below."""
+
+
 class Attachment(BaseModel):
     """An attachment to an agent turn.
@ -73,7 +83,7 @@ class StepCommon(BaseModel):
completed_at: datetime | None = None completed_at: datetime | None = None
class StepType(Enum): class StepType(StrEnum):
"""Type of the step in an agent turn. """Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM. :cvar inference: The step is an inference step that calls an LLM.
@ -97,7 +107,7 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference.value] = StepType.inference.value step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage model_response: CompletionMessage
@ -109,7 +119,7 @@ class ToolExecutionStep(StepCommon):
:param tool_responses: The tool responses from the tool calls. :param tool_responses: The tool responses from the tool calls.
""" """
step_type: Literal[StepType.tool_execution.value] = StepType.tool_execution.value step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall] tool_calls: list[ToolCall]
tool_responses: list[ToolResponse] tool_responses: list[ToolResponse]
@ -121,7 +131,7 @@ class ShieldCallStep(StepCommon):
:param violation: The violation from the shield call. :param violation: The violation from the shield call.
""" """
step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None violation: SafetyViolation | None
@ -133,7 +143,7 @@ class MemoryRetrievalStep(StepCommon):
:param inserted_context: The context retrieved from the vector databases. :param inserted_context: The context retrieved from the vector databases.
""" """
step_type: Literal[StepType.memory_retrieval.value] = StepType.memory_retrieval.value step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]? # TODO: should this be List[str]?
vector_db_ids: str vector_db_ids: str
inserted_context: InterleavedContent inserted_context: InterleavedContent
@ -154,7 +164,7 @@ class Turn(BaseModel):
input_messages: list[UserMessage | ToolResponseMessage] input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step] steps: list[Step]
output_message: CompletionMessage output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=list) output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime started_at: datetime
completed_at: datetime | None = None completed_at: datetime | None = None
@ -182,10 +192,10 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=list) input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=list) output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list) toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=list) client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead") tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead") tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None) tool_config: ToolConfig | None = Field(default=None)
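Several fields in this hunk switch from default_factory=list to default_factory=lambda: []. Both forms give every model instance its own fresh list rather than a shared mutable default; a small sketch with a hypothetical, trimmed-down stand-in for AgentConfigCommon showing that instances do not share state:

from pydantic import BaseModel, Field

class ShieldConfigSketch(BaseModel):
    # Hypothetical stand-in for AgentConfigCommon, reduced to one field.
    input_shields: list[str] | None = Field(default_factory=lambda: [])

a = ShieldConfigSketch()
b = ShieldConfigSketch()
a.input_shields.append("example-shield")

# The factory runs once per instantiation, so b still has its own empty list.
assert b.input_shields == []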
@ -232,21 +242,11 @@ class Agent(BaseModel):
created_at: datetime created_at: datetime
@json_schema_type
class ListAgentsResponse(BaseModel):
data: list[Agent]
@json_schema_type
class ListAgentSessionsResponse(BaseModel):
data: list[Session]
class AgentConfigOverridablePerTurn(AgentConfigCommon): class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None instructions: str | None = None
class AgentTurnResponseEventType(Enum): class AgentTurnResponseEventType(StrEnum):
step_start = "step_start" step_start = "step_start"
step_complete = "step_complete" step_complete = "step_complete"
step_progress = "step_progress" step_progress = "step_progress"
@ -258,15 +258,15 @@ class AgentTurnResponseEventType(Enum):
@json_schema_type @json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel): class AgentTurnResponseStepStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_start.value] = AgentTurnResponseEventType.step_start.value event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType step_type: StepType
step_id: str step_id: str
metadata: dict[str, Any] | None = Field(default_factory=dict) metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type @json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel): class AgentTurnResponseStepCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.step_complete.value] = AgentTurnResponseEventType.step_complete.value event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType step_type: StepType
step_id: str step_id: str
step_details: Step step_details: Step
@ -276,7 +276,7 @@ class AgentTurnResponseStepCompletePayload(BaseModel):
class AgentTurnResponseStepProgressPayload(BaseModel): class AgentTurnResponseStepProgressPayload(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress.value] = AgentTurnResponseEventType.step_progress.value event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType step_type: StepType
step_id: str step_id: str
@ -285,21 +285,19 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
@json_schema_type @json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel): class AgentTurnResponseTurnStartPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_start.value] = AgentTurnResponseEventType.turn_start.value event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str turn_id: str
@json_schema_type @json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel): class AgentTurnResponseTurnCompletePayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_complete.value] = AgentTurnResponseEventType.turn_complete.value event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn turn: Turn
@json_schema_type @json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = ( event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
AgentTurnResponseEventType.turn_awaiting_input.value
)
turn: Turn turn: Turn
@ -341,7 +339,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
messages: list[UserMessage | ToolResponseMessage] messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = None toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False stream: bool | None = False
tool_config: ToolConfig | None = None tool_config: ToolConfig | None = None
@ -415,8 +413,9 @@ class Agents(Protocol):
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request. :param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config. :param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object. :returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
""" """
...
@webmethod( @webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
@ -510,6 +509,7 @@ class Agents(Protocol):
:param session_id: The ID of the session to get. :param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for. :param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by. :param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
""" """
... ...
@ -519,7 +519,7 @@ class Agents(Protocol):
session_id: str, session_id: str,
agent_id: str, agent_id: str,
) -> None: ) -> None:
"""Delete an agent session by its ID. """Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete. :param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for. :param agent_id: The ID of the agent to delete the session for.
@ -531,17 +531,19 @@ class Agents(Protocol):
self, self,
agent_id: str, agent_id: str,
) -> None: ) -> None:
"""Delete an agent by its ID. """Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete. :param agent_id: The ID of the agent to delete.
""" """
... ...
@webmethod(route="/agents", method="GET") @webmethod(route="/agents", method="GET")
async def list_agents(self) -> ListAgentsResponse: async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents. """List all agents.
:returns: A ListAgentsResponse. :param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
""" """
... ...
@ -558,11 +560,15 @@ class Agents(Protocol):
async def list_agent_sessions( async def list_agent_sessions(
self, self,
agent_id: str, agent_id: str,
) -> ListAgentSessionsResponse: start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent. """List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for. :param agent_id: The ID of the agent to list sessions for.
:returns: A ListAgentSessionsResponse. :param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
""" """
... ...
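With list_agents and list_agent_sessions now returning a PaginatedResponse driven by start_index and limit, a client pages through results instead of receiving one unbounded list. A minimal sketch of paging over the /agents route, assuming a server reachable at localhost:8321 (host, port, and any version prefix on the route are deployment-specific assumptions) and assuming the response body exposes the data and has_more fields described in the DatasetIO hunk further down:

import requests

BASE_URL = "http://localhost:8321"  # hypothetical deployment; adjust as needed

def iter_agents(page_size: int = 20):
    """Yield agents one page at a time via GET /agents."""
    start_index = 0
    while True:
        resp = requests.get(
            f"{BASE_URL}/agents",
            params={"start_index": start_index, "limit": page_size},
            timeout=30,
        )
        resp.raise_for_status()
        page = resp.json()
        yield from page["data"]
        if not page.get("has_more"):
            break
        start_index += page_size

for agent in iter_agents():
    print(agent)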
@ -588,7 +594,7 @@ class Agents(Protocol):
@webmethod(route="/openai/v1/responses", method="POST") @webmethod(route="/openai/v1/responses", method="POST")
async def create_openai_response( async def create_openai_response(
self, self,
input: str | list[OpenAIResponseInputMessage], input: str | list[OpenAIResponseInput],
model: str, model: str,
previous_response_id: str | None = None, previous_response_id: str | None = None,
store: bool | None = True, store: bool | None = True,
@ -601,4 +607,6 @@ class Agents(Protocol):
:param input: Input message(s) to create the response. :param input: Input message(s) to create the response.
:param model: The underlying LLM used for completions. :param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:returns: An OpenAIResponseObject.
""" """
...
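create_openai_response now accepts either a plain string or a list of OpenAIResponseInput items, and previous_response_id lets a new response continue a stored one. A sketch of both calls over HTTP, under the same hypothetical localhost:8321 deployment and with a placeholder model identifier; the id field read off the returned object is assumed from the fact that responses can be referenced by previous_response_id:

import requests

BASE_URL = "http://localhost:8321"  # hypothetical deployment

first = requests.post(
    f"{BASE_URL}/openai/v1/responses",
    json={
        "model": "example-model-id",  # placeholder model identifier
        "input": "Summarize the Responses API in one sentence.",
        "store": True,
    },
    timeout=60,
)
first.raise_for_status()

# Fork a follow-up that continues from the stored response.
follow_up = requests.post(
    f"{BASE_URL}/openai/v1/responses",
    json={
        "model": "example-model-id",
        "input": "Now expand that into three bullet points.",
        "previous_response_id": first.json()["id"],
    },
    timeout=60,
)
follow_up.raise_for_status()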

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Annotated, Literal from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -17,6 +17,28 @@ class OpenAIResponseError(BaseModel):
message: str message: str
@json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel):
text: str
type: Literal["input_text"] = "input_text"
@json_schema_type
class OpenAIResponseInputMessageContentImage(BaseModel):
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto"
type: Literal["input_image"] = "input_image"
# TODO: handle file_id
image_url: str | None = None
# TODO: handle file content types
OpenAIResponseInputMessageContent = Annotated[
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
Field(discriminator="type"),
]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent")
@json_schema_type @json_schema_type
class OpenAIResponseOutputMessageContentOutputText(BaseModel): class OpenAIResponseOutputMessageContentOutputText(BaseModel):
text: str text: str
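The new input-content models form a Pydantic discriminated union keyed on the type field, so validation dispatches straight to the matching class instead of trying every member. A self-contained sketch (assuming Pydantic v2) with stand-in models mirroring the two content types above:

from typing import Annotated, Literal
from pydantic import BaseModel, Field, TypeAdapter

class InputTextContent(BaseModel):
    # Mirrors OpenAIResponseInputMessageContentText above.
    text: str
    type: Literal["input_text"] = "input_text"

class InputImageContent(BaseModel):
    # Mirrors OpenAIResponseInputMessageContentImage above.
    detail: Literal["low", "high", "auto"] = "auto"
    type: Literal["input_image"] = "input_image"
    image_url: str | None = None

InputContent = Annotated[
    InputTextContent | InputImageContent,
    Field(discriminator="type"),
]

adapter = TypeAdapter(InputContent)
text_item = adapter.validate_python({"type": "input_text", "text": "hello"})
image_item = adapter.validate_python({"type": "input_image", "image_url": "https://example.com/cat.png"})

# The "type" value selected the concrete model during validation.
assert isinstance(text_item, InputTextContent)
assert isinstance(image_item, InputImageContent)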
@ -31,13 +53,22 @@ register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMe
@json_schema_type @json_schema_type
class OpenAIResponseOutputMessage(BaseModel): class OpenAIResponseMessage(BaseModel):
id: str """
content: list[OpenAIResponseOutputMessageContent] Corresponds to the various Message types in the Responses API.
role: Literal["assistant"] = "assistant" They are all under one type because the Responses API gives them all
status: str the same "type" value, and there is no way to tell them apart in certain
scenarios.
"""
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] = "message" type: Literal["message"] = "message"
# The fields below are not used in all scenarios, but are required in others.
id: str | None = None
status: str | None = None
@json_schema_type @json_schema_type
class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@ -46,8 +77,18 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
type: Literal["web_search_call"] = "web_search_call" type: Literal["web_search_call"] = "web_search_call"
@json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
arguments: str
call_id: str
name: str
type: Literal["function_call"] = "function_call"
id: str
status: str
OpenAIResponseOutput = Annotated[ OpenAIResponseOutput = Annotated[
OpenAIResponseOutputMessage | OpenAIResponseOutputMessageWebSearchToolCall, OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall,
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput") register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@ -90,32 +131,29 @@ register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
@json_schema_type @json_schema_type
class OpenAIResponseInputMessageContentText(BaseModel): class OpenAIResponseInputFunctionToolCallOutput(BaseModel):
text: str """
type: Literal["input_text"] = "input_text" This represents the output of a function call that gets passed back to the model.
"""
call_id: str
output: str
type: Literal["function_call_output"] = "function_call_output"
id: str | None = None
status: str | None = None
@json_schema_type OpenAIResponseInput = Annotated[
class OpenAIResponseInputMessageContentImage(BaseModel): # Responses API allows output messages to be passed in as input
detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto" OpenAIResponseOutputMessageWebSearchToolCall
type: Literal["input_image"] = "input_image" | OpenAIResponseOutputMessageFunctionToolCall
# TODO: handle file_id | OpenAIResponseInputFunctionToolCallOutput
image_url: str | None = None |
# Fallback to the generic message type as a last resort
OpenAIResponseMessage,
# TODO: handle file content types Field(union_mode="left_to_right"),
OpenAIResponseInputMessageContent = Annotated[
OpenAIResponseInputMessageContentText | OpenAIResponseInputMessageContentImage,
Field(discriminator="type"),
] ]
register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMessageContent") register_schema(OpenAIResponseInput, name="OpenAIResponseInput")
@json_schema_type
class OpenAIResponseInputMessage(BaseModel):
content: str | list[OpenAIResponseInputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] | None = "message"
@json_schema_type @json_schema_type
@ -126,8 +164,35 @@ class OpenAIResponseInputToolWebSearch(BaseModel):
# TODO: add user_location # TODO: add user_location
@json_schema_type
class OpenAIResponseInputToolFunction(BaseModel):
type: Literal["function"] = "function"
name: str
description: str | None = None
parameters: dict[str, Any] | None
strict: bool | None = None
class FileSearchRankingOptions(BaseModel):
ranker: str | None = None
score_threshold: float | None = Field(default=0.0, ge=0.0, le=1.0)
@json_schema_type
class OpenAIResponseInputToolFileSearch(BaseModel):
type: Literal["file_search"] = "file_search"
vector_store_id: list[str]
ranking_options: FileSearchRankingOptions | None = None
# TODO: add filters
OpenAIResponseInputTool = Annotated[ OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch, OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction,
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
class OpenAIResponseInputItemList(BaseModel):
data: list[OpenAIResponseInput]
object: Literal["list"] = "list"

View file

@ -38,7 +38,17 @@ class BatchInference(Protocol):
sampling_params: SamplingParams | None = None, sampling_params: SamplingParams | None = None,
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> Job: ... ) -> Job:
"""Generate completions for a batch of content.
:param model: The model to use for the completion.
:param content_batch: The content to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param response_format: The response format to use for the completion.
:param logprobs: The logprobs to use for the completion.
:returns: A job for the completion.
"""
...
@webmethod(route="/batch-inference/chat-completion", method="POST") @webmethod(route="/batch-inference/chat-completion", method="POST")
async def chat_completion( async def chat_completion(
@ -52,4 +62,17 @@ class BatchInference(Protocol):
tool_prompt_format: ToolPromptFormat | None = None, tool_prompt_format: ToolPromptFormat | None = None,
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> Job: ... ) -> Job:
"""Generate chat completions for a batch of messages.
:param model: The model to use for the chat completion.
:param messages_batch: The messages to complete.
:param sampling_params: The sampling parameters to use for the completion.
:param tools: The tools to use for the chat completion.
:param tool_choice: The tool choice to use for the chat completion.
:param tool_prompt_format: The tool prompt format to use for the chat completion.
:param response_format: The response format to use for the chat completion.
:param logprobs: The logprobs to use for the chat completion.
:returns: A job for the chat completion.
"""
...

View file

@ -22,14 +22,14 @@ class CommonBenchmarkFields(BaseModel):
@json_schema_type @json_schema_type
class Benchmark(CommonBenchmarkFields, Resource): class Benchmark(CommonBenchmarkFields, Resource):
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value type: Literal[ResourceType.benchmark] = ResourceType.benchmark
@property @property
def benchmark_id(self) -> str: def benchmark_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_benchmark_id(self) -> str: def provider_benchmark_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id
@ -46,13 +46,24 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Benchmarks(Protocol): class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET") @webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ... async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
:returns: A ListBenchmarksResponse.
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark( async def get_benchmark(
self, self,
benchmark_id: str, benchmark_id: str,
) -> Benchmark: ... ) -> Benchmark:
"""Get a benchmark by its ID.
:param benchmark_id: The ID of the benchmark to get.
:returns: A Benchmark.
"""
...
@webmethod(route="/eval/benchmarks", method="POST") @webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark( async def register_benchmark(
@ -63,4 +74,14 @@ class Benchmarks(Protocol):
provider_benchmark_id: str | None = None, provider_benchmark_id: str | None = None,
provider_id: str | None = None, provider_id: str | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
) -> None: ... ) -> None:
"""Register a benchmark.
:param benchmark_id: The ID of the benchmark to register.
:param dataset_id: The ID of the dataset to use for the benchmark.
:param scoring_functions: The scoring functions to use for the benchmark.
:param provider_benchmark_id: The ID of the provider benchmark to use for the benchmark.
:param provider_id: The ID of the provider to use for the benchmark.
:param metadata: The metadata to use for the benchmark.
"""
...

View file

@ -28,7 +28,7 @@ class _URLOrData(BaseModel):
url: URL | None = None url: URL | None = None
# data is a base64 encoded string, hint with contentEncoding=base64 # data is a base64 encoded string, hint with contentEncoding=base64
data: str | None = Field(contentEncoding="base64", default=None) data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
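The data field now passes its base64 hint through json_schema_extra rather than as a bare contentEncoding keyword, since arbitrary extra keyword arguments on Field are deprecated in Pydantic v2 in favor of json_schema_extra. A small sketch with a hypothetical trimmed model showing that the hint still lands in the generated JSON schema:

from pydantic import BaseModel, Field

class URLOrDataSketch(BaseModel):
    # Hypothetical stand-in for _URLOrData, reduced to the data field.
    data: str | None = Field(default=None, json_schema_extra={"contentEncoding": "base64"})

schema = URLOrDataSketch.model_json_schema()
# The extra key is merged into the property's schema entry.
print(schema["properties"]["data"])  # includes "contentEncoding": "base64"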

View file

@ -34,14 +34,21 @@ class DatasetIO(Protocol):
- limit: Number of items to return. If None or -1, returns all items. - limit: Number of items to return. If None or -1, returns all items.
The response includes: The response includes:
- data: List of items for the current page - data: List of items for the current page.
- has_more: Whether there are more items available after this set - has_more: Whether there are more items available after this set.
:param dataset_id: The ID of the dataset to get the rows from. :param dataset_id: The ID of the dataset to get the rows from.
:param start_index: Index into dataset for the first row to get. Get all rows if None. :param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get. :param limit: The number of rows to get.
:returns: A PaginatedResponse.
""" """
... ...
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST") @webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: ... async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.
:param dataset_id: The ID of the dataset to append the rows to.
:param rows: The rows to append to the dataset.
"""
...

View file

@ -106,14 +106,14 @@ class CommonDatasetFields(BaseModel):
@json_schema_type @json_schema_type
class Dataset(CommonDatasetFields, Resource): class Dataset(CommonDatasetFields, Resource):
type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value type: Literal[ResourceType.dataset] = ResourceType.dataset
@property @property
def dataset_id(self) -> str: def dataset_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_dataset_id(self) -> str: def provider_dataset_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id
@ -137,7 +137,8 @@ class Datasets(Protocol):
""" """
Register a new dataset. Register a new dataset.
:param purpose: The purpose of the dataset. One of :param purpose: The purpose of the dataset.
One of:
- "post-training/messages": The dataset contains a messages column with list of messages for post-training. - "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{ {
"messages": [ "messages": [
@ -188,8 +189,9 @@ class Datasets(Protocol):
] ]
} }
:param metadata: The metadata for the dataset. :param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"} - E.g. {"description": "My dataset"}.
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated. :param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
:returns: A Dataset.
""" """
... ...
@ -197,13 +199,29 @@ class Datasets(Protocol):
async def get_dataset( async def get_dataset(
self, self,
dataset_id: str, dataset_id: str,
) -> Dataset: ... ) -> Dataset:
"""Get a dataset by its ID.
:param dataset_id: The ID of the dataset to get.
:returns: A Dataset.
"""
...
@webmethod(route="/datasets", method="GET") @webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ... async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
:returns: A ListDatasetsResponse.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE") @webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset( async def unregister_dataset(
self, self,
dataset_id: str, dataset_id: str,
) -> None: ... ) -> None:
"""Unregister a dataset by its ID.
:param dataset_id: The ID of the dataset to unregister.
"""
...

View file

@ -93,8 +93,9 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.
:param benchmark_config: The configuration for the benchmark. :param benchmark_config: The configuration for the benchmark.
:return: The job that was created to run the evaluation. :returns: The job that was created to run the evaluation.
""" """
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST") @webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows( async def evaluate_rows(
@ -110,8 +111,9 @@ class Eval(Protocol):
:param input_rows: The rows to evaluate. :param input_rows: The rows to evaluate.
:param scoring_functions: The scoring functions to use for the evaluation. :param scoring_functions: The scoring functions to use for the evaluation.
:param benchmark_config: The configuration for the benchmark. :param benchmark_config: The configuration for the benchmark.
:return: EvaluateResponse object containing generations and scores :returns: EvaluateResponse object containing generations and scores.
""" """
...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Job: async def job_status(self, benchmark_id: str, job_id: str) -> Job:
@ -119,7 +121,7 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the status of. :param job_id: The ID of the job to get the status of.
:return: The status of the evaluationjob. :returns: The status of the evaluation job.
""" """
... ...
@ -138,5 +140,6 @@ class Eval(Protocol):
:param benchmark_id: The ID of the benchmark to run the evaluation on. :param benchmark_id: The ID of the benchmark to run the evaluation on.
:param job_id: The ID of the job to get the result of. :param job_id: The ID of the job to get the result of.
:return: The result of the job. :returns: The result of the job.
""" """
...

View file

@ -91,10 +91,11 @@ class Files(Protocol):
""" """
Create a new upload session for a file identified by a bucket and key. Create a new upload session for a file identified by a bucket and key.
:param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-) :param bucket: Bucket under which the file is stored (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.) :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:param mime_type: MIME type of the file :param mime_type: MIME type of the file.
:param size: File size in bytes :param size: File size in bytes.
:returns: A FileUploadResponse.
""" """
... ...
@ -107,7 +108,8 @@ class Files(Protocol):
Upload file content to an existing upload session. Upload file content to an existing upload session.
On the server, request body will have the raw bytes that are uploaded. On the server, request body will have the raw bytes that are uploaded.
:param upload_id: ID of the upload session :param upload_id: ID of the upload session.
:returns: A FileResponse or None if the upload is not complete.
""" """
... ...
@ -117,9 +119,10 @@ class Files(Protocol):
upload_id: str, upload_id: str,
) -> FileUploadResponse: ) -> FileUploadResponse:
""" """
Returns information about an existsing upload session Returns information about an existing upload session.
:param upload_id: ID of the upload session :param upload_id: ID of the upload session.
:returns: A FileUploadResponse.
""" """
... ...
@ -130,6 +133,9 @@ class Files(Protocol):
) -> ListBucketResponse: ) -> ListBucketResponse:
""" """
List all buckets. List all buckets.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListBucketResponse.
""" """
... ...
@ -141,7 +147,8 @@ class Files(Protocol):
""" """
List all files in a bucket. List all files in a bucket.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-) :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:returns: A ListFileResponse.
""" """
... ...
@ -154,8 +161,9 @@ class Files(Protocol):
""" """
Get a file info identified by a bucket and key. Get a file info identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-) :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.) :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
:returns: A FileResponse.
""" """
... ...
@ -168,7 +176,7 @@ class Files(Protocol):
""" """
Delete a file identified by a bucket and key. Delete a file identified by a bucket and key.
:param bucket: Bucket name (valid chars: a-zA-Z0-9_-) :param bucket: Bucket name (valid chars: a-zA-Z0-9_-).
:param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.) :param key: Key under which the file is stored (valid chars: a-zA-Z0-9_-/.).
""" """
... ...

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import sys
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from enum import Enum from enum import Enum
from typing import ( from typing import (
@ -35,6 +36,16 @@ register_schema(ToolCall)
register_schema(ToolParamDefinition) register_schema(ToolParamDefinition)
register_schema(ToolDefinition) register_schema(ToolDefinition)
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
@json_schema_type @json_schema_type
class GreedySamplingStrategy(BaseModel): class GreedySamplingStrategy(BaseModel):
@ -187,7 +198,7 @@ class CompletionMessage(BaseModel):
role: Literal["assistant"] = "assistant" role: Literal["assistant"] = "assistant"
content: InterleavedContent content: InterleavedContent
stop_reason: StopReason stop_reason: StopReason
tool_calls: list[ToolCall] | None = Field(default_factory=list) tool_calls: list[ToolCall] | None = Field(default_factory=lambda: [])
Message = Annotated[ Message = Annotated[
@ -267,7 +278,7 @@ class ChatCompletionResponseEvent(BaseModel):
stop_reason: StopReason | None = None stop_reason: StopReason | None = None
class ResponseFormatType(Enum): class ResponseFormatType(StrEnum):
"""Types of formats for structured (guided) decoding. """Types of formats for structured (guided) decoding.
:cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model. :cvar json_schema: Response should conform to a JSON schema. In a Python SDK, this is often a `pydantic` model.
@ -286,7 +297,7 @@ class JsonSchemaResponseFormat(BaseModel):
:param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model. :param json_schema: The JSON schema the response should conform to. In a Python SDK, this is often a `pydantic` model.
""" """
type: Literal[ResponseFormatType.json_schema.value] = ResponseFormatType.json_schema.value type: Literal[ResponseFormatType.json_schema] = ResponseFormatType.json_schema
json_schema: dict[str, Any] json_schema: dict[str, Any]
@ -298,7 +309,7 @@ class GrammarResponseFormat(BaseModel):
:param bnf: The BNF grammar specification the response should conform to :param bnf: The BNF grammar specification the response should conform to
""" """
type: Literal[ResponseFormatType.grammar.value] = ResponseFormatType.grammar.value type: Literal[ResponseFormatType.grammar] = ResponseFormatType.grammar
bnf: dict[str, Any] bnf: dict[str, Any]
@ -394,7 +405,7 @@ class ChatCompletionRequest(BaseModel):
messages: list[Message] messages: list[Message]
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
tools: list[ToolDefinition] | None = Field(default_factory=list) tools: list[ToolDefinition] | None = Field(default_factory=lambda: [])
tool_config: ToolConfig | None = Field(default_factory=ToolConfig) tool_config: ToolConfig | None = Field(default_factory=ToolConfig)
response_format: ResponseFormat | None = None response_format: ResponseFormat | None = None
@ -567,14 +578,14 @@ class OpenAIResponseFormatText(BaseModel):
@json_schema_type @json_schema_type
class OpenAIJSONSchema(TypedDict, total=False): class OpenAIJSONSchema(TypedDict, total=False):
name: str name: str
description: str | None = None description: str | None
strict: bool | None = None strict: bool | None
# Pydantic BaseModel cannot be used with a schema param, since it already # Pydantic BaseModel cannot be used with a schema param, since it already
# has one. And, we don't want to alias here because then have to handle # has one. And, we don't want to alias here because then have to handle
# that alias when converting to OpenAI params. So, to support schema, # that alias when converting to OpenAI params. So, to support schema,
# we use a TypedDict. # we use a TypedDict.
schema: dict[str, Any] | None = None schema: dict[str, Any] | None
@json_schema_type @json_schema_type
@ -809,15 +820,32 @@ class BatchChatCompletionResponse(BaseModel):
batch: list[ChatCompletionResponse] batch: list[ChatCompletionResponse]
class OpenAICompletionWithInputMessages(OpenAIChatCompletion):
input_messages: list[OpenAIMessageParam]
@json_schema_type
class ListOpenAIChatCompletionResponse(BaseModel):
data: list[OpenAICompletionWithInputMessages]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"
class Order(Enum):
asc = "asc"
desc = "desc"
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class Inference(Protocol): class InferenceProvider(Protocol):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
""" """
This protocol defines the interface that should be implemented by all inference providers.
"""
API_NAMESPACE: str = "Inference"
model_store: ModelStore | None = None model_store: ModelStore | None = None
@ -834,13 +862,13 @@ class Inference(Protocol):
"""Generate a completion for the given content using the specified model. """Generate a completion for the given content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content: The content to generate a completion for :param content: The content to generate a completion for.
:param sampling_params: (Optional) Parameters to control the sampling strategy :param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding :param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False. :param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned. :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: If stream=False, returns a CompletionResponse with the full completion. :returns: If stream=False, returns a CompletionResponse with the full completion.
If stream=True, returns an SSE event stream of CompletionResponseStreamChunk If stream=True, returns an SSE event stream of CompletionResponseStreamChunk.
""" """
... ...
@ -853,6 +881,15 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> BatchCompletionResponse: ) -> BatchCompletionResponse:
"""Generate completions for a batch of content using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param content_batch: The content to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch completion is not implemented") raise NotImplementedError("Batch completion is not implemented")
@webmethod(route="/inference/chat-completion", method="POST") @webmethod(route="/inference/chat-completion", method="POST")
@ -872,9 +909,9 @@ class Inference(Protocol):
"""Generate a chat completion for the given messages using the specified model. """Generate a chat completion for the given messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation :param messages: List of messages in the conversation.
:param sampling_params: Parameters to control the sampling strategy :param sampling_params: Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model :param tools: (Optional) List of tool definitions available to the model.
:param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto. :param tool_choice: (Optional) Whether tool use is required or automatic. Defaults to ToolChoice.auto.
.. deprecated:: .. deprecated::
Use tool_config instead. Use tool_config instead.
@ -891,7 +928,7 @@ class Inference(Protocol):
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned. :param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:param tool_config: (Optional) Configuration for tool use. :param tool_config: (Optional) Configuration for tool use.
:returns: If stream=False, returns a ChatCompletionResponse with the full completion. :returns: If stream=False, returns a ChatCompletionResponse with the full completion.
If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk If stream=True, returns an SSE event stream of ChatCompletionResponseStreamChunk.
""" """
... ...
@ -906,6 +943,17 @@ class Inference(Protocol):
response_format: ResponseFormat | None = None, response_format: ResponseFormat | None = None,
logprobs: LogProbConfig | None = None, logprobs: LogProbConfig | None = None,
) -> BatchChatCompletionResponse: ) -> BatchChatCompletionResponse:
"""Generate chat completions for a batch of messages using the specified model.
:param model_id: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages_batch: The messages to generate completions for.
:param sampling_params: (Optional) Parameters to control the sampling strategy.
:param tools: (Optional) List of tool definitions available to the model.
:param tool_config: (Optional) Configuration for tool use.
:param response_format: (Optional) Grammar specification for guided (structured) decoding.
:param logprobs: (Optional) If specified, log probabilities for each token position will be returned.
:returns: A BatchChatCompletionResponse with the full completions.
"""
raise NotImplementedError("Batch chat completion is not implemented") raise NotImplementedError("Batch chat completion is not implemented")
@webmethod(route="/inference/embeddings", method="POST") @webmethod(route="/inference/embeddings", method="POST")
@ -924,7 +972,7 @@ class Inference(Protocol):
:param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models. :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models.
:param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length. :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length.
:param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models. :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models.
:returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id} :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}.
""" """
... ...
@ -956,22 +1004,23 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible completion for the given prompt using the specified model. """Generate an OpenAI-compatible completion for the given prompt using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param prompt: The prompt to generate a completion for :param prompt: The prompt to generate a completion for.
:param best_of: (Optional) The number of completions to generate :param best_of: (Optional) The number of completions to generate.
:param echo: (Optional) Whether to echo the prompt :param echo: (Optional) Whether to echo the prompt.
:param frequency_penalty: (Optional) The penalty for repeated tokens :param frequency_penalty: (Optional) The penalty for repeated tokens.
:param logit_bias: (Optional) The logit bias to use :param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use :param logprobs: (Optional) The log probabilities to use.
:param max_tokens: (Optional) The maximum number of tokens to generate :param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate :param n: (Optional) The number of completions to generate.
:param presence_penalty: (Optional) The penalty for repeated tokens :param presence_penalty: (Optional) The penalty for repeated tokens.
:param seed: (Optional) The seed to use :param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use :param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response :param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use :param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use :param temperature: (Optional) The temperature to use.
:param top_p: (Optional) The top p to use :param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use :param user: (Optional) The user to use.
:returns: An OpenAICompletion.
""" """
... ...
@ -1005,27 +1054,64 @@ class Inference(Protocol):
"""Generate an OpenAI-compatible chat completion for the given messages using the specified model. """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
:param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint. :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
:param messages: List of messages in the conversation :param messages: List of messages in the conversation.
:param frequency_penalty: (Optional) The penalty for repeated tokens :param frequency_penalty: (Optional) The penalty for repeated tokens.
:param function_call: (Optional) The function call to use :param function_call: (Optional) The function call to use.
:param functions: (Optional) List of functions to use :param functions: (Optional) List of functions to use.
:param logit_bias: (Optional) The logit bias to use :param logit_bias: (Optional) The logit bias to use.
:param logprobs: (Optional) The log probabilities to use :param logprobs: (Optional) The log probabilities to use.
:param max_completion_tokens: (Optional) The maximum number of tokens to generate :param max_completion_tokens: (Optional) The maximum number of tokens to generate.
:param max_tokens: (Optional) The maximum number of tokens to generate :param max_tokens: (Optional) The maximum number of tokens to generate.
:param n: (Optional) The number of completions to generate :param n: (Optional) The number of completions to generate.
:param parallel_tool_calls: (Optional) Whether to parallelize tool calls :param parallel_tool_calls: (Optional) Whether to parallelize tool calls.
:param presence_penalty: (Optional) The penalty for repeated tokens :param presence_penalty: (Optional) The penalty for repeated tokens.
:param response_format: (Optional) The response format to use :param response_format: (Optional) The response format to use.
:param seed: (Optional) The seed to use :param seed: (Optional) The seed to use.
:param stop: (Optional) The stop tokens to use :param stop: (Optional) The stop tokens to use.
:param stream: (Optional) Whether to stream the response :param stream: (Optional) Whether to stream the response.
:param stream_options: (Optional) The stream options to use :param stream_options: (Optional) The stream options to use.
:param temperature: (Optional) The temperature to use :param temperature: (Optional) The temperature to use.
:param tool_choice: (Optional) The tool choice to use :param tool_choice: (Optional) The tool choice to use.
:param tools: (Optional) The tools to use :param tools: (Optional) The tools to use.
:param top_logprobs: (Optional) The top log probabilities to use :param top_logprobs: (Optional) The top log probabilities to use.
:param top_p: (Optional) The top p to use :param top_p: (Optional) The top p to use.
:param user: (Optional) The user to use :param user: (Optional) The user to use.
:returns: An OpenAIChatCompletion.
""" """
... ...
class Inference(InferenceProvider):
"""Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET")
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
"""List all chat completions.
:param after: The ID of the last chat completion to return.
:param limit: The maximum number of chat completions to return.
:param model: The model to filter by.
:param order: The order to sort the chat completions by: "asc" or "desc". Defaults to "desc".
:returns: A ListOpenAIChatCompletionResponse.
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(route="/openai/v1/chat/completions/{completion_id}", method="GET")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Describe a chat completion by its ID.
:param completion_id: ID of the chat completion.
:returns: An OpenAICompletionWithInputMessages.
"""
raise NotImplementedError("Get chat completion is not implemented")

View file

@ -36,10 +36,25 @@ class ListRoutesResponse(BaseModel):
@runtime_checkable @runtime_checkable
class Inspect(Protocol): class Inspect(Protocol):
@webmethod(route="/inspect/routes", method="GET") @webmethod(route="/inspect/routes", method="GET")
async def list_routes(self) -> ListRoutesResponse: ... async def list_routes(self) -> ListRoutesResponse:
"""List all routes.
:returns: A ListRoutesResponse.
"""
...
@webmethod(route="/health", method="GET") @webmethod(route="/health", method="GET")
async def health(self) -> HealthInfo: ... async def health(self) -> HealthInfo:
"""Get the health of the service.
:returns: A HealthInfo.
"""
...
@webmethod(route="/version", method="GET") @webmethod(route="/version", method="GET")
async def version(self) -> VersionInfo: ... async def version(self) -> VersionInfo:
"""Get the version of the service.
:returns: A VersionInfo.
"""
...

View file

@ -29,14 +29,14 @@ class ModelType(str, Enum):
@json_schema_type @json_schema_type
class Model(CommonModelFields, Resource): class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value type: Literal[ResourceType.model] = ResourceType.model
@property @property
def model_id(self) -> str: def model_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_model_id(self) -> str: def provider_model_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -80,16 +80,32 @@ class OpenAIListModelsResponse(BaseModel):
@trace_protocol @trace_protocol
class Models(Protocol): class Models(Protocol):
@webmethod(route="/models", method="GET") @webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ... async def list_models(self) -> ListModelsResponse:
"""List all models.
:returns: A ListModelsResponse.
"""
...
@webmethod(route="/openai/v1/models", method="GET") @webmethod(route="/openai/v1/models", method="GET")
async def openai_list_models(self) -> OpenAIListModelsResponse: ... async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.
:returns: An OpenAIListModelsResponse.
"""
...
@webmethod(route="/models/{model_id:path}", method="GET") @webmethod(route="/models/{model_id:path}", method="GET")
async def get_model( async def get_model(
self, self,
model_id: str, model_id: str,
) -> Model: ... ) -> Model:
"""Get a model by its identifier.
:param model_id: The identifier of the model to get.
:returns: A Model.
"""
...
@webmethod(route="/models", method="POST") @webmethod(route="/models", method="POST")
async def register_model( async def register_model(
@ -99,10 +115,25 @@ class Models(Protocol):
provider_id: str | None = None, provider_id: str | None = None,
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None, model_type: ModelType | None = None,
) -> Model: ... ) -> Model:
"""Register a model.
:param model_id: The identifier of the model to register.
:param provider_model_id: The identifier of the model in the provider.
:param provider_id: The identifier of the provider.
:param metadata: Any additional metadata for this model.
:param model_type: The type of model to register.
:returns: A Model.
"""
...
@webmethod(route="/models/{model_id:path}", method="DELETE") @webmethod(route="/models/{model_id:path}", method="DELETE")
async def unregister_model( async def unregister_model(
self, self,
model_id: str, model_id: str,
) -> None: ... ) -> None:
"""Unregister a model.
:param model_id: The identifier of the model to unregister.
"""
...
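With the Models protocol now documenting each route, registering a model is a POST to /models carrying the parameters listed above. A sketch against the same hypothetical localhost:8321 server; the identifiers and provider values are placeholders, and "llm" is an assumed ModelType value since the enum's members are not shown in this hunk:

import requests

BASE_URL = "http://localhost:8321"  # hypothetical deployment

resp = requests.post(
    f"{BASE_URL}/models",
    json={
        "model_id": "example-model-id",                  # placeholder identifier
        "provider_id": "example-provider",               # placeholder provider
        "provider_model_id": "provider/example-model",   # placeholder provider-side id
        "model_type": "llm",                             # assumed ModelType value
        "metadata": {},
    },
    timeout=30,
)
resp.raise_for_status()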

View file

@ -182,7 +182,19 @@ class PostTraining(Protocol):
), ),
checkpoint_dir: str | None = None, checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None, algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob:
"""Run supervised fine-tuning of a model.
:param job_uuid: The UUID of the job to create.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:param model: The model to fine-tune.
:param checkpoint_dir: The directory to save checkpoint(s) to.
:param algorithm_config: The algorithm configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST") @webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize( async def preference_optimize(
@ -193,16 +205,49 @@ class PostTraining(Protocol):
training_config: TrainingConfig, training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any], hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any], logger_config: dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob:
"""Run preference optimization of a model.
:param job_uuid: The UUID of the job to create.
:param finetuned_model: The model to fine-tune.
:param algorithm_config: The algorithm configuration.
:param training_config: The training configuration.
:param hyperparam_search_config: The hyperparam search configuration.
:param logger_config: The logger configuration.
:returns: A PostTrainingJob.
"""
...
@webmethod(route="/post-training/jobs", method="GET") @webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
:returns: A ListPostTrainingJobsResponse.
"""
...
@webmethod(route="/post-training/job/status", method="GET") @webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ... async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
:param job_uuid: The UUID of the job to get the status of.
:returns: A PostTrainingJobStatusResponse.
"""
...
@webmethod(route="/post-training/job/cancel", method="POST") @webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
:param job_uuid: The UUID of the job to cancel.
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET") @webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ... async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.
:param job_uuid: The UUID of the job to get the artifacts of.
:returns: A PostTrainingJobArtifactsResponse.
"""
...

View file

@ -32,7 +32,18 @@ class Providers(Protocol):
""" """
@webmethod(route="/providers", method="GET") @webmethod(route="/providers", method="GET")
async def list_providers(self) -> ListProvidersResponse: ... async def list_providers(self) -> ListProvidersResponse:
"""List all available providers.
:returns: A ListProvidersResponse containing information about all providers.
"""
...
@webmethod(route="/providers/{provider_id}", method="GET") @webmethod(route="/providers/{provider_id}", method="GET")
async def inspect_provider(self, provider_id: str) -> ProviderInfo: ... async def inspect_provider(self, provider_id: str) -> ProviderInfo:
"""Get detailed information about a specific provider.
:param provider_id: The ID of the provider to inspect.
:returns: A ProviderInfo object containing the provider's details.
"""
...

View file

@ -4,12 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import sys
from enum import Enum from enum import Enum
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
# TODO: use enum.StrEnum when we drop support for python 3.10
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class ResourceType(Enum): class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
class ResourceType(StrEnum):
model = "model" model = "model"
shield = "shield" shield = "shield"
vector_db = "vector_db" vector_db = "vector_db"
@ -25,9 +36,9 @@ class Resource(BaseModel):
identifier: str = Field(description="Unique identifier for this resource in llama stack") identifier: str = Field(description="Unique identifier for this resource in llama stack")
provider_resource_id: str = Field( provider_resource_id: str | None = Field(
description="Unique identifier for this resource in the provider",
default=None, default=None,
description="Unique identifier for this resource in the provider",
) )
provider_id: str = Field(description="ID of the provider that owns this resource") provider_id: str = Field(description="ID of the provider that owns this resource")
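The switch from `Enum` to a `StrEnum` (with a Python 3.10 backport) means resource type members behave as plain strings, which is what allows the `Literal[ResourceType.model]` style defaults elsewhere in this diff. A standalone sketch of the pattern:

```python
import sys
from enum import Enum

if sys.version_info >= (3, 11):
    from enum import StrEnum
else:

    class StrEnum(str, Enum):
        """Backport of StrEnum for Python 3.10 and below."""

        pass


class ResourceType(StrEnum):
    model = "model"
    shield = "shield"


# Members are str subclasses, so they compare equal to their values and JSON-serialize cleanly.
assert ResourceType.model == "model"
assert ResourceType.model.value == "model"
```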

View file

@ -53,5 +53,13 @@ class Safety(Protocol):
self, self,
shield_id: str, shield_id: str,
messages: list[Message], messages: list[Message],
params: dict[str, Any] = None, params: dict[str, Any],
) -> RunShieldResponse: ... ) -> RunShieldResponse:
"""Run a shield.
:param shield_id: The identifier of the shield to run.
:param messages: The messages to run the shield on.
:param params: The parameters of the shield.
:returns: A RunShieldResponse.
"""
...

View file

@ -61,7 +61,15 @@ class Scoring(Protocol):
dataset_id: str, dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None], scoring_functions: dict[str, ScoringFnParams | None],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ... ) -> ScoreBatchResponse:
"""Score a batch of rows.
:param dataset_id: The ID of the dataset to score.
:param scoring_functions: The scoring functions to use for the scoring.
:param save_results_dataset: Whether to save the results to a dataset.
:returns: A ScoreBatchResponse.
"""
...
@webmethod(route="/scoring/score", method="POST") @webmethod(route="/scoring/score", method="POST")
async def score( async def score(
@ -73,6 +81,6 @@ class Scoring(Protocol):
:param input_rows: The rows to score. :param input_rows: The rows to score.
:param scoring_functions: The scoring functions to use for the scoring. :param scoring_functions: The scoring functions to use for the scoring.
:return: ScoreResponse object containing rows and aggregated results :returns: A ScoreResponse object containing rows and aggregated results.
""" """
... ...
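Illustrative usage sketch (not part of the diff) for the batch scoring endpoint documented above; `scoring` is an assumed implementation of the protocol and the scoring-function id is a placeholder.

```python
async def score_dataset(scoring, dataset_id: str):
    # None means "use the provider's default params", per the ScoringFnParams | None typing above.
    return await scoring.score_batch(
        dataset_id=dataset_id,
        scoring_functions={"basic::equality": None},  # placeholder scoring-function id
        save_results_dataset=False,
    )
```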

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# TODO: use enum.StrEnum when we drop support for python 3.10
import sys
from enum import Enum from enum import Enum
from typing import ( from typing import (
Annotated, Annotated,
@ -19,18 +21,27 @@ from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
if sys.version_info >= (3, 11):
from enum import StrEnum
else:
class StrEnum(str, Enum):
"""Backport of StrEnum for Python 3.10 and below."""
pass
# Perhaps more structure can be imposed on these functions. Maybe they could be associated # Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up? # with standard metrics so they can be rolled up?
@json_schema_type @json_schema_type
class ScoringFnParamsType(Enum): class ScoringFnParamsType(StrEnum):
llm_as_judge = "llm_as_judge" llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser" regex_parser = "regex_parser"
basic = "basic" basic = "basic"
@json_schema_type @json_schema_type
class AggregationFunctionType(Enum): class AggregationFunctionType(StrEnum):
average = "average" average = "average"
weighted_average = "weighted_average" weighted_average = "weighted_average"
median = "median" median = "median"
@ -40,36 +51,36 @@ class AggregationFunctionType(Enum):
@json_schema_type @json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel): class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.llm_as_judge.value] = ScoringFnParamsType.llm_as_judge.value type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge
judge_model: str judge_model: str
prompt_template: str | None = None prompt_template: str | None = None
judge_score_regexes: list[str] | None = Field( judge_score_regexes: list[str] = Field(
description="Regexes to extract the answer from generated response", description="Regexes to extract the answer from generated response",
default_factory=list, default_factory=lambda: [],
) )
aggregation_functions: list[AggregationFunctionType] | None = Field( aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=lambda: [],
) )
@json_schema_type @json_schema_type
class RegexParserScoringFnParams(BaseModel): class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.regex_parser.value] = ScoringFnParamsType.regex_parser.value type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser
parsing_regexes: list[str] | None = Field( parsing_regexes: list[str] = Field(
description="Regex to extract the answer from generated response", description="Regex to extract the answer from generated response",
default_factory=list, default_factory=lambda: [],
) )
aggregation_functions: list[AggregationFunctionType] | None = Field( aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=lambda: [],
) )
@json_schema_type @json_schema_type
class BasicScoringFnParams(BaseModel): class BasicScoringFnParams(BaseModel):
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic
aggregation_functions: list[AggregationFunctionType] | None = Field( aggregation_functions: list[AggregationFunctionType] = Field(
description="Aggregation functions to apply to the scores of each row", description="Aggregation functions to apply to the scores of each row",
default_factory=list, default_factory=list,
) )
@ -99,14 +110,14 @@ class CommonScoringFnFields(BaseModel):
@json_schema_type @json_schema_type
class ScoringFn(CommonScoringFnFields, Resource): class ScoringFn(CommonScoringFnFields, Resource):
type: Literal[ResourceType.scoring_function.value] = ResourceType.scoring_function.value type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function
@property @property
def scoring_fn_id(self) -> str: def scoring_fn_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_scoring_fn_id(self) -> str: def provider_scoring_fn_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id
@ -123,10 +134,21 @@ class ListScoringFunctionsResponse(BaseModel):
@runtime_checkable @runtime_checkable
class ScoringFunctions(Protocol): class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET") @webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
"""List all scoring functions.
:returns: A ListScoringFunctionsResponse.
"""
...
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ... async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn:
"""Get a scoring function by its ID.
:param scoring_fn_id: The ID of the scoring function to get.
:returns: A ScoringFn.
"""
...
@webmethod(route="/scoring-functions", method="POST") @webmethod(route="/scoring-functions", method="POST")
async def register_scoring_function( async def register_scoring_function(
@ -137,4 +159,14 @@ class ScoringFunctions(Protocol):
provider_scoring_fn_id: str | None = None, provider_scoring_fn_id: str | None = None,
provider_id: str | None = None, provider_id: str | None = None,
params: ScoringFnParams | None = None, params: ScoringFnParams | None = None,
) -> None: ... ) -> None:
"""Register a scoring function.
:param scoring_fn_id: The ID of the scoring function to register.
:param description: The description of the scoring function.
:param return_type: The return type of the scoring function.
:param provider_scoring_fn_id: The ID of the provider scoring function to use for the scoring function.
:param provider_id: The ID of the provider to use for the scoring function.
:param params: The parameters for the scoring function for benchmark eval; these can be overridden for app eval.
"""
...
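A standalone sketch of the typing pattern adopted above: `StrEnum` members used directly inside `Literal[...]` (no more `.value`), and list fields that default to empty lists instead of `None`. It assumes Python 3.11+ and pydantic v2; the model is a simplified stand-in, not the real class.

```python
from enum import StrEnum
from typing import Literal

from pydantic import BaseModel, Field


class ParamsType(StrEnum):
    llm_as_judge = "llm_as_judge"


class JudgeParams(BaseModel):
    # StrEnum members are valid Literal values, so the discriminator no longer needs `.value`.
    type: Literal[ParamsType.llm_as_judge] = ParamsType.llm_as_judge
    judge_model: str
    # Required list fields with empty defaults replace the previous `list[str] | None`.
    judge_score_regexes: list[str] = Field(default_factory=list)


print(JudgeParams(judge_model="example-judge").model_dump(mode="json"))
# {'type': 'llm_as_judge', 'judge_model': 'example-judge', 'judge_score_regexes': []}
```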

View file

@ -21,14 +21,14 @@ class CommonShieldFields(BaseModel):
class Shield(CommonShieldFields, Resource): class Shield(CommonShieldFields, Resource):
"""A safety shield resource that can be used to check content""" """A safety shield resource that can be used to check content"""
type: Literal[ResourceType.shield.value] = ResourceType.shield.value type: Literal[ResourceType.shield] = ResourceType.shield
@property @property
def shield_id(self) -> str: def shield_id(self) -> str:
return self.identifier return self.identifier
@property @property
def provider_shield_id(self) -> str: def provider_shield_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id
@ -46,10 +46,21 @@ class ListShieldsResponse(BaseModel):
@trace_protocol @trace_protocol
class Shields(Protocol): class Shields(Protocol):
@webmethod(route="/shields", method="GET") @webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ... async def list_shields(self) -> ListShieldsResponse:
"""List all shields.
:returns: A ListShieldsResponse.
"""
...
@webmethod(route="/shields/{identifier:path}", method="GET") @webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Shield: ... async def get_shield(self, identifier: str) -> Shield:
"""Get a shield by its identifier.
:param identifier: The identifier of the shield to get.
:returns: A Shield.
"""
...
@webmethod(route="/shields", method="POST") @webmethod(route="/shields", method="POST")
async def register_shield( async def register_shield(
@ -58,4 +69,13 @@ class Shields(Protocol):
provider_shield_id: str | None = None, provider_shield_id: str | None = None,
provider_id: str | None = None, provider_id: str | None = None,
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> Shield: ... ) -> Shield:
"""Register a shield.
:param shield_id: The identifier of the shield to register.
:param provider_shield_id: The identifier of the shield in the provider.
:param provider_id: The identifier of the provider.
:param params: The parameters of the shield.
:returns: A Shield.
"""
...
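Illustrative sketch (not part of the diff) combining the Shields and Safety protocols above: register a shield, then run it over some messages. `shields_api`, `safety_api`, the identifier, and the `violation` field are assumptions; method names and parameters come from the protocol definitions.

```python
async def moderate(shields_api, safety_api, messages: list) -> bool:
    shield = await shields_api.register_shield(
        shield_id="content-safety",  # placeholder identifier
        provider_shield_id=None,
        provider_id=None,
        params=None,
    )
    response = await safety_api.run_shield(
        shield_id=shield.identifier,
        messages=messages,
        params={},  # now required rather than defaulting to None
    )
    # RunShieldResponse is assumed to carry a `violation` field that is None when content is safe.
    return getattr(response, "violation", None) is None
```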

View file

@ -37,7 +37,7 @@ class Span(BaseModel):
name: str name: str
start_time: datetime start_time: datetime
end_time: datetime | None = None end_time: datetime | None = None
attributes: dict[str, Any] | None = Field(default_factory=dict) attributes: dict[str, Any] | None = Field(default_factory=lambda: {})
def set_attribute(self, key: str, value: Any): def set_attribute(self, key: str, value: Any):
if self.attributes is None: if self.attributes is None:
@ -74,19 +74,19 @@ class EventCommon(BaseModel):
trace_id: str trace_id: str
span_id: str span_id: str
timestamp: datetime timestamp: datetime
attributes: dict[str, Primitive] | None = Field(default_factory=dict) attributes: dict[str, Primitive] | None = Field(default_factory=lambda: {})
@json_schema_type @json_schema_type
class UnstructuredLogEvent(EventCommon): class UnstructuredLogEvent(EventCommon):
type: Literal[EventType.UNSTRUCTURED_LOG.value] = EventType.UNSTRUCTURED_LOG.value type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG
message: str message: str
severity: LogSeverity severity: LogSeverity
@json_schema_type @json_schema_type
class MetricEvent(EventCommon): class MetricEvent(EventCommon):
type: Literal[EventType.METRIC.value] = EventType.METRIC.value type: Literal[EventType.METRIC] = EventType.METRIC
metric: str # this would be an enum metric: str # this would be an enum
value: int | float value: int | float
unit: str unit: str
@ -131,14 +131,14 @@ class StructuredLogType(Enum):
@json_schema_type @json_schema_type
class SpanStartPayload(BaseModel): class SpanStartPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_START.value] = StructuredLogType.SPAN_START.value type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START
name: str name: str
parent_span_id: str | None = None parent_span_id: str | None = None
@json_schema_type @json_schema_type
class SpanEndPayload(BaseModel): class SpanEndPayload(BaseModel):
type: Literal[StructuredLogType.SPAN_END.value] = StructuredLogType.SPAN_END.value type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END
status: SpanStatus status: SpanStatus
@ -151,7 +151,7 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload")
@json_schema_type @json_schema_type
class StructuredLogEvent(EventCommon): class StructuredLogEvent(EventCommon):
type: Literal[EventType.STRUCTURED_LOG.value] = EventType.STRUCTURED_LOG.value type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG
payload: StructuredLogPayload payload: StructuredLogPayload
@ -203,10 +203,61 @@ class QuerySpanTreeResponse(BaseModel):
data: dict[str, SpanWithStatus] data: dict[str, SpanWithStatus]
class MetricQueryType(Enum):
RANGE = "range"
INSTANT = "instant"
class MetricLabelOperator(Enum):
EQUALS = "="
NOT_EQUALS = "!="
REGEX_MATCH = "=~"
REGEX_NOT_MATCH = "!~"
class MetricLabelMatcher(BaseModel):
name: str
value: str
operator: MetricLabelOperator = MetricLabelOperator.EQUALS
@json_schema_type
class MetricLabel(BaseModel):
name: str
value: str
@json_schema_type
class MetricDataPoint(BaseModel):
timestamp: int
value: float
@json_schema_type
class MetricSeries(BaseModel):
metric: str
labels: list[MetricLabel]
values: list[MetricDataPoint]
class QueryMetricsResponse(BaseModel):
data: list[MetricSeries]
@runtime_checkable @runtime_checkable
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/events", method="POST") @webmethod(route="/telemetry/events", method="POST")
async def log_event(self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400) -> None: ... async def log_event(
self,
event: Event,
ttl_seconds: int = DEFAULT_TTL_DAYS * 86400,
) -> None:
"""Log an event.
:param event: The event to log.
:param ttl_seconds: The time to live of the event.
"""
...
@webmethod(route="/telemetry/traces", method="POST") @webmethod(route="/telemetry/traces", method="POST")
async def query_traces( async def query_traces(
@ -215,13 +266,35 @@ class Telemetry(Protocol):
limit: int | None = 100, limit: int | None = 100,
offset: int | None = 0, offset: int | None = 0,
order_by: list[str] | None = None, order_by: list[str] | None = None,
) -> QueryTracesResponse: ... ) -> QueryTracesResponse:
"""Query traces.
:param attribute_filters: The attribute filters to apply to the traces.
:param limit: The limit of traces to return.
:param offset: The offset of the traces to return.
:param order_by: The order by of the traces to return.
:returns: A QueryTracesResponse.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ... async def get_trace(self, trace_id: str) -> Trace:
"""Get a trace by its ID.
:param trace_id: The ID of the trace to get.
:returns: A Trace.
"""
...
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET") @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ... async def get_span(self, trace_id: str, span_id: str) -> Span:
"""Get a span by its ID.
:param trace_id: The ID of the trace to get the span from.
:param span_id: The ID of the span to get.
:returns: A Span.
"""
...
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST") @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="POST")
async def get_span_tree( async def get_span_tree(
@ -229,7 +302,15 @@ class Telemetry(Protocol):
span_id: str, span_id: str,
attributes_to_return: list[str] | None = None, attributes_to_return: list[str] | None = None,
max_depth: int | None = None, max_depth: int | None = None,
) -> QuerySpanTreeResponse: ... ) -> QuerySpanTreeResponse:
"""Get a span tree by its ID.
:param span_id: The ID of the span to get the tree from.
:param attributes_to_return: The attributes to return in the tree.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpanTreeResponse.
"""
...
@webmethod(route="/telemetry/spans", method="POST") @webmethod(route="/telemetry/spans", method="POST")
async def query_spans( async def query_spans(
@ -237,7 +318,15 @@ class Telemetry(Protocol):
attribute_filters: list[QueryCondition], attribute_filters: list[QueryCondition],
attributes_to_return: list[str], attributes_to_return: list[str],
max_depth: int | None = None, max_depth: int | None = None,
) -> QuerySpansResponse: ... ) -> QuerySpansResponse:
"""Query spans.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_return: The attributes to return in the spans.
:param max_depth: The maximum depth of the tree.
:returns: A QuerySpansResponse.
"""
...
@webmethod(route="/telemetry/spans/export", method="POST") @webmethod(route="/telemetry/spans/export", method="POST")
async def save_spans_to_dataset( async def save_spans_to_dataset(
@ -246,4 +335,34 @@ class Telemetry(Protocol):
attributes_to_save: list[str], attributes_to_save: list[str],
dataset_id: str, dataset_id: str,
max_depth: int | None = None, max_depth: int | None = None,
) -> None: ... ) -> None:
"""Save spans to a dataset.
:param attribute_filters: The attribute filters to apply to the spans.
:param attributes_to_save: The attributes to save to the dataset.
:param dataset_id: The ID of the dataset to save the spans to.
:param max_depth: The maximum depth of the tree.
"""
...
@webmethod(route="/telemetry/metrics/{metric_name}", method="POST")
async def query_metrics(
self,
metric_name: str,
start_time: int,
end_time: int | None = None,
granularity: str | None = "1d",
query_type: MetricQueryType = MetricQueryType.RANGE,
label_matchers: list[MetricLabelMatcher] | None = None,
) -> QueryMetricsResponse:
"""Query metrics.
:param metric_name: The name of the metric to query.
:param start_time: The start time of the metric to query.
:param end_time: The end time of the metric to query.
:param granularity: The granularity of the metric to query.
:param query_type: The type of query to perform.
:param label_matchers: The label matchers to apply to the metric.
:returns: A QueryMetricsResponse.
"""
...
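Illustrative usage sketch for the new `query_metrics` endpoint. The enums and the matcher model are copied from the definitions above so the sketch is self-contained (in practice they would be imported from the telemetry API module); `telemetry`, the metric name, and the label name are assumptions.

```python
import time
from enum import Enum

from pydantic import BaseModel


class MetricQueryType(Enum):
    RANGE = "range"
    INSTANT = "instant"


class MetricLabelOperator(Enum):
    EQUALS = "="
    NOT_EQUALS = "!="
    REGEX_MATCH = "=~"
    REGEX_NOT_MATCH = "!~"


class MetricLabelMatcher(BaseModel):
    name: str
    value: str
    operator: MetricLabelOperator = MetricLabelOperator.EQUALS


async def tokens_for_model(telemetry, model_id: str):
    matcher = MetricLabelMatcher(name="model_id", value=model_id)  # label name is a placeholder
    return await telemetry.query_metrics(
        metric_name="prompt_tokens",          # placeholder metric name
        start_time=int(time.time()) - 86400,  # last 24 hours
        query_type=MetricQueryType.RANGE,
        label_matchers=[matcher],
    )
```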

View file

@ -7,7 +7,7 @@
from enum import Enum from enum import Enum
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from typing_extensions import Protocol, runtime_checkable from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.common.content_types import URL, InterleavedContent
@ -67,11 +67,33 @@ register_schema(RAGQueryGeneratorConfig, name="RAGQueryGeneratorConfig")
@json_schema_type @json_schema_type
class RAGQueryConfig(BaseModel): class RAGQueryConfig(BaseModel):
"""
Configuration for the RAG query generation.
:param query_generator_config: Configuration for the query generator.
:param max_tokens_in_context: Maximum number of tokens in the context.
:param max_chunks: Maximum number of chunks to retrieve.
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
"""
# This config defines how a query is generated using the messages # This config defines how a query is generated using the messages
# for memory bank retrieval. # for memory bank retrieval.
query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig()) query_generator_config: RAGQueryGeneratorConfig = Field(default=DefaultRAGQueryGeneratorConfig())
max_tokens_in_context: int = 4096 max_tokens_in_context: int = 4096
max_chunks: int = 5 max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
if "{chunk.content}" not in v:
raise ValueError("chunk_template must contain {chunk.content}")
if "{index}" not in v:
raise ValueError("chunk_template must contain {index}")
if len(v) == 0:
raise ValueError("chunk_template must not be empty")
return v
@runtime_checkable @runtime_checkable
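Illustration of how the default `chunk_template` above is expected to be expanded; the `Chunk` stand-in and the metadata shape are assumptions, while the template string is the default from this diff.

```python
from dataclasses import dataclass, field


@dataclass
class Chunk:  # minimal stand-in for the real Chunk type
    content: str
    metadata: dict = field(default_factory=dict)


template = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
chunk = Chunk(content="Llama Stack is a set of APIs.", metadata={"source": "docs"})
print(template.format(index=1, chunk=chunk, metadata=chunk.metadata))
# Result 1
# Content: Llama Stack is a set of APIs.
# Metadata: {'source': 'docs'}
```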

View file

@ -36,7 +36,7 @@ class ToolHost(Enum):
@json_schema_type @json_schema_type
class Tool(Resource): class Tool(Resource):
type: Literal[ResourceType.tool.value] = ResourceType.tool.value type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str toolgroup_id: str
tool_host: ToolHost tool_host: ToolHost
description: str description: str
@ -62,7 +62,7 @@ class ToolGroupInput(BaseModel):
@json_schema_type @json_schema_type
class ToolGroup(Resource): class ToolGroup(Resource):
type: Literal[ResourceType.tool_group.value] = ResourceType.tool_group.value type: Literal[ResourceType.tool_group] = ResourceType.tool_group
mcp_endpoint: URL | None = None mcp_endpoint: URL | None = None
args: dict[str, Any] | None = None args: dict[str, Any] | None = None
@ -103,37 +103,65 @@ class ToolGroups(Protocol):
mcp_endpoint: URL | None = None, mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None, args: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Register a tool group""" """Register a tool group.
:param toolgroup_id: The ID of the tool group to register.
:param provider_id: The ID of the provider to use for the tool group.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:param args: A dictionary of arguments to pass to the tool group.
"""
... ...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET") @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
async def get_tool_group( async def get_tool_group(
self, self,
toolgroup_id: str, toolgroup_id: str,
) -> ToolGroup: ... ) -> ToolGroup:
"""Get a tool group by its ID.
:param toolgroup_id: The ID of the tool group to get.
:returns: A ToolGroup.
"""
...
@webmethod(route="/toolgroups", method="GET") @webmethod(route="/toolgroups", method="GET")
async def list_tool_groups(self) -> ListToolGroupsResponse: async def list_tool_groups(self) -> ListToolGroupsResponse:
"""List tool groups with optional provider""" """List tool groups with optional provider.
:returns: A ListToolGroupsResponse.
"""
... ...
@webmethod(route="/tools", method="GET") @webmethod(route="/tools", method="GET")
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse: async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
"""List tools with optional tool group""" """List tools with optional tool group.
:param toolgroup_id: The ID of the tool group to list tools for.
:returns: A ListToolsResponse.
"""
... ...
@webmethod(route="/tools/{tool_name:path}", method="GET") @webmethod(route="/tools/{tool_name:path}", method="GET")
async def get_tool( async def get_tool(
self, self,
tool_name: str, tool_name: str,
) -> Tool: ... ) -> Tool:
"""Get a tool by its name.
:param tool_name: The name of the tool to get.
:returns: A Tool.
"""
...
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE") @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup( async def unregister_toolgroup(
self, self,
toolgroup_id: str, toolgroup_id: str,
) -> None: ) -> None:
"""Unregister a tool group""" """Unregister a tool group.
:param toolgroup_id: The ID of the tool group to unregister.
"""
... ...
@ -152,9 +180,21 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET") @webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools( async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse: ... ) -> ListToolDefsResponse:
"""List all tools in the runtime.
:param tool_group_id: The ID of the tool group to list tools for.
:param mcp_endpoint: The MCP endpoint to use for the tool group.
:returns: A ListToolDefsResponse.
"""
...
@webmethod(route="/tool-runtime/invoke", method="POST") @webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult: async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
"""Run a tool with the given arguments""" """Run a tool with the given arguments.
:param tool_name: The name of the tool to invoke.
:param kwargs: A dictionary of arguments to pass to the tool.
:returns: A ToolInvocationResult.
"""
... ...
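Illustrative usage sketch (not part of the diff) for the ToolRuntime protocol above; `tool_runtime` is an assumed implementation and the tool name and arguments are placeholders.

```python
async def search(tool_runtime, query: str):
    # Discover available tools first (both arguments are optional per the signature above).
    await tool_runtime.list_runtime_tools(tool_group_id=None, mcp_endpoint=None)
    return await tool_runtime.invoke_tool(
        tool_name="web_search",       # placeholder tool name
        kwargs={"query": query},
    )
```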

View file

@ -15,7 +15,7 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type @json_schema_type
class VectorDB(Resource): class VectorDB(Resource):
type: Literal[ResourceType.vector_db.value] = ResourceType.vector_db.value type: Literal[ResourceType.vector_db] = ResourceType.vector_db
embedding_model: str embedding_model: str
embedding_dimension: int embedding_dimension: int
@ -25,7 +25,7 @@ class VectorDB(Resource):
return self.identifier return self.identifier
@property @property
def provider_vector_db_id(self) -> str: def provider_vector_db_id(self) -> str | None:
return self.provider_resource_id return self.provider_resource_id
@ -44,13 +44,24 @@ class ListVectorDBsResponse(BaseModel):
@trace_protocol @trace_protocol
class VectorDBs(Protocol): class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET") @webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ... async def list_vector_dbs(self) -> ListVectorDBsResponse:
"""List all vector databases.
:returns: A ListVectorDBsResponse.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET") @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db( async def get_vector_db(
self, self,
vector_db_id: str, vector_db_id: str,
) -> VectorDB: ... ) -> VectorDB:
"""Get a vector database by its identifier.
:param vector_db_id: The identifier of the vector database to get.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs", method="POST") @webmethod(route="/vector-dbs", method="POST")
async def register_vector_db( async def register_vector_db(
@ -60,7 +71,22 @@ class VectorDBs(Protocol):
embedding_dimension: int | None = 384, embedding_dimension: int | None = 384,
provider_id: str | None = None, provider_id: str | None = None,
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
) -> VectorDB: ... ) -> VectorDB:
"""Register a vector database.
:param vector_db_id: The identifier of the vector database to register.
:param embedding_model: The embedding model to use.
:param embedding_dimension: The dimension of the embedding model.
:param provider_id: The identifier of the provider.
:param provider_vector_db_id: The identifier of the vector database in the provider.
:returns: A VectorDB.
"""
...
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE") @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ... async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database.
:param vector_db_id: The identifier of the vector database to unregister.
"""
...

View file

@ -46,7 +46,14 @@ class VectorIO(Protocol):
vector_db_id: str, vector_db_id: str,
chunks: list[Chunk], chunks: list[Chunk],
ttl_seconds: int | None = None, ttl_seconds: int | None = None,
) -> None: ... ) -> None:
"""Insert chunks into a vector database.
:param vector_db_id: The identifier of the vector database to insert the chunks into.
:param chunks: The chunks to insert.
:param ttl_seconds: The time to live of the chunks.
"""
...
@webmethod(route="/vector-io/query", method="POST") @webmethod(route="/vector-io/query", method="POST")
async def query_chunks( async def query_chunks(
@ -54,4 +61,12 @@ class VectorIO(Protocol):
vector_db_id: str, vector_db_id: str,
query: InterleavedContent, query: InterleavedContent,
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ... ) -> QueryChunksResponse:
"""Query chunks from a vector database.
:param vector_db_id: The identifier of the vector database to query.
:param query: The query to search for.
:param params: The parameters of the query.
:returns: A QueryChunksResponse.
"""
...
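Illustrative end-to-end sketch (not part of the diff) combining the VectorDBs and VectorIO protocols above. `vector_dbs_api`, `vector_io_api`, the embedding model, and the chunk objects are assumptions; method names and parameters come from the protocol definitions.

```python
async def index_and_query(vector_dbs_api, vector_io_api, chunks, question: str):
    await vector_dbs_api.register_vector_db(
        vector_db_id="my-docs",              # placeholder identifier
        embedding_model="all-MiniLM-L6-v2",  # placeholder embedding model
        embedding_dimension=384,
    )
    await vector_io_api.insert_chunks(
        vector_db_id="my-docs",
        chunks=chunks,
        ttl_seconds=None,
    )
    return await vector_io_api.query_chunks(
        vector_db_id="my-docs",
        query=question,
        params={"max_chunks": 5},            # parameter name is an assumption
    )
```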

View file

@ -38,7 +38,10 @@ class LlamaCLIParser:
print_subcommand_description(self.parser, subparsers) print_subcommand_description(self.parser, subparsers)
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args() args = self.parser.parse_args()
if not isinstance(args, argparse.Namespace):
raise TypeError(f"Expected argparse.Namespace, got {type(args)}")
return args
def run(self, args: argparse.Namespace) -> None: def run(self, args: argparse.Namespace) -> None:
args.func(args) args.func(args)

View file

@ -12,6 +12,7 @@ import shutil
import sys import sys
import textwrap import textwrap
from functools import lru_cache from functools import lru_cache
from importlib.abc import Traversable
from pathlib import Path from pathlib import Path
import yaml import yaml
@ -36,7 +37,8 @@ from llama_stack.distribution.datatypes import (
) )
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.stack import replace_env_vars
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_command from llama_stack.distribution.utils.exec import formulate_run_args, run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType from llama_stack.distribution.utils.image_types import LlamaStackImageType
@ -202,7 +204,11 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
else: else:
with open(args.config) as f: with open(args.config) as f:
try: try:
build_config = BuildConfig(**yaml.safe_load(f)) contents = yaml.safe_load(f)
contents = replace_env_vars(contents)
build_config = BuildConfig(**contents)
if args.image_type:
build_config.image_type = args.image_type
except Exception as e: except Exception as e:
cprint( cprint(
f"Could not parse config file {args.config}: {e}", f"Could not parse config file {args.config}: {e}",
@ -245,11 +251,12 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
sys.exit(1) sys.exit(1)
if args.run: if args.run:
run_config = Path(run_config)
config_dict = yaml.safe_load(run_config.read_text()) config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict) config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(config.external_providers_dir):
os.makedirs(config.external_providers_dir, exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))]) run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
run_command(run_args) run_command(run_args)
@ -257,7 +264,7 @@ def _generate_run_config(
build_config: BuildConfig, build_config: BuildConfig,
build_dir: Path, build_dir: Path,
image_name: str, image_name: str,
) -> str: ) -> Path:
""" """
Generate a run.yaml template file for user to edit from a build.yaml file Generate a run.yaml template file for user to edit from a build.yaml file
""" """
@ -267,7 +274,9 @@ def _generate_run_config(
image_name=image_name, image_name=image_name,
apis=apis, apis=apis,
providers={}, providers={},
external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None, external_providers_dir=build_config.external_providers_dir
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
) )
# build providers dict # build providers dict
provider_registry = get_provider_registry(build_config) provider_registry = get_provider_registry(build_config)
@ -334,7 +343,7 @@ def _run_stack_build_command_from_build_config(
image_name: str | None = None, image_name: str | None = None,
template_name: str | None = None, template_name: str | None = None,
config_path: str | None = None, config_path: str | None = None,
) -> str: ) -> Path | Traversable:
image_name = image_name or build_config.image_name image_name = image_name or build_config.image_name
if build_config.image_type == LlamaStackImageType.CONTAINER.value: if build_config.image_type == LlamaStackImageType.CONTAINER.value:
if template_name: if template_name:

View file

@ -49,7 +49,7 @@ class StackBuild(Subcommand):
type=str, type=str,
help="Image Type to use for the build. If not specified, will use the image type from the template config.", help="Image Type to use for the build. If not specified, will use the image type from the template config.",
choices=[e.value for e in ImageType], choices=[e.value for e in ImageType],
default=ImageType.CONDA.value, default=None, # no default so we can detect if a user specified --image-type and override image_type in the config
) )
self.parser.add_argument( self.parser.add_argument(

View file

@ -46,7 +46,7 @@ class StackListProviders(Subcommand):
else: else:
providers = [(k.value, prov) for k, prov in all_providers.items()] providers = [(k.value, prov) for k, prov in all_providers.items()]
providers = [p for api, p in providers if api in self.providable_apis] providers = [(api, p) for api, p in providers if api in self.providable_apis]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions # eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [ headers = [
@ -57,7 +57,7 @@ class StackListProviders(Subcommand):
rows = [] rows = []
specs = [spec for p in providers for spec in p.values()] specs = [spec for api, p in providers for spec in p.values()]
for spec in specs: for spec in specs:
if spec.is_sample: if spec.is_sample:
continue continue
@ -65,7 +65,7 @@ class StackListProviders(Subcommand):
[ [
spec.api.value, spec.api.value,
spec.provider_type, spec.provider_type,
",".join(spec.pip_packages), ",".join(spec.pip_packages) if hasattr(spec, "pip_packages") else "",
] ]
) )
print_table( print_table(

View file

@ -33,7 +33,8 @@ class StackRun(Subcommand):
self.parser.add_argument( self.parser.add_argument(
"config", "config",
type=str, type=str,
help="Path to config file to use for the run", nargs="?", # Make it optional
help="Path to config file to use for the run. Required for venv and conda environments.",
) )
self.parser.add_argument( self.parser.add_argument(
"--port", "--port",
@ -47,28 +48,12 @@ class StackRun(Subcommand):
default=os.environ.get("CONDA_DEFAULT_ENV"), default=os.environ.get("CONDA_DEFAULT_ENV"),
help="Name of the image to run. Defaults to the current environment", help="Name of the image to run. Defaults to the current environment",
) )
self.parser.add_argument(
"--disable-ipv6",
action="store_true",
help="Disable IPv6 support",
default=False,
)
self.parser.add_argument( self.parser.add_argument(
"--env", "--env",
action="append", action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.", help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
metavar="KEY=VALUE", metavar="KEY=VALUE",
) )
self.parser.add_argument(
"--tls-keyfile",
type=str,
help="Path to TLS key file for HTTPS",
)
self.parser.add_argument(
"--tls-certfile",
type=str,
help="Path to TLS certificate file for HTTPS",
)
self.parser.add_argument( self.parser.add_argument(
"--image-type", "--image-type",
type=str, type=str,
@ -98,6 +83,13 @@ class StackRun(Subcommand):
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command from llama_stack.distribution.utils.exec import formulate_run_args, run_command
image_type, image_name = self._get_image_type_and_name(args)
# Check if config is required based on image type
if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
self.parser.error("Config file is required for venv and conda environments")
if args.config:
config_file = Path(args.config) config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml") has_yaml_suffix = args.config.endswith(".yaml")
template_name = None template_name = None
@ -131,10 +123,14 @@ class StackRun(Subcommand):
try: try:
config = parse_and_maybe_upgrade_config(config_dict) config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e: except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}") self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
else:
image_type, image_name = self._get_image_type_and_name(args) config = None
config_file = None
template_name = None
# If neither image type nor image name is provided, assume the server should be run directly # If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages. # using the current environment packages.
@ -157,9 +153,10 @@ class StackRun(Subcommand):
else: else:
run_args = formulate_run_args(image_type, image_name, config, template_name) run_args = formulate_run_args(image_type, image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)]) run_args.extend([str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6") if config_file:
run_args.extend(["--config", str(config_file)])
if args.env: if args.env:
for env_var in args.env: for env_var in args.env:
@ -172,6 +169,4 @@ class StackRun(Subcommand):
return return
run_args.extend(["--env", f"{key}={value}"]) run_args.extend(["--env", f"{key}={value}"])
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_command(run_args) run_command(run_args)

View file

@ -154,6 +154,12 @@ get_python_cmd() {
fi fi
} }
# Add other required item commands generic to all containers
add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama/providers.d /.cache
EOF
if [ -n "$run_config" ]; then if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path # Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml" cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
@ -166,17 +172,19 @@ EOF
# and update the configuration to reference the new container path # and update the configuration to reference the new container path
python_cmd=$(get_python_cmd) python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')") external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
if [ -n "$external_providers_dir" ]; then external_providers_dir=$(eval echo "$external_providers_dir")
if [ -n "$external_providers_dir" ] && [ -d "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir" echo "Copying external providers directory: $external_providers_dir"
cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
add_to_container << EOF add_to_container << EOF
COPY $external_providers_dir /app/providers.d COPY providers.d /.llama/providers.d
EOF EOF
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d # Edit the run.yaml file to change the external_providers_dir to /.llama/providers.d
if [ "$(uname)" = "Darwin" ]; then if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak" rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else else
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" sed -i 's|external_providers_dir:.*|external_providers_dir: /.llama/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi fi
fi fi
fi fi
@ -255,9 +263,6 @@ fi
# Add other required item commands generic to all containers # Add other required item commands generic to all containers
add_to_container << EOF add_to_container << EOF
# Allows running as non-root user
RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache RUN chmod -R g+rw /app /.llama /.cache
EOF EOF

View file

@ -17,6 +17,7 @@ from llama_stack.distribution.distribution import (
builtin_automatically_routed_apis, builtin_automatically_routed_apis,
get_provider_registry, get_provider_registry,
) )
from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec
@ -73,11 +74,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
existing_providers = config.providers.get(api_str, []) existing_providers = config.providers.get(api_str, [])
if existing_providers: if existing_providers:
logger.info( logger.info(f"Re-configuring existing providers for API `{api_str}`...")
f"Re-configuring existing providers for API `{api_str}`...",
"green",
attrs=["bold"],
)
updated_providers = [] updated_providers = []
for p in existing_providers: for p in existing_providers:
logger.info(f"> Configuring provider `({p.provider_type})`") logger.info(f"> Configuring provider `({p.provider_type})`")
@ -91,7 +88,7 @@ def configure_api_providers(config: StackRunConfig, build_spec: DistributionSpec
if not plist: if not plist:
raise ValueError(f"No provider configured for API {api_str}?") raise ValueError(f"No provider configured for API {api_str}?")
logger.info(f"Configuring API `{api_str}`...", "green", attrs=["bold"]) logger.info(f"Configuring API `{api_str}`...")
updated_providers = [] updated_providers = []
for i, provider_type in enumerate(plist): for i, provider_type in enumerate(plist):
if i >= 1: if i >= 1:
@ -174,4 +171,7 @@ def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfi
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
return StackRunConfig(**config_dict) return StackRunConfig(**config_dict)

View file

@ -5,9 +5,10 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Annotated, Any from typing import Annotated, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
@ -249,10 +250,18 @@ class ServerConfig(BaseModel):
default=None, default=None,
description="Path to TLS key file for HTTPS", description="Path to TLS key file for HTTPS",
) )
tls_cafile: str | None = Field(
default=None,
description="Path to TLS CA file for HTTPS with mutual TLS authentication",
)
auth: AuthenticationConfig | None = Field( auth: AuthenticationConfig | None = Field(
default=None, default=None,
description="Authentication configuration for the server", description="Authentication configuration for the server",
) )
host: str | None = Field(
default=None,
description="The host the server should listen on",
)
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
@ -304,11 +313,20 @@ a default SQLite store will be used.""",
description="Configuration for the HTTP(S) server", description="Configuration for the HTTP(S) server",
) )
external_providers_dir: str | None = Field( external_providers_dir: Path | None = Field(
default=None, default=None,
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.", description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
) )
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
class BuildConfig(BaseModel): class BuildConfig(BaseModel):
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
@ -322,8 +340,17 @@ class BuildConfig(BaseModel):
default=None, default=None,
description="Name of the distribution to build", description="Name of the distribution to build",
) )
external_providers_dir: str | None = Field( external_providers_dir: Path | None = Field(
default=None, default=None,
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. " description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.", "pip_packages MUST contain the provider package name.",
) )
@field_validator("external_providers_dir")
@classmethod
def validate_external_providers_dir(cls, v):
if v is None:
return None
if isinstance(v, str):
return Path(v)
return v
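A standalone sketch of the validator pattern above: a string taken from YAML is coerced into a `pathlib.Path`, while `None` passes through untouched. The model is a simplified stand-in, not the real `BuildConfig`/`StackRunConfig`.

```python
from pathlib import Path

from pydantic import BaseModel, field_validator


class ExampleConfig(BaseModel):  # stand-in model for illustration only
    external_providers_dir: Path | None = None

    @field_validator("external_providers_dir")
    @classmethod
    def validate_external_providers_dir(cls, v):
        if v is None:
            return None
        if isinstance(v, str):
            return Path(v)
        return v


print(ExampleConfig(external_providers_dir="~/.llama/providers.d").external_providers_dir)
# e.g. PosixPath('~/.llama/providers.d'); "~" is expanded later via os.path.expanduser()
```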

View file

@ -145,7 +145,7 @@ def get_provider_registry(
# Check if config has the external_providers_dir attribute # Check if config has the external_providers_dir attribute
if config and hasattr(config, "external_providers_dir") and config.external_providers_dir: if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
external_providers_dir = os.path.abspath(config.external_providers_dir) external_providers_dir = os.path.abspath(os.path.expanduser(config.external_providers_dir))
if not os.path.exists(external_providers_dir): if not os.path.exists(external_providers_dir):
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}") raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
logger.info(f"Loading external providers from {external_providers_dir}") logger.info(f"Loading external providers from {external_providers_dir}")

View file

@ -30,7 +30,7 @@ from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.datatypes import Api from llama_stack.distribution.datatypes import Api, BuildConfig, DistributionSpec
from llama_stack.distribution.request_headers import ( from llama_stack.distribution.request_headers import (
PROVIDER_DATA_VAR, PROVIDER_DATA_VAR,
request_provider_data_context, request_provider_data_context,
@ -216,7 +216,19 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
"yellow", "yellow",
) )
if self.config_path_or_template_name.endswith(".yaml"): if self.config_path_or_template_name.endswith(".yaml"):
print_pip_install_help(self.config.providers) # Convert Provider objects to their types
provider_types: dict[str, str | list[str]] = {}
for api, providers in self.config.providers.items():
types = [p.provider_type for p in providers]
# Convert single-item lists to strings
provider_types[api] = types[0] if len(types) == 1 else types
build_config = BuildConfig(
distribution_spec=DistributionSpec(
providers=provider_types,
),
external_providers_dir=self.config.external_providers_dir,
)
print_pip_install_help(build_config)
else: else:
prefix = "!" if in_notebook() else "" prefix = "!" if in_notebook() else ""
cprint( cprint(

View file

@ -99,7 +99,7 @@ class ProviderImpl(Providers):
try: try:
health = await asyncio.wait_for(impl.health(), timeout=timeout) health = await asyncio.wait_for(impl.health(), timeout=timeout)
return api_name, health return api_name, health
except asyncio.TimeoutError: except (asyncio.TimeoutError, TimeoutError):
return ( return (
api_name, api_name,
HealthResponse( HealthResponse(

View file

@ -44,7 +44,8 @@ class RequestProviderDataContext(AbstractContextManager):
class NeedsRequestProviderData: class NeedsRequestProviderData:
def get_request_provider_data(self) -> Any: def get_request_provider_data(self) -> Any:
spec = self.__provider_spec__ spec = self.__provider_spec__
assert spec, f"Provider spec not set on {self.__class__}" if not spec:
raise ValueError(f"Provider spec not set on {self.__class__}")
provider_type = spec.provider_type provider_type = spec.provider_type
validator_class = spec.provider_data_validator validator_class = spec.provider_data_validator

View file

@ -13,7 +13,7 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference, InferenceProvider
from llama_stack.apis.inspect import Inspect from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining from llama_stack.apis.post_training import PostTraining
@ -83,6 +83,13 @@ def api_protocol_map() -> dict[Api, Any]:
} }
def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
return {
**api_protocol_map(),
Api.inference: InferenceProvider,
}
def additional_protocols_map() -> dict[Api, Any]: def additional_protocols_map() -> dict[Api, Any]:
return { return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models), Api.inference: (ModelsProtocolPrivate, Models, Api.models),
@ -302,9 +309,6 @@ async def instantiate_provider(
inner_impls: dict[str, Any], inner_impls: dict[str, Any],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
): ):
protocols = api_protocol_map()
additional_protocols = additional_protocols_map()
provider_spec = provider.spec provider_spec = provider.spec
if not hasattr(provider_spec, "module"): if not hasattr(provider_spec, "module"):
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute") raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
@ -342,6 +346,8 @@ async def instantiate_provider(
impl.__provider_spec__ = provider_spec impl.__provider_spec__ = provider_spec
impl.__provider_config__ = config impl.__provider_config__ = config
protocols = api_protocol_map_for_compliance_check()
additional_protocols = additional_protocols_map()
# TODO: check compliance for special tool groups # TODO: check compliance for special tool groups
# the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol # the impl should be for Api.tool_runtime, the name should be the special tool group, the protocol should be the special tool group protocol
check_protocol_compliance(impl, protocols[provider_spec.api]) check_protocol_compliance(impl, protocols[provider_spec.api])

View file

@ -573,6 +573,12 @@ class InferenceRouter(Inference):
for tool in tools: for tool in tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
# Some providers make tool calls even when tool_choice is "none"
# so just clear them both out to avoid unexpected tool calls
if tool_choice == "none" and tools is not None:
tool_choice = None
tools = None
params = dict( params = dict(
model=model_obj.identifier, model=model_obj.identifier,
messages=messages, messages=messages,
@ -600,7 +606,19 @@ class InferenceRouter(Inference):
) )
provider = self.routing_table.get_provider_impl(model_obj.identifier) provider = self.routing_table.get_provider_impl(model_obj.identifier)
if stream:
return await provider.openai_chat_completion(**params) return await provider.openai_chat_completion(**params)
else:
return await self._nonstream_openai_chat_completion(provider, params)
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params)
for choice in response.choices:
# some providers return an empty list for no tool calls in non-streaming responses
# but the OpenAI API returns None. So, set tool_calls to None if it's empty
if choice.message and choice.message.tool_calls is not None and len(choice.message.tool_calls) == 0:
choice.message.tool_calls = None
return response
async def health(self) -> dict[str, HealthResponse]: async def health(self) -> dict[str, HealthResponse]:
health_statuses = {} health_statuses = {}
@ -612,7 +630,7 @@ class InferenceRouter(Inference):
continue continue
health = await asyncio.wait_for(impl.health(), timeout=timeout) health = await asyncio.wait_for(impl.health(), timeout=timeout)
health_statuses[provider_id] = health health_statuses[provider_id] = health
except asyncio.TimeoutError: except (asyncio.TimeoutError, TimeoutError):
health_statuses[provider_id] = HealthResponse( health_statuses[provider_id] = HealthResponse(
status=HealthStatus.ERROR, status=HealthStatus.ERROR,
message=f"Health check timed out after {timeout} seconds", message=f"Health check timed out after {timeout} seconds",

View file

@ -93,7 +93,7 @@ class AuthenticationMiddleware:
# Validate token and get access attributes # Validate token and get access attributes
try: try:
access_attributes = await self.auth_provider.validate_token(token, scope) validation_result = await self.auth_provider.validate_token(token, scope)
except httpx.TimeoutException: except httpx.TimeoutException:
logger.exception("Authentication request timed out") logger.exception("Authentication request timed out")
return await self._send_auth_error(send, "Authentication service timeout") return await self._send_auth_error(send, "Authentication service timeout")
@ -105,17 +105,20 @@ class AuthenticationMiddleware:
return await self._send_auth_error(send, "Authentication service error") return await self._send_auth_error(send, "Authentication service error")
# Store attributes in request scope for access control # Store attributes in request scope for access control
if access_attributes: if validation_result.access_attributes:
user_attributes = access_attributes.model_dump(exclude_none=True) user_attributes = validation_result.access_attributes.model_dump(exclude_none=True)
else: else:
logger.warning("No access attributes, setting namespace to token by default") logger.warning("No access attributes, setting namespace to token by default")
user_attributes = { user_attributes = {
"namespaces": [token], "roles": [token],
} }
# Store attributes in request scope # Store attributes in request scope
scope["user_attributes"] = user_attributes scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(scope['user_attributes'])} attributes") scope["principal"] = validation_result.principal
logger.debug(
f"Authentication successful: {validation_result.principal} with {len(scope['user_attributes'])} attributes"
)
return await self.app(scope, receive, send) return await self.app(scope, receive, send)

View file

@ -5,12 +5,14 @@
# the root directory of this source tree. # the root directory of this source tree.
import json import json
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from urllib.parse import parse_qs from urllib.parse import parse_qs
import httpx import httpx
from pydantic import BaseModel, Field from jose import jwt
from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -18,9 +20,11 @@ from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth") logger = get_logger(name=__name__, category="auth")
class AuthResponse(BaseModel): class TokenValidationResult(BaseModel):
"""The format of the authentication response from the auth endpoint.""" principal: str | None = Field(
default=None,
description="The principal (username or persistent identifier) of the authenticated user",
)
access_attributes: AccessAttributes | None = Field( access_attributes: AccessAttributes | None = Field(
default=None, default=None,
description=""" description="""
@ -43,6 +47,10 @@ class AuthResponse(BaseModel):
""", """,
) )
class AuthResponse(TokenValidationResult):
"""The format of the authentication response from the auth endpoint."""
message: str | None = Field( message: str | None = Field(
default=None, description="Optional message providing additional context about the authentication result." default=None, description="Optional message providing additional context about the authentication result."
) )
@ -69,6 +77,7 @@ class AuthProviderType(str, Enum):
KUBERNETES = "kubernetes" KUBERNETES = "kubernetes"
CUSTOM = "custom" CUSTOM = "custom"
OAUTH2_TOKEN = "oauth2_token"
class AuthProviderConfig(BaseModel): class AuthProviderConfig(BaseModel):
@ -82,7 +91,7 @@ class AuthProvider(ABC):
"""Abstract base class for authentication providers.""" """Abstract base class for authentication providers."""
@abstractmethod @abstractmethod
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token and return access attributes.""" """Validate a token and return access attributes."""
pass pass
@ -92,12 +101,16 @@ class AuthProvider(ABC):
pass pass
class KubernetesAuthProviderConfig(BaseModel):
api_server_url: str
ca_cert_path: str | None = None
class KubernetesAuthProvider(AuthProvider): class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server.""" """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: dict[str, str]): def __init__(self, config: KubernetesAuthProviderConfig):
self.api_server_url = config["api_server_url"] self.config = config
self.ca_cert_path = config.get("ca_cert_path")
self._client = None self._client = None
async def _get_client(self): async def _get_client(self):
@ -110,16 +123,16 @@ class KubernetesAuthProvider(AuthProvider):
# Configure the client # Configure the client
configuration = client.Configuration() configuration = client.Configuration()
configuration.host = self.api_server_url configuration.host = self.config.api_server_url
if self.ca_cert_path: if self.config.ca_cert_path:
configuration.ssl_ca_cert = self.ca_cert_path configuration.ssl_ca_cert = self.config.ca_cert_path
configuration.verify_ssl = bool(self.ca_cert_path) configuration.verify_ssl = bool(self.config.ca_cert_path)
# Create API client # Create API client
self._client = ApiClient(configuration) self._client = ApiClient(configuration)
return self._client return self._client
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a Kubernetes token and return access attributes.""" """Validate a Kubernetes token and return access attributes."""
try: try:
client = await self._get_client() client = await self._get_client()
@ -146,9 +159,12 @@ class KubernetesAuthProvider(AuthProvider):
username = payload.get("sub", "") username = payload.get("sub", "")
groups = payload.get("groups", []) groups = payload.get("groups", [])
return AccessAttributes( return TokenValidationResult(
principal=username,
access_attributes=AccessAttributes(
roles=[username], # Use username as a role roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams teams=groups, # Use Kubernetes groups as teams
),
) )
except Exception as e: except Exception as e:
@ -162,18 +178,125 @@ class KubernetesAuthProvider(AuthProvider):
self._client = None self._client = None
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key):
continue
claim = claims[claim_key]
if isinstance(claim, list):
values = claim
else:
values = claim.split()
current = getattr(attributes, attribute_key)
if current:
current.extend(values)
else:
setattr(attributes, attribute_key, values)
return attributes
class OAuth2TokenAuthProviderConfig(BaseModel):
# The JWKS URI for collecting public keys
jwks_uri: str
cache_ttl: int = 3600
audience: str = "llama-stack"
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v
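
To make the default claims_mapping concrete, here is a hedged walk-through of how get_attributes_from_claims (defined above) would treat a typical decoded token; the claim values are illustrative only:

claims = {
    "sub": "alice",
    "username": "alice@example.com",
    "groups": ["ml-team", "infra"],
    "namespace": "research",
}
# With the default mapping, "sub" and "username" both feed roles (string claims are
# whitespace-split, list claims are taken as-is), "groups" feeds teams, and "namespace"
# feeds namespaces, so the result is roughly:
#   AccessAttributes(roles=["alice", "alice@example.com"],
#                    teams=["ml-team", "infra"],
#                    namespaces=["research"])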
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: OAuth2TokenAuthProviderConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
try:
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode(
token,
key_data,
algorithms=[algorithm],
audience=self.config.audience,
options={"verify_exp": True},
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=access_attributes,
)
async def close(self):
"""Close the HTTP client."""
async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
self._jwks = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
self._jwks[kid] = k
self._jwks_at = time.time()
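
Condensed, the validate_token path above does the following, assuming python-jose and a JWKS dict already keyed by kid the way _refresh_jwks builds it (a sketch, not the provider itself):

from jose import jwt

def validate_with_jwks(token: str, jwks_by_kid: dict[str, dict], audience: str) -> dict:
    header = jwt.get_unverified_header(token)     # read kid/alg without verifying the signature
    key = jwks_by_kid[header["kid"]]              # unknown kid -> KeyError, i.e. reject the token
    return jwt.decode(
        token,
        key,                                      # python-jose accepts the JWK dict directly
        algorithms=[header.get("alg", "RS256")],
        audience=audience,
        options={"verify_exp": True},             # expired tokens raise an exception
    )

The JWKS itself is refetched at most once per cache_ttl seconds, so a key rotation at the issuer can take up to that long to be picked up.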
class CustomAuthProviderConfig(BaseModel):
endpoint: str
class CustomAuthProvider(AuthProvider): class CustomAuthProvider(AuthProvider):
"""Custom authentication provider that uses an external endpoint.""" """Custom authentication provider that uses an external endpoint."""
def __init__(self, config: dict[str, str]): def __init__(self, config: CustomAuthProviderConfig):
self.endpoint = config["endpoint"] self.config = config
self._client = None self._client = None
async def validate_token(self, token: str, scope: dict | None = None) -> AccessAttributes | None: async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the custom authentication endpoint.""" """Validate a token using the custom authentication endpoint."""
if not self.endpoint:
raise ValueError("Authentication endpoint not configured")
if scope is None: if scope is None:
scope = {} scope = {}
@ -202,7 +325,7 @@ class CustomAuthProvider(AuthProvider):
try: try:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
self.endpoint, self.config.endpoint,
json=auth_request.model_dump(), json=auth_request.model_dump(),
timeout=10.0, # Add a reasonable timeout timeout=10.0, # Add a reasonable timeout
) )
@ -214,19 +337,7 @@ class CustomAuthProvider(AuthProvider):
try: try:
response_data = response.json() response_data = response.json()
auth_response = AuthResponse(**response_data) auth_response = AuthResponse(**response_data)
return auth_response
# Store attributes in request scope for access control
if auth_response.access_attributes:
return auth_response.access_attributes
else:
logger.warning("No access attributes, setting namespace to api_key by default")
user_attributes = {
"namespaces": [token],
}
scope["user_attributes"] = user_attributes
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
return auth_response.access_attributes
except Exception as e: except Exception as e:
logger.exception("Error parsing authentication response") logger.exception("Error parsing authentication response")
raise ValueError("Invalid authentication response format") from e raise ValueError("Invalid authentication response format") from e
@ -253,9 +364,11 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
provider_type = config.provider_type.lower() provider_type = config.provider_type.lower()
if provider_type == "kubernetes": if provider_type == "kubernetes":
return KubernetesAuthProvider(config.config) return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom": elif provider_type == "custom":
return CustomAuthProvider(config.config) return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
else: else:
supported_providers = ", ".join([t.value for t in AuthProviderType]) supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}") raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")

View file

@ -9,6 +9,7 @@ import asyncio
import inspect import inspect
import json import json
import os import os
import ssl
import sys import sys
import traceback import traceback
import warnings import warnings
@ -17,6 +18,7 @@ from importlib.metadata import version as parse_version
from pathlib import Path from pathlib import Path
from typing import Annotated, Any from typing import Annotated, Any
import rich.pretty
import yaml import yaml
from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath from fastapi import Path as FastapiPath
@ -114,7 +116,7 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
return HTTPException(status_code=400, detail=str(exc)) return HTTPException(status_code=400, detail=str(exc))
elif isinstance(exc, PermissionError): elif isinstance(exc, PermissionError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, TimeoutError): elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}") return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError): elif isinstance(exc, NotImplementedError):
return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}") return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
@ -139,7 +141,7 @@ async def shutdown(app):
await asyncio.wait_for(impl.shutdown(), timeout=5) await asyncio.wait_for(impl.shutdown(), timeout=5)
else: else:
logger.warning("No shutdown method for %s", impl_name) logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError: except (asyncio.TimeoutError, TimeoutError):
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except (Exception, asyncio.CancelledError) as e: except (Exception, asyncio.CancelledError) as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e}) logger.exception("Failed to shutdown %s: %s", impl_name, {e})
@ -186,11 +188,30 @@ async def sse_generator(event_gen_coroutine):
) )
async def log_request_pre_validation(request: Request):
if request.method in ("POST", "PUT", "PATCH"):
try:
body_bytes = await request.body()
if body_bytes:
try:
parsed_body = json.loads(body_bytes.decode())
log_output = rich.pretty.pretty_repr(parsed_body)
except (json.JSONDecodeError, UnicodeDecodeError):
log_output = repr(body_bytes)
logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{log_output}")
else:
logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
except Exception as e:
logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
def create_dynamic_typed_route(func: Any, method: str, route: str): def create_dynamic_typed_route(func: Any, method: str, route: str):
async def endpoint(request: Request, **kwargs): async def endpoint(request: Request, **kwargs):
# Get auth attributes from the request scope # Get auth attributes from the request scope
user_attributes = request.scope.get("user_attributes", {}) user_attributes = request.scope.get("user_attributes", {})
await log_request_pre_validation(request)
# Use context manager with both provider data and auth attributes # Use context manager with both provider data and auth attributes
with request_provider_data_context(request.headers, user_attributes): with request_provider_data_context(request.headers, user_attributes):
is_streaming = is_streaming_request(func.__name__, request, **kwargs) is_streaming = is_streaming_request(func.__name__, request, **kwargs)
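
log_request_pre_validation only emits at debug level; its parse-or-fallback logic boils down to the following standalone helper (assuming rich is installed, as the new import above implies):

import json
import rich.pretty

def render_body(body_bytes: bytes) -> str:
    # Pretty-print JSON bodies, fall back to repr() for anything that is not valid JSON/UTF-8.
    try:
        return rich.pretty.pretty_repr(json.loads(body_bytes.decode()))
    except (json.JSONDecodeError, UnicodeDecodeError):
        return repr(body_bytes)

print(render_body(b'{"model": "llama", "stream": true}'))   # {'model': 'llama', 'stream': True}
print(render_body(b"\xff\xfe not json"))                     # b'\xff\xfe not json'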
@ -337,22 +358,11 @@ def main(args: argparse.Namespace | None = None):
default=int(os.getenv("LLAMA_STACK_PORT", 8321)), default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on", help="Port to listen on",
) )
parser.add_argument("--disable-ipv6", action="store_true", help="Whether to disable IPv6 support")
parser.add_argument( parser.add_argument(
"--env", "--env",
action="append", action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.", help="Environment variables in KEY=value format. Can be specified multiple times.",
) )
parser.add_argument(
"--tls-keyfile",
help="Path to TLS key file for HTTPS",
required="--tls-certfile" in sys.argv,
)
parser.add_argument(
"--tls-certfile",
help="Path to TLS certificate file for HTTPS",
required="--tls-keyfile" in sys.argv,
)
# Determine whether the server args are being passed by the "run" command, if this is the case # Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be # the args will be passed as a Namespace object to the main function, otherwise they will be
@ -361,9 +371,9 @@ def main(args: argparse.Namespace | None = None):
args = parser.parse_args() args = parser.parse_args()
# Check for deprecated argument usage # Check for deprecated argument usage
if "--yaml-config" in sys.argv: if "--config" in sys.argv:
warnings.warn( warnings.warn(
"The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.", "The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
@ -381,7 +391,7 @@ def main(args: argparse.Namespace | None = None):
raise ValueError(f"Template {args.template} does not exist") raise ValueError(f"Template {args.template} does not exist")
log_line = f"Using template {args.template} config file: {config_file}" log_line = f"Using template {args.template} config file: {config_file}"
else: else:
raise ValueError("Either --yaml-config or --template must be provided") raise ValueError("Either --config or --template must be provided")
logger_config = None logger_config = None
with open(config_file) as fp: with open(config_file) as fp:
@ -486,10 +496,6 @@ def main(args: argparse.Namespace | None = None):
port = args.port or config.server.port port = args.port or config.server.port
ssl_config = None ssl_config = None
if args.tls_keyfile:
keyfile = args.tls_keyfile
certfile = args.tls_certfile
else:
keyfile = config.server.tls_keyfile keyfile = config.server.tls_keyfile
certfile = config.server.tls_certfile certfile = config.server.tls_certfile
@ -498,9 +504,16 @@ def main(args: argparse.Namespace | None = None):
"ssl_keyfile": keyfile, "ssl_keyfile": keyfile,
"ssl_certfile": certfile, "ssl_certfile": certfile,
} }
if config.server.tls_cafile:
ssl_config["ssl_ca_certs"] = config.server.tls_cafile
ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
logger.info(
f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
)
else:
logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}") logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")
listen_host = ["::", "0.0.0.0"] if not args.disable_ipv6 else "0.0.0.0" listen_host = config.server.host or ["::", "0.0.0.0"]
logger.info(f"Listening on {listen_host}:{port}") logger.info(f"Listening on {listen_host}:{port}")
uvicorn_config = { uvicorn_config = {
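
When tls_cafile is set, the assembled ssl_config handed to uvicorn amounts to the following (the paths are placeholders; the real values come from config.server.tls_* in run.yaml). Requiring client certificates turns plain HTTPS into mutual TLS:

import ssl

ssl_config = {
    "ssl_keyfile": "/etc/llama-stack/server.key",
    "ssl_certfile": "/etc/llama-stack/server.crt",
    # Present only when tls_cafile is configured: clients must then present a
    # certificate signed by this CA.
    "ssl_ca_certs": "/etc/llama-stack/ca.crt",
    "ssl_cert_reqs": ssl.CERT_REQUIRED,
}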

View file

@ -29,7 +29,7 @@ error_handler() {
trap 'error_handler ${LINENO}' ERR trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then if [ $# -lt 3 ]; then
echo "Usage: $0 <env_type> <env_path_or_name> <yaml_config> <port> <script_args...>" echo "Usage: $0 <env_type> <env_path_or_name> <port> [--config <yaml_config>] [--env KEY=VALUE]..."
exit 1 exit 1
fi fi
@ -40,23 +40,30 @@ env_path_or_name="$1"
container_image="localhost/$env_path_or_name" container_image="localhost/$env_path_or_name"
shift shift
yaml_config="$1"
shift
port="$1" port="$1"
shift shift
SCRIPT_DIR=$(dirname "$(readlink -f "$0")") SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh" source "$SCRIPT_DIR/common.sh"
# Initialize env_vars as an string # Initialize variables
yaml_config=""
env_vars="" env_vars=""
other_args="" other_args=""
# Process environment variables from --env arguments
# Process remaining arguments
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case "$1" in case "$1" in
--config)
if [[ -n "$2" ]]; then
yaml_config="$2"
shift 2
else
echo -e "${RED}Error: $1 requires a CONFIG argument${NC}" >&2
exit 1
fi
;;
--env) --env)
if [[ -n "$2" ]]; then if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2" env_vars="$env_vars --env $2"
shift 2 shift 2
@ -71,6 +78,13 @@ while [[ $# -gt 0 ]]; do
;; ;;
esac esac
done done
# Check if yaml_config is required based on env_type
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then
echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2
exit 1
fi
PYTHON_BINARY="python" PYTHON_BINARY="python"
case "$env_type" in case "$env_type" in
"venv") "venv")
@ -106,8 +120,14 @@ esac
if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then
set -x set -x
if [ -n "$yaml_config" ]; then
yaml_config_arg="--config $yaml_config"
else
yaml_config_arg=""
fi
$PYTHON_BINARY -m llama_stack.distribution.server.server \ $PYTHON_BINARY -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \ $yaml_config_arg \
--port "$port" \ --port "$port" \
$env_vars \ $env_vars \
$other_args $other_args
@ -149,15 +169,26 @@ elif [[ "$env_type" == "container" ]]; then
version_tag=$(curl -s $URL | jq -r '.info.version') version_tag=$(curl -s $URL | jq -r '.info.version')
fi fi
$CONTAINER_BINARY run $CONTAINER_OPTS -it \ # Build the command with optional yaml config
cmd="$CONTAINER_BINARY run $CONTAINER_OPTS -it \
-p $port:$port \ -p $port:$port \
$env_vars \ $env_vars \
-v "$yaml_config:/app/config.yaml" \
$mounts \ $mounts \
--env LLAMA_STACK_PORT=$port \ --env LLAMA_STACK_PORT=$port \
--entrypoint python \ --entrypoint python \
$container_image:$version_tag \ $container_image:$version_tag \
-m llama_stack.distribution.server.server \ -m llama_stack.distribution.server.server"
--yaml-config /app/config.yaml \
$other_args # Add yaml config if provided, otherwise use default
if [ -n "$yaml_config" ]; then
cmd="$cmd -v $yaml_config:/app/run.yaml --config /app/run.yaml"
else
cmd="$cmd --config /app/run.yaml"
fi
# Add any other args
cmd="$cmd $other_args"
# Execute the command
eval $cmd
fi fi

View file

@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry):
async def get_all(self) -> list[RoutableObjectWithProvider]: async def get_all(self) -> list[RoutableObjectWithProvider]:
start_key, end_key = _get_registry_key_range() start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key) values = await self.kvstore.values_in_range(start_key, end_key)
return _parse_registry_values(values) return _parse_registry_values(values)
async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
return return
start_key, end_key = _get_registry_key_range() start_key, end_key = _get_registry_key_range()
values = await self.kvstore.range(start_key, end_key) values = await self.kvstore.values_in_range(start_key, end_key)
objects = _parse_registry_values(values) objects = _parse_registry_values(values)
async with self._locked_cache() as cache: async with self._locked_cache() as cache:

View file

@ -124,7 +124,7 @@ if prompt := st.chat_input("Example: What is Llama Stack?"):
message_placeholder.markdown(full_response + "▌") message_placeholder.markdown(full_response + "▌")
message_placeholder.markdown(full_response) message_placeholder.markdown(full_response)
else: else:
full_response = response full_response = response.completion_message.content
message_placeholder.markdown(full_response.completion_message.content) message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response}) st.session_state.messages.append({"role": "assistant", "content": full_response})

View file

@ -14,3 +14,5 @@ DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"
DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints" DEFAULT_CHECKPOINT_DIR = LLAMA_STACK_CONFIG_DIR / "checkpoints"
RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime" RUNTIME_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "runtime"
EXTERNAL_PROVIDERS_DIR = LLAMA_STACK_CONFIG_DIR / "providers.d"

View file

@ -22,8 +22,10 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
def formulate_run_args(image_type, image_name, config, template_name) -> list: def formulate_run_args(image_type, image_name, config, template_name) -> list:
env_name = "" env_name = ""
if image_type == LlamaStackImageType.CONTAINER.value or config.container_image: if image_type == LlamaStackImageType.CONTAINER.value:
env_name = f"distribution-{template_name}" if template_name else config.container_image env_name = (
f"distribution-{template_name}" if template_name else (config.container_image if config else image_name)
)
elif image_type == LlamaStackImageType.CONDA.value: elif image_type == LlamaStackImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
env_name = image_name or current_conda_env env_name = image_name or current_conda_env

View file

@ -245,7 +245,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{"function_description": self._gen_function_description(custom_tools)}, {"function_description": self._gen_function_description(custom_tools)},
) )
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> str:
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
Here is a list of functions in JSON format that you can invoke. Here is a list of functions in JSON format that you can invoke.
@ -286,10 +286,12 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
""" """
) )
return PromptTemplate( template = PromptTemplate(
template_str.strip("\n"), template_str.strip("\n"),
{"tools": [t.model_dump() for t in custom_tools]}, {"tools": [t.model_dump() for t in custom_tools]},
).render() )
rendered: str = template.render()
return rendered
def data_examples(self) -> list[list[ToolDefinition]]: def data_examples(self) -> list[list[ToolDefinition]]:
return [ return [

View file

@ -173,9 +173,7 @@ INCORRECT: [get_events(location="Singapore")] <- If function not in list
- Don't repeat tool response verbatim - Don't repeat tool response verbatim
- Don't add supplementary information - Don't add supplementary information
Here is a list of functions in JSON format that you can invoke:
Here is a list of functions in JSON format that you can invoke.
[ [
{ {
"name": "get_weather", "name": "get_weather",
@ -196,10 +194,7 @@ Here is a list of functions in JSON format that you can invoke.
} }
} }
} }
] ]<|eot|><|header_start|>user<|header_end|>
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>
What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|> What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>

View file

@ -61,7 +61,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
- Don't repeat tool response verbatim - Don't repeat tool response verbatim
- Don't add supplementary information - Don't add supplementary information
{{ function_description }} {{ function_description }}
""".strip("\n") """.strip("\n")
) )
@ -76,8 +75,7 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate: def _gen_function_description(self, custom_tools: list[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent( template_str = textwrap.dedent(
""" """
Here is a list of functions in JSON format that you can invoke. Here is a list of functions in JSON format that you can invoke:
[ [
{% for t in tools -%} {% for t in tools -%}
{# manually setting up JSON because jinja sorts keys in unexpected ways -#} {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
@ -108,10 +106,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
{% endif -%} {% endif -%}
{%- endfor %} {%- endfor %}
] ]
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.
""" """
) )
return PromptTemplate( return PromptTemplate(

View file

@ -948,6 +948,8 @@ def llama_meta_net_info(model: Model) -> LlamaDownloadInfo:
elif model.core_model_id == CoreModelId.llama_guard_2_8b: elif model.core_model_id == CoreModelId.llama_guard_2_8b:
folder = "llama-guard-2" folder = "llama-guard-2"
else: else:
if model.huggingface_repo is None:
raise ValueError(f"Model {model.core_model_id} has no huggingface_repo set")
folder = model.huggingface_repo.split("/")[-1] folder = model.huggingface_repo.split("/")[-1]
if "Llama-2" in folder: if "Llama-2" in folder:
folder = folder.lower() folder = folder.lower()
@ -1024,3 +1026,4 @@ def llama_meta_pth_size(model: Model) -> int:
return 54121549657 return 54121549657
else: else:
return 100426653046 return 100426653046
return 0

View file

@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_groups_api: ToolGroups, tool_groups_api: ToolGroups,
vector_io_api: VectorIO, vector_io_api: VectorIO,
persistence_store: KVStore, persistence_store: KVStore,
created_at: str,
): ):
self.agent_id = agent_id self.agent_id = agent_id
self.agent_config = agent_config self.agent_config = agent_config
@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin):
self.storage = AgentPersistence(agent_id, persistence_store) self.storage = AgentPersistence(agent_id, persistence_store)
self.tool_runtime_api = tool_runtime_api self.tool_runtime_api = tool_runtime_api
self.tool_groups_api = tool_groups_api self.tool_groups_api = tool_groups_api
self.created_at = created_at
ShieldRunnerMixin.__init__( ShieldRunnerMixin.__init__(
self, self,

View file

@ -4,10 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
import logging import logging
import uuid import uuid
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
Agent, Agent,
@ -20,14 +20,13 @@ from llama_stack.apis.agents import (
AgentTurnCreateRequest, AgentTurnCreateRequest,
AgentTurnResumeRequest, AgentTurnResumeRequest,
Document, Document,
ListAgentSessionsResponse, OpenAIResponseInput,
ListAgentsResponse,
OpenAIResponseInputMessage,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseObject, OpenAIResponseObject,
Session, Session,
Turn, Turn,
) )
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
Inference, Inference,
ToolConfig, ToolConfig,
@ -39,13 +38,14 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .agent_instance import ChatAgent from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.INFO)
class MetaReferenceAgentsImpl(Agents): class MetaReferenceAgentsImpl(Agents):
@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents):
agent_config: AgentConfig, agent_config: AgentConfig,
) -> AgentCreateResponse: ) -> AgentCreateResponse:
agent_id = str(uuid.uuid4()) agent_id = str(uuid.uuid4())
created_at = datetime.now(timezone.utc)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set( await self.persistence_store.set(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
value=agent_config.model_dump_json(), value=agent_info.model_dump_json(),
) )
return AgentCreateResponse( return AgentCreateResponse(
agent_id=agent_id, agent_id=agent_id,
) )
async def _get_agent_impl(self, agent_id: str) -> ChatAgent: async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_config = await self.persistence_store.get( agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}", key=f"agent:{agent_id}",
) )
if not agent_config: if not agent_info_json:
raise ValueError(f"Could not find agent config for {agent_id}") raise ValueError(f"Could not find agent info for {agent_id}")
try: try:
agent_config = json.loads(agent_config) agent_info = AgentInfo.model_validate_json(agent_info_json)
except json.JSONDecodeError as e:
raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e: except Exception as e:
raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent( return ChatAgent(
agent_id=agent_id, agent_id=agent_id,
agent_config=agent_config, agent_config=agent_info,
inference_api=self.inference_api, inference_api=self.inference_api,
safety_api=self.safety_api, safety_api=self.safety_api,
vector_io_api=self.vector_io_api, vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api, tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api, tool_groups_api=self.tool_groups_api,
persistence_store=( persistence_store=(
self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
), ),
created_at=agent_info.created_at,
) )
async def create_agent_session( async def create_agent_session(
@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: list[str] | None = None, turn_ids: list[str] | None = None,
) -> Session: ) -> Session:
agent = await self._get_agent_impl(agent_id) agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id) session_info = await agent.storage.get_session_info(session_id)
if session_info is None: if session_info is None:
raise ValueError(f"Session {session_id} not found") raise ValueError(f"Session {session_id} not found")
@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents):
) )
async def delete_agents_session(self, agent_id: str, session_id: str) -> None: async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
await self.persistence_store.delete(f"session:{agent_id}:{session_id}") agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None: async def delete_agent(self, agent_id: str) -> None:
# First get all sessions for this agent
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Delete all sessions
for session in sessions:
await self.delete_agents_session(agent_id, session.session_id)
# Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}") await self.persistence_store.delete(f"agent:{agent_id}")
async def shutdown(self) -> None: async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
pass agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
async def list_agents(self) -> ListAgentsResponse: # Get the agent info using the key
pass agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent: async def get_agent(self, agent_id: str) -> Agent:
pass chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
)
return agent
async def list_agent_sessions( async def list_agent_sessions(
self, self, agent_id: str, start_index: int | None = None, limit: int | None = None
agent_id: str, ) -> PaginatedResponse:
) -> ListAgentSessionsResponse: agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass pass
# OpenAI responses # OpenAI responses
@ -255,7 +311,7 @@ class MetaReferenceAgentsImpl(Agents):
async def create_openai_response( async def create_openai_response(
self, self,
input: str | list[OpenAIResponseInputMessage], input: str | list[OpenAIResponseInput],
model: str, model: str,
previous_response_id: str | None = None, previous_response_id: str | None = None,
store: bool | None = True, store: bool | None = True,
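
list_agents and list_agent_sessions above both materialize full model dumps and hand them to paginate_records. A standalone sketch of that list-then-slice style of pagination (not the llama_stack helper itself, whose response fields may differ):

def paginate_records_sketch(records: list[dict], start_index: int | None, limit: int | None) -> dict:
    # Slice the already-materialized records and report whether more remain.
    start = start_index or 0
    end = len(records) if limit is None else start + limit
    page = records[start:end]
    return {"data": page, "has_more": end < len(records)}

agents = [{"agent_id": f"agent-{i}"} for i in range(5)]
page = paginate_records_sketch(agents, start_index=2, limit=2)
assert [a["agent_id"] for a in page["data"]] == ["agent-2", "agent-3"]
assert page["has_more"] is True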

View file

@ -7,22 +7,29 @@
import json import json
import uuid import uuid
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import cast from typing import Any, cast
from openai.types.chat import ChatCompletionToolParam from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import ( from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputMessage, OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputItemList,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage, OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText, OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseInputToolFunction,
OpenAIResponseMessage,
OpenAIResponseObject, OpenAIResponseObject,
OpenAIResponseObjectStream, OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted, OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated, OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseOutput, OpenAIResponseOutput,
OpenAIResponseOutputMessage, OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageWebSearchToolCall, OpenAIResponseOutputMessageWebSearchToolCall,
) )
from llama_stack.apis.inference.inference import ( from llama_stack.apis.inference.inference import (
@ -32,10 +39,13 @@ from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam, OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam, OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction, OpenAIChatCompletionToolCallFunction,
OpenAIChoice, OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL, OpenAIImageURL,
OpenAIMessageParam, OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam, OpenAIToolMessageParam,
OpenAIUserMessageParam, OpenAIUserMessageParam,
) )
@ -50,31 +60,110 @@ logger = get_logger(name=__name__, category="openai_responses")
OPENAI_RESPONSES_PREFIX = "openai_responses:" OPENAI_RESPONSES_PREFIX = "openai_responses:"
async def _previous_response_to_messages(previous_response: OpenAIResponseObject) -> list[OpenAIMessageParam]: async def _convert_response_content_to_chat_content(
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def _convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = [] messages: list[OpenAIMessageParam] = []
for output_message in previous_response.output: if isinstance(input, list):
if isinstance(output_message, OpenAIResponseOutputMessage): for input_item in input:
messages.append(OpenAIAssistantMessageParam(content=output_message.content[0].text)) if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await _convert_response_content_to_chat_content(input_item.content)
message_type = await _get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages return messages
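
As a hedged illustration of _convert_response_input_to_chat_messages, with illustrative values and required-but-irrelevant fields elided:

# A Responses-style input list such as:
#   [OpenAIResponseMessage(role="user", content="What's the weather in SF?"),
#    OpenAIResponseOutputMessageFunctionToolCall(call_id="call_1", name="get_weather",
#                                                arguments='{"city": "SF"}'),
#    OpenAIResponseInputFunctionToolCallOutput(call_id="call_1", output="sunny")]
# is converted, in order, to Chat Completion messages:
#   [OpenAIUserMessageParam(content="What's the weather in SF?"),
#    OpenAIAssistantMessageParam(tool_calls=[<call_1 -> get_weather({"city": "SF"})>]),
#    OpenAIToolMessageParam(content="sunny", tool_call_id="call_1")]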
async def _openai_choices_to_output_messages(choices: list[OpenAIChoice]) -> list[OpenAIResponseOutputMessage]: async def _convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
output_messages = [] """
for choice in choices: Convert an OpenAI Chat Completion choice into an OpenAI Response output message.
"""
output_content = "" output_content = ""
if isinstance(choice.message.content, str): if isinstance(choice.message.content, str):
output_content = choice.message.content output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text output_content = choice.message.content.text
# TODO: handle image content else:
output_messages.append( raise ValueError(
OpenAIResponseOutputMessage( f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}", id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed", status="completed",
role="assistant",
) )
)
return output_messages
async def _get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: OpenAIResponseInputItemList
response: OpenAIResponseObject
class OpenAIResponsesImpl: class OpenAIResponsesImpl:
@ -90,19 +179,45 @@ class OpenAIResponsesImpl:
self.tool_groups_api = tool_groups_api self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api self.tool_runtime_api = tool_runtime_api
async def get_openai_response( async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
self,
id: str,
) -> OpenAIResponseObject:
key = f"{OPENAI_RESPONSES_PREFIX}{id}" key = f"{OPENAI_RESPONSES_PREFIX}{id}"
response_json = await self.persistence_store.get(key=key) response_json = await self.persistence_store.get(key=key)
if response_json is None: if response_json is None:
raise ValueError(f"OpenAI response with id '{id}' not found") raise ValueError(f"OpenAI response with id '{id}' not found")
return OpenAIResponseObject.model_validate_json(response_json) return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
async def _prepend_previous_response(
self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
):
if previous_response_id:
previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input_items.data
# previous response output items
new_input_items.extend(previous_response_with_input.response.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
async def get_openai_response(
self,
id: str,
) -> OpenAIResponseObject:
response_with_input = await self._get_previous_response_with_input(id)
return response_with_input.response
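
The effect of _prepend_previous_response is simply to rebuild the conversation context as previous inputs, then previous outputs, then the new input. A minimal standalone sketch with plain dicts standing in for the response message types:

def chain_inputs(previous_inputs: list, previous_outputs: list, new_input: str | list) -> list:
    combined = list(previous_inputs) + list(previous_outputs)
    if isinstance(new_input, str):
        combined.append({"role": "user", "content": new_input})   # stand-in for OpenAIResponseMessage
    else:
        combined.extend(new_input)
    return combined

assert len(chain_inputs(["prev_input"], ["prev_output"], "follow-up question")) == 3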
async def create_openai_response( async def create_openai_response(
self, self,
input: str | list[OpenAIResponseInputMessage], input: str | list[OpenAIResponseInput],
model: str, model: str,
previous_response_id: str | None = None, previous_response_id: str | None = None,
store: bool | None = True, store: bool | None = True,
@ -112,31 +227,8 @@ class OpenAIResponsesImpl:
): ):
stream = False if stream is None else stream stream = False if stream is None else stream
messages: list[OpenAIMessageParam] = [] input = await self._prepend_previous_response(input, previous_response_id)
if previous_response_id: messages = await _convert_response_input_to_chat_messages(input)
previous_response = await self.get_openai_response(previous_response_id)
messages.extend(await _previous_response_to_messages(previous_response))
# TODO: refactor this user_content parsing out into a separate method
user_content: str | list[OpenAIChatCompletionContentPartParam] = ""
if isinstance(input, list):
user_content = []
for user_input in input:
if isinstance(user_input.content, list):
for user_input_content in user_input.content:
if isinstance(user_input_content, OpenAIResponseInputMessageContentText):
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input_content.text))
elif isinstance(user_input_content, OpenAIResponseInputMessageContentImage):
if user_input_content.image_url:
image_url = OpenAIImageURL(
url=user_input_content.image_url, detail=user_input_content.detail
)
user_content.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
else:
user_content.append(OpenAIChatCompletionContentPartTextParam(text=user_input.content))
else:
user_content = input
messages.append(OpenAIUserMessageParam(content=user_content))
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None
chat_response = await self.inference_api.openai_chat_completion( chat_response = await self.inference_api.openai_chat_completion(
model=model, model=model,
@ -150,6 +242,7 @@ class OpenAIResponsesImpl:
# TODO: refactor this into a separate method that handles streaming # TODO: refactor this into a separate method that handles streaming
chat_response_id = "" chat_response_id = ""
chat_response_content = [] chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
# TODO: these chunk_ fields are hacky and only take the last chunk into account # TODO: these chunk_ fields are hacky and only take the last chunk into account
chunk_created = 0 chunk_created = 0
chunk_model = "" chunk_model = ""
@ -163,7 +256,30 @@ class OpenAIResponsesImpl:
chat_response_content.append(chunk_choice.delta.content or "") chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason: if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason chunk_finish_reason = chunk_choice.finish_reason
assistant_message = OpenAIAssistantMessageParam(content="".join(chat_response_content))
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
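
The chunk aggregation above exists because streamed tool calls arrive as fragments that share an index. A standalone sketch with plain dicts standing in for the delta objects:

deltas = [
    {"index": 0, "id": "call_1", "function": {"name": "get_weather", "arguments": '{"ci'}},
    {"index": 0, "id": None, "function": {"name": None, "arguments": 'ty": "SF"}'}},
]
calls: dict[int, dict] = {}
for delta in deltas:
    existing = calls.get(delta["index"])
    if existing:
        # Later fragments only carry more argument text; append it to the call seen first.
        existing["function"]["arguments"] += delta["function"]["arguments"]
    else:
        calls[delta["index"]] = {"id": delta["id"], "function": dict(delta["function"])}
assert calls[0]["function"]["arguments"] == '{"city": "SF"}'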
chat_response = OpenAIChatCompletion( chat_response = OpenAIChatCompletion(
id=chat_response_id, id=chat_response_id,
choices=[ choices=[
@ -181,12 +297,26 @@ class OpenAIResponsesImpl:
chat_response = OpenAIChatCompletion(**chat_response.model_dump()) chat_response = OpenAIChatCompletion(**chat_response.model_dump())
output_messages: list[OpenAIResponseOutput] = [] output_messages: list[OpenAIResponseOutput] = []
if chat_response.choices[0].message.tool_calls: for choice in chat_response.choices:
output_messages.extend( if choice.message.tool_calls and tools:
await self._execute_tool_and_return_final_output(model, stream, chat_response, messages, temperature) # Assume if the first tool is a function, all tools are functions
if isinstance(tools[0], OpenAIResponseInputToolFunction):
for tool_call in choice.message.tool_calls:
output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=f"fc_{uuid.uuid4()}",
status="completed",
)
) )
else: else:
output_messages.extend(await _openai_choices_to_output_messages(chat_response.choices)) output_messages.extend(
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature)
)
else:
output_messages.append(await _convert_chat_choice_to_response_message(choice))
response = OpenAIResponseObject( response = OpenAIResponseObject(
created_at=chat_response.created, created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}", id=f"resp-{uuid.uuid4()}",
@ -195,13 +325,43 @@ class OpenAIResponsesImpl:
status="completed", status="completed",
output=output_messages, output=output_messages,
) )
logger.debug(f"OpenAI Responses response: {response}")
if store: if store:
# Store in kvstore # Store in kvstore
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
input_items = OpenAIResponseInputItemList(data=input_items_data)
prev_response = OpenAIResponsePreviousResponseWithInputItems(
input_items=input_items,
response=response,
)
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}" key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
await self.persistence_store.set( await self.persistence_store.set(
key=key, key=key,
value=response.model_dump_json(), value=prev_response.model_dump_json(),
) )
if stream: if stream:
@ -221,7 +381,9 @@ class OpenAIResponsesImpl:
chat_tools: list[ChatCompletionToolParam] = [] chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools: for input_tool in tools:
# TODO: Handle other tool types # TODO: Handle other tool types
if input_tool.type == "web_search": if input_tool.type == "function":
chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type == "web_search":
tool_name = "web_search" tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name) tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition( tool_def = ToolDefinition(
@ -247,12 +409,11 @@ class OpenAIResponsesImpl:
self, self,
model_id: str, model_id: str,
stream: bool, stream: bool,
chat_response: OpenAIChatCompletion, choice: OpenAIChoice,
messages: list[OpenAIMessageParam], messages: list[OpenAIMessageParam],
temperature: float, temperature: float,
) -> list[OpenAIResponseOutput]: ) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = [] output_messages: list[OpenAIResponseOutput] = []
choice = chat_response.choices[0]
# If the choice is not an assistant message, we don't need to execute any tools # If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam): if not isinstance(choice.message, OpenAIAssistantMessageParam):
@ -262,6 +423,9 @@ class OpenAIResponsesImpl:
if not choice.message.tool_calls: if not choice.message.tool_calls:
return output_messages return output_messages
# Copy the messages list to avoid mutating the original list
messages = messages.copy()
# Add the assistant message with tool_calls response to the messages list # Add the assistant message with tool_calls response to the messages list
messages.append(choice.message) messages.append(choice.message)
@ -307,7 +471,9 @@ class OpenAIResponsesImpl:
) )
# type cast to appease mypy # type cast to appease mypy
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response) tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
tool_final_outputs = await _openai_choices_to_output_messages(tool_results_chat_response.choices) tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
]
# TODO: Wire in annotations with URLs, titles, etc to these output messages # TODO: Wire in annotations with URLs, titles, etc to these output messages
output_messages.extend(tool_final_outputs) output_messages.extend(tool_final_outputs)
return output_messages return output_messages
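
The streaming path above accumulates tool-call argument fragments keyed by each delta's index before attaching the completed calls to the assistant message. For illustration, a minimal standalone sketch of the same index-keyed aggregation, using plain dataclasses as hypothetical stand-ins for the OpenAI chunk types (not the llama_stack classes):

from dataclasses import dataclass


@dataclass
class ToolCallDelta:
    """One streamed fragment of a tool call; only the first fragment carries id/name."""
    index: int
    id: str | None = None
    name: str | None = None
    arguments: str = ""


@dataclass
class AggregatedToolCall:
    id: str | None = None
    name: str | None = None
    arguments: str = ""


def aggregate_tool_calls(deltas: list[ToolCallDelta]) -> list[AggregatedToolCall]:
    by_index: dict[int, AggregatedToolCall] = {}
    for delta in deltas:
        call = by_index.get(delta.index)
        if call is None:
            # First fragment for this index: record the id and function name.
            by_index[delta.index] = AggregatedToolCall(delta.id, delta.name, delta.arguments)
        else:
            # Subsequent fragments only extend the JSON-encoded arguments string.
            call.arguments += delta.arguments
    # Return the calls ordered by index, as the response-building code above does.
    return [by_index[i] for i in sorted(by_index)]


deltas = [
    ToolCallDelta(index=0, id="call_1", name="get_weather", arguments='{"city":'),
    ToolCallDelta(index=0, arguments=' "Paris"}'),
]
print(aggregate_tool_calls(deltas))
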


@ -9,9 +9,7 @@ import logging
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from pydantic import BaseModel from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.agents import ToolExecutionStep, Turn
from llama_stack.distribution.access_control import check_access from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.request_headers import get_auth_attributes from llama_stack.distribution.request_headers import get_auth_attributes
@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class AgentSessionInfo(BaseModel): class AgentSessionInfo(Session):
session_id: str
session_name: str
# TODO: is this used anywhere? # TODO: is this used anywhere?
vector_db_id: str | None = None vector_db_id: str | None = None
started_at: datetime started_at: datetime
access_attributes: AccessAttributes | None = None access_attributes: AccessAttributes | None = None
class AgentInfo(AgentConfig):
created_at: datetime
class AgentPersistence: class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore): def __init__(self, agent_id: str, kvstore: KVStore):
self.agent_id = agent_id self.agent_id = agent_id
@ -46,6 +46,7 @@ class AgentPersistence:
session_name=name, session_name=name,
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
access_attributes=access_attributes, access_attributes=access_attributes,
turns=[],
) )
await self.kvstore.set( await self.kvstore.set(
@ -109,7 +110,7 @@ class AgentPersistence:
if not await self.get_session_if_accessible(session_id): if not await self.get_session_if_accessible(session_id):
raise ValueError(f"Session {session_id} not found or access denied") raise ValueError(f"Session {session_id} not found or access denied")
values = await self.kvstore.range( values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:", start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
) )
@ -121,7 +122,6 @@ class AgentPersistence:
except Exception as e: except Exception as e:
log.error(f"Error parsing turn: {e}") log.error(f"Error parsing turn: {e}")
continue continue
turns.sort(key=lambda x: (x.completed_at or datetime.min))
return turns return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None: async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
@ -170,3 +170,43 @@ class AgentPersistence:
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
) )
return int(value) if value else None return int(value) if value else None
async def list_sessions(self) -> list[Session]:
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:",
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
)
sessions = []
for value in values:
try:
session_info = Session(**json.loads(value))
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")
continue
return sessions
async def delete_session_turns(self, session_id: str) -> None:
"""Delete all turns and their associated data for a session.
Args:
session_id: The ID of the session whose turns should be deleted.
"""
turns = await self.get_session_turns(session_id)
for turn in turns:
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
async def delete_session(self, session_id: str) -> None:
"""Delete a session and all its associated turns.
Args:
session_id: The ID of the session to delete.
Raises:
ValueError: If the session does not exist.
"""
session_info = await self.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
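
The persistence layer above relies on a hierarchical key scheme — session:{agent_id}:{session_id} for session records and session:{agent_id}:{session_id}:{turn_id} for turns — plus values_in_range scans over those prefixes. A minimal sketch of that layout with a dict-backed stand-in (MemoryKVStore here is hypothetical, not the llama_stack KVStore implementation):

import asyncio
import json


class MemoryKVStore:
    """Hypothetical in-memory stand-in for the KVStore protocol used above."""

    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        # Lexicographic range scan, matching how the session/turn keys are queried.
        return [v for k, v in sorted(self._data.items()) if start_key <= k <= end_key]


async def demo() -> None:
    kv = MemoryKVStore()
    agent_id, session_id = "agent-1", "sess-1"
    await kv.set(f"session:{agent_id}:{session_id}", json.dumps({"session_id": session_id}))
    await kv.set(f"session:{agent_id}:{session_id}:turn-1", json.dumps({"turn_id": "turn-1"}))

    # Scan everything under the agent prefix, as list_sessions() does.
    print(await kv.values_in_range(f"session:{agent_id}:", f"session:{agent_id}:\xff\xff\xff\xff"))

    # Delete the turns first, then the session record, mirroring
    # delete_session_turns() followed by delete_session().
    await kv.delete(f"session:{agent_id}:{session_id}:turn-1")
    await kv.delete(f"session:{agent_id}:{session_id}")


asyncio.run(demo())
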


@ -11,9 +11,9 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from .config import LocalFSDatasetIOConfig from .config import LocalFSDatasetIOConfig
@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
# Load existing datasets from kvstore # Load existing datasets from kvstore
start_key = DATASETS_PREFIX start_key = DATASETS_PREFIX
end_key = f"{DATASETS_PREFIX}\xff" end_key = f"{DATASETS_PREFIX}\xff"
stored_datasets = await self.kvstore.range(start_key, end_key) stored_datasets = await self.kvstore.values_in_range(start_key, end_key)
for dataset in stored_datasets: for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset) dataset = Dataset.model_validate_json(dataset)


@ -58,7 +58,7 @@ class MetaReferenceEvalImpl(
# Load existing benchmarks from kvstore # Load existing benchmarks from kvstore
start_key = EVAL_TASKS_PREFIX start_key = EVAL_TASKS_PREFIX
end_key = f"{EVAL_TASKS_PREFIX}\xff" end_key = f"{EVAL_TASKS_PREFIX}\xff"
stored_benchmarks = await self.kvstore.range(start_key, end_key) stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key)
for benchmark in stored_benchmarks: for benchmark in stored_benchmarks:
benchmark = Benchmark.model_validate_json(benchmark) benchmark = Benchmark.model_validate_json(benchmark)


@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
CompletionRequest, CompletionRequest,
CompletionResponse, CompletionResponse,
CompletionResponseStreamChunk, CompletionResponseStreamChunk,
Inference, InferenceProvider,
InterleavedContent, InterleavedContent,
LogProbConfig, LogProbConfig,
Message, Message,
@ -86,7 +86,7 @@ class MetaReferenceInferenceImpl(
OpenAICompletionToLlamaStackMixin, OpenAICompletionToLlamaStackMixin,
OpenAIChatCompletionToLlamaStackMixin, OpenAIChatCompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
Inference, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
def __init__(self, config: MetaReferenceInferenceConfig) -> None: def __init__(self, config: MetaReferenceInferenceConfig) -> None:


@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionResponse, CompletionResponse,
Inference, InferenceProvider,
InterleavedContent, InterleavedContent,
LogProbConfig, LogProbConfig,
Message, Message,
@ -38,7 +38,7 @@ class SentenceTransformersInferenceImpl(
OpenAIChatCompletionToLlamaStackMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin, OpenAICompletionToLlamaStackMixin,
SentenceTransformerEmbeddingMixin, SentenceTransformerEmbeddingMixin,
Inference, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None: def __init__(self, config: SentenceTransformersInferenceConfig) -> None:


@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import gc
def evacuate_model_from_device(model, device: str):
"""Safely clear a model from memory and free device resources.
This function handles the proper cleanup of a model by:
1. Moving the model to CPU if it's on a non-CPU device
2. Deleting the model object to free memory
3. Running garbage collection
4. Clearing CUDA cache if the model was on a CUDA device
Args:
model: The PyTorch model to clear
device: The device type the model is currently on ('cuda', 'mps', 'cpu')
Note:
- For CUDA devices, this will clear the CUDA cache after moving the model to CPU
- For MPS devices, only moves the model to CPU (no cache clearing available)
- For CPU devices, only deletes the model object and runs garbage collection
"""
if device != "cpu":
model.to("cpu")
del model
gc.collect()
if device == "cuda":
# import torch lazily so it is only loaded when this function is called
import torch
torch.cuda.empty_cache()
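
A hedged usage sketch for the helper above (assumes PyTorch is installed and that evacuate_model_from_device is importable; its module path is not shown in this diff). Note that the del inside the helper only drops the helper's local reference, so callers should drop their own reference as well for the memory to actually be reclaimed:

import torch

# Pick whichever device is available; the tiny Linear layer is only a placeholder.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 16).to(device)

evacuate_model_from_device(model, device)

# The caller still holds a reference after the call; drop it so the object
# can be garbage-collected.
del model
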


@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.distribution.datatypes import Api
from .config import HuggingFacePostTrainingConfig
# The post_training API and the HuggingFace provider are still experimental and under heavy development
async def get_provider_impl(
config: HuggingFacePostTrainingConfig,
deps: dict[Api, Any],
):
from .post_training import HuggingFacePostTrainingImpl
impl = HuggingFacePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl
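
For illustration, how this factory might be invoked outside the normal distribution wiring — a sketch only, with placeholder dependency objects (datasetio_impl / datasets_impl are hypothetical; the real objects are resolved and injected by the Llama Stack runtime, and the import path for get_provider_impl is assumed from the package layout shown elsewhere in this diff):

import asyncio

from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.post_training.huggingface import get_provider_impl  # assumed path
from llama_stack.providers.inline.post_training.huggingface.config import (
    HuggingFacePostTrainingConfig,
)


async def main() -> None:
    # Placeholder dependencies; in a real deployment these are provider implementations.
    datasetio_impl, datasets_impl = object(), object()
    impl = await get_provider_impl(
        HuggingFacePostTrainingConfig(device="cpu"),
        {Api.datasetio: datasetio_impl, Api.datasets: datasets_impl},
    )
    print(type(impl).__name__)


asyncio.run(main())
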


@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Literal
from pydantic import BaseModel
class HuggingFacePostTrainingConfig(BaseModel):
# Device to run training on (cuda, cpu, mps)
device: str = "cuda"
# Distributed training backend if using multiple devices
# fsdp: Fully Sharded Data Parallel
# deepspeed: DeepSpeed ZeRO optimization
distributed_backend: Literal["fsdp", "deepspeed"] | None = None
# Format for saving model checkpoints
# full_state: Save complete model state
# huggingface: Save in HuggingFace format (recommended for compatibility)
checkpoint_format: Literal["full_state", "huggingface"] | None = "huggingface"
# Template for formatting chat inputs and outputs
# Used to structure the conversation format for training
chat_template: str = "<|user|>\n{input}\n<|assistant|>\n{output}"
# Model-specific configuration parameters
# trust_remote_code: Allow execution of custom model code
# attn_implementation: Use SDPA (Scaled Dot Product Attention) for better performance
model_specific_config: dict = {
"trust_remote_code": True,
"attn_implementation": "sdpa",
}
# Maximum sequence length for training
# Set to 2048 as this is the maximum that works reliably on MPS (Apple Silicon)
# Longer sequences may cause memory issues on MPS devices
max_seq_length: int = 2048
# Enable gradient checkpointing to reduce memory usage
# Trades computation for memory by recomputing activations
gradient_checkpointing: bool = False
# Maximum number of checkpoints to keep
# Older checkpoints are deleted when this limit is reached
save_total_limit: int = 3
# Number of training steps between logging updates
logging_steps: int = 10
# Ratio of training steps used for learning rate warmup
# Helps stabilize early training
warmup_ratio: float = 0.1
# L2 regularization coefficient
# Helps prevent overfitting
weight_decay: float = 0.01
# Number of worker processes for data loading
# Higher values can improve data loading speed but increase memory usage
dataloader_num_workers: int = 4
# Whether to pin memory in data loader
# Can improve data transfer speed to GPU but uses more memory
dataloader_pin_memory: bool = True
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}


@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
DPOAlignmentConfig,
JobStatus,
ListPostTrainingJobsResponse,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.config import (
HuggingFacePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import (
HFFinetuningSingleDevice,
)
from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus
from llama_stack.schema_utils import webmethod
class TrainingArtifactType(Enum):
CHECKPOINT = "checkpoint"
RESOURCES_STATS = "resources_stats"
_JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune"
class HuggingFacePostTrainingImpl:
def __init__(
self,
config: HuggingFacePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
self._scheduler = Scheduler()
async def shutdown(self) -> None:
await self._scheduler.shutdown()
@staticmethod
def _checkpoint_to_artifact(checkpoint: Checkpoint) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.CHECKPOINT.value,
name=checkpoint.identifier,
uri=checkpoint.path,
metadata=dict(checkpoint),
)
@staticmethod
def _resources_stats_to_artifact(resources_stats: dict[str, Any]) -> JobArtifact:
return JobArtifact(
type=TrainingArtifactType.RESOURCES_STATS.value,
name=TrainingArtifactType.RESOURCES_STATS.value,
metadata=resources_stats,
)
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
model: str,
checkpoint_dir: str | None = None,
algorithm_config: AlgorithmConfig | None = None,
) -> PostTrainingJob:
async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb):
on_log_message_cb("Starting HF finetuning")
recipe = HFFinetuningSingleDevice(
job_uuid=job_uuid,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
)
resources_allocated, checkpoints = await recipe.train(
model=model,
output_dir=checkpoint_dir,
job_uuid=job_uuid,
lora_config=algorithm_config,
config=training_config,
provider_config=self.config,
)
on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated))
if checkpoints:
for checkpoint in checkpoints:
artifact = self._checkpoint_to_artifact(checkpoint)
on_artifact_collected_cb(artifact)
on_status_change_cb(SchedulerJobStatus.completed)
on_log_message_cb("HF finetuning completed")
job_uuid = self._scheduler.schedule(_JOB_TYPE_SUPERVISED_FINE_TUNE, job_uuid, handler)
return PostTrainingJob(job_uuid=job_uuid)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: dict[str, Any],
logger_config: dict[str, Any],
) -> PostTrainingJob:
raise NotImplementedError("DPO alignment is not implemented yet")
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
return ListPostTrainingJobsResponse(
data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()]
)
@staticmethod
def _get_artifacts_metadata_by_type(job, artifact_type):
return [artifact.metadata for artifact in job.artifacts if artifact.type == artifact_type]
@classmethod
def _get_checkpoints(cls, job):
return cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.CHECKPOINT.value)
@classmethod
def _get_resources_allocated(cls, job):
data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value)
return data[0] if data else None
@webmethod(route="/post-training/job/status")
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None:
job = self._scheduler.get_job(job_uuid)
match job.status:
# TODO: Add support for other statuses to API
case SchedulerJobStatus.new | SchedulerJobStatus.scheduled:
status = JobStatus.scheduled
case SchedulerJobStatus.running:
status = JobStatus.in_progress
case SchedulerJobStatus.completed:
status = JobStatus.completed
case SchedulerJobStatus.failed:
status = JobStatus.failed
case _:
raise NotImplementedError()
return PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=status,
scheduled_at=job.scheduled_at,
started_at=job.started_at,
completed_at=job.completed_at,
checkpoints=self._get_checkpoints(job),
resources_allocated=self._get_resources_allocated(job),
)
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
self._scheduler.cancel(job_uuid)
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None:
job = self._scheduler.get_job(job_uuid)
return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job))
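
The match statement in get_training_job_status() above collapses scheduler-side job states into the public JobStatus values. The same mapping expressed as a lookup table — a minimal sketch assuming only the enum members referenced in that match block:

from llama_stack.apis.post_training import JobStatus
from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus

# Scheduler states on the left, API-visible states on the right, mirroring the
# match statement above.
_STATUS_MAP = {
    SchedulerJobStatus.new: JobStatus.scheduled,
    SchedulerJobStatus.scheduled: JobStatus.scheduled,
    SchedulerJobStatus.running: JobStatus.in_progress,
    SchedulerJobStatus.completed: JobStatus.completed,
    SchedulerJobStatus.failed: JobStatus.failed,
}


def to_api_status(status: SchedulerJobStatus) -> JobStatus:
    try:
        return _STATUS_MAP[status]
    except KeyError:
        # Keep parity with the match statement's fall-through behaviour.
        raise NotImplementedError(f"Unmapped scheduler status: {status}") from None
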

Some files were not shown because too many files have changed in this diff.