Merge branch 'main' into nvidia-e2e-notebook

Jash Gulabrai · 2025-05-28 17:48:15 -04:00
commit f5cb965f0f
226 changed files with 16519 additions and 8666 deletions


@ -1,10 +1,8 @@
# What does this PR do? # What does this PR do?
[Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] <!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
[//]: # (If resolving an issue, uncomment and update the line below) <!-- If resolving an issue, uncomment and update the line below -->
[//]: # (Closes #[issue-number]) <!-- Closes #[issue-number] -->
## Test Plan ## Test Plan
[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
[//]: # (## Documentation)

.github/actions/setup-runner/action.yml (new file)

@ -0,0 +1,22 @@
name: Setup runner
description: Prepare a runner for the tests (install uv, python, project dependencies, etc.)
runs:
  using: "composite"
  steps:
    - name: Install uv
      uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
      with:
        python-version: "3.10"
        activate-environment: true
        version: 0.7.6
    - name: Install dependencies
      shell: bash
      run: |
        uv sync --all-groups
        uv pip install ollama faiss-cpu
        # always test against the latest version of the client
        # TODO: this is not necessarily a good idea. we need to test against both published and latest
        # to find out backwards compatibility issues.
        uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
        uv pip install -e .
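For context, the workflows later in this diff consume this composite action through a local `uses:` reference. A minimal sketch of such a job, assuming the action lives at `.github/actions/setup-runner` as above (the workflow name, trigger, and job name are illustrative, not part of the changeset):

```yaml
# Illustrative workflow; only the checkout/setup-runner/build steps mirror the changeset.
name: example-ci
on: [push]
jobs:
  build:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      - name: Install dependencies
        uses: ./.github/actions/setup-runner
      - name: Build Llama Stack
        run: |
          llama stack build --template ollama --image-type venv
```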


@ -23,23 +23,18 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
auth-provider: [kubernetes]
auth-provider: [oauth2_token]
fail-fast: false # we want to run all tests regardless of failure
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
- name: Install dependencies
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
uses: ./.github/actions/setup-runner
with:
python-version: "3.10"
activate-environment: true
- name: Set Up Environment and Install Dependencies
- name: Build Llama Stack
run: |
uv sync --extra dev --extra test
uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Install minikube
@ -47,29 +42,53 @@ jobs:
uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19
- name: Start minikube
if: ${{ matrix.auth-provider == 'kubernetes' }}
if: ${{ matrix.auth-provider == 'oauth2_token' }}
run: |
minikube start
kubectl get pods -A
- name: Configure Kube Auth
if: ${{ matrix.auth-provider == 'kubernetes' }}
if: ${{ matrix.auth-provider == 'oauth2_token' }}
run: |
kubectl create namespace llama-stack
kubectl create serviceaccount llama-stack-auth -n llama-stack
kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
cat <<EOF | kubectl apply -f -
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: allow-anonymous-openid
rules:
- nonResourceURLs: ["/openid/v1/jwks"]
verbs: ["get"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: allow-anonymous-openid
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: allow-anonymous-openid
subjects:
- kind: User
name: system:anonymous
apiGroup: rbac.authorization.k8s.io
EOF
- name: Set Kubernetes Config
if: ${{ matrix.auth-provider == 'kubernetes' }}
if: ${{ matrix.auth-provider == 'oauth2_token' }}
run: |
echo "KUBERNETES_API_SERVER_URL=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.server}')" >> $GITHUB_ENV
echo "KUBERNETES_API_SERVER_URL=$(kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri)" >> $GITHUB_ENV
echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV
- name: Set Kube Auth Config and run server
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
if: ${{ matrix.auth-provider == 'kubernetes' }}
if: ${{ matrix.auth-provider == 'oauth2_token' }}
run: |
run_dir=$(mktemp -d)
cat <<'EOF' > $run_dir/run.yaml
@ -81,10 +100,10 @@ jobs:
port: 8321
EOF
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config = {"api_server_url": "${{ env.KUBERNETES_API_SERVER_URL }}", "ca_cert_path": "${{ env.KUBERNETES_CA_CERT_PATH }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
cat $run_dir/run.yaml
source .venv/bin/activate
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
- name: Wait for Llama Stack server to be ready


@ -24,7 +24,7 @@ jobs:
matrix:
# Listing tests manually since some of them currently fail
# TODO: generate matrix list from tests/integration when fixed
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers]
test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime]
client-type: [library, http]
fail-fast: false # we want to run all tests regardless of failure
@ -32,24 +32,14 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
- name: Install dependencies
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
uses: ./.github/actions/setup-runner
with:
python-version: "3.10"
activate-environment: true
- name: Setup ollama
uses: ./.github/actions/setup-ollama
- name: Set Up Environment and Install Dependencies
- name: Build Llama Stack
run: |
uv sync --extra dev --extra test
uv pip install ollama faiss-cpu
# always test against the latest version of the client
# TODO: this is not necessarily a good idea. we need to test against both published and latest
# to find out backwards compatibility issues.
uv pip install git+https://github.com/meta-llama/llama-stack-client-python.git@main
uv pip install -e .
llama stack build --template ollama --image-type venv
- name: Start Llama Stack server in background
@ -57,7 +47,6 @@ jobs:
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
source .venv/bin/activate
LLAMA_STACK_LOG_FILE=server.log nohup uv run llama stack run ./llama_stack/templates/ollama/run.yaml --image-type venv &
- name: Wait for Llama Stack server to be ready
@ -85,6 +74,7 @@ jobs:
echo "Ollama health check failed" echo "Ollama health check failed"
exit 1 exit 1
fi fi
- name: Check Storage and Memory Available Before Tests - name: Check Storage and Memory Available Before Tests
if: ${{ always() }} if: ${{ always() }}
run: | run: |
@ -100,7 +90,7 @@ jobs:
else
stack_config="http://localhost:8321"
fi
uv run pytest -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
--text-model="meta-llama/Llama-3.2-3B-Instruct" \
--embedding-model=all-MiniLM-L6-v2


@ -29,6 +29,7 @@ jobs:
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
env:
SKIP: no-commit-to-branch
RUFF_OUTPUT_FORMAT: github
- name: Verify if there are any diff files after pre-commit
run: |


@ -50,21 +50,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
- name: Install dependencies
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
uses: ./.github/actions/setup-runner
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Print build dependencies
run: |
@ -79,7 +66,6 @@ jobs:
- name: Print dependencies in the image
if: matrix.image-type == 'venv'
run: |
source test/bin/activate
uv pip list
build-single-provider:
@ -88,21 +74,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
- name: Install dependencies
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
uses: ./.github/actions/setup-runner
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Build a single provider
run: |
@ -114,21 +87,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
- name: Install dependencies
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
uses: ./.github/actions/setup-runner
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Build a single provider
run: |
@ -152,21 +112,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
- name: Install dependencies
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
uses: ./.github/actions/setup-runner
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Pin template to UBI9 base
run: |


@ -25,15 +25,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install uv
- name: Install dependencies
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
uses: ./.github/actions/setup-runner
with:
python-version: "3.10"
- name: Set Up Environment and Install Dependencies
run: |
uv sync --extra dev --extra test
uv pip install -e .
- name: Apply image type to config file
run: |
@ -59,7 +52,6 @@ jobs:
env:
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
run: |
source ci-test/bin/activate
uv run pip list
nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &


@ -30,17 +30,11 @@ jobs:
- "3.12" - "3.12"
- "3.13" - "3.13"
steps: steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python ${{ matrix.python }} - name: Install dependencies
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 uses: ./.github/actions/setup-runner
with:
python-version: ${{ matrix.python }}
- uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
with:
python-version: ${{ matrix.python }}
enable-cache: false
- name: Run unit tests
run: |


@ -37,16 +37,8 @@ jobs:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
- name: Install dependencies
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
uses: ./.github/actions/setup-runner
with:
python-version: '3.11'
- name: Install the latest version of uv
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
- name: Sync with uv
run: uv sync --extra docs
- name: Build HTML
run: |

.gitignore

@ -6,6 +6,7 @@ dev_requirements.txt
build
.DS_Store
llama_stack/configs/*
.cursor/
xcuserdata/
*.hmap
.DS_Store


@ -53,7 +53,7 @@ repos:
- black==24.3.0
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.6.3
rev: 0.7.8
hooks:
- id: uv-lock
- id: uv-export
@ -61,6 +61,7 @@ repos:
"--frozen", "--frozen",
"--no-hashes", "--no-hashes",
"--no-emit-project", "--no-emit-project",
"--no-default-groups",
"--output-file=requirements.txt" "--output-file=requirements.txt"
] ]
@ -88,20 +89,17 @@ repos:
- id: distro-codegen
name: Distribution Template Codegen
additional_dependencies:
- uv==0.6.0
- uv==0.7.8
entry: uv run --extra codegen ./scripts/distro_codegen.py
entry: uv run --group codegen ./scripts/distro_codegen.py
language: python
pass_filenames: false
require_serial: true
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
- repo: local
hooks:
- id: openapi-codegen
name: API Spec Codegen
additional_dependencies:
- uv==0.6.2
- uv==0.7.8
entry: sh -c 'uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
language: python
pass_filenames: false
require_serial: true


@ -1,5 +1,26 @@
# Changelog
# v0.2.7
Published on: 2025-05-16T20:38:10Z
## Highlights
This is a small update. But a couple highlights:
* feat: function tools in OpenAI Responses by @bbrowning in https://github.com/meta-llama/llama-stack/pull/2094, getting closer to ready. Streaming is the next missing piece.
* feat: Adding support for customizing chunk context in RAG insertion and querying by @franciscojavierarceo in https://github.com/meta-llama/llama-stack/pull/2134
* feat: scaffolding for Llama Stack UI by @ehhuang in https://github.com/meta-llama/llama-stack/pull/2149, more to come in the coming releases.
---
# v0.2.6
Published on: 2025-05-12T18:06:52Z
---
# v0.2.5
Published on: 2025-05-04T20:16:49Z


@ -167,14 +167,11 @@ If you have made changes to a provider's configuration in any form (introducing
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
```bash
cd docs
uv sync --extra docs
# This rebuilds the documentation pages.
uv run make html
uv run --with ".[docs]" make -C docs/ html
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
uv run sphinx-autobuild source build/html --write-all
uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
```
### Update API Documentation


@ -1,5 +1,4 @@
include pyproject.toml
include llama_stack/templates/dependencies.json
include llama_stack/models/llama/llama3/tokenizer.model
include llama_stack/models/llama/llama4/tokenizer.model
include llama_stack/distribution/*.sh


@ -110,7 +110,7 @@ Here is a list of the various API providers and available distributions that can
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ |
| SambaNova | Hosted | | ✅ | | | |
| Cerebras | Hosted | | ✅ | | | |
| Fireworks | Hosted | ✅ | ✅ | ✅ | | |
| AWS Bedrock | Hosted | | ✅ | | ✅ | |


@ -518,6 +518,74 @@
}
},
},
"/v1/openai/v1/responses": {
"get": {
"responses": {
"200": {
"description": "A ListOpenAIResponseObject.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListOpenAIResponseObject"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "List all OpenAI responses.",
"parameters": [
{
"name": "after",
"in": "query",
"description": "The ID of the last response to return.",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "limit",
"in": "query",
"description": "The number of responses to return.",
"required": false,
"schema": {
"type": "integer"
}
},
{
"name": "model",
"in": "query",
"description": "The model to filter responses by.",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "order",
"in": "query",
"description": "The order to sort responses by when sorted by created_at ('asc' or 'desc').",
"required": false,
"schema": {
"$ref": "#/components/schemas/Order"
}
}
]
},
"post": { "post": {
"responses": { "responses": {
"200": { "200": {
@ -1395,7 +1463,7 @@
]
}
},
"/v1/openai/v1/responses/{id}": {
"/v1/openai/v1/responses/{response_id}": {
"get": {
"responses": {
"200": {
@ -1427,7 +1495,7 @@
"description": "Retrieve an OpenAI response by its ID.", "description": "Retrieve an OpenAI response by its ID.",
"parameters": [ "parameters": [
{ {
"name": "id", "name": "response_id",
"in": "path", "in": "path",
"description": "The ID of the OpenAI response to retrieve.", "description": "The ID of the OpenAI response to retrieve.",
"required": true, "required": true,
@ -2926,6 +2994,97 @@
}
}
},
"/v1/openai/v1/responses/{response_id}/input_items": {
"get": {
"responses": {
"200": {
"description": "An ListOpenAIResponseInputItem.",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ListOpenAIResponseInputItem"
}
}
}
},
"400": {
"$ref": "#/components/responses/BadRequest400"
},
"429": {
"$ref": "#/components/responses/TooManyRequests429"
},
"500": {
"$ref": "#/components/responses/InternalServerError500"
},
"default": {
"$ref": "#/components/responses/DefaultError"
}
},
"tags": [
"Agents"
],
"description": "List input items for a given OpenAI response.",
"parameters": [
{
"name": "response_id",
"in": "path",
"description": "The ID of the response to retrieve input items for.",
"required": true,
"schema": {
"type": "string"
}
},
{
"name": "after",
"in": "query",
"description": "An item ID to list items after, used for pagination.",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "before",
"in": "query",
"description": "An item ID to list items before, used for pagination.",
"required": false,
"schema": {
"type": "string"
}
},
{
"name": "include",
"in": "query",
"description": "Additional fields to include in the response.",
"required": false,
"schema": {
"type": "array",
"items": {
"type": "string"
}
}
},
{
"name": "limit",
"in": "query",
"description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.",
"required": false,
"schema": {
"type": "integer"
}
},
{
"name": "order",
"in": "query",
"description": "The order to return the input items in. Default is desc.",
"required": false,
"schema": {
"$ref": "#/components/schemas/Order"
}
}
]
}
},
"/v1/providers": { "/v1/providers": {
"get": { "get": {
"responses": { "responses": {
@ -6742,6 +6901,9 @@
},
{
"$ref": "#/components/schemas/OpenAIResponseInputToolFunction"
},
{
"$ref": "#/components/schemas/OpenAIResponseInputToolMCP"
}
],
"discriminator": {
@ -6749,7 +6911,8 @@
"mapping": { "mapping": {
"web_search": "#/components/schemas/OpenAIResponseInputToolWebSearch", "web_search": "#/components/schemas/OpenAIResponseInputToolWebSearch",
"file_search": "#/components/schemas/OpenAIResponseInputToolFileSearch", "file_search": "#/components/schemas/OpenAIResponseInputToolFileSearch",
"function": "#/components/schemas/OpenAIResponseInputToolFunction" "function": "#/components/schemas/OpenAIResponseInputToolFunction",
"mcp": "#/components/schemas/OpenAIResponseInputToolMCP"
} }
} }
}, },
@ -6839,6 +7002,110 @@
],
"title": "OpenAIResponseInputToolFunction"
},
"OpenAIResponseInputToolMCP": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "mcp",
"default": "mcp"
},
"server_label": {
"type": "string"
},
"server_url": {
"type": "string"
},
"headers": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"require_approval": {
"oneOf": [
{
"type": "string",
"const": "always"
},
{
"type": "string",
"const": "never"
},
{
"type": "object",
"properties": {
"always": {
"type": "array",
"items": {
"type": "string"
}
},
"never": {
"type": "array",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
"title": "ApprovalFilter"
}
],
"default": "never"
},
"allowed_tools": {
"oneOf": [
{
"type": "array",
"items": {
"type": "string"
}
},
{
"type": "object",
"properties": {
"tool_names": {
"type": "array",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
"title": "AllowedToolsFilter"
}
]
}
},
"additionalProperties": false,
"required": [
"type",
"server_label",
"server_url",
"require_approval"
],
"title": "OpenAIResponseInputToolMCP"
},
"OpenAIResponseInputToolWebSearch": { "OpenAIResponseInputToolWebSearch": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -6951,15 +7218,15 @@
"OpenAIResponseOutputMessageFunctionToolCall": { "OpenAIResponseOutputMessageFunctionToolCall": {
"type": "object", "type": "object",
"properties": { "properties": {
"arguments": {
"type": "string"
},
"call_id": { "call_id": {
"type": "string" "type": "string"
}, },
"name": { "name": {
"type": "string" "type": "string"
}, },
"arguments": {
"type": "string"
},
"type": { "type": {
"type": "string", "type": "string",
"const": "function_call", "const": "function_call",
@ -6974,12 +7241,10 @@
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [
"arguments",
"call_id", "call_id",
"name", "name",
"type", "arguments",
"id", "type"
"status"
], ],
"title": "OpenAIResponseOutputMessageFunctionToolCall" "title": "OpenAIResponseOutputMessageFunctionToolCall"
}, },
@ -7027,6 +7292,9 @@
"type": "string", "type": "string",
"description": "The underlying LLM used for completions." "description": "The underlying LLM used for completions."
}, },
"instructions": {
"type": "string"
},
"previous_response_id": { "previous_response_id": {
"type": "string", "type": "string",
"description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses." "description": "(Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses."
@ -7142,6 +7410,12 @@
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall"
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPCall"
},
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
}
],
"discriminator": {
@ -7149,15 +7423,126 @@
"mapping": { "mapping": {
"message": "#/components/schemas/OpenAIResponseMessage", "message": "#/components/schemas/OpenAIResponseMessage",
"web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall", "web_search_call": "#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall",
"function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall" "function_call": "#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall",
"mcp_call": "#/components/schemas/OpenAIResponseOutputMessageMCPCall",
"mcp_list_tools": "#/components/schemas/OpenAIResponseOutputMessageMCPListTools"
} }
} }
}, },
"OpenAIResponseOutputMessageMCPCall": {
"type": "object",
"properties": {
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "mcp_call",
"default": "mcp_call"
},
"arguments": {
"type": "string"
},
"name": {
"type": "string"
},
"server_label": {
"type": "string"
},
"error": {
"type": "string"
},
"output": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"id",
"type",
"arguments",
"name",
"server_label"
],
"title": "OpenAIResponseOutputMessageMCPCall"
},
"OpenAIResponseOutputMessageMCPListTools": {
"type": "object",
"properties": {
"id": {
"type": "string"
},
"type": {
"type": "string",
"const": "mcp_list_tools",
"default": "mcp_list_tools"
},
"server_label": {
"type": "string"
},
"tools": {
"type": "array",
"items": {
"type": "object",
"properties": {
"input_schema": {
"type": "object",
"additionalProperties": {
"oneOf": [
{
"type": "null"
},
{
"type": "boolean"
},
{
"type": "number"
},
{
"type": "string"
},
{
"type": "array"
},
{
"type": "object"
}
]
}
},
"name": {
"type": "string"
},
"description": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"input_schema",
"name"
],
"title": "MCPListToolsTool"
}
}
},
"additionalProperties": false,
"required": [
"id",
"type",
"server_label",
"tools"
],
"title": "OpenAIResponseOutputMessageMCPListTools"
},
"OpenAIResponseObjectStream": { "OpenAIResponseObjectStream": {
"oneOf": [ "oneOf": [
{ {
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated" "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated"
}, },
{
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta"
},
{ {
"$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
} }
@ -7166,6 +7551,7 @@
"propertyName": "type", "propertyName": "type",
"mapping": { "mapping": {
"response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated", "response.created": "#/components/schemas/OpenAIResponseObjectStreamResponseCreated",
"response.output_text.delta": "#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta",
"response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted"
} }
} }
@ -7208,6 +7594,41 @@
],
"title": "OpenAIResponseObjectStreamResponseCreated"
},
"OpenAIResponseObjectStreamResponseOutputTextDelta": {
"type": "object",
"properties": {
"content_index": {
"type": "integer"
},
"delta": {
"type": "string"
},
"item_id": {
"type": "string"
},
"output_index": {
"type": "integer"
},
"sequence_number": {
"type": "integer"
},
"type": {
"type": "string",
"const": "response.output_text.delta",
"default": "response.output_text.delta"
}
},
"additionalProperties": false,
"required": [
"content_index",
"delta",
"item_id",
"output_index",
"sequence_number",
"type"
],
"title": "OpenAIResponseObjectStreamResponseOutputTextDelta"
},
"CreateUploadSessionRequest": { "CreateUploadSessionRequest": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -9173,9 +9594,6 @@
"toolgroup_id": { "toolgroup_id": {
"type": "string" "type": "string"
}, },
"tool_host": {
"$ref": "#/components/schemas/ToolHost"
},
"description": { "description": {
"type": "string" "type": "string"
}, },
@ -9217,21 +9635,11 @@
"provider_id", "provider_id",
"type", "type",
"toolgroup_id", "toolgroup_id",
"tool_host",
"description", "description",
"parameters" "parameters"
], ],
"title": "Tool" "title": "Tool"
}, },
"ToolHost": {
"type": "string",
"enum": [
"distribution",
"client",
"model_context_protocol"
],
"title": "ToolHost"
},
"ToolGroup": { "ToolGroup": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -10068,6 +10476,130 @@
],
"title": "ListModelsResponse"
},
"ListOpenAIResponseInputItem": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIResponseInput"
}
},
"object": {
"type": "string",
"const": "list",
"default": "list"
}
},
"additionalProperties": false,
"required": [
"data",
"object"
],
"title": "ListOpenAIResponseInputItem"
},
"ListOpenAIResponseObject": {
"type": "object",
"properties": {
"data": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIResponseObjectWithInput"
}
},
"has_more": {
"type": "boolean"
},
"first_id": {
"type": "string"
},
"last_id": {
"type": "string"
},
"object": {
"type": "string",
"const": "list",
"default": "list"
}
},
"additionalProperties": false,
"required": [
"data",
"has_more",
"first_id",
"last_id",
"object"
],
"title": "ListOpenAIResponseObject"
},
"OpenAIResponseObjectWithInput": {
"type": "object",
"properties": {
"created_at": {
"type": "integer"
},
"error": {
"$ref": "#/components/schemas/OpenAIResponseError"
},
"id": {
"type": "string"
},
"model": {
"type": "string"
},
"object": {
"type": "string",
"const": "response",
"default": "response"
},
"output": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIResponseOutput"
}
},
"parallel_tool_calls": {
"type": "boolean",
"default": false
},
"previous_response_id": {
"type": "string"
},
"status": {
"type": "string"
},
"temperature": {
"type": "number"
},
"top_p": {
"type": "number"
},
"truncation": {
"type": "string"
},
"user": {
"type": "string"
},
"input": {
"type": "array",
"items": {
"$ref": "#/components/schemas/OpenAIResponseInput"
}
}
},
"additionalProperties": false,
"required": [
"created_at",
"id",
"model",
"object",
"output",
"parallel_tool_calls",
"status",
"input"
],
"title": "OpenAIResponseObjectWithInput"
},
"ListProvidersResponse": { "ListProvidersResponse": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -11605,6 +12137,10 @@
"type": "string", "type": "string",
"default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n", "default": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
"description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\"" "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
},
"mode": {
"type": "string",
"description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"."
} }
}, },
"additionalProperties": false, "additionalProperties": false,


@ -349,6 +349,53 @@ paths:
$ref: '#/components/schemas/CreateAgentTurnRequest'
required: true
/v1/openai/v1/responses:
get:
responses:
'200':
description: A ListOpenAIResponseObject.
content:
application/json:
schema:
$ref: '#/components/schemas/ListOpenAIResponseObject'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: List all OpenAI responses.
parameters:
- name: after
in: query
description: The ID of the last response to return.
required: false
schema:
type: string
- name: limit
in: query
description: The number of responses to return.
required: false
schema:
type: integer
- name: model
in: query
description: The model to filter responses by.
required: false
schema:
type: string
- name: order
in: query
description: >-
The order to sort responses by when sorted by created_at ('asc' or 'desc').
required: false
schema:
$ref: '#/components/schemas/Order'
post:
responses:
'200':
@ -963,7 +1010,7 @@ paths:
required: true
schema:
type: string
/v1/openai/v1/responses/{id}:
/v1/openai/v1/responses/{response_id}:
get:
responses:
'200':
@ -986,7 +1033,7 @@ paths:
- Agents
description: Retrieve an OpenAI response by its ID.
parameters:
- name: id
- name: response_id
in: path
description: >-
The ID of the OpenAI response to retrieve.
@ -2038,6 +2085,75 @@ paths:
schema:
$ref: '#/components/schemas/RegisterModelRequest'
required: true
/v1/openai/v1/responses/{response_id}/input_items:
get:
responses:
'200':
description: An ListOpenAIResponseInputItem.
content:
application/json:
schema:
$ref: '#/components/schemas/ListOpenAIResponseInputItem'
'400':
$ref: '#/components/responses/BadRequest400'
'429':
$ref: >-
#/components/responses/TooManyRequests429
'500':
$ref: >-
#/components/responses/InternalServerError500
default:
$ref: '#/components/responses/DefaultError'
tags:
- Agents
description: >-
List input items for a given OpenAI response.
parameters:
- name: response_id
in: path
description: >-
The ID of the response to retrieve input items for.
required: true
schema:
type: string
- name: after
in: query
description: >-
An item ID to list items after, used for pagination.
required: false
schema:
type: string
- name: before
in: query
description: >-
An item ID to list items before, used for pagination.
required: false
schema:
type: string
- name: include
in: query
description: >-
Additional fields to include in the response.
required: false
schema:
type: array
items:
type: string
- name: limit
in: query
description: >-
A limit on the number of objects to be returned. Limit can range between
1 and 100, and the default is 20.
required: false
schema:
type: integer
- name: order
in: query
description: >-
The order to return the input items in. Default is desc.
required: false
schema:
$ref: '#/components/schemas/Order'
/v1/providers:
get:
responses:
@ -4762,12 +4878,14 @@ components:
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'
- $ref: '#/components/schemas/OpenAIResponseInputToolFileSearch'
- $ref: '#/components/schemas/OpenAIResponseInputToolFunction'
- $ref: '#/components/schemas/OpenAIResponseInputToolMCP'
discriminator:
propertyName: type
mapping:
web_search: '#/components/schemas/OpenAIResponseInputToolWebSearch'
file_search: '#/components/schemas/OpenAIResponseInputToolFileSearch'
function: '#/components/schemas/OpenAIResponseInputToolFunction'
mcp: '#/components/schemas/OpenAIResponseInputToolMCP'
OpenAIResponseInputToolFileSearch:
type: object
properties:
@ -4822,6 +4940,66 @@ components:
- type
- name
title: OpenAIResponseInputToolFunction
OpenAIResponseInputToolMCP:
type: object
properties:
type:
type: string
const: mcp
default: mcp
server_label:
type: string
server_url:
type: string
headers:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
require_approval:
oneOf:
- type: string
const: always
- type: string
const: never
- type: object
properties:
always:
type: array
items:
type: string
never:
type: array
items:
type: string
additionalProperties: false
title: ApprovalFilter
default: never
allowed_tools:
oneOf:
- type: array
items:
type: string
- type: object
properties:
tool_names:
type: array
items:
type: string
additionalProperties: false
title: AllowedToolsFilter
additionalProperties: false
required:
- type
- server_label
- server_url
- require_approval
title: OpenAIResponseInputToolMCP
OpenAIResponseInputToolWebSearch:
type: object
properties:
@ -4897,12 +5075,12 @@ components:
"OpenAIResponseOutputMessageFunctionToolCall": "OpenAIResponseOutputMessageFunctionToolCall":
type: object type: object
properties: properties:
arguments:
type: string
call_id: call_id:
type: string type: string
name: name:
type: string type: string
arguments:
type: string
type: type:
type: string type: string
const: function_call const: function_call
@ -4913,12 +5091,10 @@ components:
type: string
additionalProperties: false
required:
- arguments
- call_id
- name
- arguments
- type
- id
- status
title: >-
OpenAIResponseOutputMessageFunctionToolCall
"OpenAIResponseOutputMessageWebSearchToolCall":
@ -4952,6 +5128,8 @@ components:
model:
type: string
description: The underlying LLM used for completions.
instructions:
type: string
previous_response_id:
type: string
description: >-
@ -5034,20 +5212,95 @@ components:
- $ref: '#/components/schemas/OpenAIResponseMessage'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
- $ref: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
discriminator:
propertyName: type
mapping:
message: '#/components/schemas/OpenAIResponseMessage'
web_search_call: '#/components/schemas/OpenAIResponseOutputMessageWebSearchToolCall'
function_call: '#/components/schemas/OpenAIResponseOutputMessageFunctionToolCall'
mcp_call: '#/components/schemas/OpenAIResponseOutputMessageMCPCall'
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
OpenAIResponseOutputMessageMCPCall:
type: object
properties:
id:
type: string
type:
type: string
const: mcp_call
default: mcp_call
arguments:
type: string
name:
type: string
server_label:
type: string
error:
type: string
output:
type: string
additionalProperties: false
required:
- id
- type
- arguments
- name
- server_label
title: OpenAIResponseOutputMessageMCPCall
OpenAIResponseOutputMessageMCPListTools:
type: object
properties:
id:
type: string
type:
type: string
const: mcp_list_tools
default: mcp_list_tools
server_label:
type: string
tools:
type: array
items:
type: object
properties:
input_schema:
type: object
additionalProperties:
oneOf:
- type: 'null'
- type: boolean
- type: number
- type: string
- type: array
- type: object
name:
type: string
description:
type: string
additionalProperties: false
required:
- input_schema
- name
title: MCPListToolsTool
additionalProperties: false
required:
- id
- type
- server_label
- tools
title: OpenAIResponseOutputMessageMCPListTools
OpenAIResponseObjectStream:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
- $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
discriminator:
propertyName: type
mapping:
response.created: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated'
response.output_text.delta: '#/components/schemas/OpenAIResponseObjectStreamResponseOutputTextDelta'
response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted'
"OpenAIResponseObjectStreamResponseCompleted":
type: object
@ -5079,6 +5332,33 @@ components:
- type
title: >-
OpenAIResponseObjectStreamResponseCreated
"OpenAIResponseObjectStreamResponseOutputTextDelta":
type: object
properties:
content_index:
type: integer
delta:
type: string
item_id:
type: string
output_index:
type: integer
sequence_number:
type: integer
type:
type: string
const: response.output_text.delta
default: response.output_text.delta
additionalProperties: false
required:
- content_index
- delta
- item_id
- output_index
- sequence_number
- type
title: >-
OpenAIResponseObjectStreamResponseOutputTextDelta
CreateUploadSessionRequest:
type: object
properties:
@ -6462,8 +6742,6 @@ components:
default: tool
toolgroup_id:
type: string
tool_host:
$ref: '#/components/schemas/ToolHost'
description:
type: string
parameters:
@ -6486,17 +6764,9 @@ components:
- provider_id
- type
- toolgroup_id
- tool_host
- description
- parameters
title: Tool
ToolHost:
type: string
enum:
- distribution
- client
- model_context_protocol
title: ToolHost
ToolGroup:
type: object
properties:
@ -7042,6 +7312,96 @@ components:
required:
- data
title: ListModelsResponse
ListOpenAIResponseInputItem:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/OpenAIResponseInput'
object:
type: string
const: list
default: list
additionalProperties: false
required:
- data
- object
title: ListOpenAIResponseInputItem
ListOpenAIResponseObject:
type: object
properties:
data:
type: array
items:
$ref: '#/components/schemas/OpenAIResponseObjectWithInput'
has_more:
type: boolean
first_id:
type: string
last_id:
type: string
object:
type: string
const: list
default: list
additionalProperties: false
required:
- data
- has_more
- first_id
- last_id
- object
title: ListOpenAIResponseObject
OpenAIResponseObjectWithInput:
type: object
properties:
created_at:
type: integer
error:
$ref: '#/components/schemas/OpenAIResponseError'
id:
type: string
model:
type: string
object:
type: string
const: response
default: response
output:
type: array
items:
$ref: '#/components/schemas/OpenAIResponseOutput'
parallel_tool_calls:
type: boolean
default: false
previous_response_id:
type: string
status:
type: string
temperature:
type: number
top_p:
type: number
truncation:
type: string
user:
type: string
input:
type: array
items:
$ref: '#/components/schemas/OpenAIResponseInput'
additionalProperties: false
required:
- created_at
- id
- model
- object
- output
- parallel_tool_calls
- status
- input
title: OpenAIResponseObjectWithInput
ListProvidersResponse:
type: object
properties:
@ -8084,6 +8444,10 @@ components:
placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk
content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
{chunk.content}\nMetadata: {metadata}\n"
mode:
type: string
description: >-
Search mode for retrieval—either "vector" or "keyword". Default "vector".
additionalProperties: false
required:
- query_generator_config

File diff suppressed because it is too large.


@ -3,10 +3,10 @@
Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).
## Render locally
From the llama-stack root directory, run the following command to render the docs locally:
```bash
pip install -r requirements.txt
uv run --with ".[docs]" sphinx-autobuild docs/source docs/build/html --write-all
cd docs
python -m sphinx_autobuild source _build
```
You can open up the docs in your browser at http://localhost:8000


@ -1,16 +0,0 @@
linkify
myst-parser
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
sphinx==8.1.3
sphinx-copybutton
sphinx-design
sphinx-pdj-theme
sphinx-rtd-theme>=1.0.0
sphinx-tabs
sphinx_autobuild
sphinx_rtd_dark_mode
sphinxcontrib-mermaid
sphinxcontrib-openapi
sphinxcontrib-redoc
sphinxcontrib-video
tomli


@ -22,7 +22,11 @@ from docutils import nodes
# Read version from pyproject.toml
with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
pypi_url = "https://pypi.org/pypi/llama-stack/json"
version_tag = json.loads(requests.get(pypi_url).text)["info"]["version"]
headers = {
'User-Agent': 'pip/23.0.1 (python 3.11)', # Mimic pip's user agent
'Accept': 'application/json'
}
version_tag = json.loads(requests.get(pypi_url, headers=headers).text)["info"]["version"]
print(f"{version_tag=}")
# generate the full link including text and url here
@ -53,14 +57,6 @@ myst_enable_extensions = ["colon_fence"]
html_theme = "sphinx_rtd_theme" html_theme = "sphinx_rtd_theme"
html_use_relative_paths = True html_use_relative_paths = True
# html_theme = "sphinx_pdj_theme"
# html_theme_path = [sphinx_pdj_theme.get_html_theme_path()]
# html_theme = "pytorch_sphinx_theme"
# html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]
templates_path = ["_templates"] templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]


@ -338,6 +338,48 @@ INFO: Application startup complete.
INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
INFO: 2401:db00:35c:2d2b:face:0:c9:0:54678 - "GET /models/list HTTP/1.1" 200 OK
```
### Listing Distributions
Using the list command, you can view all existing Llama Stack distributions, including stacks built from templates, from scratch, or using custom configuration files.
```
llama stack list -h
usage: llama stack list [-h]
list the build stacks
options:
-h, --help show this help message and exit
```
Example Usage
```
llama stack list
```
### Removing a Distribution
Use the remove command to delete a distribution you've previously built.
```
llama stack rm -h
usage: llama stack rm [-h] [--all] [name]
Remove the build stack
positional arguments:
name Name of the stack to delete (default: None)
options:
-h, --help show this help message and exit
--all, -a Delete all stacks (use with caution) (default: False)
```
Example
```
llama stack rm llamastack-test
```
To keep your environment organized and avoid clutter, consider using `llama stack list` to review old or unused distributions and `llama stack rm <name>` to delete them when they're no longer needed.
### Troubleshooting


@ -118,11 +118,6 @@ server:
port: 8321 # Port to listen on (default: 8321)
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
auth: # Optional: Authentication configuration
provider_type: "kubernetes" # Type of auth provider
config: # Provider-specific configuration
api_server_url: "https://kubernetes.default.svc"
ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate
```
### Authentication Configuration
@ -135,7 +130,7 @@ Authorization: Bearer <token>
The server supports multiple authentication providers:
#### Kubernetes Provider
#### OAuth 2.0/OpenID Connect Provider with Kubernetes
The Kubernetes cluster must be configured to use a service account for authentication.
@ -146,14 +141,67 @@ kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --se
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
```
Validates tokens against the Kubernetes API server:
Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests
and that the correct RoleBinding is created to allow the service account to access the necessary
resources. If that is not the case, you can create a RoleBinding for the service account to access
the necessary resources:
```yaml
# allow-anonymous-openid.yaml
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRole
metadata:
name: allow-anonymous-openid
rules:
- nonResourceURLs: ["/openid/v1/jwks"]
verbs: ["get"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: ClusterRoleBinding
metadata:
name: allow-anonymous-openid
roleRef:
apiGroup: rbac.authorization.k8s.io
kind: ClusterRole
name: allow-anonymous-openid
subjects:
- kind: User
name: system:anonymous
apiGroup: rbac.authorization.k8s.io
```
And then apply the configuration:
```bash
kubectl apply -f allow-anonymous-openid.yaml
```
Validates tokens against the Kubernetes API server through the OIDC provider:
```yaml ```yaml
server: server:
auth: auth:
provider_type: "kubernetes" provider_type: "oauth2_token"
config: config:
api_server_url: "https://kubernetes.default.svc" # URL of the Kubernetes API server jwks:
ca_cert_path: "/path/to/ca.crt" # Optional: Path to CA certificate uri: "https://kubernetes.default.svc"
key_recheck_period: 3600
tls_cafile: "/path/to/ca.crt"
issuer: "https://kubernetes.default.svc"
audience: "https://kubernetes.default.svc"
```
To find your cluster's audience, run:
```bash
kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud
```
For the issuer, you can use the OIDC provider's URL:
```bash
kubectl get --raw /.well-known/openid-configuration | jq .issuer
```
For `tls_cafile`, you can use the CA certificate of the OIDC provider:
```bash
kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}'
``` ```
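Once the server is configured, clients authenticate by sending the service-account token as a Bearer token. Below is a minimal sketch using `requests`; the token path, server URL, and endpoint are illustrative assumptions based on the commands above, not part of the configuration itself.
```python
import requests

# Token written by `kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token` above.
with open("llama-stack-auth-token") as f:
    token = f.read().strip()

# Any authenticated route works the same way; the path below is illustrative.
resp = requests.get(
    "http://localhost:8321/v1/models",
    headers={"Authorization": f"Bearer {token}"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```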
The provider extracts user information from the JWT token: The provider extracts user information from the JWT token:
@ -208,6 +256,80 @@ And must respond with:
If no access attributes are returned, the token is used as a namespace. If no access attributes are returned, the token is used as a namespace.
### Quota Configuration
The `quota` section allows you to enable server-side request throttling for both
authenticated and anonymous clients. This is useful for preventing abuse, enforcing
fairness across tenants, and controlling infrastructure costs without requiring
client-side rate limiting or external proxies.
Quotas are disabled by default. When enabled, each client is tracked using either:
* Their authenticated `client_id` (derived from the Bearer token), or
* Their IP address (fallback for anonymous requests)
Quota state is stored in a SQLite-backed key-value store, and rate limits are applied
within a configurable time window (currently only `day` is supported).
#### Example
```yaml
server:
quota:
kvstore:
type: sqlite
db_path: ./quotas.db
anonymous_max_requests: 100
authenticated_max_requests: 1000
period: day
```
#### Configuration Options
| Field | Description |
| ---------------------------- | -------------------------------------------------------------------------- |
| `kvstore` | Required. Backend storage config for tracking request counts. |
| `kvstore.type` | Must be `"sqlite"` for now. Other backends may be supported in the future. |
| `kvstore.db_path` | File path to the SQLite database. |
| `anonymous_max_requests` | Max requests per period for unauthenticated clients. |
| `authenticated_max_requests` | Max requests per period for authenticated clients. |
| `period` | Time window for quota enforcement. Only `"day"` is supported. |
> Note: if `authenticated_max_requests` is set but no authentication provider is
configured, the server will fall back to applying `anonymous_max_requests` to all
clients.
#### Example with Authentication Enabled
```yaml
server:
port: 8321
auth:
provider_type: custom
config:
endpoint: https://auth.example.com/validate
quota:
kvstore:
type: sqlite
db_path: ./quotas.db
anonymous_max_requests: 100
authenticated_max_requests: 1000
period: day
```
If a client exceeds their limit, the server responds with:
```http
HTTP/1.1 429 Too Many Requests
Content-Type: application/json
{
"error": {
"message": "Quota exceeded"
}
}
```
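Clients should treat HTTP 429 as a signal to back off and retry later. A minimal sketch, assuming a `requests`-based client and the default port (the endpoint and backoff schedule are illustrative):
```python
import time

import requests


def get_with_backoff(url: str, token: str | None = None, retries: int = 3) -> requests.Response:
    """GET with a simple exponential backoff when the server answers 429 (quota exceeded)."""
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    resp = requests.get(url, headers=headers, timeout=30)
    for attempt in range(retries):
        if resp.status_code != 429:
            break
        time.sleep(2**attempt)  # back off before retrying
        resp = requests.get(url, headers=headers, timeout=30)
    return resp


print(get_with_backoff("http://localhost:8321/v1/models").status_code)
```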
## Extending to handle Safety ## Extending to handle Safety
Configuring Safety can be a little involved so it is instructive to go through an example. Configuring Safety can be a little involved so it is instructive to go through an example.

View file

@ -17,7 +17,7 @@ The `llamastack/distribution-sambanova` distribution consists of the following p
|-----|-------------| |-----|-------------|
| agents | `inline::meta-reference` | | agents | `inline::meta-reference` |
| inference | `remote::sambanova`, `inline::sentence-transformers` | | inference | `remote::sambanova`, `inline::sentence-transformers` |
| safety | `inline::llama-guard` | | safety | `remote::sambanova` |
| telemetry | `inline::meta-reference` | | telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::rag-runtime`, `remote::model-context-protocol`, `remote::wolfram-alpha` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
@ -48,33 +48,44 @@ The following models are available by default:
### Prerequisite: API Keys ### Prerequisite: API Keys
Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](https://sambanova.ai/). Make sure you have access to a SambaNova API Key. You can get one by visiting [SambaNova.ai](http://cloud.sambanova.ai?utm_source=llamastack&utm_medium=external&utm_campaign=cloud_signup).
## Running Llama Stack with SambaNova ## Running Llama Stack with SambaNova
You can do this via Conda (build code) or Docker which has a pre-built image. You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code. ### Via Docker
```bash ```bash
LLAMA_STACK_PORT=8321 LLAMA_STACK_PORT=8321
llama stack build --template sambanova --image-type container
docker run \ docker run \
-it \ -it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
llamastack/distribution-sambanova \ -v ~/.llama:/root/.llama \
distribution-sambanova \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
``` ```
### Via Venv
```bash
llama stack build --template sambanova --image-type venv
llama stack run --image-type venv ~/.llama/distributions/sambanova/sambanova-run.yaml \
--port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
```
### Via Conda ### Via Conda
```bash ```bash
llama stack build --template sambanova --image-type conda llama stack build --template sambanova --image-type conda
llama stack run ./run.yaml \ llama stack run --image-type conda ~/.llama/distributions/sambanova/sambanova-run.yaml \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \
--env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY --env SAMBANOVA_API_KEY=$SAMBANOVA_API_KEY
``` ```

View file

@ -66,6 +66,25 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use SQLite-Vec. 2. Configure your Llama Stack project to use SQLite-Vec.
3. Start storing and querying vectors. 3. Start storing and querying vectors.
## Supported Search Modes
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
`RAGQueryConfig`. For example:
```python
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
query_config = RAGQueryConfig(max_chunks=6, mode="vector")
results = client.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id],
content="what is torchtune",
query_config=query_config,
)
```
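To use keyword (full-text) search instead, set `mode="keyword"`. The snippet below is a minimal variation of the example above and assumes the same `client` and `vector_db_id`:
```python
query_config = RAGQueryConfig(max_chunks=6, mode="keyword")

results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="what is torchtune",
    query_config=query_config,
)
```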
## Installation ## Installation
You can install SQLite-Vec using pip: You can install SQLite-Vec using pip:

View file

@ -13,7 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionMessage, CompletionMessage,
ResponseFormat, ResponseFormat,
@ -31,6 +31,8 @@ from llama_stack.apis.tools import ToolDef
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from .openai_responses import ( from .openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIResponseInput, OpenAIResponseInput,
OpenAIResponseInputTool, OpenAIResponseInputTool,
OpenAIResponseObject, OpenAIResponseObject,
@ -579,14 +581,14 @@ class Agents(Protocol):
# #
# Both of these APIs are inherently stateful. # Both of these APIs are inherently stateful.
@webmethod(route="/openai/v1/responses/{id}", method="GET") @webmethod(route="/openai/v1/responses/{response_id}", method="GET")
async def get_openai_response( async def get_openai_response(
self, self,
id: str, response_id: str,
) -> OpenAIResponseObject: ) -> OpenAIResponseObject:
"""Retrieve an OpenAI response by its ID. """Retrieve an OpenAI response by its ID.
:param id: The ID of the OpenAI response to retrieve. :param response_id: The ID of the OpenAI response to retrieve.
:returns: An OpenAIResponseObject. :returns: An OpenAIResponseObject.
""" """
... ...
@ -596,6 +598,7 @@ class Agents(Protocol):
self, self,
input: str | list[OpenAIResponseInput], input: str | list[OpenAIResponseInput],
model: str, model: str,
instructions: str | None = None,
previous_response_id: str | None = None, previous_response_id: str | None = None,
store: bool | None = True, store: bool | None = True,
stream: bool | None = False, stream: bool | None = False,
@ -610,3 +613,43 @@ class Agents(Protocol):
:returns: An OpenAIResponseObject. :returns: An OpenAIResponseObject.
""" """
... ...
@webmethod(route="/openai/v1/responses", method="GET")
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
"""List all OpenAI responses.
:param after: The ID of the last response to return.
:param limit: The number of responses to return.
:param model: The model to filter responses by.
:param order: The order to sort responses by when sorted by created_at ('asc' or 'desc').
:returns: A ListOpenAIResponseObject.
"""
...
@webmethod(route="/openai/v1/responses/{response_id}/input_items", method="GET")
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: The order to return the input items in. Default is desc.
:returns: A ListOpenAIResponseInputItem.
"""
...
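# Editorial sketch, not part of this diff: the two GET endpoints declared above can be exercised
# over plain HTTP. The base URL (including the prefix the server mounts these routes under), the
# port, and the returned field names are assumptions for illustration only.
import requests

base = "http://localhost:8321/openai/v1"
responses = requests.get(f"{base}/responses", params={"limit": 10, "order": "desc"}).json()
if responses["data"]:
    response_id = responses["data"][0]["id"]
    items = requests.get(f"{base}/responses/{response_id}/input_items", params={"limit": 20}).json()
    print(len(items["data"]))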

View file

@ -10,6 +10,9 @@ from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, register_schema from llama_stack.schema_utils import json_schema_type, register_schema
# NOTE(ashwin): this file is literally a copy of the OpenAI responses API schema. We should probably
# take their YAML and generate this file automatically. Their YAML is available.
@json_schema_type @json_schema_type
class OpenAIResponseError(BaseModel): class OpenAIResponseError(BaseModel):
@ -79,16 +82,45 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel):
@json_schema_type @json_schema_type
class OpenAIResponseOutputMessageFunctionToolCall(BaseModel): class OpenAIResponseOutputMessageFunctionToolCall(BaseModel):
arguments: str
call_id: str call_id: str
name: str name: str
arguments: str
type: Literal["function_call"] = "function_call" type: Literal["function_call"] = "function_call"
id: str | None = None
status: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPCall(BaseModel):
id: str id: str
status: str type: Literal["mcp_call"] = "mcp_call"
arguments: str
name: str
server_label: str
error: str | None = None
output: str | None = None
class MCPListToolsTool(BaseModel):
input_schema: dict[str, Any]
name: str
description: str | None = None
@json_schema_type
class OpenAIResponseOutputMessageMCPListTools(BaseModel):
id: str
type: Literal["mcp_list_tools"] = "mcp_list_tools"
server_label: str
tools: list[MCPListToolsTool]
OpenAIResponseOutput = Annotated[ OpenAIResponseOutput = Annotated[
OpenAIResponseMessage | OpenAIResponseOutputMessageWebSearchToolCall | OpenAIResponseOutputMessageFunctionToolCall, OpenAIResponseMessage
| OpenAIResponseOutputMessageWebSearchToolCall
| OpenAIResponseOutputMessageFunctionToolCall
| OpenAIResponseOutputMessageMCPCall
| OpenAIResponseOutputMessageMCPListTools,
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput") register_schema(OpenAIResponseOutput, name="OpenAIResponseOutput")
@ -117,6 +149,16 @@ class OpenAIResponseObjectStreamResponseCreated(BaseModel):
type: Literal["response.created"] = "response.created" type: Literal["response.created"] = "response.created"
@json_schema_type
class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel):
content_index: int
delta: str
item_id: str
output_index: int
sequence_number: int
type: Literal["response.output_text.delta"] = "response.output_text.delta"
@json_schema_type @json_schema_type
class OpenAIResponseObjectStreamResponseCompleted(BaseModel): class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
response: OpenAIResponseObject response: OpenAIResponseObject
@ -124,7 +166,9 @@ class OpenAIResponseObjectStreamResponseCompleted(BaseModel):
OpenAIResponseObjectStream = Annotated[ OpenAIResponseObjectStream = Annotated[
OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseCompleted, OpenAIResponseObjectStreamResponseCreated
| OpenAIResponseObjectStreamResponseOutputTextDelta
| OpenAIResponseObjectStreamResponseCompleted,
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream") register_schema(OpenAIResponseObjectStream, name="OpenAIResponseObjectStream")
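# Editorial sketch, not part of this diff: a streamed response is a sequence of
# OpenAIResponseObjectStream events: "response.created", zero or more
# "response.output_text.delta" chunks, then "response.completed". The field values below are
# illustrative placeholders.
delta = OpenAIResponseObjectStreamResponseOutputTextDelta(
    content_index=0,
    delta="Hello",
    item_id="msg_123",
    output_index=0,
    sequence_number=1,
)
assert delta.type == "response.output_text.delta"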
@ -186,13 +230,50 @@ class OpenAIResponseInputToolFileSearch(BaseModel):
# TODO: add filters # TODO: add filters
class ApprovalFilter(BaseModel):
always: list[str] | None = None
never: list[str] | None = None
class AllowedToolsFilter(BaseModel):
tool_names: list[str] | None = None
@json_schema_type
class OpenAIResponseInputToolMCP(BaseModel):
type: Literal["mcp"] = "mcp"
server_label: str
server_url: str
headers: dict[str, Any] | None = None
require_approval: Literal["always"] | Literal["never"] | ApprovalFilter = "never"
allowed_tools: list[str] | AllowedToolsFilter | None = None
OpenAIResponseInputTool = Annotated[ OpenAIResponseInputTool = Annotated[
OpenAIResponseInputToolWebSearch | OpenAIResponseInputToolFileSearch | OpenAIResponseInputToolFunction, OpenAIResponseInputToolWebSearch
| OpenAIResponseInputToolFileSearch
| OpenAIResponseInputToolFunction
| OpenAIResponseInputToolMCP,
Field(discriminator="type"), Field(discriminator="type"),
] ]
register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool")
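# Editorial sketch, not part of this diff: attaching a remote MCP server as an input tool.
# The server label and URL are placeholders; `require_approval` accepts "always", "never",
# or an ApprovalFilter, and `allowed_tools` accepts a list of tool names or an AllowedToolsFilter.
mcp_tool = OpenAIResponseInputToolMCP(
    server_label="docs-mcp",
    server_url="https://example.com/mcp",
    require_approval="never",
    allowed_tools=["search_docs"],
)
assert mcp_tool.type == "mcp"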
class OpenAIResponseInputItemList(BaseModel): class ListOpenAIResponseInputItem(BaseModel):
data: list[OpenAIResponseInput] data: list[OpenAIResponseInput]
object: Literal["list"] = "list" object: Literal["list"] = "list"
@json_schema_type
class OpenAIResponseObjectWithInput(OpenAIResponseObject):
input: list[OpenAIResponseInput]
@json_schema_type
class ListOpenAIResponseObject(BaseModel):
data: list[OpenAIResponseObjectWithInput]
has_more: bool
first_id: str
last_id: str
object: Literal["list"] = "list"

View file

@ -1,30 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class RestAPIMethod(Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
@json_schema_type
class RestAPIExecutionConfig(BaseModel):
url: URL
method: RestAPIMethod
params: dict[str, Any] | None = None
headers: dict[str, Any] | None = None
body: dict[str, Any] | None = None

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@ -11,6 +12,11 @@ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
class Order(Enum):
asc = "asc"
desc = "desc"
@json_schema_type @json_schema_type
class PaginatedResponse(BaseModel): class PaginatedResponse(BaseModel):
"""A generic paginated response that follows a simple format. """A generic paginated response that follows a simple format.

View file

@ -19,6 +19,7 @@ from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict from typing_extensions import TypedDict
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.common.responses import Order
from llama_stack.apis.models import Model from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import ( from llama_stack.models.llama.datatypes import (
@ -833,11 +834,6 @@ class ListOpenAIChatCompletionResponse(BaseModel):
object: Literal["list"] = "list" object: Literal["list"] = "list"
class Order(Enum):
asc = "asc"
desc = "desc"
@runtime_checkable @runtime_checkable
@trace_protocol @trace_protocol
class InferenceProvider(Protocol): class InferenceProvider(Protocol):

View file

@ -76,6 +76,7 @@ class RAGQueryConfig(BaseModel):
:param chunk_template: Template for formatting each retrieved chunk in the context. :param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrieval, either "vector" or "keyword". Default "vector".
""" """
# This config defines how a query is generated using the messages # This config defines how a query is generated using the messages
@ -84,6 +85,7 @@ class RAGQueryConfig(BaseModel):
max_tokens_in_context: int = 4096 max_tokens_in_context: int = 4096
max_chunks: int = 5 max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: str | None = None
@field_validator("chunk_template") @field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str: def validate_chunk_template(cls, v: str) -> str:

View file

@ -27,18 +27,10 @@ class ToolParameter(BaseModel):
default: Any | None = None default: Any | None = None
@json_schema_type
class ToolHost(Enum):
distribution = "distribution"
client = "client"
model_context_protocol = "model_context_protocol"
@json_schema_type @json_schema_type
class Tool(Resource): class Tool(Resource):
type: Literal[ResourceType.tool] = ResourceType.tool type: Literal[ResourceType.tool] = ResourceType.tool
toolgroup_id: str toolgroup_id: str
tool_host: ToolHost
description: str description: str
parameters: list[ToolParameter] parameters: list[ToolParameter]
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
@ -76,8 +68,8 @@ class ToolInvocationResult(BaseModel):
class ToolStore(Protocol): class ToolStore(Protocol):
def get_tool(self, tool_name: str) -> Tool: ... async def get_tool(self, tool_name: str) -> Tool: ...
def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ... async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: ...
class ListToolGroupsResponse(BaseModel): class ListToolGroupsResponse(BaseModel):

View file

@ -9,6 +9,7 @@ import asyncio
import json import json
import os import os
import shutil import shutil
import sys
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial
@ -377,14 +378,15 @@ def _meta_download(
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads) downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks)) asyncio.run(downloader.download_all(tasks))
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green") cprint(f"\nSuccessfully downloaded model to {output_dir}", color="green", file=sys.stderr)
cprint( cprint(
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}", f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
"white", file=sys.stderr,
) )
cprint( cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}", f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
"yellow", color="yellow",
file=sys.stderr,
) )

View file

@ -79,6 +79,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates", f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
build_config = available_templates[args.template] build_config = available_templates[args.template]
@ -88,6 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}", f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
elif args.providers: elif args.providers:
@ -97,6 +99,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2", "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
api, provider = api_provider.split("=") api, provider = api_provider.split("=")
@ -105,6 +108,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"{api} is not a valid API.", f"{api} is not a valid API.",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
if provider in providers_for_api: if provider in providers_for_api:
@ -113,6 +117,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"{provider} is not a valid provider for the {api} API.", f"{provider} is not a valid provider for the {api} API.",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
distribution_spec = DistributionSpec( distribution_spec = DistributionSpec(
@ -123,6 +128,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"Please specify a image-type (container | conda | venv) for {args.template}", f"Please specify a image-type (container | conda | venv) for {args.template}",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
@ -151,12 +157,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`", f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
color="yellow", color="yellow",
file=sys.stderr,
) )
image_name = f"llamastack-{name}" image_name = f"llamastack-{name}"
else: else:
cprint( cprint(
f"Using conda environment {image_name}", f"Using conda environment {image_name}",
color="green", color="green",
file=sys.stderr,
) )
else: else:
image_name = f"llamastack-{name}" image_name = f"llamastack-{name}"
@ -169,9 +177,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
""", """,
), ),
color="green", color="green",
file=sys.stderr,
) )
print("Tip: use <TAB> to see options for the providers.\n") cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
providers = dict() providers = dict()
for api, providers_for_api in get_provider_registry().items(): for api, providers_for_api in get_provider_registry().items():
@ -213,6 +222,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"Could not parse config file {args.config}: {e}", f"Could not parse config file {args.config}: {e}",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
@ -239,22 +249,25 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
cprint( cprint(
f"Error building stack: {exc}", f"Error building stack: {exc}",
color="red", color="red",
file=sys.stderr,
) )
cprint("Stack trace:", color="red") cprint("Stack trace:", color="red", file=sys.stderr)
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)
if run_config is None: if run_config is None:
cprint( cprint(
"Run config path is empty", "Run config path is empty",
color="red", color="red",
file=sys.stderr,
) )
sys.exit(1) sys.exit(1)
if args.run: if args.run:
config_dict = yaml.safe_load(run_config.read_text()) config_dict = yaml.safe_load(run_config.read_text())
config = parse_and_maybe_upgrade_config(config_dict) config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(config.external_providers_dir): if config.external_providers_dir and not config.external_providers_dir.exists():
os.makedirs(config.external_providers_dir, exist_ok=True) config.external_providers_dir.mkdir(exist_ok=True)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config])
run_command(run_args) run_command(run_args)
@ -304,6 +317,7 @@ def _generate_run_config(
cprint( cprint(
f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping", f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
color="yellow", color="yellow",
file=sys.stderr,
) )
# Set config_type to None to avoid UnboundLocalError # Set config_type to None to avoid UnboundLocalError
config_type = None config_type = None
@ -331,10 +345,7 @@ def _generate_run_config(
# For non-container builds, the run.yaml is generated at the very end of the build process so it # For non-container builds, the run.yaml is generated at the very end of the build process so it
# makes sense to display this message # makes sense to display this message
if build_config.image_type != LlamaStackImageType.CONTAINER.value: if build_config.image_type != LlamaStackImageType.CONTAINER.value:
cprint( cprint(f"You can now run your stack with `llama stack run {run_config_file}`", color="green", file=sys.stderr)
f"You can now run your stack with `llama stack run {run_config_file}`",
color="green",
)
return run_config_file return run_config_file
@ -372,7 +383,7 @@ def _run_stack_build_command_from_build_config(
# Generate the run.yaml so it can be included in the container image with the proper entrypoint # Generate the run.yaml so it can be included in the container image with the proper entrypoint
# Only do this if we're building a container image and we're not using a template # Only do this if we're building a container image and we're not using a template
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path: if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
cprint("Generating run.yaml file", color="green") cprint("Generating run.yaml file", color="yellow", file=sys.stderr)
run_config_file = _generate_run_config(build_config, build_dir, image_name) run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f: with open(build_file_path, "w") as f:
@ -396,11 +407,13 @@ def _run_stack_build_command_from_build_config(
run_config_file = build_dir / f"{template_name}-run.yaml" run_config_file = build_dir / f"{template_name}-run.yaml"
shutil.copy(path, run_config_file) shutil.copy(path, run_config_file)
cprint("Build Successful!", color="green") cprint("Build Successful!", color="green", file=sys.stderr)
cprint("You can find the newly-built template here: " + colored(template_path, "light_blue")) cprint(f"You can find the newly-built template here: {template_path}", color="light_blue", file=sys.stderr)
cprint( cprint(
"You can run the new Llama Stack distro via: " "You can run the new Llama Stack distro via: "
+ colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue") + colored(f"llama stack run {template_path} --image-type {build_config.image_type}", "light_blue"),
color="green",
file=sys.stderr,
) )
return template_path return template_path
else: else:

View file

@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackListBuilds(Subcommand):
"""List built stacks in .llama/distributions directory"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama stack list",
description="list the build stacks",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._list_stack_command)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stack_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if not distributions:
print("No stacks found in ~/.llama/distributions")
return
headers = ["Stack Name", "Path"]
headers.extend(["Build Config", "Run Config"])
rows = []
for name, path in distributions.items():
row = [name, str(path)]
# Check for build and run config files
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
row.extend([build_config, run_config])
rows.append(row)
print_table(rows, headers, separate_rows=True)

View file

@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import shutil
import sys
from pathlib import Path
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackRemove(Subcommand):
"""Remove the build stack"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"rm",
prog="llama stack rm",
description="Remove the build stack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._remove_stack_build_command)
def _add_arguments(self) -> None:
self.parser.add_argument(
"name",
type=str,
nargs="?",
help="Name of the stack to delete",
)
self.parser.add_argument(
"--all",
"-a",
action="store_true",
help="Delete all stacks (use with caution)",
)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stacks(self) -> None:
"""Display available stacks in a table"""
distributions = self._get_distribution_dirs()
if not distributions:
cprint("No stacks found in ~/.llama/distributions", color="red", file=sys.stderr)
sys.exit(1)
headers = ["Stack Name", "Path"]
rows = [[name, str(path)] for name, path in distributions.items()]
print_table(rows, headers, separate_rows=True)
def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if args.all:
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
if confirm != "yes-i-really-want":
cprint("Deletion cancelled.", color="green", file=sys.stderr)
return
for name, path in distributions.items():
try:
shutil.rmtree(path)
cprint(f"Deleted stack: {name}", color="green", file=sys.stderr)
except Exception as e:
cprint(
f"Failed to delete stack {name}: {e}",
color="red",
file=sys.stderr,
)
sys.exit(1)
if not args.name:
self._list_stacks()
if not args.name:
return
if args.name not in distributions:
self._list_stacks()
cprint(
f"Stack not found: {args.name}",
color="red",
file=sys.stderr,
)
sys.exit(1)
stack_path = distributions[args.name]
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
if confirm != "y":
cprint("Deletion cancelled.", color="green", file=sys.stderr)
return
try:
shutil.rmtree(stack_path)
cprint(f"Successfully deleted stack: {args.name}", color="green", file=sys.stderr)
except Exception as e:
cprint(f"Failed to delete stack {args.name}: {e}", color="red", file=sys.stderr)
sys.exit(1)

View file

@ -6,6 +6,7 @@
import argparse import argparse
import os import os
import subprocess
from pathlib import Path from pathlib import Path
from llama_stack.cli.stack.utils import ImageType from llama_stack.cli.stack.utils import ImageType
@ -60,6 +61,11 @@ class StackRun(Subcommand):
help="Image Type used during the build. This can be either conda or container or venv.", help="Image Type used during the build. This can be either conda or container or venv.",
choices=[e.value for e in ImageType], choices=[e.value for e in ImageType],
) )
self.parser.add_argument(
"--enable-ui",
action="store_true",
help="Start the UI server",
)
# If neither image type nor image name is provided, but at the same time # If neither image type nor image name is provided, but at the same time
# the current environment has conda breadcrumbs, then assume what the user # the current environment has conda breadcrumbs, then assume what the user
@ -83,6 +89,8 @@ class StackRun(Subcommand):
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_command from llama_stack.distribution.utils.exec import formulate_run_args, run_command
if args.enable_ui:
self._start_ui_development_server(args.port)
image_type, image_name = self._get_image_type_and_name(args) image_type, image_name = self._get_image_type_and_name(args)
# Check if config is required based on image type # Check if config is required based on image type
@ -170,3 +178,44 @@ class StackRun(Subcommand):
run_args.extend(["--env", f"{key}={value}"]) run_args.extend(["--env", f"{key}={value}"])
run_command(run_args) run_command(run_args)
def _start_ui_development_server(self, stack_server_port: int):
logger.info("Attempting to start UI development server...")
# Check if npm is available
npm_check = subprocess.run(["npm", "--version"], capture_output=True, text=True, check=False)
if npm_check.returncode != 0:
logger.warning(
f"'npm' command not found or not executable. UI development server will not be started. Error: {npm_check.stderr}"
)
return
ui_dir = REPO_ROOT / "llama_stack" / "ui"
logs_dir = Path("~/.llama/ui/logs").expanduser()
try:
# Create logs directory if it doesn't exist
logs_dir.mkdir(parents=True, exist_ok=True)
ui_stdout_log_path = logs_dir / "stdout.log"
ui_stderr_log_path = logs_dir / "stderr.log"
# Open log files in append mode
stdout_log_file = open(ui_stdout_log_path, "a")
stderr_log_file = open(ui_stderr_log_path, "a")
process = subprocess.Popen(
["npm", "run", "dev"],
cwd=str(ui_dir),
stdout=stdout_log_file,
stderr=stderr_log_file,
env={**os.environ, "NEXT_PUBLIC_LLAMA_STACK_BASE_URL": f"http://localhost:{stack_server_port}"},
)
logger.info(f"UI development server process started in {ui_dir} with PID {process.pid}.")
logger.info(f"Logs: stdout -> {ui_stdout_log_path}, stderr -> {ui_stderr_log_path}")
logger.info(f"UI will be available at http://localhost:{os.getenv('LLAMA_STACK_UI_PORT', 8322)}")
except FileNotFoundError:
logger.error(
"Failed to start UI development server: 'npm' command not found. Make sure npm is installed and in your PATH."
)
except Exception as e:
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")

View file

@ -7,12 +7,14 @@
import argparse import argparse
from importlib.metadata import version from importlib.metadata import version
from llama_stack.cli.stack.list_stacks import StackListBuilds
from llama_stack.cli.stack.utils import print_subcommand_description from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild from .build import StackBuild
from .list_apis import StackListApis from .list_apis import StackListApis
from .list_providers import StackListProviders from .list_providers import StackListProviders
from .remove import StackRemove
from .run import StackRun from .run import StackRun
@ -41,5 +43,6 @@ class StackParser(Subcommand):
StackListApis.create(subparsers) StackListApis.create(subparsers)
StackListProviders.create(subparsers) StackListProviders.create(subparsers)
StackRun.create(subparsers) StackRun.create(subparsers)
StackRemove.create(subparsers)
StackListBuilds.create(subparsers)
print_subcommand_description(self.parser, subparsers) print_subcommand_description(self.parser, subparsers)

View file

@ -6,6 +6,7 @@
import importlib.resources import importlib.resources
import logging import logging
import sys
from pathlib import Path from pathlib import Path
from pydantic import BaseModel from pydantic import BaseModel
@ -43,8 +44,20 @@ def get_provider_dependencies(
# Extract providers based on config type # Extract providers based on config type
if isinstance(config, DistributionTemplate): if isinstance(config, DistributionTemplate):
providers = config.providers providers = config.providers
# TODO: This is a hack to get the dependencies for internal APIs into build
# We should have a better way to do this by formalizing the concept of "internal" APIs
# and providers, with a way to specify dependencies for them.
run_configs = config.run_configs
additional_pip_packages: list[str] = []
if run_configs:
for run_config in run_configs.values():
run_config_ = run_config.run_config(name="", providers={}, container_image=None)
if run_config_.inference_store:
additional_pip_packages.extend(run_config_.inference_store.pip_packages)
elif isinstance(config, BuildConfig): elif isinstance(config, BuildConfig):
providers = config.distribution_spec.providers providers = config.distribution_spec.providers
additional_pip_packages = config.additional_pip_packages
deps = [] deps = []
registry = get_provider_registry(config) registry = get_provider_registry(config)
for api_str, provider_or_providers in providers.items(): for api_str, provider_or_providers in providers.items():
@ -72,6 +85,9 @@ def get_provider_dependencies(
else: else:
normal_deps.append(package) normal_deps.append(package)
if additional_pip_packages:
normal_deps.extend(additional_pip_packages)
return list(set(normal_deps)), list(set(special_deps)) return list(set(normal_deps)), list(set(special_deps))
@ -80,10 +96,11 @@ def print_pip_install_help(config: BuildConfig):
cprint( cprint(
f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}", f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",
"yellow", color="yellow",
file=sys.stderr,
) )
for special_dep in special_deps: for special_dep in special_deps:
cprint(f"uv pip install {special_dep}", "yellow") cprint(f"uv pip install {special_dep}", color="yellow", file=sys.stderr)
print() print()

View file

@ -6,6 +6,7 @@
import inspect import inspect
import json import json
import sys
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from enum import Enum from enum import Enum
from typing import Any, Union, get_args, get_origin from typing import Any, Union, get_args, get_origin
@ -96,13 +97,13 @@ def create_api_client_class(protocol) -> type:
try: try:
data = json.loads(data) data = json.loads(data)
if "error" in data: if "error" in data:
cprint(data, "red") cprint(data, color="red", file=sys.stderr)
continue continue
yield parse_obj_as(return_type, data) yield parse_obj_as(return_type, data)
except Exception as e: except Exception as e:
print(f"Error with parsing or validation: {e}") cprint(f"Error with parsing or validation: {e}", color="red", file=sys.stderr)
print(data) cprint(data, color="red", file=sys.stderr)
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
webmethod, sig = self.routes[method_name] webmethod, sig = self.routes[method_name]

View file

@ -25,7 +25,8 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2" LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2" LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@ -220,21 +221,38 @@ class LoggingConfig(BaseModel):
class AuthProviderType(str, Enum): class AuthProviderType(str, Enum):
"""Supported authentication provider types.""" """Supported authentication provider types."""
KUBERNETES = "kubernetes" OAUTH2_TOKEN = "oauth2_token"
CUSTOM = "custom" CUSTOM = "custom"
class AuthenticationConfig(BaseModel): class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field( provider_type: AuthProviderType = Field(
..., ...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')", description="Type of authentication provider",
) )
config: dict[str, str] = Field( config: dict[str, Any] = Field(
..., ...,
description="Provider-specific configuration", description="Provider-specific configuration",
) )
class AuthenticationRequiredError(Exception):
pass
class QuotaPeriod(str, Enum):
DAY = "day"
class QuotaConfig(BaseModel):
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
)
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
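# Editorial sketch, not part of this diff: constructing a QuotaConfig in code, mirroring the
# quota YAML shown in the server documentation earlier in this diff. Assumes SqliteKVStoreConfig
# exposes a `db_path` field, as the `kvstore.db_path` YAML key suggests.
example_quota = QuotaConfig(
    kvstore=SqliteKVStoreConfig(db_path="./quotas.db"),
    anonymous_max_requests=100,
    authenticated_max_requests=1000,
    period=QuotaPeriod.DAY,
)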
class ServerConfig(BaseModel): class ServerConfig(BaseModel):
port: int = Field( port: int = Field(
default=8321, default=8321,
@ -262,6 +280,10 @@ class ServerConfig(BaseModel):
default=None, default=None,
description="The host the server should listen on", description="The host the server should listen on",
) )
quota: QuotaConfig | None = Field(
default=None,
description="Per client quota request configuration",
)
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):
@ -297,6 +319,13 @@ Configuration for the persistence store used by the distribution registry. If no
a default SQLite store will be used.""", a default SQLite store will be used.""",
) )
inference_store: SqlStoreConfig | None = Field(
default=None,
description="""
Configuration for the persistence store used by the inference API. If not specified,
a default SQLite store will be used.""",
)
# registry of "resources" in the distribution # registry of "resources" in the distribution
models: list[ModelInput] = Field(default_factory=list) models: list[ModelInput] = Field(default_factory=list)
shields: list[ShieldInput] = Field(default_factory=list) shields: list[ShieldInput] = Field(default_factory=list)
@ -345,6 +374,10 @@ class BuildConfig(BaseModel):
description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. " description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
"pip_packages MUST contain the provider package name.", "pip_packages MUST contain the provider package name.",
) )
additional_pip_packages: list[str] = Field(
default_factory=list,
description="Additional pip packages to install in the distribution. These packages will be installed in the distribution environment.",
)
@field_validator("external_providers_dir") @field_validator("external_providers_dir")
@classmethod @classmethod

View file

@ -31,7 +31,7 @@ async def get_provider_impl(config, deps):
class DistributionInspectImpl(Inspect): class DistributionInspectImpl(Inspect):
def __init__(self, config, deps): def __init__(self, config: DistributionInspectConfig, deps):
self.config = config self.config = config
self.deps = deps self.deps = deps
@ -39,22 +39,36 @@ class DistributionInspectImpl(Inspect):
pass pass
async def list_routes(self) -> ListRoutesResponse: async def list_routes(self) -> ListRoutesResponse:
run_config = self.config.run_config run_config: StackRunConfig = self.config.run_config
ret = [] ret = []
all_endpoints = get_all_api_endpoints() all_endpoints = get_all_api_endpoints()
for api, endpoints in all_endpoints.items(): for api, endpoints in all_endpoints.items():
providers = run_config.providers.get(api.value, []) # Always include provider and inspect APIs, filter others based on run config
ret.extend( if api.value in ["providers", "inspect"]:
[ ret.extend(
RouteInfo( [
route=e.route, RouteInfo(
method=e.method, route=e.route,
provider_types=[p.provider_type for p in providers], method=e.method,
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e in endpoints
]
)
else:
providers = run_config.providers.get(api.value, [])
if providers: # Only process if there are providers for this API
ret.extend(
[
RouteInfo(
route=e.route,
method=e.method,
provider_types=[p.provider_type for p in providers],
)
for e in endpoints
]
) )
for e in endpoints
]
)
return ListRoutesResponse(data=ret) return ListRoutesResponse(data=ret)

View file

@ -9,6 +9,7 @@ import inspect
import json import json
import logging import logging
import os import os
import sys
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
@ -210,10 +211,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.endpoint_impls = None self.endpoint_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry) self.impls = await construct_stack(self.config, self.custom_provider_registry)
except ModuleNotFoundError as _e: except ModuleNotFoundError as _e:
cprint(_e.msg, "red") cprint(_e.msg, color="red", file=sys.stderr)
cprint( cprint(
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n", "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
"yellow", color="yellow",
file=sys.stderr,
) )
if self.config_path_or_template_name.endswith(".yaml"): if self.config_path_or_template_name.endswith(".yaml"):
# Convert Provider objects to their types # Convert Provider objects to their types
@ -234,7 +236,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
cprint( cprint(
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n", f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
"yellow", "yellow",
file=sys.stderr,
) )
cprint(
"Please check your internet connection and try again.",
"red",
file=sys.stderr,
)
raise _e raise _e
if Api.telemetry in self.impls: if Api.telemetry in self.impls:
@ -261,9 +269,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
raise ValueError("Client not initialized") raise ValueError("Client not initialized")
# Create headers with provider data if available # Create headers with provider data if available
headers = {} headers = options.headers or {}
if self.provider_data: if self.provider_data:
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data) keys = ["X-LlamaStack-Provider-Data", "x-llamastack-provider-data"]
if all(key not in headers for key in keys):
headers["X-LlamaStack-Provider-Data"] = json.dumps(self.provider_data)
# Use context manager for provider data # Use context manager for provider data
with request_provider_data_context(headers): with request_provider_data_context(headers):

View file

@ -47,7 +47,7 @@ from llama_stack.providers.datatypes import (
RemoteProviderSpec, RemoteProviderSpec,
ScoringFunctionsProtocolPrivate, ScoringFunctionsProtocolPrivate,
ShieldsProtocolPrivate, ShieldsProtocolPrivate,
ToolsProtocolPrivate, ToolGroupsProtocolPrivate,
VectorDBsProtocolPrivate, VectorDBsProtocolPrivate,
) )
@ -93,7 +93,7 @@ def api_protocol_map_for_compliance_check() -> dict[Api, Any]:
def additional_protocols_map() -> dict[Api, Any]: def additional_protocols_map() -> dict[Api, Any]:
return { return {
Api.inference: (ModelsProtocolPrivate, Models, Api.models), Api.inference: (ModelsProtocolPrivate, Models, Api.models),
Api.tool_groups: (ToolsProtocolPrivate, ToolGroups, Api.tool_groups), Api.tool_groups: (ToolGroupsProtocolPrivate, ToolGroups, Api.tool_groups),
Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs), Api.vector_io: (VectorDBsProtocolPrivate, VectorDBs, Api.vector_dbs),
Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields), Api.safety: (ShieldsProtocolPrivate, Shields, Api.shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets), Api.datasetio: (DatasetsProtocolPrivate, Datasets, Api.datasets),
@ -140,7 +140,7 @@ async def resolve_impls(
sorted_providers = sort_providers_by_deps(providers_with_specs, run_config) sorted_providers = sort_providers_by_deps(providers_with_specs, run_config)
return await instantiate_providers(sorted_providers, router_apis, dist_registry) return await instantiate_providers(sorted_providers, router_apis, dist_registry, run_config)
def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]: def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, dict[str, ProviderWithSpec]]:
@ -243,7 +243,10 @@ def sort_providers_by_deps(
async def instantiate_providers( async def instantiate_providers(
sorted_providers: list[tuple[str, ProviderWithSpec]], router_apis: set[Api], dist_registry: DistributionRegistry sorted_providers: list[tuple[str, ProviderWithSpec]],
router_apis: set[Api],
dist_registry: DistributionRegistry,
run_config: StackRunConfig,
) -> dict: ) -> dict:
"""Instantiates providers asynchronously while managing dependencies.""" """Instantiates providers asynchronously while managing dependencies."""
impls: dict[Api, Any] = {} impls: dict[Api, Any] = {}
@ -258,7 +261,7 @@ async def instantiate_providers(
if isinstance(provider.spec, RoutingTableProviderSpec): if isinstance(provider.spec, RoutingTableProviderSpec):
inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"] inner_impls = inner_impls_by_provider_id[f"inner-{provider.spec.router_api.value}"]
impl = await instantiate_provider(provider, deps, inner_impls, dist_registry) impl = await instantiate_provider(provider, deps, inner_impls, dist_registry, run_config)
if api_str.startswith("inner-"): if api_str.startswith("inner-"):
inner_impls_by_provider_id[api_str][provider.provider_id] = impl inner_impls_by_provider_id[api_str][provider.provider_id] = impl
@ -308,6 +311,7 @@ async def instantiate_provider(
deps: dict[Api, Any], deps: dict[Api, Any],
inner_impls: dict[str, Any], inner_impls: dict[str, Any],
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
run_config: StackRunConfig,
): ):
provider_spec = provider.spec provider_spec = provider.spec
if not hasattr(provider_spec, "module"): if not hasattr(provider_spec, "module"):
@ -327,7 +331,7 @@ async def instantiate_provider(
method = "get_auto_router_impl" method = "get_auto_router_impl"
config = None config = None
args = [provider_spec.api, deps[provider_spec.routing_table_api], deps] args = [provider_spec.api, deps[provider_spec.routing_table_api], deps, run_config]
elif isinstance(provider_spec, RoutingTableProviderSpec): elif isinstance(provider_spec, RoutingTableProviderSpec):
method = "get_routing_table_impl" method = "get_routing_table_impl"

View file

@ -7,18 +7,10 @@
from typing import Any from typing import Any
from llama_stack.distribution.datatypes import RoutedProtocol from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.stack import StackRunConfig
from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.datatypes import Api, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from .routing_tables import (
BenchmarksRoutingTable,
DatasetsRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
ToolGroupsRoutingTable,
VectorDBsRoutingTable,
)
async def get_routing_table_impl( async def get_routing_table_impl(
@ -27,6 +19,14 @@ async def get_routing_table_impl(
_deps, _deps,
dist_registry: DistributionRegistry, dist_registry: DistributionRegistry,
) -> Any: ) -> Any:
from ..routing_tables.benchmarks import BenchmarksRoutingTable
from ..routing_tables.datasets import DatasetsRoutingTable
from ..routing_tables.models import ModelsRoutingTable
from ..routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from ..routing_tables.shields import ShieldsRoutingTable
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
from ..routing_tables.vector_dbs import VectorDBsRoutingTable
api_to_tables = { api_to_tables = {
"vector_dbs": VectorDBsRoutingTable, "vector_dbs": VectorDBsRoutingTable,
"models": ModelsRoutingTable, "models": ModelsRoutingTable,
@ -45,16 +45,15 @@ async def get_routing_table_impl(
return impl return impl
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict[str, Any]) -> Any: async def get_auto_router_impl(
from .routers import ( api: Api, routing_table: RoutingTable, deps: dict[str, Any], run_config: StackRunConfig
DatasetIORouter, ) -> Any:
EvalRouter, from .datasets import DatasetIORouter
InferenceRouter, from .eval_scoring import EvalRouter, ScoringRouter
SafetyRouter, from .inference import InferenceRouter
ScoringRouter, from .safety import SafetyRouter
ToolRuntimeRouter, from .tool_runtime import ToolRuntimeRouter
VectorIORouter, from .vector_io import VectorIORouter
)
api_to_routers = { api_to_routers = {
"vector_io": VectorIORouter, "vector_io": VectorIORouter,
@ -76,6 +75,12 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, deps: dict
if dep_api in deps: if dep_api in deps:
api_to_dep_impl[dep_name] = deps[dep_api] api_to_dep_impl[dep_name] = deps[dep_api]
# TODO: move pass configs to routers instead
if api == Api.inference and run_config.inference_store:
inference_store = InferenceStore(run_config.inference_store)
await inference_store.initialize()
api_to_dep_impl["store"] = inference_store
impl = api_to_routers[api.value](routing_table, **api_to_dep_impl) impl = api_to_routers[api.value](routing_table, **api_to_dep_impl)
await impl.initialize() await impl.initialize()
return impl return impl
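Note: the wiring above only hands a store to the inference router when `run_config.inference_store` is set, and it initializes the store before injecting it. A minimal sketch of that conditional-dependency pattern, using hypothetical stand-ins (`FakeStore`, `Router`, and `build_router` are illustrative names, not llama-stack types):

import asyncio

class FakeStore:
    def __init__(self, url: str) -> None:
        self.url = url
        self.ready = False

    async def initialize(self) -> None:
        self.ready = True

class Router:
    def __init__(self, routing_table: dict, store: FakeStore | None = None) -> None:
        self.routing_table = routing_table
        self.store = store

async def build_router(routing_table: dict, store_url: str | None) -> Router:
    extra_deps: dict[str, FakeStore] = {}
    if store_url:  # mirrors: if api == Api.inference and run_config.inference_store
        store = FakeStore(store_url)
        await store.initialize()  # initialized before the router ever sees it
        extra_deps["store"] = store
    return Router(routing_table, **extra_deps)

router = asyncio.run(build_router({}, "sqlite:///inference.db"))
print(router.store.ready)  # True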

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
)
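Note: DatasetIORouter carries no provider logic of its own; every call resolves a provider through `routing_table.get_provider_impl(dataset_id)` and forwards the arguments. A self-contained sketch of that key-based delegation with toy stand-ins (`ProviderStub` and `RoutingTableStub` are illustrative, not real llama-stack classes):

import asyncio

class ProviderStub:
    def __init__(self, name: str) -> None:
        self.name = name

    async def iterrows(self, dataset_id: str, start_index: int | None = None, limit: int | None = None) -> dict:
        return {"data": [f"{self.name}:{dataset_id}:row-0"], "next_start_index": None}

class RoutingTableStub:
    def __init__(self, impls_by_key: dict[str, ProviderStub]) -> None:
        self.impls_by_key = impls_by_key

    def get_provider_impl(self, routing_key: str) -> ProviderStub:
        # the router holds no provider logic; it resolves by key on every call
        return self.impls_by_key[routing_key]

async def main() -> None:
    table = RoutingTableStub({"my-dataset": ProviderStub("localfs")})
    page = await table.get_provider_impl("my-dataset").iterrows(dataset_id="my-dataset", limit=10)
    print(page["data"])  # ['localfs:my-dataset:row-0']

asyncio.run(main())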

View file

@ -0,0 +1,148 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
)
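Note: ScoringRouter.score and score_batch fan out one provider call per scoring-function identifier and merge the per-function results into a single response. A toy version of that fan-out-and-merge loop (stand-in classes and dict-shaped responses, not the real Scoring API):

import asyncio

class ScoringProviderStub:
    def __init__(self, label: str) -> None:
        self.label = label

    async def score(self, input_rows: list[dict], scoring_functions: dict) -> dict:
        fn_id = next(iter(scoring_functions))
        return {"results": {fn_id: [f"{self.label}:{row}" for row in input_rows]}}

async def score(providers_by_fn: dict[str, ScoringProviderStub], input_rows: list[dict], scoring_functions: dict) -> dict:
    merged: dict[str, list] = {}
    for fn_identifier in scoring_functions:
        # one provider call per scoring function, results merged by function id
        response = await providers_by_fn[fn_identifier].score(
            input_rows=input_rows, scoring_functions={fn_identifier: scoring_functions[fn_identifier]}
        )
        merged.update(response["results"])
    return {"results": merged}

providers = {"accuracy": ScoringProviderStub("basic"), "bleu": ScoringProviderStub("nlp")}
out = asyncio.run(score(providers, [{"x": 1}], {"accuracy": None, "bleu": None}))
print(sorted(out["results"]))  # ['accuracy', 'bleu']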

View file

@ -14,14 +14,9 @@ from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToo
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
URL,
InterleavedContent, InterleavedContent,
InterleavedContentItem, InterleavedContentItem,
) )
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
BatchChatCompletionResponse, BatchChatCompletionResponse,
BatchCompletionResponse, BatchCompletionResponse,
@ -32,8 +27,11 @@ from llama_stack.apis.inference import (
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingTaskType, EmbeddingTaskType,
Inference, Inference,
ListOpenAIChatCompletionResponse,
LogProbConfig, LogProbConfig,
Message, Message,
OpenAICompletionWithInputMessages,
Order,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
StopReason, StopReason,
@ -51,89 +49,18 @@ from llama_stack.apis.inference.inference import (
OpenAIResponseFormatParam, OpenAIResponseFormatParam,
) )
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringFnParams,
)
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion
from llama_stack.providers.utils.telemetry.tracing import get_current_span from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
)
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
class InferenceRouter(Inference): class InferenceRouter(Inference):
"""Routes to an provider based on the model""" """Routes to an provider based on the model"""
@ -141,10 +68,12 @@ class InferenceRouter(Inference):
self, self,
routing_table: RoutingTable, routing_table: RoutingTable,
telemetry: Telemetry | None = None, telemetry: Telemetry | None = None,
store: InferenceStore | None = None,
) -> None: ) -> None:
logger.debug("Initializing InferenceRouter") logger.debug("Initializing InferenceRouter")
self.routing_table = routing_table self.routing_table = routing_table
self.telemetry = telemetry self.telemetry = telemetry
self.store = store
if self.telemetry: if self.telemetry:
self.tokenizer = Tokenizer.get_instance() self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer) self.formatter = ChatFormat(self.tokenizer)
@ -607,9 +536,31 @@ class InferenceRouter(Inference):
provider = self.routing_table.get_provider_impl(model_obj.identifier) provider = self.routing_table.get_provider_impl(model_obj.identifier)
if stream: if stream:
return await provider.openai_chat_completion(**params) response_stream = await provider.openai_chat_completion(**params)
if self.store:
return stream_and_store_openai_completion(response_stream, model, self.store, messages)
return response_stream
else: else:
return await self._nonstream_openai_chat_completion(provider, params) response = await self._nonstream_openai_chat_completion(provider, params)
if self.store:
await self.store.store_chat_completion(response, messages)
return response
async def list_chat_completions(
self,
after: str | None = None,
limit: int | None = 20,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIChatCompletionResponse:
if self.store:
return await self.store.list_chat_completions(after, limit, model, order)
raise NotImplementedError("List chat completions is not supported: inference store is not configured.")
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
if self.store:
return await self.store.get_chat_completion(completion_id)
raise NotImplementedError("Get chat completion is not supported: inference store is not configured.")
async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion: async def _nonstream_openai_chat_completion(self, provider: Inference, params: dict) -> OpenAIChatCompletion:
response = await provider.openai_chat_completion(**params) response = await provider.openai_chat_completion(**params)
@ -642,295 +593,3 @@ class InferenceRouter(Inference):
status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
) )
return health_statuses return health_statuses
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)
class DatasetIORouter(DatasetIO):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing DatasetIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("DatasetIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("DatasetIORouter.shutdown")
pass
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
logger.debug(f"DatasetIORouter.append_rows: {dataset_id}, {len(rows)} rows")
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
dataset_id=dataset_id,
rows=rows,
)
class ScoringRouter(Scoring):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ScoringRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("ScoringRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ScoringRouter.shutdown")
pass
async def score_batch(
self,
dataset_id: str,
scoring_functions: dict[str, ScoringFnParams | None] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
logger.debug(f"ScoringRouter.score_batch: {dataset_id}")
res = {}
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score_batch(
dataset_id=dataset_id,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
if save_results_dataset:
raise NotImplementedError("Save results dataset not implemented yet")
return ScoreBatchResponse(
results=res,
)
async def score(
self,
input_rows: list[dict[str, Any]],
scoring_functions: dict[str, ScoringFnParams | None] = None,
) -> ScoreResponse:
logger.debug(f"ScoringRouter.score: {len(input_rows)} rows, {len(scoring_functions)} functions")
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(fn_identifier).score(
input_rows=input_rows,
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing EvalRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("EvalRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("EvalRouter.shutdown")
pass
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
logger.debug(f"EvalRouter.run_eval: {benchmark_id}")
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
benchmark_config=benchmark_config,
)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: list[dict[str, Any]],
scoring_functions: list[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.evaluate_rows: {benchmark_id}, {len(input_rows)} rows")
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)
async def job_status(
self,
benchmark_id: str,
job_id: str,
) -> Job:
logger.debug(f"EvalRouter.job_status: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
self,
benchmark_id: str,
job_id: str,
) -> None:
logger.debug(f"EvalRouter.job_cancel: {benchmark_id}, {job_id}")
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
logger.debug(f"EvalRouter.job_result: {benchmark_id}, {job_id}")
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
)
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)
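Note: with this change the InferenceRouter persists OpenAI-compatible chat completions whenever a store is attached, and serves list_chat_completions / get_chat_completion from it (raising NotImplementedError when no store is configured). For streaming responses the diff routes the stream through stream_and_store_openai_completion, which presumably assembles the chunks before persisting. Below is a hypothetical in-memory double that satisfies just the calls the router makes on `self.store`; it is a sketch for illustration, not the real InferenceStore:

import asyncio

class InMemoryCompletionStore:
    """Hypothetical test double exposing only the calls InferenceRouter makes on self.store."""

    def __init__(self) -> None:
        self._completions: dict[str, dict] = {}

    async def initialize(self) -> None:
        pass  # a real store would open connections / create tables here

    async def store_chat_completion(self, response: dict, input_messages: list) -> None:
        # keyed by completion id so get_chat_completion can return inputs and output together
        self._completions[response["id"]] = {"response": response, "input_messages": input_messages}

    async def list_chat_completions(self, after=None, limit=20, model=None, order=None) -> dict:
        items = list(self._completions.values())
        if model is not None:
            items = [c for c in items if c["response"].get("model") == model]
        return {"data": items[: (limit or 20)], "has_more": False}

    async def get_chat_completion(self, completion_id: str) -> dict:
        return self._completions[completion_id]

async def demo() -> None:
    store = InMemoryCompletionStore()
    await store.initialize()
    await store.store_chat_completion({"id": "cmpl-1", "model": "toy"}, [{"role": "user", "content": "hi"}])
    print((await store.get_chat_completion("cmpl-1"))["input_messages"])

asyncio.run(demo())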

View file

@ -1,634 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import time
import uuid
from typing import Any
from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.apis.tools import (
ListToolGroupsResponse,
ListToolsResponse,
Tool,
ToolGroup,
ToolGroups,
ToolHost,
)
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
AccessAttributes,
BenchmarkWithACL,
DatasetWithACL,
ModelWithACL,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
ScoringFnWithACL,
ShieldWithACL,
ToolGroupWithACL,
ToolWithACL,
VectorDBWithACL,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
logger = logging.getLogger(__name__)
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.vector_io:
return await p.register_vector_db(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_tool(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
p.vector_db_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
else:
raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry
obj = self.dist_registry.get_cached(objtype, routing_key)
if not obj:
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if params is None:
params = {}
shield = ShieldWithACL(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
)
await self.register_object(shield)
return shield
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db)
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if metadata:
if metadata.get("provider_id"):
provider_id = metadata.get("provider_id") # pass through from nvidia datasetio
elif source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = DatasetWithACL(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
if dataset is None:
raise ValueError(f"Dataset {dataset_id} not found")
await self.unregister_object(dataset)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
tools = await self.get_all_with_type("tool")
if toolgroup_id:
tools = [tool for tool in tools if tool.toolgroup_id == toolgroup_id]
return ListToolsResponse(data=tools)
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
return await self.get_object_by_identifier("tool", tool_name)
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
tools = []
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs.data:
tools.append(
ToolWithACL(
identifier=tool_def.name,
toolgroup_id=toolgroup_id,
description=tool_def.description or "",
parameters=tool_def.parameters or [],
provider_id=provider_id,
provider_resource_id=tool_def.name,
metadata=tool_def.metadata,
tool_host=tool_host,
)
)
for tool in tools:
existing_tool = await self.get_tool(tool.identifier)
# Compare existing and new object if one exists
if existing_tool:
existing_dict = existing_tool.model_dump()
new_dict = tool.model_dump()
if existing_dict != new_dict:
raise ValueError(
f"Object {tool.identifier} already exists in registry. Please use a different identifier."
)
await self.register_object(tool)
await self.dist_registry.register(
ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
)
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
tools = await self.list_tools(toolgroup_id)
for tool in getattr(tools, "data", []):
await self.unregister_object(tool)
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.inference import (
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class SafetyRouter(Safety):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing SafetyRouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("SafetyRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("SafetyRouter.shutdown")
pass
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
logger.debug(f"SafetyRouter.register_shield: {shield_id}")
return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params)
async def run_shield(
self,
shield_id: str,
messages: list[Message],
params: dict[str, Any] = None,
) -> RunShieldResponse:
logger.debug(f"SafetyRouter.run_shield: {shield_id}")
return await self.routing_table.get_provider_impl(shield_id).run_shield(
shield_id=shield_id,
messages=messages,
params=params,
)

View file

@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
)
from llama_stack.apis.tools import (
ListToolsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolRuntime,
)
from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
class ToolRuntimeRouter(ToolRuntime):
class RagToolImpl(RAGToolRuntime):
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter.RagToolImpl")
self.routing_table = routing_table
async def query(
self,
content: InterleavedContent,
vector_db_ids: list[str],
query_config: RAGQueryConfig | None = None,
) -> RAGQueryResult:
logger.debug(f"ToolRuntimeRouter.RagToolImpl.query: {vector_db_ids}")
return await self.routing_table.get_provider_impl("knowledge_search").query(
content, vector_db_ids, query_config
)
async def insert(
self,
documents: list[RAGDocument],
vector_db_id: str,
chunk_size_in_tokens: int = 512,
) -> None:
logger.debug(
f"ToolRuntimeRouter.RagToolImpl.insert: {vector_db_id}, {len(documents)} documents, chunk_size={chunk_size_in_tokens}"
)
return await self.routing_table.get_provider_impl("insert_into_memory").insert(
documents, vector_db_id, chunk_size_in_tokens
)
def __init__(
self,
routing_table: ToolGroupsRoutingTable,
) -> None:
logger.debug("Initializing ToolRuntimeRouter")
self.routing_table = routing_table
# HACK ALERT this should be in sync with "get_all_api_endpoints()"
self.rag_tool = self.RagToolImpl(routing_table)
for method in ("query", "insert"):
setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))
async def initialize(self) -> None:
logger.debug("ToolRuntimeRouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("ToolRuntimeRouter.shutdown")
pass
async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> Any:
logger.debug(f"ToolRuntimeRouter.invoke_tool: {tool_name}")
return await self.routing_table.get_provider_impl(tool_name).invoke_tool(
tool_name=tool_name,
kwargs=kwargs,
)
async def list_runtime_tools(
self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
) -> ListToolsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.list_tools(tool_group_id)
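Note: the `setattr(self, f"rag_tool.{method}", ...)` hack registers attributes whose names literally contain a dot, so an endpoint table that stores the method name as a single string can dispatch via `getattr`, while ordinary `router.rag_tool.query` access keeps working through the nested object. A tiny sketch of that behaviour (toy classes, not the real router):

class RagToolStub:
    def query(self) -> str:
        return "rag query result"

class RouterStub:
    def __init__(self) -> None:
        self.rag_tool = RagToolStub()
        for method in ("query",):
            # registers a flat attribute literally named "rag_tool.query"
            setattr(self, f"rag_tool.{method}", getattr(self.rag_tool, method))

router = RouterStub()
print(getattr(router, "rag_tool.query")())  # string-addressable, as an endpoint table needs
print(router.rag_tool.query())              # ordinary attribute access still works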

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
class VectorIORouter(VectorIO):
"""Routes to an provider based on the vector db identifier"""
def __init__(
self,
routing_table: RoutingTable,
) -> None:
logger.debug("Initializing VectorIORouter")
self.routing_table = routing_table
async def initialize(self) -> None:
logger.debug("VectorIORouter.initialize")
pass
async def shutdown(self) -> None:
logger.debug("VectorIORouter.shutdown")
pass
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> None:
logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
)
async def insert_chunks(
self,
vector_db_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
logger.debug(
f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, chunk_ids={[chunk.metadata['document_id'] for chunk in chunks[:3]]}{' and more...' if len(chunks) > 3 else ''}",
)
return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)
async def query_chunks(
self,
vector_db_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.distribution.datatypes import (
BenchmarkWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_benchmark(self, benchmark_id: str) -> Benchmark:
benchmark = await self.get_object_by_identifier("benchmark", benchmark_id)
if benchmark is None:
raise ValueError(f"Benchmark '{benchmark_id}' not found")
return benchmark
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: list[str],
metadata: dict[str, Any] | None = None,
provider_benchmark_id: str | None = None,
provider_id: str | None = None,
) -> None:
if metadata is None:
metadata = {}
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = BenchmarkWithACL(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
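Note: register_benchmark follows the same provider-selection rule used throughout these routing tables: fall back to the only configured provider, otherwise require an explicit provider_id. A compact sketch of that rule in isolation (the helper name is illustrative):

def resolve_provider_id(requested: str | None, impls_by_provider_id: dict[str, object]) -> str:
    # same rule as the registration methods above: fall back to the only configured
    # provider, otherwise require the caller to name one explicitly
    if requested is not None:
        return requested
    if len(impls_by_provider_id) == 1:
        return next(iter(impls_by_provider_id))
    raise ValueError("No provider specified and multiple providers available. Please specify a provider_id.")

print(resolve_provider_id(None, {"nvidia": object()}))        # nvidia
print(resolve_provider_id("huggingface", {"a": 1, "b": 2}))   # huggingface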

View file

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.distribution.access_control import check_access
from llama_stack.distribution.datatypes import (
AccessAttributes,
RoutableObject,
RoutableObjectWithProvider,
RoutedProtocol,
)
from llama_stack.distribution.request_headers import get_auth_attributes
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
# TODO: this should return the registered object for all APIs
async def register_object_with_provider(obj: RoutableObject, p: Any) -> RoutableObject:
api = get_impl_api(p)
assert obj.provider_id != "remote", "Remote provider should not be registered"
if api == Api.inference:
return await p.register_model(obj)
elif api == Api.safety:
return await p.register_shield(obj)
elif api == Api.vector_io:
return await p.register_vector_db(obj)
elif api == Api.datasetio:
return await p.register_dataset(obj)
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_toolgroup(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
api = get_impl_api(p)
if api == Api.vector_io:
return await p.unregister_vector_db(obj.identifier)
elif api == Api.inference:
return await p.unregister_model(obj.identifier)
elif api == Api.datasetio:
return await p.unregister_dataset(obj.identifier)
elif api == Api.tool_runtime:
return await p.unregister_toolgroup(obj.identifier)
else:
raise ValueError(f"Unregister not supported for {api}")
Registry = dict[str, list[RoutableObjectWithProvider]]
class CommonRoutingTableImpl(RoutingTable):
def __init__(
self,
impls_by_provider_id: dict[str, RoutedProtocol],
dist_registry: DistributionRegistry,
) -> None:
self.impls_by_provider_id = impls_by_provider_id
self.dist_registry = dist_registry
async def initialize(self) -> None:
async def add_objects(objs: list[RoutableObjectWithProvider], provider_id: str, cls) -> None:
for obj in objs:
if cls is None:
obj.provider_id = provider_id
else:
# Create a copy of the model data and explicitly set provider_id
model_data = obj.model_dump()
model_data["provider_id"] = provider_id
obj = cls(**model_data)
await self.dist_registry.register(obj)
# Register all objects from providers
for pid, p in self.impls_by_provider_id.items():
api = get_impl_api(p)
if api == Api.inference:
p.model_store = self
elif api == Api.safety:
p.shield_store = self
elif api == Api.vector_io:
p.vector_db_store = self
elif api == Api.datasetio:
p.dataset_store = self
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
from .benchmarks import BenchmarksRoutingTable
from .datasets import DatasetsRoutingTable
from .models import ModelsRoutingTable
from .scoring_functions import ScoringFunctionsRoutingTable
from .shields import ShieldsRoutingTable
from .toolgroups import ToolGroupsRoutingTable
from .vector_dbs import VectorDBsRoutingTable
def apiname_object():
if isinstance(self, ModelsRoutingTable):
return ("Inference", "model")
elif isinstance(self, ShieldsRoutingTable):
return ("Safety", "shield")
elif isinstance(self, VectorDBsRoutingTable):
return ("VectorIO", "vector_db")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("ToolGroups", "tool_group")
else:
raise ValueError("Unknown routing table type")
apiname, objtype = apiname_object()
# Get objects from disk registry
obj = self.dist_registry.get_cached(objtype, routing_key)
if not obj:
provider_ids = list(self.impls_by_provider_id.keys())
if len(provider_ids) > 1:
provider_ids_str = f"any of the providers: {', '.join(provider_ids)}"
else:
provider_ids_str = f"provider: `{provider_ids[0]}`"
raise ValueError(
f"{objtype.capitalize()} `{routing_key}` not served by {provider_ids_str}. Make sure there is an {apiname} provider serving this {objtype}."
)
if not provider_id or provider_id == obj.provider_id:
return self.impls_by_provider_id[obj.provider_id]
raise ValueError(f"Provider not found for `{routing_key}`")
async def get_object_by_identifier(self, type: str, identifier: str) -> RoutableObjectWithProvider | None:
# Get from disk registry
obj = await self.dist_registry.get(type, identifier)
if not obj:
return None
# Check if user has permission to access this object
if not check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes()):
logger.debug(f"Access denied to {type} '{identifier}' based on attribute mismatch")
return None
return obj
async def unregister_object(self, obj: RoutableObjectWithProvider) -> None:
await self.dist_registry.delete(obj.type, obj.identifier)
await unregister_object_from_provider(obj, self.impls_by_provider_id[obj.provider_id])
async def register_object(self, obj: RoutableObjectWithProvider) -> RoutableObjectWithProvider:
# if provider_id is not specified, pick an arbitrary one from existing entries
if not obj.provider_id and len(self.impls_by_provider_id) > 0:
obj.provider_id = list(self.impls_by_provider_id.keys())[0]
if obj.provider_id not in self.impls_by_provider_id:
raise ValueError(f"Provider `{obj.provider_id}` not found")
p = self.impls_by_provider_id[obj.provider_id]
# If object supports access control but no attributes set, use creator's attributes
if not obj.access_attributes:
creator_attributes = get_auth_attributes()
if creator_attributes:
obj.access_attributes = AccessAttributes(**creator_attributes)
logger.info(f"Setting access attributes for {obj.type} '{obj.identifier}' based on creator's identity")
registered_obj = await register_object_with_provider(obj, p)
# TODO: This needs to be fixed for all APIs once they return the registered object
if obj.type == ResourceType.model.value:
await self.dist_registry.register(registered_obj)
return registered_obj
else:
await self.dist_registry.register(obj)
return obj
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
objs = await self.dist_registry.get_all()
filtered_objs = [obj for obj in objs if obj.type == type]
# Apply attribute-based access control filtering
if filtered_objs:
filtered_objs = [
obj
for obj in filtered_objs
if check_access(obj.identifier, getattr(obj, "access_attributes", None), get_auth_attributes())
]
return filtered_objs

View file

@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Any
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
RowsDataSource,
URIDataSource,
)
from llama_stack.apis.resource import ResourceType
from llama_stack.distribution.datatypes import (
DatasetWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def list_datasets(self) -> ListDatasetsResponse:
return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value))
async def get_dataset(self, dataset_id: str) -> Dataset:
dataset = await self.get_object_by_identifier("dataset", dataset_id)
if dataset is None:
raise ValueError(f"Dataset '{dataset_id}' not found")
return dataset
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: dict[str, Any] | None = None,
dataset_id: str | None = None,
) -> Dataset:
if isinstance(source, dict):
if source["type"] == "uri":
source = URIDataSource.parse_obj(source)
elif source["type"] == "rows":
source = RowsDataSource.parse_obj(source)
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if metadata:
if metadata.get("provider_id"):
provider_id = metadata.get("provider_id") # pass through from nvidia datasetio
elif source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = DatasetWithACL(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)
if dataset is None:
raise ValueError(f"Dataset {dataset_id} not found")
await self.unregister_object(dataset)
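
The provider-inference branch above means callers usually do not have to name a datasetio provider: inline rows land on `localfs`, while a `huggingface://...` URI is routed to the `huggingface` provider. A minimal client-side sketch, assuming a running stack on the default port and that the `llama-stack-client` `datasets.register` call mirrors the `register_dataset` signature shown here; the purpose value and URI are illustrative:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# URI source starting with "huggingface" -> provider_id inferred as "huggingface"
hf_ds = client.datasets.register(
    purpose="eval/question-answer",  # assumed DatasetPurpose value
    source={"type": "uri", "uri": "huggingface://datasets/llamastack/simpleqa?split=train"},
)

# Inline rows -> provider_id inferred as "localfs"
rows_ds = client.datasets.register(
    purpose="eval/question-answer",
    source={"type": "rows", "rows": [{"question": "What is 2 + 2?", "answer": "4"}]},
)
```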

View file

@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
from typing import Any
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
from llama_stack.distribution.datatypes import (
ModelWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
)
for model in models
]
return OpenAIListModelsResponse(data=openai_models)
async def get_model(self, model_id: str) -> Model:
model = await self.get_object_by_identifier("model", model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
return model
async def register_model(
self,
model_id: str,
provider_model_id: str | None = None,
provider_id: str | None = None,
metadata: dict[str, Any] | None = None,
model_type: ModelType | None = None,
) -> Model:
if provider_model_id is None:
provider_model_id = model_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this model
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
if metadata is None:
metadata = {}
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError("Embedding model must have an embedding dimension in its metadata")
model = ModelWithACL(
identifier=model_id,
provider_resource_id=provider_model_id,
provider_id=provider_id,
metadata=metadata,
model_type=model_type,
)
registered_model = await self.register_object(model)
return registered_model
async def unregister_model(self, model_id: str) -> None:
existing_model = await self.get_model(model_id)
if existing_model is None:
raise ValueError(f"Model {model_id} not found")
await self.unregister_object(existing_model)
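
Note the guard above: an embedding model must carry `embedding_dimension` in its metadata, and when several inference providers are configured a `provider_id` is required. A hedged client-side sketch, assuming the `llama-stack-client` `models.register` call mirrors `register_model`; the model names, provider id, and dimension are illustrative:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# LLMs default to model_type="llm" and need no extra metadata
client.models.register(model_id="meta-llama/Llama-3.2-3B-Instruct", provider_id="ollama")

# Embedding models must declare their dimension, otherwise register_model raises ValueError
client.models.register(
    model_id="all-MiniLM-L6-v2",
    provider_id="ollama",
    model_type="embedding",
    metadata={"embedding_dimension": 384},
)
```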

View file

@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
ListScoringFunctionsResponse,
ScoringFn,
ScoringFnParams,
ScoringFunctions,
)
from llama_stack.distribution.datatypes import (
ScoringFnWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> ListScoringFunctionsResponse:
return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value))
async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn:
scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id)
if scoring_fn is None:
raise ValueError(f"Scoring function '{scoring_fn_id}' not found")
return scoring_fn
async def register_scoring_function(
self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: str | None = None,
provider_id: str | None = None,
params: ScoringFnParams | None = None,
) -> None:
if provider_scoring_fn_id is None:
provider_scoring_fn_id = scoring_fn_id
if provider_id is None:
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
scoring_fn = ScoringFnWithACL(
identifier=scoring_fn_id,
description=description,
return_type=return_type,
provider_resource_id=provider_scoring_fn_id,
provider_id=provider_id,
params=params,
)
scoring_fn.provider_id = provider_id
await self.register_object(scoring_fn)

View file

@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields
from llama_stack.distribution.datatypes import (
ShieldWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> ListShieldsResponse:
return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value))
async def get_shield(self, identifier: str) -> Shield:
shield = await self.get_object_by_identifier("shield", identifier)
if shield is None:
raise ValueError(f"Shield '{identifier}' not found")
return shield
async def register_shield(
self,
shield_id: str,
provider_shield_id: str | None = None,
provider_id: str | None = None,
params: dict[str, Any] | None = None,
) -> Shield:
if provider_shield_id is None:
provider_shield_id = shield_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this shield type
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
else:
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if params is None:
params = {}
shield = ShieldWithACL(
identifier=shield_id,
provider_resource_id=provider_shield_id,
provider_id=provider_id,
params=params,
)
await self.register_object(shield)
return shield

View file

@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups
from llama_stack.distribution.datatypes import ToolGroupWithACL
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
# handle the funny case like "builtin::rag/knowledge_search"
parts = toolgroup_name_with_maybe_tool_name.split("/")
if len(parts) == 2:
return parts[0]
else:
return None
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
toolgroups_to_tools: dict[str, list[Tool]] = {}
tool_to_toolgroup: dict[str, str] = {}
# overridden
def get_provider_impl(self, routing_key: str, provider_id: str | None = None) -> Any:
# we don't index tools in the registry anymore, but only keep a cache of them by toolgroup_id
# TODO: we may want to invalidate the cache (for a given toolgroup_id) every once in a while?
toolgroup_id = parse_toolgroup_from_toolgroup_name_pair(routing_key)
if toolgroup_id:
routing_key = toolgroup_id
if routing_key in self.tool_to_toolgroup:
routing_key = self.tool_to_toolgroup[routing_key]
return super().get_provider_impl(routing_key, provider_id)
async def list_tools(self, toolgroup_id: str | None = None) -> ListToolsResponse:
if toolgroup_id:
if group_id := parse_toolgroup_from_toolgroup_name_pair(toolgroup_id):
toolgroup_id = group_id
toolgroups = [await self.get_tool_group(toolgroup_id)]
else:
toolgroups = await self.get_all_with_type("tool_group")
all_tools = []
for toolgroup in toolgroups:
if toolgroup.identifier not in self.toolgroups_to_tools:
await self._index_tools(toolgroup)
all_tools.extend(self.toolgroups_to_tools[toolgroup.identifier])
return ListToolsResponse(data=all_tools)
async def _index_tools(self, toolgroup: ToolGroup):
provider_impl = super().get_provider_impl(toolgroup.identifier, toolgroup.provider_id)
tooldefs_response = await provider_impl.list_runtime_tools(toolgroup.identifier, toolgroup.mcp_endpoint)
# TODO: kill this Tool vs ToolDef distinction
tooldefs = tooldefs_response.data
tools = []
for t in tooldefs:
tools.append(
Tool(
identifier=t.name,
toolgroup_id=toolgroup.identifier,
description=t.description or "",
parameters=t.parameters or [],
metadata=t.metadata,
provider_id=toolgroup.provider_id,
)
)
self.toolgroups_to_tools[toolgroup.identifier] = tools
for tool in tools:
self.tool_to_toolgroup[tool.identifier] = toolgroup.identifier
async def list_tool_groups(self) -> ListToolGroupsResponse:
return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group"))
async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group '{toolgroup_id}' not found")
return tool_group
async def get_tool(self, tool_name: str) -> Tool:
if tool_name in self.tool_to_toolgroup:
toolgroup_id = self.tool_to_toolgroup[tool_name]
tools = self.toolgroups_to_tools[toolgroup_id]
for tool in tools:
if tool.identifier == tool_name:
return tool
raise ValueError(f"Tool '{tool_name}' not found")
async def register_tool_group(
self,
toolgroup_id: str,
provider_id: str,
mcp_endpoint: URL | None = None,
args: dict[str, Any] | None = None,
) -> None:
toolgroup = ToolGroupWithACL(
identifier=toolgroup_id,
provider_id=provider_id,
provider_resource_id=toolgroup_id,
mcp_endpoint=mcp_endpoint,
args=args,
)
await self.register_object(toolgroup)
# ideally, indexing of the tools should not be necessary because anyone using
# the tools should first list the tools and then use them. but there are assumptions
# baked in some of the code and tests right now.
if not toolgroup.mcp_endpoint:
await self._index_tools(toolgroup)
return toolgroup
async def unregister_toolgroup(self, toolgroup_id: str) -> None:
tool_group = await self.get_tool_group(toolgroup_id)
if tool_group is None:
raise ValueError(f"Tool group {toolgroup_id} not found")
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
pass
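
`parse_toolgroup_from_toolgroup_name_pair` is what lets callers address either a toolgroup or a fully qualified `<toolgroup>/<tool>` pair with the same routing key. A small sketch of the expected behaviour, with the helper from the module above in scope:

```python
# "builtin::rag/knowledge_search" names the knowledge_search tool inside builtin::rag
assert parse_toolgroup_from_toolgroup_name_pair("builtin::rag/knowledge_search") == "builtin::rag"

# a bare toolgroup id has no "/" pair, so the helper returns None and the key is used as-is
assert parse_toolgroup_from_toolgroup_name_pair("builtin::websearch") is None
```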

View file

@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import TypeAdapter
from llama_stack.apis.models import ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
from llama_stack.distribution.datatypes import (
VectorDBWithACL,
)
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
async def list_vector_dbs(self) -> ListVectorDBsResponse:
return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db"))
async def get_vector_db(self, vector_db_id: str) -> VectorDB:
vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
if vector_db is None:
raise ValueError(f"Vector DB '{vector_db_id}' not found")
return vector_db
async def register_vector_db(
self,
vector_db_id: str,
embedding_model: str,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorDB:
if provider_vector_db_id is None:
provider_vector_db_id = vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
if len(self.impls_by_provider_id) > 1:
logger.warning(
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
)
else:
raise ValueError("No provider available. Please configure a vector_io provider.")
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
vector_db_data = {
"identifier": vector_db_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
}
vector_db = TypeAdapter(VectorDBWithACL).validate_python(vector_db_data)
await self.register_object(vector_db)
return vector_db
async def unregister_vector_db(self, vector_db_id: str) -> None:
existing_vector_db = await self.get_vector_db(vector_db_id)
if existing_vector_db is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
await self.unregister_object(existing_vector_db)
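
Because `register_vector_db` looks up the embedding model in the same registry and copies its `embedding_dimension`, the embedding model has to be registered first. A hedged client-side sketch, assuming the `llama-stack-client` surface mirrors this API; the identifiers are illustrative:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# assumes "all-MiniLM-L6-v2" was already registered as an embedding model with
# embedding_dimension metadata, as in the models example above
client.vector_dbs.register(
    vector_db_id="my-documents",
    embedding_model="all-MiniLM-L6-v2",
    provider_id="faiss",
)
```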

View file

@@ -8,7 +8,8 @@ import json
 import httpx
 
-from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
+from llama_stack.distribution.datatypes import AuthenticationConfig
+from llama_stack.distribution.server.auth_providers import create_auth_provider
 from llama_stack.log import get_logger
 
 logger = get_logger(name=__name__, category="auth")
@@ -77,7 +78,7 @@ class AuthenticationMiddleware:
     access resources that don't have access_attributes defined.
     """
 
-    def __init__(self, app, auth_config: AuthProviderConfig):
+    def __init__(self, app, auth_config: AuthenticationConfig):
         self.app = app
         self.auth_provider = create_auth_provider(auth_config)
@@ -113,6 +114,10 @@ class AuthenticationMiddleware:
                 "roles": [token],
            }
 
+        # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
+        # can identify the requester and enforce per-client rate limits.
+        scope["authenticated_client_id"] = token
+
         # Store attributes in request scope
         scope["user_attributes"] = user_attributes
         scope["principal"] = validation_result.principal

View file

@@ -4,17 +4,19 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
+import ssl
 import time
 from abc import ABC, abstractmethod
-from enum import Enum
+from asyncio import Lock
+from pathlib import Path
 from urllib.parse import parse_qs
 
 import httpx
 from jose import jwt
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_validator, model_validator
+from typing_extensions import Self
 
-from llama_stack.distribution.datatypes import AccessAttributes
+from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
 from llama_stack.log import get_logger
 
 logger = get_logger(name=__name__, category="auth")
@@ -72,21 +74,6 @@ class AuthRequest(BaseModel):
     request: AuthRequestContext = Field(description="Context information about the request being authenticated")
 
 
-class AuthProviderType(str, Enum):
-    """Supported authentication provider types."""
-
-    KUBERNETES = "kubernetes"
-    CUSTOM = "custom"
-    OAUTH2_TOKEN = "oauth2_token"
-
-
-class AuthProviderConfig(BaseModel):
-    """Base configuration for authentication providers."""
-
-    provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
-    config: dict[str, str] = Field(..., description="Provider-specific configuration")
-
-
 class AuthProvider(ABC):
     """Abstract base class for authentication providers."""
@@ -101,83 +88,6 @@ class AuthProvider(ABC):
         pass
 
 
-class KubernetesAuthProviderConfig(BaseModel):
-    api_server_url: str
-    ca_cert_path: str | None = None
-
-
-class KubernetesAuthProvider(AuthProvider):
-    """Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
-
-    def __init__(self, config: KubernetesAuthProviderConfig):
-        self.config = config
-        self._client = None
-
-    async def _get_client(self):
-        """Get or create a Kubernetes client."""
-        if self._client is None:
-            # kubernetes-client has not async support, see:
-            # https://github.com/kubernetes-client/python/issues/323
-            from kubernetes import client
-            from kubernetes.client import ApiClient
-
-            # Configure the client
-            configuration = client.Configuration()
-            configuration.host = self.config.api_server_url
-            if self.config.ca_cert_path:
-                configuration.ssl_ca_cert = self.config.ca_cert_path
-            configuration.verify_ssl = bool(self.config.ca_cert_path)
-
-            # Create API client
-            self._client = ApiClient(configuration)
-        return self._client
-
-    async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
-        """Validate a Kubernetes token and return access attributes."""
-        try:
-            client = await self._get_client()
-
-            # Set the token in the client
-            client.set_default_header("Authorization", f"Bearer {token}")
-
-            # Make a request to validate the token
-            # We use the /api endpoint which requires authentication
-            from kubernetes.client import CoreV1Api
-
-            api = CoreV1Api(client)
-            api.get_api_resources(_request_timeout=3.0)  # Set timeout for this specific request
-
-            # If we get here, the token is valid
-            # Extract user info from the token claims
-            import base64
-
-            # Decode the token (without verification since we've already validated it)
-            token_parts = token.split(".")
-            payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
-
-            # Extract user information from the token
-            username = payload.get("sub", "")
-            groups = payload.get("groups", [])
-
-            return TokenValidationResult(
-                principal=username,
-                access_attributes=AccessAttributes(
-                    roles=[username],  # Use username as a role
-                    teams=groups,  # Use Kubernetes groups as teams
-                ),
-            )
-        except Exception as e:
-            logger.exception("Failed to validate Kubernetes token")
-            raise ValueError("Invalid or expired token") from e
-
-    async def close(self):
-        """Close the HTTP client."""
-        if self._client:
-            self._client.close()
-            self._client = None
-
-
 def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
     attributes = AccessAttributes()
     for claim_key, attribute_key in mapping.items():
@@ -197,11 +107,24 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
     return attributes
 
 
-class OAuth2TokenAuthProviderConfig(BaseModel):
+class OAuth2JWKSConfig(BaseModel):
     # The JWKS URI for collecting public keys
-    jwks_uri: str
-    cache_ttl: int = 3600
+    uri: str
+    key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
+
+
+class OAuth2IntrospectionConfig(BaseModel):
+    url: str
+    client_id: str
+    client_secret: str
+    send_secret_in_body: bool = False
+
+
+class OAuth2TokenAuthProviderConfig(BaseModel):
     audience: str = "llama-stack"
+    verify_tls: bool = True
+    tls_cafile: Path | None = None
+    issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
     claims_mapping: dict[str, str] = Field(
         default_factory=lambda: {
             "sub": "roles",
@@ -213,6 +136,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
             "namespace": "namespaces",
         },
     )
+    jwks: OAuth2JWKSConfig | None
+    introspection: OAuth2IntrospectionConfig | None = None
 
     @classmethod
     @field_validator("claims_mapping")
@@ -224,6 +149,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
                 raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
         return v
 
+    @model_validator(mode="after")
+    def validate_mode(self) -> Self:
+        if not self.jwks and not self.introspection:
+            raise ValueError("One of jwks or introspection must be configured")
+        if self.jwks and self.introspection:
+            raise ValueError("At present only one of jwks or introspection should be configured")
+        return self
+
 
 class OAuth2TokenAuthProvider(AuthProvider):
     """
@@ -236,8 +169,16 @@ class OAuth2TokenAuthProvider(AuthProvider):
         self.config = config
         self._jwks_at: float = 0.0
         self._jwks: dict[str, str] = {}
+        self._jwks_lock = Lock()
 
     async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+        if self.config.jwks:
+            return await self.validate_jwt_token(token, scope)
+        if self.config.introspection:
+            return await self.introspect_token(token, scope)
+        raise ValueError("One of jwks or introspection must be configured")
+
+    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
         """Validate a token using the JWT token."""
         await self._refresh_jwks()
@@ -253,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
                 key_data,
                 algorithms=[algorithm],
                 audience=self.config.audience,
-                options={"verify_exp": True},
+                issuer=self.config.issuer,
             )
         except Exception as exc:
             raise ValueError(f"Invalid JWT token: {token}") from exc
@@ -267,21 +208,84 @@ class OAuth2TokenAuthProvider(AuthProvider):
             access_attributes=access_attributes,
         )
 
+    async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
+        """Validate a token using token introspection as defined by RFC 7662."""
+        form = {
+            "token": token,
+        }
+        if self.config.introspection is None:
+            raise ValueError("Introspection is not configured")
+        if self.config.introspection.send_secret_in_body:
+            form["client_id"] = self.config.introspection.client_id
+            form["client_secret"] = self.config.introspection.client_secret
+            auth = None
+        else:
+            auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
+        ssl_ctxt = None
+        if self.config.tls_cafile:
+            ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
+        try:
+            async with httpx.AsyncClient(verify=ssl_ctxt) as client:
+                response = await client.post(
+                    self.config.introspection.url,
+                    data=form,
+                    auth=auth,
+                    timeout=10.0,  # Add a reasonable timeout
+                )
+                if response.status_code != 200:
+                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
+                    raise ValueError(f"Token introspection failed: {response.status_code}")
+                fields = response.json()
+                if not fields["active"]:
+                    raise ValueError("Token not active")
+                principal = fields["sub"] or fields["username"]
+                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
+                return TokenValidationResult(
+                    principal=principal,
+                    access_attributes=access_attributes,
+                )
+        except httpx.TimeoutException:
+            logger.exception("Token introspection request timed out")
+            raise
+        except ValueError:
+            # Re-raise ValueError exceptions to preserve their message
+            raise
+        except Exception as e:
+            logger.exception("Error during token introspection")
+            raise ValueError("Token introspection error") from e
+
     async def close(self):
-        """Close the HTTP client."""
+        pass
 
     async def _refresh_jwks(self) -> None:
-        if time.time() - self._jwks_at > self.config.cache_ttl:
-            async with httpx.AsyncClient() as client:
-                res = await client.get(self.config.jwks_uri, timeout=5)
-                res.raise_for_status()
-                jwks_data = res.json()["keys"]
-                self._jwks = {}
-                for k in jwks_data:
-                    kid = k["kid"]
-                    # Store the entire key object as it may be needed for different algorithms
-                    self._jwks[kid] = k
-                self._jwks_at = time.time()
+        """
+        Refresh the JWKS cache.
+
+        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
+        If the cache is expired, we refresh the JWKS from the JWKS URI.
+
+        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
+        * It doesn't have user authentication flows
+        * It doesn't have refresh tokens
+        """
+        async with self._jwks_lock:
+            if self.config.jwks is None:
+                raise ValueError("JWKS is not configured")
+            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
+                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
+                async with httpx.AsyncClient(verify=verify) as client:
+                    res = await client.get(self.config.jwks.uri, timeout=5)
+                    res.raise_for_status()
+                    jwks_data = res.json()["keys"]
+                    updated = {}
+                    for k in jwks_data:
+                        kid = k["kid"]
+                        # Store the entire key object as it may be needed for different algorithms
+                        updated[kid] = k
+                    self._jwks = updated
+                    self._jwks_at = time.time()
 
 
 class CustomAuthProviderConfig(BaseModel):
@@ -359,13 +363,11 @@ class CustomAuthProvider(AuthProvider):
             self._client = None
 
 
-def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
+def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
     """Factory function to create the appropriate auth provider."""
     provider_type = config.provider_type.lower()
 
-    if provider_type == "kubernetes":
-        return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
-    elif provider_type == "custom":
+    if provider_type == "custom":
         return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
    elif provider_type == "oauth2_token":
        return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
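
With the dedicated Kubernetes provider removed, Kubernetes API servers are expected to be handled through the generic `oauth2_token` provider, which accepts either a JWKS endpoint or RFC 7662 introspection (but not both, per the `validate_mode` check). A minimal sketch of building the config directly, assuming the module path shown in this diff; the issuer and URIs are placeholders:

```python
from llama_stack.distribution.server.auth_providers import (
    OAuth2JWKSConfig,
    OAuth2TokenAuthProviderConfig,
)

# JWKS-based validation (e.g. a Kubernetes API server's /openid/v1/jwks endpoint)
jwks_config = OAuth2TokenAuthProviderConfig(
    audience="llama-stack",
    issuer="https://kubernetes.default.svc",  # placeholder issuer
    jwks=OAuth2JWKSConfig(uri="https://kubernetes.default.svc/openid/v1/jwks"),
    introspection=None,
)

# Passing neither jwks nor introspection is rejected by the model validator;
# pydantic surfaces the "One of jwks or introspection must be configured" error.
try:
    OAuth2TokenAuthProviderConfig(jwks=None, introspection=None)
except ValueError as exc:
    print(exc)
```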

View file

@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
from datetime import datetime, timedelta, timezone
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
class QuotaMiddleware:
"""
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
within a configurable time window.
- For authenticated requests, it reads the client ID from the
`Authorization: Bearer <client_id>` header.
- For anonymous requests, it falls back to the IP address of the client.
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
once a client exceeds its quota.
"""
def __init__(
self,
app: ASGIApp,
kv_config: KVStoreConfig,
anonymous_max_requests: int,
authenticated_max_requests: int,
window_seconds: int = 86400,
):
self.app = app
self.kv_config = kv_config
self.kv: KVStore | None = None
self.anonymous_max_requests = anonymous_max_requests
self.authenticated_max_requests = authenticated_max_requests
self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
async def _get_kv(self) -> KVStore:
if self.kv is None:
self.kv = await kvstore_impl(self.kv_config)
return self.kv
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "http":
# pick key & limit based on auth
auth_id = scope.get("authenticated_client_id")
if auth_id:
key_id = auth_id
limit = self.authenticated_max_requests
else:
# fallback to IP
client = scope.get("client")
key_id = client[0] if client else "anonymous"
limit = self.anonymous_max_requests
current_window = int(time.time() // self.window_seconds)
key = f"quota:{key_id}:{current_window}"
try:
kv = await self._get_kv()
prev = await kv.get(key) or "0"
count = int(prev) + 1
if int(prev) == 0:
# Set with expiration datetime when it is the first request in the window.
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
await kv.set(key, str(count), expiration=expiration)
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")
if count > limit:
logger.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,
limit,
)
return await self._send_error(send, 429, "Quota exceeded")
return await self.app(scope, receive, send)
async def _send_error(self, send: Send, status: int, message: str):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
body = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": body})
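
A minimal sketch of wiring the middleware into a bare ASGI app, assuming only that `llama_stack` and `uvicorn` are installed; the limits and the SQLite path are illustrative:

```python
import uvicorn

from llama_stack.distribution.server.quota import QuotaMiddleware
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig


async def app(scope, receive, send):
    # trivial ASGI app: always answer 200 OK for HTTP requests
    if scope["type"] != "http":
        return
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


limited_app = QuotaMiddleware(
    app,
    kv_config=SqliteKVStoreConfig(db_path="/tmp/quota.db"),  # illustrative path
    anonymous_max_requests=100,
    authenticated_max_requests=1000,
    window_seconds=86400,  # one day, matching the server's "day" period
)

if __name__ == "__main__":
    uvicorn.run(limited_app, host="127.0.0.1", port=8000)
```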

View file

@@ -23,11 +23,12 @@ import yaml
 from fastapi import Body, FastAPI, HTTPException, Request
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
 
-from llama_stack.distribution.datatypes import LoggingConfig, StackRunConfig
+from llama_stack.distribution.datatypes import AuthenticationRequiredError, LoggingConfig, StackRunConfig
 from llama_stack.distribution.distribution import builtin_automatically_routed_apis
 from llama_stack.distribution.request_headers import (
     PROVIDER_DATA_VAR,
@@ -60,6 +61,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
 
 from .auth import AuthenticationMiddleware
 from .endpoints import get_all_api_endpoints
+from .quota import QuotaMiddleware
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
@@ -120,6 +122,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
         return HTTPException(status_code=504, detail=f"Operation timed out: {str(exc)}")
     elif isinstance(exc, NotImplementedError):
         return HTTPException(status_code=501, detail=f"Not implemented: {str(exc)}")
+    elif isinstance(exc, AuthenticationRequiredError):
+        return HTTPException(status_code=401, detail=f"Authentication required: {str(exc)}")
     else:
         return HTTPException(
             status_code=500,
@@ -280,7 +284,18 @@ class TracingMiddleware:
             logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
             return await self.app(scope, receive, send)
 
-        trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})
+        trace_attributes = {"__location__": "server", "raw_path": path}
+
+        # Extract W3C trace context headers and store as trace attributes
+        headers = dict(scope.get("headers", []))
+        traceparent = headers.get(b"traceparent", b"").decode()
+        if traceparent:
+            trace_attributes["traceparent"] = traceparent
+        tracestate = headers.get(b"tracestate", b"").decode()
+        if tracestate:
+            trace_attributes["tracestate"] = tracestate
+
+        trace_context = await start_trace(trace_path, trace_attributes)
 
         async def send_with_trace_id(message):
             if message["type"] == "http.response.start":
@@ -370,14 +385,6 @@ def main(args: argparse.Namespace | None = None):
     if args is None:
         args = parser.parse_args()
 
-    # Check for deprecated argument usage
-    if "--config" in sys.argv:
-        warnings.warn(
-            "The '--config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-
     log_line = ""
     if args.config:
         # if the user provided a config file, use it, even if template was specified
@@ -431,6 +438,46 @@ def main(args: argparse.Namespace | None = None):
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
         app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
+    else:
+        if config.server.quota:
+            quota = config.server.quota
+            logger.warning(
+                "Configured authenticated_max_requests (%d) but no auth is enabled; "
+                "falling back to anonymous_max_requests (%d) for all the requests",
+                quota.authenticated_max_requests,
+                quota.anonymous_max_requests,
+            )
+
+    if config.server.quota:
+        logger.info("Enabling quota middleware for authenticated and anonymous clients")
+
+        quota = config.server.quota
+        anonymous_max_requests = quota.anonymous_max_requests
+        # if auth is disabled, use the anonymous max requests
+        authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
+
+        kv_config = quota.kvstore
+        window_map = {"day": 86400}
+        window_seconds = window_map[quota.period.value]
+
+        app.add_middleware(
+            QuotaMiddleware,
+            kv_config=kv_config,
+            anonymous_max_requests=anonymous_max_requests,
+            authenticated_max_requests=authenticated_max_requests,
+            window_seconds=window_seconds,
+        )
+
+    # --- CORS middleware for local development ---
+    # TODO: move to reverse proxy
+    ui_port = os.environ.get("LLAMA_STACK_UI_PORT", 8322)
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=[f"http://localhost:{ui_port}"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
     try:
         impls = asyncio.run(construct_stack(config))
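
Since `TracingMiddleware` now picks up W3C trace-context headers, an upstream service can stitch its own trace to the stack's spans simply by forwarding them. A hedged sketch with an illustrative endpoint and trace identifiers (the `traceparent` value follows the standard `version-trace_id-parent_id-flags` layout):

```python
import httpx

# illustrative endpoint and trace identifiers
headers = {
    "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
    "tracestate": "vendor=opaque-value",
}
resp = httpx.get("http://localhost:8321/v1/models", headers=headers)
print(resp.status_code)
```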

View file

@@ -36,7 +36,7 @@ class DistributionRegistry(Protocol):
 REGISTER_PREFIX = "distributions:registry"
-KEY_VERSION = "v8"
+KEY_VERSION = "v9"
 KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"

View file

@@ -8,6 +8,7 @@ import logging
 import os
 import signal
 import subprocess
+import sys
 
 from termcolor import cprint
@@ -33,6 +34,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
             cprint(
                 "No current conda environment detected, please specify a conda environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
@@ -49,12 +51,13 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
                     return envpath
             return None
 
-        print(f"Using conda environment: {env_name}")
+        cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr)
         conda_prefix = get_conda_prefix(env_name)
         if not conda_prefix:
            cprint(
                 f"Conda environment {env_name} does not exist.",
                 color="red",
+                file=sys.stderr,
             )
             return
@@ -63,6 +66,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
            cprint(
                 f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
     else:
@@ -73,9 +77,10 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
            cprint(
                 "No current virtual environment detected, please specify a virtual environment name with --image-name",
                 color="red",
+                file=sys.stderr,
             )
             return
-        print(f"Using virtual environment: {env_name}")
+        cprint(f"Using virtual environment: {env_name}", file=sys.stderr)
 
     script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh"
     run_args = [

View file

@@ -6,6 +6,7 @@
 
 import logging
 import os
+import sys
 from logging.config import dictConfig
 
 from rich.console import Console
@@ -234,7 +235,7 @@ def get_logger(
 env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
 if env_config:
-    cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", "yellow")
+    cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr)
     _category_levels.update(parse_environment_config(env_config))
 
 log_file = os.environ.get("LLAMA_STACK_LOG_FILE")

View file

@@ -174,6 +174,7 @@ class Llama3:
             cprint(
                 "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                 "red",
+                file=sys.stderr,
             )
         prompt_tokens = [inp.tokens for inp in llm_inputs]
@@ -184,7 +185,11 @@ class Llama3:
         max_prompt_len = max(len(t) for t in prompt_tokens)
 
         if max_prompt_len >= params.max_seq_len:
-            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
+            cprint(
+                f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
+                color="red",
+                file=sys.stderr,
+            )
             return
 
         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

View file

@@ -133,9 +133,9 @@ class Llama4:
         print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
         if print_model_input:
-            cprint("Input to model:\n", "yellow")
+            cprint("Input to model:\n", color="yellow", file=sys.stderr)
             for inp in llm_inputs:
-                cprint(self.tokenizer.decode(inp.tokens), "grey")
+                cprint(self.tokenizer.decode(inp.tokens), color="grey", file=sys.stderr)
 
         prompt_tokens = [inp.tokens for inp in llm_inputs]
         bsz = len(llm_inputs)
@@ -145,7 +145,7 @@ class Llama4:
         max_prompt_len = max(len(t) for t in prompt_tokens)
 
         if max_prompt_len >= params.max_seq_len:
-            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
+            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", color="red", file=sys.stderr)
             return
 
         total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

View file

@@ -16,7 +16,7 @@ from llama_stack.apis.datatypes import Api
 from llama_stack.apis.models import Model
 from llama_stack.apis.scoring_functions import ScoringFn
 from llama_stack.apis.shields import Shield
-from llama_stack.apis.tools import Tool
+from llama_stack.apis.tools import ToolGroup
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.schema_utils import json_schema_type
@@ -74,10 +74,10 @@ class BenchmarksProtocolPrivate(Protocol):
     async def register_benchmark(self, benchmark: Benchmark) -> None: ...
 
 
-class ToolsProtocolPrivate(Protocol):
-    async def register_tool(self, tool: Tool) -> None: ...
+class ToolGroupsProtocolPrivate(Protocol):
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None: ...
 
-    async def unregister_tool(self, tool_id: str) -> None: ...
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None: ...
 
 
 @json_schema_type

View file

@@ -20,9 +20,12 @@ from llama_stack.apis.agents import (
     AgentTurnCreateRequest,
     AgentTurnResumeRequest,
     Document,
+    ListOpenAIResponseInputItem,
+    ListOpenAIResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
+    Order,
     Session,
     Turn,
 )
@@ -39,6 +42,7 @@ from llama_stack.apis.tools import ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
 from llama_stack.providers.utils.pagination import paginate_records
+from llama_stack.providers.utils.responses.responses_store import ResponsesStore
 
 from .agent_instance import ChatAgent
 from .config import MetaReferenceAgentsImplConfig
@@ -66,15 +70,17 @@ class MetaReferenceAgentsImpl(Agents):
         self.tool_groups_api = tool_groups_api
 
         self.in_memory_store = InmemoryKVStoreImpl()
-        self.openai_responses_impl = None
+        self.openai_responses_impl: OpenAIResponsesImpl | None = None
 
     async def initialize(self) -> None:
         self.persistence_store = await kvstore_impl(self.config.persistence_store)
+        self.responses_store = ResponsesStore(self.config.responses_store)
+        await self.responses_store.initialize()
         self.openai_responses_impl = OpenAIResponsesImpl(
-            self.persistence_store,
             inference_api=self.inference_api,
             tool_groups_api=self.tool_groups_api,
             tool_runtime_api=self.tool_runtime_api,
+            responses_store=self.responses_store,
         )
 
     async def create_agent(
@@ -305,14 +311,15 @@ class MetaReferenceAgentsImpl(Agents):
     # OpenAI responses
     async def get_openai_response(
         self,
-        id: str,
+        response_id: str,
     ) -> OpenAIResponseObject:
-        return await self.openai_responses_impl.get_openai_response(id)
+        return await self.openai_responses_impl.get_openai_response(response_id)
 
     async def create_openai_response(
         self,
         input: str | list[OpenAIResponseInput],
         model: str,
+        instructions: str | None = None,
         previous_response_id: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
@@ -320,5 +327,27 @@ class MetaReferenceAgentsImpl(Agents):
         tools: list[OpenAIResponseInputTool] | None = None,
     ) -> OpenAIResponseObject:
         return await self.openai_responses_impl.create_openai_response(
-            input, model, previous_response_id, store, stream, temperature, tools
+            input, model, instructions, previous_response_id, store, stream, temperature, tools
+        )
+
+    async def list_openai_responses(
+        self,
+        after: str | None = None,
+        limit: int | None = 50,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseObject:
+        return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
+
+    async def list_openai_response_input_items(
+        self,
+        response_id: str,
+        after: str | None = None,
+        before: str | None = None,
+        include: list[str] | None = None,
+        limit: int | None = 20,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseInputItem:
+        return await self.openai_responses_impl.list_openai_response_input_items(
+            response_id, after, before, include, limit, order
         )

View file

@@ -10,10 +10,12 @@ from pydantic import BaseModel
 
 from llama_stack.providers.utils.kvstore import KVStoreConfig
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig
 
 
 class MetaReferenceAgentsImplConfig(BaseModel):
     persistence_store: KVStoreConfig
+    responses_store: SqlStoreConfig
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
@@ -21,5 +23,9 @@ class MetaReferenceAgentsImplConfig(BaseModel):
             "persistence_store": SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=__distro_dir__,
                 db_name="agents_store.db",
-            )
+            ),
+            "responses_store": SqliteSqlStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="responses_store.db",
+            ),
         }
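
The meta-reference agents provider now needs two stores: a KV store for agent/session persistence and a SQL store backing the OpenAI responses listing APIs. A small sketch of what the sample config resolves to, assuming the config module lives at its usual inline-provider path; the distro directory is illustrative:

```python
from llama_stack.providers.inline.agents.meta_reference.config import (
    MetaReferenceAgentsImplConfig,
)

sample = MetaReferenceAgentsImplConfig.sample_run_config(
    __distro_dir__="~/.llama/distributions/ollama"  # illustrative distro dir
)
# expected shape: {"persistence_store": {...sqlite agents_store.db...},
#                  "responses_store": {...sqlite responses_store.db...}}
print(sample)
```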

View file

@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import json
+import time
 import uuid
 from collections.abc import AsyncIterator
 from typing import Any, cast
@@ -12,24 +13,29 @@ from typing import Any, cast
 from openai.types.chat import ChatCompletionToolParam
 from pydantic import BaseModel
 
+from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
+    AllowedToolsFilter,
+    ListOpenAIResponseInputItem,
+    ListOpenAIResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputFunctionToolCallOutput,
-    OpenAIResponseInputItemList,
     OpenAIResponseInputMessageContent,
     OpenAIResponseInputMessageContentImage,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputTool,
-    OpenAIResponseInputToolFunction,
+    OpenAIResponseInputToolMCP,
     OpenAIResponseMessage,
     OpenAIResponseObject,
     OpenAIResponseObjectStream,
     OpenAIResponseObjectStreamResponseCompleted,
     OpenAIResponseObjectStreamResponseCreated,
+    OpenAIResponseObjectStreamResponseOutputTextDelta,
     OpenAIResponseOutput,
     OpenAIResponseOutputMessageContent,
     OpenAIResponseOutputMessageContentOutputText,
     OpenAIResponseOutputMessageFunctionToolCall,
+    OpenAIResponseOutputMessageMCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall,
 )
 from llama_stack.apis.inference.inference import (
@@ -49,11 +55,12 @@ from llama_stack.apis.inference.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.tools.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools.tools import ToolGroups, ToolRuntime
 from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
 from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
-from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.responses.responses_store import ResponsesStore
+from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools
 
 logger = get_logger(name=__name__, category="openai_responses")
@@ -162,41 +169,43 @@ async def _get_message_type_by_role(role: str):
 class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
-    input_items: OpenAIResponseInputItemList
+    input_items: ListOpenAIResponseInputItem
     response: OpenAIResponseObject
 
 
+class ChatCompletionContext(BaseModel):
+    model: str
+    messages: list[OpenAIMessageParam]
+    tools: list[ChatCompletionToolParam] | None = None
+    mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP]
+    stream: bool
+    temperature: float | None
+
+
 class OpenAIResponsesImpl:
     def __init__(
         self,
-        persistence_store: KVStore,
         inference_api: Inference,
         tool_groups_api: ToolGroups,
         tool_runtime_api: ToolRuntime,
+        responses_store: ResponsesStore,
     ):
-        self.persistence_store = persistence_store
         self.inference_api = inference_api
         self.tool_groups_api = tool_groups_api
         self.tool_runtime_api = tool_runtime_api
+        self.responses_store = responses_store
 
-    async def _get_previous_response_with_input(self, id: str) -> OpenAIResponsePreviousResponseWithInputItems:
-        key = f"{OPENAI_RESPONSES_PREFIX}{id}"
-        response_json = await self.persistence_store.get(key=key)
-        if response_json is None:
-            raise ValueError(f"OpenAI response with id '{id}' not found")
-        return OpenAIResponsePreviousResponseWithInputItems.model_validate_json(response_json)
-
     async def _prepend_previous_response(
         self, input: str | list[OpenAIResponseInput], previous_response_id: str | None = None
     ):
         if previous_response_id:
-            previous_response_with_input = await self._get_previous_response_with_input(previous_response_id)
+            previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
 
             # previous response input items
-            new_input_items = previous_response_with_input.input_items.data
+            new_input_items = previous_response_with_input.input
 
             # previous response output items
-            new_input_items.extend(previous_response_with_input.response.output)
+            new_input_items.extend(previous_response_with_input.output)
 
             # new input items from the current request
             if isinstance(input, str):
@@ -208,99 +217,60 @@ class OpenAIResponsesImpl:
         return input
 
+    async def _prepend_instructions(self, messages, instructions):
+        if instructions:
+            messages.insert(0, OpenAISystemMessageParam(content=instructions))
+
     async def get_openai_response(
         self,
-        id: str,
+        response_id: str,
     ) -> OpenAIResponseObject:
-        response_with_input = await self._get_previous_response_with_input(id)
-        return response_with_input.response
+        response_with_input = await self.responses_store.get_response_object(response_id)
+        return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
 
-    async def create_openai_response(
+    async def list_openai_responses(
         self,
-        input: str | list[OpenAIResponseInput],
-        model: str,
-        previous_response_id: str | None = None,
-        store: bool | None = True,
-        stream: bool | None = False,
-        temperature: float | None = None,
-        tools: list[OpenAIResponseInputTool] | None = None,
-    ):
-        stream = False if stream is None else stream
-
-        input = await self._prepend_previous_response(input, previous_response_id)
+        after: str | None = None,
+        limit: int | None = 50,
+        model: str | None = None,
+        order: Order | None = Order.desc,
+    ) -> ListOpenAIResponseObject:
+        return await self.responses_store.list_responses(after, limit, model, order)
+
+    async def list_openai_response_input_items(
messages = await _convert_response_input_to_chat_messages(input) self,
chat_tools = await self._convert_response_tools_to_chat_tools(tools) if tools else None response_id: str,
chat_response = await self.inference_api.openai_chat_completion( after: str | None = None,
model=model, before: str | None = None,
messages=messages, include: list[str] | None = None,
tools=chat_tools, limit: int | None = 20,
stream=stream, order: Order | None = Order.desc,
temperature=temperature, ) -> ListOpenAIResponseInputItem:
) """List input items for a given OpenAI response.
if stream: :param response_id: The ID of the response to retrieve input items for.
# TODO: refactor this into a separate method that handles streaming :param after: An item ID to list items after, used for pagination.
chat_response_id = "" :param before: An item ID to list items before, used for pagination.
chat_response_content = [] :param include: Additional fields to include in the response.
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} :param limit: A limit on the number of objects to be returned.
# TODO: these chunk_ fields are hacky and only take the last chunk into account :param order: The order to return the input items in.
chunk_created = 0 :returns: An ListOpenAIResponseInputItem.
chunk_model = "" """
chunk_finish_reason = "" return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async for chunk in chat_response:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# TODO: this only works for text content
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
# Ensure we don't have any empty type field in the tool call dict.
# The OpenAI client used by providers often returns a type=None here.
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert the dict of tool calls by index to a list of tool calls to pass back in our response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
else:
# dump and reload to map to our pydantic types
chat_response = OpenAIChatCompletion(**chat_response.model_dump())
async def _process_response_choices(
self,
chat_response: OpenAIChatCompletion,
ctx: ChatCompletionContext,
tools: list[OpenAIResponseInputTool] | None,
) -> list[OpenAIResponseOutput]:
"""Handle tool execution and response message creation."""
output_messages: list[OpenAIResponseOutput] = [] output_messages: list[OpenAIResponseOutput] = []
# Execute tool calls if any
for choice in chat_response.choices: for choice in chat_response.choices:
if choice.message.tool_calls and tools: if choice.message.tool_calls and tools:
# Assume if the first tool is a function, all tools are functions # Assume if the first tool is a function, all tools are functions
if isinstance(tools[0], OpenAIResponseInputToolFunction): if tools[0].type == "function":
for tool_call in choice.message.tool_calls: for tool_call in choice.message.tool_calls:
output_messages.append( output_messages.append(
OpenAIResponseOutputMessageFunctionToolCall( OpenAIResponseOutputMessageFunctionToolCall(
@ -312,11 +282,133 @@ class OpenAIResponsesImpl:
) )
) )
else: else:
output_messages.extend( tool_messages = await self._execute_tool_and_return_final_output(choice, ctx)
await self._execute_tool_and_return_final_output(model, stream, choice, messages, temperature) output_messages.extend(tool_messages)
)
else: else:
output_messages.append(await _convert_chat_choice_to_response_message(choice)) output_messages.append(await _convert_chat_choice_to_response_message(choice))
return output_messages
async def _store_response(
self,
response: OpenAIResponseObject,
original_input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(original_input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=original_input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in original_input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
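
The `_store_response` helper above normalizes the request input before persisting it: a bare string is wrapped in a synthesized user message, and list items that arrive without an `id` get one assigned. A minimal, self-contained sketch of that normalization over plain dicts (the helper name and dict shapes are illustrative, not the actual Llama Stack types):

import uuid

def normalize_input_for_storage(original_input):
    """Return a list of message dicts, each guaranteed to carry an 'id'."""
    new_input_id = f"msg_{uuid.uuid4()}"
    if isinstance(original_input, str):
        # Synthesize a single user message from the raw prompt string.
        return [{"id": new_input_id, "role": "user",
                 "content": [{"type": "input_text", "text": original_input}]}]
    items = []
    for item in original_input:
        item = dict(item)
        # Add an id only if missing (the real code does this for message items).
        item.setdefault("id", new_input_id)
        items.append(item)
    return items

# Example: both input forms end up as a uniform list of id-carrying messages.
print(normalize_input_for_storage("What is the capital of France?"))
print(normalize_input_for_storage([{"role": "user", "content": "hi"}]))
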
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
):
stream = False if stream is None else stream
original_input = input # Keep reference for storage
output_messages: list[OpenAIResponseOutput] = []
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await _convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Tool setup
chat_tools, mcp_tool_to_server, mcp_list_message = (
await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None)
)
if mcp_list_message:
output_messages.append(mcp_list_message)
ctx = ChatCompletionContext(
model=model,
messages=messages,
tools=chat_tools,
mcp_tool_to_server=mcp_tool_to_server,
stream=stream,
temperature=temperature,
)
inference_result = await self.inference_api.openai_chat_completion(
model=model,
messages=messages,
tools=chat_tools,
stream=stream,
temperature=temperature,
)
if stream:
return self._create_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
original_input=original_input,
model=model,
store=store,
tools=tools,
)
else:
return await self._create_non_streaming_response(
inference_result=inference_result,
ctx=ctx,
output_messages=output_messages,
original_input=original_input,
model=model,
store=store,
tools=tools,
)
async def _create_non_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
original_input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> OpenAIResponseObject:
chat_response = OpenAIChatCompletion(**inference_result.model_dump())
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response,
ctx=ctx,
tools=tools,
)
)
response = OpenAIResponseObject( response = OpenAIResponseObject(
created_at=chat_response.created, created_at=chat_response.created,
id=f"resp-{uuid.uuid4()}", id=f"resp-{uuid.uuid4()}",
@ -327,57 +419,168 @@ class OpenAIResponsesImpl:
) )
logger.debug(f"OpenAI Responses response: {response}") logger.debug(f"OpenAI Responses response: {response}")
# Store response if requested
if store: if store:
# Store in kvstore await self._store_response(
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
input_items = OpenAIResponseInputItemList(data=input_items_data)
prev_response = OpenAIResponsePreviousResponseWithInputItems(
input_items=input_items,
response=response, response=response,
original_input=original_input,
) )
key = f"{OPENAI_RESPONSES_PREFIX}{response.id}"
await self.persistence_store.set(
key=key,
value=prev_response.model_dump_json(),
)
if stream:
async def async_response() -> AsyncIterator[OpenAIResponseObjectStream]:
# TODO: response created should actually get emitted much earlier in the process
yield OpenAIResponseObjectStreamResponseCreated(response=response)
yield OpenAIResponseObjectStreamResponseCompleted(response=response)
return async_response()
return response return response
async def _create_streaming_response(
self,
inference_result: Any,
ctx: ChatCompletionContext,
output_messages: list[OpenAIResponseOutput],
original_input: str | list[OpenAIResponseInput],
model: str,
store: bool | None,
tools: list[OpenAIResponseInputTool] | None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Create initial response and emit response.created immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
initial_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="in_progress",
output=output_messages.copy(),
)
# Emit response.created immediately
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# For streaming, inference_result is an async iterator of chunks
# Stream chunks and emit delta events as they arrive
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
sequence_number = 0
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
async for chunk in inference_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks, using their index as the aggregation key
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
if response_tool_call:
response_tool_call.function.arguments += tool_call.function.arguments
else:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Convert collected chunks to complete response
if chat_response_tool_calls:
tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content="".join(chat_response_content),
tool_calls=tool_calls,
)
chat_response_obj = OpenAIChatCompletion(
id=chat_response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=chunk_finish_reason,
index=0,
)
],
created=chunk_created,
model=chunk_model,
)
# Process response choices (tool execution and message creation)
output_messages.extend(
await self._process_response_choices(
chat_response=chat_response_obj,
ctx=ctx,
tools=tools,
)
)
# Create final response
final_response = OpenAIResponseObject(
created_at=created_at,
id=response_id,
model=model,
object="response",
status="completed",
output=output_messages,
)
if store:
await self._store_response(
response=final_response,
original_input=original_input,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
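
The streaming path emits `response.created` right away, forwards text deltas as they arrive, and meanwhile accumulates content plus tool-call argument fragments keyed by chunk index so it can rebuild a complete chat completion at the end. A hedged, self-contained sketch of that aggregation pattern over simplified OpenAI-style chunk dicts (the chunk shape is an assumption, reduced from the real client types):

import asyncio

async def aggregate_chunks(chunks):
    """Collect streamed deltas into final text plus per-index tool calls."""
    text_parts: list[str] = []
    tool_calls: dict[int, dict] = {}
    async for chunk in chunks:
        for choice in chunk["choices"]:
            delta = choice["delta"]
            if delta.get("content"):
                text_parts.append(delta["content"])  # incremental text
            for tc in delta.get("tool_calls") or []:
                slot = tool_calls.setdefault(tc["index"], {"name": "", "arguments": ""})
                fn = tc.get("function", {})
                slot["name"] = fn.get("name") or slot["name"]
                slot["arguments"] += fn.get("arguments", "")  # append argument fragments
    return "".join(text_parts), [tool_calls[i] for i in sorted(tool_calls)]

async def _demo():
    async def fake_stream():
        yield {"choices": [{"delta": {"content": "Hel"}}]}
        yield {"choices": [{"delta": {"content": "lo", "tool_calls": [
            {"index": 0, "function": {"name": "web_search", "arguments": "{\"q\":"}}]}}]}
        yield {"choices": [{"delta": {"tool_calls": [
            {"index": 0, "function": {"arguments": "\"llama\"}"}}]}}]}
    print(await aggregate_chunks(fake_stream()))

asyncio.run(_demo())
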
async def _convert_response_tools_to_chat_tools( async def _convert_response_tools_to_chat_tools(
self, tools: list[OpenAIResponseInputTool] self, tools: list[OpenAIResponseInputTool]
) -> list[ChatCompletionToolParam]: ) -> tuple[
list[ChatCompletionToolParam],
dict[str, OpenAIResponseInputToolMCP],
OpenAIResponseOutput | None,
]:
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
)
from llama_stack.apis.tools.tools import Tool
mcp_tool_to_server = {}
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
mcp_list_message = None
chat_tools: list[ChatCompletionToolParam] = [] chat_tools: list[ChatCompletionToolParam] = []
for input_tool in tools: for input_tool in tools:
# TODO: Handle other tool types # TODO: Handle other tool types
@ -386,91 +589,95 @@ class OpenAIResponsesImpl:
elif input_tool.type == "web_search": elif input_tool.type == "web_search":
tool_name = "web_search" tool_name = "web_search"
tool = await self.tool_groups_api.get_tool(tool_name) tool = await self.tool_groups_api.get_tool(tool_name)
tool_def = ToolDefinition( if not tool:
tool_name=tool_name, raise ValueError(f"Tool {tool_name} not found")
description=tool.description, chat_tools.append(make_openai_tool(tool_name, tool))
parameters={ elif input_tool.type == "mcp":
param.name: ToolParamDefinition( always_allowed = None
param_type=param.parameter_type, never_allowed = None
description=param.description, if input_tool.allowed_tools:
required=param.required, if isinstance(input_tool.allowed_tools, list):
default=param.default, always_allowed = input_tool.allowed_tools
) elif isinstance(input_tool.allowed_tools, AllowedToolsFilter):
for param in tool.parameters always_allowed = input_tool.allowed_tools.always
}, never_allowed = input_tool.allowed_tools.never
tool_defs = await list_mcp_tools(
endpoint=input_tool.server_url,
headers=input_tool.headers or {},
) )
chat_tool = convert_tooldef_to_openai_tool(tool_def)
chat_tools.append(chat_tool) mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
status="completed",
server_label=input_tool.server_label,
tools=[],
)
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
chat_tools.append(make_openai_tool(t.name, t))
if t.name in mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}")
mcp_tool_to_server[t.name] = input_tool
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
else: else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
return chat_tools return chat_tools, mcp_tool_to_server, mcp_list_message
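
For MCP servers, the conversion above lists the server's tools and then applies the `allowed_tools` filter: anything on a `never` list is dropped, and when an `always` list is present only those names pass through. A small self-contained sketch of that filtering rule (the tool names are made up):

def filter_mcp_tools(tool_names, always_allowed=None, never_allowed=None):
    """Apply the never-list first, then restrict to the always-list if present."""
    selected = []
    for name in tool_names:
        if never_allowed and name in never_allowed:
            continue  # explicitly blocked
        if not always_allowed or name in always_allowed:
            selected.append(name)
    return selected

# Example: "delete_repo" is blocked, and only the always-list survives otherwise.
print(filter_mcp_tools(
    ["search", "fetch_page", "delete_repo"],
    always_allowed=["search", "fetch_page"],
    never_allowed=["delete_repo"],
))
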
async def _execute_tool_and_return_final_output( async def _execute_tool_and_return_final_output(
self, self,
model_id: str,
stream: bool,
choice: OpenAIChoice, choice: OpenAIChoice,
messages: list[OpenAIMessageParam], ctx: ChatCompletionContext,
temperature: float,
) -> list[OpenAIResponseOutput]: ) -> list[OpenAIResponseOutput]:
output_messages: list[OpenAIResponseOutput] = [] output_messages: list[OpenAIResponseOutput] = []
# If the choice is not an assistant message, we don't need to execute any tools
if not isinstance(choice.message, OpenAIAssistantMessageParam): if not isinstance(choice.message, OpenAIAssistantMessageParam):
return output_messages return output_messages
# If the assistant message doesn't have any tool calls, we don't need to execute any tools
if not choice.message.tool_calls: if not choice.message.tool_calls:
return output_messages return output_messages
# Copy the messages list to avoid mutating the original list next_turn_messages = ctx.messages.copy()
messages = messages.copy()
# Add the assistant message with tool_calls response to the messages list # Add the assistant message with tool_calls response to the messages list
messages.append(choice.message) next_turn_messages.append(choice.message)
for tool_call in choice.message.tool_calls: for tool_call in choice.message.tool_calls:
tool_call_id = tool_call.id
function = tool_call.function
# If for some reason the tool call doesn't have a function or id, we can't execute it
if not function or not tool_call_id:
continue
# TODO: telemetry spans for tool calls # TODO: telemetry spans for tool calls
result = await self._execute_tool_call(function) tool_call_log, further_input = await self._execute_tool_call(tool_call, ctx)
if tool_call_log:
# Handle tool call failure output_messages.append(tool_call_log)
if not result: if further_input:
output_messages.append( next_turn_messages.append(further_input)
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="failed",
)
)
continue
output_messages.append(
OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
),
)
result_content = ""
# TODO: handle other result content types and lists
if isinstance(result.content, str):
result_content = result.content
messages.append(OpenAIToolMessageParam(content=result_content, tool_call_id=tool_call_id))
tool_results_chat_response = await self.inference_api.openai_chat_completion( tool_results_chat_response = await self.inference_api.openai_chat_completion(
model=model_id, model=ctx.model,
messages=messages, messages=next_turn_messages,
stream=stream, stream=ctx.stream,
temperature=temperature, temperature=ctx.temperature,
) )
# type cast to appease mypy # type cast to appease mypy: this is needed because we don't handle streaming properly :)
tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response) tool_results_chat_response = cast(OpenAIChatCompletion, tool_results_chat_response)
# Huge TODO: these are NOT the final outputs, we must keep the loop going
tool_final_outputs = [ tool_final_outputs = [
await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices await _convert_chat_choice_to_response_message(choice) for choice in tool_results_chat_response.choices
] ]
@ -480,15 +687,86 @@ class OpenAIResponsesImpl:
async def _execute_tool_call( async def _execute_tool_call(
self, self,
function: OpenAIChatCompletionToolCallFunction, tool_call: OpenAIChatCompletionToolCall,
) -> ToolInvocationResult | None: ctx: ChatCompletionContext,
if not function.name: ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]:
return None from llama_stack.providers.utils.inference.prompt_adapter import (
function_args = json.loads(function.arguments) if function.arguments else {} interleaved_content_as_str,
logger.info(f"executing tool call: {function.name} with args: {function_args}")
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=function_args,
) )
logger.debug(f"tool call {function.name} completed with result: {result}")
return result tool_call_id = tool_call.id
function = tool_call.function
if not function or not tool_call_id or not function.name:
return None, None
error_exc = None
result = None
try:
if function.name in ctx.mcp_tool_to_server:
mcp_tool = ctx.mcp_tool_to_server[function.name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function.name,
kwargs=json.loads(function.arguments) if function.arguments else {},
)
except Exception as e:
error_exc = e
if function.name in ctx.mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import OpenAIResponseOutputMessageMCPCall
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=ctx.mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result.error_code and result.error_code > 0) or result.error_message:
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if error_exc or (result.error_code and result.error_code > 0) or result.error_message:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
return message, input_message
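
When a tool result is turned back into a chat message, text items are passed through and image items become `image_url` parts, preferring an inline base64 data URL whenever raw bytes are present. A hedged sketch of that URL selection, using plain dicts in place of the content-item and message-part types:

def image_part_from_item(item: dict) -> dict:
    """Build an OpenAI-style image_url part from an image content item."""
    if item.get("data"):
        # Inline the raw bytes as a data URL so the model needs no extra fetch.
        url = f"data:image;base64,{item['data']}"
    else:
        url = item["url"]
    return {"type": "image_url", "image_url": {"url": url}}

# Example: inline bytes win over a remote URL when both could apply.
print(image_part_from_item({"data": "iVBORw0KGgo="}))
print(image_part_from_item({"url": "https://example.com/cat.png"}))
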


@ -6,6 +6,7 @@
import asyncio import asyncio
import os import os
import sys
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from pydantic import BaseModel from pydantic import BaseModel
@ -455,9 +456,9 @@ class MetaReferenceInferenceImpl(
first = token_results[0] first = token_results[0]
if not first.finished and not first.ignore_token: if not first.finished and not first.ignore_token:
if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"): if os.environ.get("LLAMA_MODELS_DEBUG", "0") in ("1", "2"):
cprint(first.text, "cyan", end="") cprint(first.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{first.token}>", "magenta", end="") cprint(f"<{first.token}>", color="magenta", end="", file=sys.stderr)
for result in token_results: for result in token_results:
idx = result.batch_idx idx = result.batch_idx
@ -519,9 +520,9 @@ class MetaReferenceInferenceImpl(
for token_results in self.generator.chat_completion([request]): for token_results in self.generator.chat_completion([request]):
token_result = token_results[0] token_result = token_results[0]
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1": if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="") cprint(token_result.text, color="cyan", end="", file=sys.stderr)
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2": if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":
cprint(f"<{token_result.token}>", "magenta", end="") cprint(f"<{token_result.token}>", color="magenta", end="", file=sys.stderr)
if token_result.token == tokenizer.eot_id: if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn stop_reason = StopReason.end_of_turn
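
Routing the `LLAMA_MODELS_DEBUG` token dumps to stderr keeps stdout free for real program output. A minimal sketch of the same gating, assuming termcolor is available as it is for the meta-reference provider:

import os
import sys
from termcolor import cprint

def debug_print_token(text: str, token_id: int) -> None:
    """Echo generated text (and optionally token ids) to stderr when enabled."""
    level = os.environ.get("LLAMA_MODELS_DEBUG", "0")
    if level in ("1", "2"):
        cprint(text, color="cyan", end="", file=sys.stderr)
    if level == "2":
        cprint(f"<{token_id}>", color="magenta", end="", file=sys.stderr)

# Example: run with LLAMA_MODELS_DEBUG=2 to see both the text and the token id.
debug_print_token("Hello", 9906)
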


@ -16,6 +16,7 @@ from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from llama_stack.apis.telemetry import ( from llama_stack.apis.telemetry import (
Event, Event,
@ -44,6 +45,7 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
) )
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.providers.utils.telemetry.tracing import ROOT_SPAN_MARKERS
from .config import TelemetryConfig, TelemetrySink from .config import TelemetryConfig, TelemetrySink
@ -146,7 +148,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
if span: if span:
timestamp_ns = int(event.timestamp.timestamp() * 1e9) timestamp_ns = int(event.timestamp.timestamp() * 1e9)
span.add_event( span.add_event(
name=event.type, name=event.type.value,
attributes={ attributes={
"message": event.message, "message": event.message,
"severity": event.severity.value, "severity": event.severity.value,
@ -206,6 +208,15 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
event.attributes = {} event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds event.attributes["__ttl__"] = ttl_seconds
# Extract these W3C trace context attributes so they are not written to
# underlying storage, as we just need them to propagate the trace context.
traceparent = event.attributes.pop("traceparent", None)
tracestate = event.attributes.pop("tracestate", None)
if traceparent:
# If we have a traceparent header value, we're not the root span.
for root_attribute in ROOT_SPAN_MARKERS:
event.attributes.pop(root_attribute, None)
if isinstance(event.payload, SpanStartPayload): if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates # Check if span already exists to prevent duplicates
if span_id in _GLOBAL_STORAGE["active_spans"]: if span_id in _GLOBAL_STORAGE["active_spans"]:
@ -216,8 +227,12 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
parent_span_id = int(event.payload.parent_span_id, 16) parent_span_id = int(event.payload.parent_span_id, 16)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id) parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span) context = trace.set_span_in_context(parent_span)
else: elif traceparent:
event.attributes["__root_span__"] = "true" carrier = {
"traceparent": traceparent,
"tracestate": tracestate,
}
context = TraceContextTextMapPropagator().extract(carrier=carrier)
span = tracer.start_span( span = tracer.start_span(
name=event.payload.name, name=event.payload.name,
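
The traceparent/tracestate handling above is standard W3C trace-context propagation: when an incoming event carries a `traceparent`, the span is started inside the extracted remote context instead of being marked as a root span. A hedged sketch with the OpenTelemetry SDK (exporter setup omitted; the carrier value is the canonical W3C example, not real data):

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

# Headers (or event attributes) received from the calling service.
carrier = {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}

# Extract the remote context; spans started under it join the caller's trace.
ctx = TraceContextTextMapPropagator().extract(carrier=carrier)
with tracer.start_as_current_span("llama-stack-request", context=ctx) as span:
    span.set_attribute("component", "telemetry-demo")
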


@ -25,14 +25,14 @@ from llama_stack.apis.tools import (
RAGQueryConfig, RAGQueryConfig,
RAGQueryResult, RAGQueryResult,
RAGToolRuntime, RAGToolRuntime,
Tool,
ToolDef, ToolDef,
ToolGroup,
ToolInvocationResult, ToolInvocationResult,
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
content_from_doc, content_from_doc,
@ -49,7 +49,7 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__( def __init__(
self, self,
config: RagToolRuntimeConfig, config: RagToolRuntimeConfig,
@ -66,10 +66,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def shutdown(self): async def shutdown(self):
pass pass
async def register_tool(self, tool: Tool) -> None: async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
pass pass
async def unregister_tool(self, tool_id: str) -> None: async def unregister_toolgroup(self, toolgroup_id: str) -> None:
return return
async def insert( async def insert(
@ -122,6 +122,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
query=query, query=query,
params={ params={
"max_chunks": query_config.max_chunks, "max_chunks": query_config.max_chunks,
"mode": query_config.mode,
}, },
) )
for vector_db_id in vector_db_ids for vector_db_id in vector_db_ids
@ -146,7 +147,7 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
metadata = chunk.metadata metadata = chunk.metadata
tokens += metadata["token_count"] tokens += metadata["token_count"]
tokens += metadata["metadata_token_count"] tokens += metadata.get("metadata_token_count", 0)
if tokens > query_config.max_tokens_in_context: if tokens > query_config.max_tokens_in_context:
log.error( log.error(


@ -99,9 +99,13 @@ class FaissIndex(EmbeddingIndex):
# Save updated index # Save updated index
await self._save_index() await self._save_index()
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k) distances, indices = await asyncio.to_thread(self.index.search, embedding.reshape(1, -1).astype(np.float32), k)
chunks = [] chunks = []
scores = [] scores = []
for d, i in zip(distances[0], indices[0], strict=False): for d, i in zip(distances[0], indices[0], strict=False):
@ -112,6 +116,14 @@ class FaissIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in FAISS")
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
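
Splitting `query` into `query_vector` and `query_keyword` lets each backend declare which search modes it supports; FAISS keeps vector search and refuses keyword search outright. A small sketch of that interface shape using an abstract base class (simplified signatures, not the actual `EmbeddingIndex` definition):

import asyncio
from abc import ABC, abstractmethod

class EmbeddingIndexSketch(ABC):
    @abstractmethod
    async def query_vector(self, embedding: list[float], k: int, score_threshold: float): ...

    @abstractmethod
    async def query_keyword(self, query_string: str, k: int, score_threshold: float): ...

class FaissLikeIndex(EmbeddingIndexSketch):
    async def query_vector(self, embedding, k, score_threshold):
        return []  # similarity search would run against the FAISS index here

    async def query_keyword(self, query_string, k, score_threshold):
        # FAISS stores only vectors, so keyword search is explicitly unsupported.
        raise NotImplementedError("Keyword search is not supported in FAISS")

print(asyncio.run(FaissLikeIndex().query_vector([0.1, 0.2], k=1, score_threshold=0.0)))
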


@ -24,6 +24,11 @@ from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, Vect
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
def serialize_vector(vector: list[float]) -> bytes: def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation.""" """Serialize a list of floats into a compact binary representation."""
@ -45,6 +50,7 @@ class SQLiteVecIndex(EmbeddingIndex):
Two tables are used: Two tables are used:
- A metadata table (chunks_{bank_id}) that holds the chunk JSON. - A metadata table (chunks_{bank_id}) that holds the chunk JSON.
- A virtual table (vec_chunks_{bank_id}) that holds the serialized vector. - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector.
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
""" """
def __init__(self, dimension: int, db_path: str, bank_id: str): def __init__(self, dimension: int, db_path: str, bank_id: str):
@ -53,6 +59,7 @@ class SQLiteVecIndex(EmbeddingIndex):
self.bank_id = bank_id self.bank_id = bank_id
self.metadata_table = f"chunks_{bank_id}".replace("-", "_") self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
@classmethod @classmethod
async def create(cls, dimension: int, db_path: str, bank_id: str): async def create(cls, dimension: int, db_path: str, bank_id: str):
@ -78,6 +85,14 @@ class SQLiteVecIndex(EmbeddingIndex):
USING vec0(embedding FLOAT[{self.dimension}], id TEXT); USING vec0(embedding FLOAT[{self.dimension}], id TEXT);
""") """)
connection.commit() connection.commit()
# FTS5 table (for keyword search). Both tables are created by default and the relevant one is used per
# query. A follow-up client-side change will allow passing a search_mode option during initialization so
# that only the required table is created.
cur.execute(f"""
CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table}
USING fts5(id, content);
""")
connection.commit()
finally: finally:
cur.close() cur.close()
connection.close() connection.close()
@ -91,6 +106,7 @@ class SQLiteVecIndex(EmbeddingIndex):
try: try:
cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};")
cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};")
connection.commit() connection.commit()
finally: finally:
cur.close() cur.close()
@ -104,6 +120,7 @@ class SQLiteVecIndex(EmbeddingIndex):
For each chunk, we insert its JSON into the metadata table and then insert its For each chunk, we insert its JSON into the metadata table and then insert its
embedding (serialized to raw bytes) into the virtual table using the assigned rowid. embedding (serialized to raw bytes) into the virtual table using the assigned rowid.
If any insert fails, the transaction is rolled back to maintain consistency. If any insert fails, the transaction is rolled back to maintain consistency.
Also inserts chunk content into FTS table for keyword search support.
""" """
assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks" assert all(isinstance(chunk.content, str) for chunk in chunks), "SQLiteVecIndex only supports text chunks"
@ -112,18 +129,16 @@ class SQLiteVecIndex(EmbeddingIndex):
cur = connection.cursor() cur = connection.cursor()
try: try:
# Start a single transaction for all batches
cur.execute("BEGIN TRANSACTION") cur.execute("BEGIN TRANSACTION")
for i in range(0, len(chunks), batch_size): for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i : i + batch_size] batch_chunks = chunks[i : i + batch_size]
batch_embeddings = embeddings[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size]
# Prepare metadata inserts
# Insert metadata
metadata_data = [ metadata_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json()) (generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.model_dump_json())
for chunk in batch_chunks for chunk in batch_chunks
if isinstance(chunk.content, str)
] ]
# Insert metadata (ON CONFLICT to avoid duplicates)
cur.executemany( cur.executemany(
f""" f"""
INSERT INTO {self.metadata_table} (id, chunk) INSERT INTO {self.metadata_table} (id, chunk)
@ -132,21 +147,43 @@ class SQLiteVecIndex(EmbeddingIndex):
""", """,
metadata_data, metadata_data,
) )
# Prepare embeddings inserts
# Insert vector embeddings
embedding_data = [ embedding_data = [
( (
generate_chunk_id(chunk.metadata["document_id"], chunk.content), (
serialize_vector(emb.tolist()), generate_chunk_id(chunk.metadata["document_id"], chunk.content),
serialize_vector(emb.tolist()),
)
) )
for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True)
if isinstance(chunk.content, str)
] ]
# Insert embeddings in batch cur.executemany(
cur.executemany(f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", embedding_data) f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);",
embedding_data,
)
# Insert FTS content
fts_data = [
(generate_chunk_id(chunk.metadata["document_id"], chunk.content), chunk.content)
for chunk in batch_chunks
]
# DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT)
cur.executemany(
f"DELETE FROM {self.fts_table} WHERE id = ?;",
[(row[0],) for row in fts_data],
)
# INSERT new entries
cur.executemany(
f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);",
fts_data,
)
connection.commit() connection.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
connection.rollback() # Rollback on failure connection.rollback()
logger.error(f"Error inserting into {self.vector_table}: {e}") logger.error(f"Error inserting into {self.vector_table}: {e}")
raise raise
@ -154,22 +191,25 @@ class SQLiteVecIndex(EmbeddingIndex):
cur.close() cur.close()
connection.close() connection.close()
# Process all batches in a single thread # Run batch insertion in a background thread
await asyncio.to_thread(_execute_all_batch_inserts) await asyncio.to_thread(_execute_all_batch_inserts)
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query_vector(
self,
embedding: NDArray,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
""" """
Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query Performs vector-based search using a virtual table for vector similarity.
against the virtual table. The SQL joins the metadata table to recover the chunk JSON.
""" """
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
emb_blob = serialize_vector(emb_list)
def _execute_query(): def _execute_query():
connection = _create_sqlite_connection(self.db_path) connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor() cur = connection.cursor()
try: try:
emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding)
emb_blob = serialize_vector(emb_list)
query_sql = f""" query_sql = f"""
SELECT m.id, m.chunk, v.distance SELECT m.id, m.chunk, v.distance
FROM {self.vector_table} AS v FROM {self.vector_table} AS v
@ -184,17 +224,66 @@ class SQLiteVecIndex(EmbeddingIndex):
connection.close() connection.close()
rows = await asyncio.to_thread(_execute_query) rows = await asyncio.to_thread(_execute_query)
chunks, scores = [], [] chunks, scores = [], []
for _id, chunk_json, distance in rows: for row in rows:
_id, chunk_json, distance = row
score = 1.0 / distance if distance != 0 else float("inf")
if score < score_threshold:
continue
try:
chunk = Chunk.model_validate_json(chunk_json)
except Exception as e:
logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
continue
chunks.append(chunk)
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_keyword(
self,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
"""
if query_string is None:
raise ValueError("query_string is required for keyword search.")
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
cur = connection.cursor()
try:
query_sql = f"""
SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score
FROM {self.fts_table} AS f
JOIN {self.metadata_table} AS m ON m.id = f.id
WHERE f.content MATCH ?
ORDER BY score ASC
LIMIT ?;
"""
cur.execute(query_sql, (query_string, k))
return cur.fetchall()
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_execute_query)
chunks, scores = [], []
for row in rows:
_id, chunk_json, score = row
# BM25 scores returned by sqlite-vec are NEGATED (i.e., more relevant = more negative).
# This design is intentional to simplify sorting by ascending score.
# Reference: https://alexgarcia.xyz/blog/2024/sqlite-vec-hybrid-search/index.html
if score > -score_threshold:
continue
try: try:
chunk = Chunk.model_validate_json(chunk_json) chunk = Chunk.model_validate_json(chunk_json)
except Exception as e: except Exception as e:
logger.error(f"Error parsing chunk JSON for id {_id}: {e}") logger.error(f"Error parsing chunk JSON for id {_id}: {e}")
continue continue
chunks.append(chunk) chunks.append(chunk)
# Mimic the Faiss scoring: score = 1/distance (avoid division by zero)
score = 1.0 / distance if distance != 0 else float("inf")
scores.append(score) scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)
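
The keyword path relies on SQLite's FTS5 extension: rows are matched with `MATCH` and ranked with `bm25()`, which returns negated scores (more relevant rows are more negative), so ordering ascending puts the best hits first. A self-contained sketch, assuming the local sqlite3 build includes FTS5:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE VIRTUAL TABLE fts_chunks USING fts5(id, content);")
cur.executemany(
    "INSERT INTO fts_chunks (id, content) VALUES (?, ?);",
    [
        ("c1", "Llamas are members of the camelid family."),
        ("c2", "FTS5 provides full-text search inside SQLite."),
        ("c3", "Vector search ranks chunks by embedding distance."),
    ],
)
conn.commit()

# bm25() is negative; ORDER BY score ASC therefore lists the best matches first.
cur.execute(
    """
    SELECT id, content, bm25(fts_chunks) AS score
    FROM fts_chunks
    WHERE fts_chunks MATCH ?
    ORDER BY score ASC
    LIMIT ?;
    """,
    ("search", 2),
)
for row in cur.fetchall():
    print(row)
conn.close()
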


@ -63,4 +63,14 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
), ),
), ),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
),
),
] ]


@ -80,8 +80,9 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="model-context-protocol", adapter_type="model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol", module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.ModelContextProtocolConfig", config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp"], pip_packages=["mcp"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
), ),
), ),
] ]


@ -92,8 +92,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
if prompt_logprobs is not None: if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id, model=model_id,
prompt=prompt, prompt=prompt,
best_of=best_of, best_of=best_of,
echo=echo, echo=echo,
@ -139,8 +142,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
top_p: float | None = None, top_p: float | None = None,
user: str | None = None, user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params( params = await prepare_openai_completion_params(
model=(await self.model_store.get_model(model)).provider_resource_id, model=model_id,
messages=messages, messages=messages,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
function_call=function_call, function_call=function_call,
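
Registered model identifiers may carry an `openai/` namespace prefix that the upstream OpenAI API does not recognize, so the adapter strips it before building the request. The same normalization in isolation (Python 3.9+ `str.removeprefix` would work equally well):

def normalize_openai_model_id(provider_resource_id: str) -> str:
    """Drop the local 'openai/' namespace before calling the OpenAI API."""
    prefix = "openai/"
    if provider_resource_id.startswith(prefix):
        return provider_resource_id[len(prefix):]
    return provider_resource_id

print(normalize_openai_model_id("openai/gpt-4o-mini"))  # -> gpt-4o-mini
print(normalize_openai_model_id("gpt-4o-mini"))         # unchanged
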


@ -4,8 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pathlib import Path
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -24,11 +25,27 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake", default="fake",
description="The API token", description="The API token",
) )
tls_verify: bool = Field( tls_verify: bool | str = Field(
default=True, default=True,
description="Whether to verify TLS certificates", description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
) )
@field_validator("tls_verify")
@classmethod
def validate_tls_verify(cls, v):
if isinstance(v, str):
# Check if it's a boolean string
if v.lower() in ("true", "false"):
return v.lower() == "true"
# Otherwise, treat it as a cert path
cert_path = Path(v).expanduser().resolve()
if not cert_path.exists():
raise ValueError(f"TLS certificate file does not exist: {v}")
if not cert_path.is_file():
raise ValueError(f"TLS certificate path is not a file: {v}")
return v
return v
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,


@ -313,7 +313,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
return AsyncOpenAI( return AsyncOpenAI(
base_url=self.config.url, base_url=self.config.url,
api_key=self.config.api_token, api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), http_client=httpx.AsyncClient(verify=self.config.tls_verify),
) )
async def completion( async def completion(
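
Because `httpx` accepts either a boolean or a CA-bundle path for `verify`, the widened `tls_verify: bool | str` setting can be passed straight through to the client; the validator above only normalizes "true"/"false" strings and checks that a supplied path exists. A hedged usage sketch (the CA path in the comment is a placeholder):

import httpx

def make_vllm_client(tls_verify: bool | str) -> httpx.AsyncClient:
    """verify=False disables checks; a string is treated as a CA bundle path."""
    return httpx.AsyncClient(verify=tls_verify)

# verify=True (the default) and verify=False work everywhere:
default_client = make_vllm_client(True)
insecure_client = make_vllm_client(False)
# With a real CA bundle on disk you would pass its path instead, e.g.:
#   make_vllm_client("/etc/ssl/certs/my-ca.pem")
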


@ -224,7 +224,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
Parameters: Parameters:
training_config: TrainingConfig - Configuration for training training_config: TrainingConfig - Configuration for training
model: str - Model identifier model: str - NeMo Customizer configuration name
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
job_uuid: str - Unique identifier for the job, ignored atm job_uuid: str - Unique identifier for the job, ignored atm
@ -299,9 +299,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
User is informed about unsupported parameters via warnings. User is informed about unsupported parameters via warnings.
""" """
# Map model to nvidia model name
# See `_MODEL_ENTRIES` for supported models
nvidia_model = self.get_provider_model_id(model)
# Check for unsupported method parameters # Check for unsupported method parameters
unsupported_method_params = [] unsupported_method_params = []
@ -347,7 +344,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
# Prepare base job configuration # Prepare base job configuration
job_config = { job_config = {
"config": nvidia_model, "config": model,
"dataset": { "dataset": {
"name": training_config["data_config"]["dataset_id"], "name": training_config["data_config"]["dataset_id"],
"namespace": self.config.dataset_namespace, "namespace": self.config.dataset_namespace,


@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import SambaNovaSafetyConfig
async def get_adapter_impl(config: SambaNovaSafetyConfig, _deps) -> Any:
from .sambanova import SambaNovaSafetyAdapter
impl = SambaNovaSafetyAdapter(config)
await impl.initialize()
return impl


@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class SambaNovaProviderDataValidator(BaseModel):
sambanova_api_key: str | None = Field(
default=None,
description="Sambanova Cloud API key",
)
@json_schema_type
class SambaNovaSafetyConfig(BaseModel):
url: str = Field(
default="https://api.sambanova.ai/v1",
description="The URL for the SambaNova AI server",
)
api_key: SecretStr | None = Field(
default=None,
description="The SambaNova cloud API Key",
)
@classmethod
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
return {
"url": "https://api.sambanova.ai/v1",
"api_key": api_key,
}


@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import logging
from typing import Any

import litellm
import requests

from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
    RunShieldResponse,
    Safety,
    SafetyViolation,
    ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new

from .config import SambaNovaSafetyConfig

logger = logging.getLogger(__name__)

CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
    def __init__(self, config: SambaNovaSafetyConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    def _get_api_key(self) -> str:
        config_api_key = self.config.api_key if self.config.api_key else None
        if config_api_key:
            return config_api_key.get_secret_value()
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.sambanova_api_key:
                raise ValueError(
                    'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": <your api key> }'
                )
            return provider_data.sambanova_api_key

    async def register_shield(self, shield: Shield) -> None:
        list_models_url = self.config.url + "/models"
        try:
            response = requests.get(list_models_url)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise RuntimeError(f"Request to {list_models_url} failed") from e
        available_models = [model.get("id") for model in response.json().get("data", {})]
        if (
            len(available_models) == 0
            or "guard" not in shield.provider_resource_id.lower()
            or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
        ):
            raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova")

    async def run_shield(
        self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None
    ) -> RunShieldResponse:
        shield = await self.shield_store.get_shield(shield_id)
        if not shield:
            raise ValueError(f"Shield {shield_id} not found")

        shield_params = shield.params
        logger.debug(f"run_shield::{shield_params}::messages={messages}")
        content_messages = [await convert_message_to_openai_dict_new(m) for m in messages]
        logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")

        response = litellm.completion(
            model=shield.provider_resource_id, messages=content_messages, api_key=self._get_api_key()
        )
        shield_message = response.choices[0].message.content

        if "unsafe" in shield_message.lower():
            user_message = CANNED_RESPONSE_TEXT
            violation_type = shield_message.split("\n")[-1]
            metadata = {"violation_type": violation_type}

            return RunShieldResponse(
                violation=SafetyViolation(
                    user_message=user_message,
                    violation_level=ViolationLevel.ERROR,
                    metadata=metadata,
                )
            )

        return RunShieldResponse()
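A hedged usage sketch for the adapter above: the key comes either from the provider config or per request via the X-LlamaStack-Provider-Data header, exactly as the error message in _get_api_key spells out. The endpoint route, port, and shield id below are assumptions for illustration, not taken from this diff.

import json
import requests

headers = {
    "Content-Type": "application/json",
    "X-LlamaStack-Provider-Data": json.dumps({"sambanova_api_key": "sk-..."}),  # per-request key
}
payload = {
    "shield_id": "sambanova/Meta-Llama-Guard-3-8B",  # assumed shield identifier
    "messages": [{"role": "user", "content": "Is this message safe?"}],
    "params": {},
}
# Assumed Llama Stack route and default port; adjust to your deployment.
resp = requests.post("http://localhost:8321/v1/safety/run-shield", headers=headers, json=payload)
print(resp.json())  # includes a violation with the canned response text if the guard model answers "unsafe"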

View file

@@ -12,19 +12,19 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import BingSearchToolConfig


-class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class BingSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: BingSearchToolConfig):
         self.config = config
         self.url = "https://api.bing.microsoft.com/v7.0/search"
@@ -32,10 +32,10 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:
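With register_tool/unregister_tool replaced by register_toolgroup/unregister_toolgroup, these tool-runtime providers are now managed at toolgroup granularity. A hedged client-side sketch of what that looks like in practice; the toolgroup id, provider id, and exact client method signatures are my assumptions, not confirmed by this diff.

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")
# Register a whole toolgroup backed by the search provider (ids are illustrative).
client.toolgroups.register(
    toolgroup_id="builtin::websearch",
    provider_id="bing-search",
)
# Individual tools are then listed per toolgroup rather than registered one by one.
for tool in client.tools.list(toolgroup_id="builtin::websearch"):
    print(tool.identifier)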

View file

@@ -11,30 +11,30 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.models.llama.datatypes import BuiltinTool
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import BraveSearchToolConfig


-class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class BraveSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: BraveSearchToolConfig):
         self.config = config

     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:

View file

@@ -4,18 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from pydantic import BaseModel
-
-from .config import ModelContextProtocolConfig
-
-
-class ModelContextProtocolToolProviderDataValidator(BaseModel):
-    api_key: str
-
-
-async def get_adapter_impl(config: ModelContextProtocolConfig, _deps):
+from .config import MCPProviderConfig
+
+
+async def get_adapter_impl(config: MCPProviderConfig, _deps):
     from .model_context_protocol import ModelContextProtocolToolRuntimeImpl

-    impl = ModelContextProtocolToolRuntimeImpl(config)
+    impl = ModelContextProtocolToolRuntimeImpl(config, _deps)
     await impl.initialize()
     return impl

View file

@@ -9,7 +9,12 @@ from typing import Any
 from pydantic import BaseModel


-class ModelContextProtocolConfig(BaseModel):
+class MCPProviderDataValidator(BaseModel):
+    # mcp_endpoint => dict of headers to send
+    mcp_headers: dict[str, dict[str, str]] | None = None
+
+
+class MCPProviderConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
         return {}
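A minimal, self-contained check of the provider-data shape introduced here; the validator class is re-declared locally so the snippet runs on its own and mirrors the definition in the diff.

from pydantic import BaseModel


class MCPProviderDataValidator(BaseModel):  # mirrors the validator added above
    # mcp_endpoint => dict of headers to send
    mcp_headers: dict[str, dict[str, str]] | None = None


data = MCPProviderDataValidator.model_validate(
    {"mcp_headers": {"http://localhost:8000/sse": {"Authorization": "Bearer example-token"}}}
)
print(data.mcp_headers)  # {'http://localhost:8000/sse': {'Authorization': 'Bearer example-token'}}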

View file

@@ -7,61 +7,45 @@
 from typing import Any
 from urllib.parse import urlparse

-from mcp import ClientSession
-from mcp.client.sse import sse_client
-
 from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.datatypes import Api
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    ToolDef,
+    ToolGroup,
     ToolInvocationResult,
-    ToolParameter,
     ToolRuntime,
 )
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
+from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools

-from .config import ModelContextProtocolConfig
+from .config import MCPProviderConfig

+logger = get_logger(__name__, category="tools")

-class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
-    def __init__(self, config: ModelContextProtocolConfig):
+
+class ModelContextProtocolToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+    def __init__(self, config: MCPProviderConfig, _deps: dict[Api, Any]):
         self.config = config

     async def initialize(self):
         pass

+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
+        pass
+
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
+        return
+
     async def list_runtime_tools(
         self, tool_group_id: str | None = None, mcp_endpoint: URL | None = None
     ) -> ListToolDefsResponse:
-        # this endpoint should be retrieved by getting the tool group right?
         if mcp_endpoint is None:
             raise ValueError("mcp_endpoint is required")
-
-        tools = []
-        async with sse_client(mcp_endpoint.uri) as streams:
-            async with ClientSession(*streams) as session:
-                await session.initialize()
-                tools_result = await session.list_tools()
-                for tool in tools_result.tools:
-                    parameters = []
-                    for param_name, param_schema in tool.inputSchema.get("properties", {}).items():
-                        parameters.append(
-                            ToolParameter(
-                                name=param_name,
-                                parameter_type=param_schema.get("type", "string"),
-                                description=param_schema.get("description", ""),
-                            )
-                        )
-                    tools.append(
-                        ToolDef(
-                            name=tool.name,
-                            description=tool.description,
-                            parameters=parameters,
-                            metadata={
-                                "endpoint": mcp_endpoint.uri,
-                            },
-                        )
-                    )
-        return ListToolDefsResponse(data=tools)
+        headers = await self.get_headers_from_request(mcp_endpoint.uri)
+        return await list_mcp_tools(mcp_endpoint.uri, headers)

     async def invoke_tool(self, tool_name: str, kwargs: dict[str, Any]) -> ToolInvocationResult:
         tool = await self.tool_store.get_tool(tool_name)
@@ -71,12 +55,19 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
         if urlparse(endpoint).scheme not in ("http", "https"):
             raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")

-        async with sse_client(endpoint) as streams:
-            async with ClientSession(*streams) as session:
-                await session.initialize()
-                result = await session.call_tool(tool.identifier, kwargs)
-
-        return ToolInvocationResult(
-            content="\n".join([result.model_dump_json() for result in result.content]),
-            error_code=1 if result.isError else 0,
-        )
+        headers = await self.get_headers_from_request(endpoint)
+        return await invoke_mcp_tool(endpoint, headers, tool_name, kwargs)
+
+    async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
+        def canonicalize_uri(uri: str) -> str:
+            return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
+
+        headers = {}
+        provider_data = self.get_request_provider_data()
+        if provider_data and provider_data.mcp_headers:
+            for uri, values in provider_data.mcp_headers.items():
+                if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
+                    continue
+                headers.update(values)
+        return headers
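To see how get_headers_from_request above is meant to be fed, a hedged client-side sketch: mcp_headers maps an MCP endpoint URI to headers that should be forwarded only to that endpoint. The endpoint URL and token are illustrative, not taken from this diff.

import json

provider_data = {
    "mcp_headers": {
        "http://localhost:8000/sse": {"Authorization": "Bearer example-token"},
    }
}
request_headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
# A request carrying request_headers that lists or invokes tools against http://localhost:8000/sse
# gets the Authorization header forwarded to that MCP server; endpoints whose netloc/path do not
# match the canonicalized URI receive no extra headers.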

View file

@@ -12,29 +12,29 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import TavilySearchToolConfig


-class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class TavilySearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: TavilySearchToolConfig):
         self.config = config

     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:

View file

@@ -12,19 +12,19 @@ import httpx
 from llama_stack.apis.common.content_types import URL
 from llama_stack.apis.tools import (
     ListToolDefsResponse,
-    Tool,
     ToolDef,
+    ToolGroup,
     ToolInvocationResult,
     ToolParameter,
     ToolRuntime,
 )
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import ToolsProtocolPrivate
+from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate

 from .config import WolframAlphaToolConfig


-class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
+class WolframAlphaToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, NeedsRequestProviderData):
     def __init__(self, config: WolframAlphaToolConfig):
         self.config = config
         self.url = "https://api.wolframalpha.com/v2/query"
@@ -32,10 +32,10 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
     async def initialize(self):
         pass

-    async def register_tool(self, tool: Tool) -> None:
+    async def register_toolgroup(self, toolgroup: ToolGroup) -> None:
         pass

-    async def unregister_tool(self, tool_id: str) -> None:
+    async def unregister_toolgroup(self, toolgroup_id: str) -> None:
         return

     def _get_api_key(self) -> str:

View file

@@ -84,6 +84,14 @@ class ChromaIndex(EmbeddingIndex):
     async def delete(self):
         await maybe_await(self.client.delete_collection(self.collection.name))

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Chroma")
+

 class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(

View file

@@ -73,7 +73,7 @@ class MilvusIndex(EmbeddingIndex):
             logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}")
             raise e

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         search_res = await asyncio.to_thread(
             self.client.search,
             collection_name=self.collection_name,
@@ -86,6 +86,14 @@ class MilvusIndex(EmbeddingIndex):
         scores = [res["distance"] for res in search_res[0]]
         return QueryChunksResponse(chunks=chunks, scores=scores)

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Milvus")
+

 class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
     def __init__(

View file

@@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             execute_values(cur, query, values, template="(%s, %s, %s::vector)")

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(
                 f"""
@@ -120,6 +120,14 @@ class PGVectorIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in PGVector")
+
     async def delete(self):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View file

@@ -68,7 +68,7 @@ class QdrantIndex(EmbeddingIndex):
         await self.client.upsert(collection_name=self.collection_name, points=points)

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
                 collection_name=self.collection_name,
@@ -95,6 +95,14 @@ class QdrantIndex(EmbeddingIndex):

         return QueryChunksResponse(chunks=chunks, scores=scores)

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Qdrant")
+
     async def delete(self):
         await self.client.delete_collection(collection_name=self.collection_name)

View file

@@ -55,7 +55,7 @@ class WeaviateIndex(EmbeddingIndex):
         # TODO: make this async friendly
         collection.data.insert_many(data_objects)

-    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
+    async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)

         results = collection.query.near_vector(
@@ -84,6 +84,14 @@ class WeaviateIndex(EmbeddingIndex):
         collection = self.client.collections.get(self.collection_name)
         collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))

+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Keyword search is not supported in Weaviate")
+

 class WeaviateVectorIOAdapter(
     VectorIO,
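Taken together, the Chroma, Milvus, PGVector, Qdrant, and Weaviate changes split index querying into query_vector (ANN over embeddings) and query_keyword (full-text), with backends that lack keyword support opting out explicitly. A minimal sketch of the pattern, using an illustrative stand-in class rather than the real EmbeddingIndex base:

from abc import ABC, abstractmethod


class KeywordlessIndex(ABC):  # illustrative stand-in for an EmbeddingIndex-style adapter
    @abstractmethod
    async def query_vector(self, embedding, k: int, score_threshold: float):
        """ANN search over stored embeddings; every backend must implement this."""

    async def query_keyword(self, query_string: str, k: int, score_threshold: float):
        # Backends without full-text search raise instead of silently returning nothing,
        # mirroring the NotImplementedError added to each adapter above.
        raise NotImplementedError("Keyword search is not supported by this backend")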

Some files were not shown because too many files have changed in this diff.