Merge branch 'main' into toolcall-arg-recursive-type

Ben Keith 2025-11-04 11:18:07 -05:00 committed by GitHub
commit 56aa508f82
443 changed files with 98114 additions and 79944 deletions


@ -0,0 +1,60 @@
name: Install llama-stack-client
description: Install llama-stack-client based on branch context and client-version input
inputs:
client-version:
description: 'Client version to install on non-release branches (latest or published). Ignored on release branches.'
required: false
default: ""
outputs:
uv-extra-index-url:
description: 'UV_EXTRA_INDEX_URL to use (set for release branches)'
value: ${{ steps.configure.outputs.uv-extra-index-url }}
install-after-sync:
description: 'Whether to install client after uv sync'
value: ${{ steps.configure.outputs.install-after-sync }}
install-source:
description: 'Where to install client from after sync'
value: ${{ steps.configure.outputs.install-source }}
runs:
using: "composite"
steps:
- name: Configure client installation
id: configure
shell: bash
run: |
# Determine the branch we're working with
BRANCH="${{ github.base_ref || github.ref }}"
BRANCH="${BRANCH#refs/heads/}"
echo "Working with branch: $BRANCH"
# On release branches: use test.pypi for uv sync, then install from git
# On non-release branches: install based on client-version after sync
if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
echo "Detected release branch: $BRANCH"
# Check if matching branch exists in client repo
if ! git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$BRANCH" > /dev/null 2>&1; then
echo "::error::Branch $BRANCH not found in llama-stack-client-python repository"
echo "::error::Please create the matching release branch in llama-stack-client-python before testing"
exit 1
fi
# Configure to use test.pypi as extra index (PyPI is primary)
echo "uv-extra-index-url=https://test.pypi.org/simple/" >> $GITHUB_OUTPUT
echo "install-after-sync=true" >> $GITHUB_OUTPUT
echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@$BRANCH" >> $GITHUB_OUTPUT
elif [ "${{ inputs.client-version }}" = "latest" ]; then
# Install from main git after sync
echo "install-after-sync=true" >> $GITHUB_OUTPUT
echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@main" >> $GITHUB_OUTPUT
elif [ "${{ inputs.client-version }}" = "published" ]; then
# Use published version from PyPI (installed by sync)
echo "install-after-sync=false" >> $GITHUB_OUTPUT
elif [ -n "${{ inputs.client-version }}" ]; then
echo "::error::Invalid client-version: ${{ inputs.client-version }}"
exit 1
fi
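The branch detection above keys entirely on the `^release-[0-9]+\.[0-9]+\.x$` pattern. As a quick local sanity check, the same classification can be reproduced in a standalone sketch; the branch names below are hypothetical examples, not branches that necessarily exist in the repository:

```bash
#!/bin/bash
# Sketch: classify branch names the same way the composite action does.
for branch in main release-0.3.x release-1.12.x release-0.3.1 feature/toolcall-arg-recursive-type; do
  if [[ "$branch" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
    # Release branch: test.pypi extra index for uv sync, then client from the matching git branch.
    echo "$branch -> release branch"
  else
    # Non-release branch: behavior is driven by the client-version input (latest/published).
    echo "$branch -> non-release branch"
  fi
done
```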


@ -94,7 +94,7 @@ runs:
if: ${{ always() }}
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }}
name: logs-${{ github.run_id }}-${{ github.run_attempt || '1' }}-${{ strategy.job-index || github.job }}-${{ github.action }}
path: |
*.log
retention-days: 1


@ -18,25 +18,35 @@ runs:
python-version: ${{ inputs.python-version }}
version: 0.7.6
- name: Configure client installation
id: client-config
uses: ./.github/actions/install-llama-stack-client
with:
client-version: ${{ inputs.client-version }}
- name: Install dependencies
shell: bash
env:
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
run: |
# Export UV env vars for current step and persist to GITHUB_ENV for subsequent steps
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
export UV_INDEX_STRATEGY=unsafe-best-match
echo "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" >> $GITHUB_ENV
echo "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" >> $GITHUB_ENV
echo "Exported UV environment variables for current and subsequent steps"
fi
echo "Updating project dependencies via uv sync"
uv sync --all-groups
echo "Installing ad-hoc dependencies"
uv pip install faiss-cpu
# Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "Installing latest llama-stack-client-python from main branch"
uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "Installing published llama-stack-client-python from PyPI"
uv pip install llama-stack-client
else
echo "Invalid client-version: ${{ inputs.client-version }}"
exit 1
# Install specific client version after sync if needed
if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
uv pip install ${{ steps.client-config.outputs.install-source }}
fi
echo "Installed llama packages"


@ -42,18 +42,7 @@ runs:
- name: Build Llama Stack
shell: bash
run: |
# Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "Installing latest llama-stack-client-python from main branch"
export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "Installing published llama-stack-client-python from PyPI"
unset LLAMA_STACK_CLIENT_DIR
else
echo "Invalid client-version: ${{ inputs.client-version }}"
exit 1
fi
# Client is already installed by setup-runner (handles both main and release branches)
echo "Building Llama Stack"
LLAMA_STACK_DIR=. \

.github/mergify.yml vendored Normal file

@ -0,0 +1,23 @@
pull_request_rules:
- name: ping author on conflicts and add 'needs-rebase' label
conditions:
- conflict
- -closed
actions:
label:
add:
- needs-rebase
comment:
message: >
This pull request has merge conflicts that must be resolved before it
can be merged. @{{author}} please rebase it.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
- name: remove 'needs-rebase' label when conflict is resolved
conditions:
- -conflict
- -closed
actions:
label:
remove:
- needs-rebase


@ -13,7 +13,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR |
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
| Test llama stack list-deps | [providers-list-deps.yml](providers-list-deps.yml) | Test llama stack list-deps |
| Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |


@ -4,7 +4,11 @@ run-name: Check backward compatibility for run.yaml configs
on:
pull_request:
branches: [main]
branches:
- main
- 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+'
- 'release-[0-9]+.[0-9]+.[0-9]+'
- 'release-[0-9]+.[0-9]+'
paths:
- 'src/llama_stack/core/datatypes.py'
- 'src/llama_stack/providers/datatypes.py'
@ -33,7 +37,7 @@ jobs:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
with:
enable-cache: true
@ -411,7 +415,7 @@ jobs:
python-version: '3.12'
- name: Install uv
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
with:
enable-cache: true


@ -22,7 +22,6 @@ on:
- 'docs/static/stable-llama-stack-spec.yaml' # Stable APIs spec
- 'docs/static/experimental-llama-stack-spec.yaml' # Experimental APIs spec
- 'docs/static/deprecated-llama-stack-spec.yaml' # Deprecated APIs spec
- 'docs/static/llama-stack-spec.html' # Legacy HTML spec
- '.github/workflows/conformance.yml' # This workflow itself
concurrency:


@ -30,10 +30,16 @@ jobs:
- name: Build a single provider
run: |
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=starter"
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
fi
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
fi
docker build . \
-f containers/Containerfile \
--build-arg INSTALL_MODE=editable \
--build-arg DISTRO_NAME=starter \
$BUILD_ARGS \
--tag llama-stack:starter-ci
- name: Run installer end-to-end


@ -4,9 +4,13 @@ run-name: Run the integration test suite with Kubernetes authentication
on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'distributions/**'
- 'src/llama_stack/**'


@ -4,9 +4,13 @@ run-name: Run the integration test suite with SqlStore
on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'src/llama_stack/providers/utils/sqlstore/**'
- 'tests/integration/sqlstore/**'


@ -4,9 +4,13 @@ run-name: Run the integration test suites from tests/integration in replay mode
on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
types: [opened, synchronize, reopened]
paths:
- 'src/llama_stack/**'
@ -18,6 +22,7 @@ on:
- '.github/actions/setup-ollama/action.yml'
- '.github/actions/setup-test-environment/action.yml'
- '.github/actions/run-and-record-tests/action.yml'
- 'scripts/integration-tests.sh'
schedule:
# If changing the cron schedule, update the provider in the test-matrix job
- cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
@ -47,7 +52,7 @@ jobs:
strategy:
fail-fast: false
matrix:
client-type: [library, docker]
client-type: [library, docker, server]
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}


@ -4,9 +4,13 @@ run-name: Run the integration test suite with various VectorIO providers
on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'src/llama_stack/**'
- '!src/llama_stack/ui/**'


@ -5,7 +5,9 @@ run-name: Run pre-commit checks
on:
pull_request:
push:
branches: [main]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
concurrency:
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
@ -43,23 +45,41 @@ jobs:
cache: 'npm'
cache-dependency-path: 'src/llama_stack/ui/'
- name: Set up uv
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
- name: Install npm dependencies
run: npm ci
working-directory: src/llama_stack/ui
- name: Install pre-commit
run: python -m pip install pre-commit
- name: Cache pre-commit
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
with:
path: ~/.cache/pre-commit
key: pre-commit-3|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
- name: Run pre-commit
id: precommit
uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true
run: |
set +e
pre-commit run --show-diff-on-failure --color=always --all-files 2>&1 | tee /tmp/precommit.log
status=${PIPESTATUS[0]}
echo "status=$status" >> $GITHUB_OUTPUT
exit 0
env:
SKIP: no-commit-to-branch
SKIP: no-commit-to-branch,mypy
RUFF_OUTPUT_FORMAT: github
- name: Check pre-commit results
if: steps.precommit.outcome == 'failure'
if: steps.precommit.outputs.status != '0'
run: |
echo "::error::Pre-commit hooks failed. Please run 'pre-commit run --all-files' locally and commit the fixes."
echo "::warning::Some pre-commit hooks failed. Check the output above for details."
echo ""
echo "Failed hooks output:"
cat /tmp/precommit.log
exit 1
- name: Debug
@ -109,3 +129,39 @@ jobs:
echo "$unstaged_files"
exit 1
fi
- name: Configure client installation
id: client-config
uses: ./.github/actions/install-llama-stack-client
- name: Sync dev + type_checking dependencies
env:
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
run: |
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
export UV_INDEX_STRATEGY="unsafe-best-match"
fi
uv sync --group dev --group type_checking
# Install specific client version after sync if needed
if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
uv pip install ${{ steps.client-config.outputs.install-source }}
fi
- name: Run mypy (full type_checking)
env:
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
run: |
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
export UV_INDEX_STRATEGY="unsafe-best-match"
fi
set +e
uv run --group dev --group type_checking mypy
status=$?
if [ $status -ne 0 ]; then
echo "::error::Full mypy failed. Reproduce locally with 'uv run pre-commit run mypy-full --hook-stage manual --all-files'."
fi
exit $status


@ -1,227 +0,0 @@
name: Pre-commit Bot
run-name: Pre-commit bot for PR #${{ github.event.issue.number }}
on:
issue_comment:
types: [created]
jobs:
pre-commit:
# Only run on pull request comments
if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit')
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- name: Check comment author and get PR details
id: check_author
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
// Get PR details
const pr = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number
});
// Check if commenter has write access or is the PR author
const commenter = context.payload.comment.user.login;
const prAuthor = pr.data.user.login;
let hasPermission = false;
// Check if commenter is PR author
if (commenter === prAuthor) {
hasPermission = true;
console.log(`Comment author ${commenter} is the PR author`);
} else {
// Check if commenter has write/admin access
try {
const permission = await github.rest.repos.getCollaboratorPermissionLevel({
owner: context.repo.owner,
repo: context.repo.repo,
username: commenter
});
const level = permission.data.permission;
hasPermission = ['write', 'admin', 'maintain'].includes(level);
console.log(`Comment author ${commenter} has permission: ${level}`);
} catch (error) {
console.log(`Could not check permissions for ${commenter}: ${error.message}`);
}
}
if (!hasPermission) {
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.`
});
core.setFailed(`User ${commenter} does not have permission`);
return;
}
// Save PR info for later steps
core.setOutput('pr_number', context.issue.number);
core.setOutput('pr_head_ref', pr.data.head.ref);
core.setOutput('pr_head_sha', pr.data.head.sha);
core.setOutput('pr_head_repo', pr.data.head.repo.full_name);
core.setOutput('pr_base_ref', pr.data.base.ref);
core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name);
core.setOutput('authorized', 'true');
- name: React to comment
if: steps.check_author.outputs.authorized == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.reactions.createForIssueComment({
owner: context.repo.owner,
repo: context.repo.repo,
comment_id: context.payload.comment.id,
content: 'rocket'
});
- name: Comment starting
if: steps.check_author.outputs.authorized == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...`
});
- name: Checkout PR branch (same-repo)
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false'
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
ref: ${{ steps.check_author.outputs.pr_head_ref }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Checkout PR branch (fork)
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true'
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: ${{ steps.check_author.outputs.pr_head_repo }}
ref: ${{ steps.check_author.outputs.pr_head_ref }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Verify checkout
if: steps.check_author.outputs.authorized == 'true'
run: |
echo "Current SHA: $(git rev-parse HEAD)"
echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}"
if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then
echo "::error::Checked out SHA does not match expected SHA"
exit 1
fi
- name: Set up Python
if: steps.check_author.outputs.authorized == 'true'
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
with:
python-version: '3.12'
cache: pip
cache-dependency-path: |
**/requirements*.txt
.pre-commit-config.yaml
- name: Set up Node.js
if: steps.check_author.outputs.authorized == 'true'
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0
with:
node-version: '20'
cache: 'npm'
cache-dependency-path: 'src/llama_stack/ui/'
- name: Install npm dependencies
if: steps.check_author.outputs.authorized == 'true'
run: npm ci
working-directory: src/llama_stack/ui
- name: Run pre-commit
if: steps.check_author.outputs.authorized == 'true'
id: precommit
uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true
env:
SKIP: no-commit-to-branch
RUFF_OUTPUT_FORMAT: github
- name: Check for changes
if: steps.check_author.outputs.authorized == 'true'
id: changes
run: |
if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
echo "has_changes=true" >> $GITHUB_OUTPUT
echo "Changes detected after pre-commit"
else
echo "has_changes=false" >> $GITHUB_OUTPUT
echo "No changes after pre-commit"
fi
- name: Commit and push changes
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git add -A
git commit -m "style: apply pre-commit fixes
🤖 Applied by @github-actions bot via pre-commit workflow"
# Push changes
git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }}
- name: Comment success with changes
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.`
});
- name: Comment success without changes
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success'
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.`
});
- name: Comment failure
if: failure()
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: ${{ steps.check_author.outputs.pr_number }},
body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.`
});


@ -72,10 +72,16 @@ jobs:
- name: Build container image
if: matrix.image-type == 'container'
run: |
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=${{ matrix.distro }}"
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
fi
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
fi
docker build . \
-f containers/Containerfile \
--build-arg INSTALL_MODE=editable \
--build-arg DISTRO_NAME=${{ matrix.distro }} \
$BUILD_ARGS \
--tag llama-stack:${{ matrix.distro }}-ci
- name: Print dependencies in the image
@ -108,12 +114,18 @@ jobs:
- name: Build container image
run: |
BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "python:3.12-slim"' src/llama_stack/distributions/ci-tests/build.yaml)
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
fi
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
fi
docker build . \
-f containers/Containerfile \
--build-arg INSTALL_MODE=editable \
--build-arg DISTRO_NAME=ci-tests \
--build-arg BASE_IMAGE="$BASE_IMAGE" \
--build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
$BUILD_ARGS \
-t llama-stack:ci-tests
- name: Inspect the container image entrypoint
@ -148,12 +160,18 @@ jobs:
- name: Build UBI9 container image
run: |
BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "registry.access.redhat.com/ubi9:latest"' src/llama_stack/distributions/ci-tests/build.yaml)
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
fi
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
fi
docker build . \
-f containers/Containerfile \
--build-arg INSTALL_MODE=editable \
--build-arg DISTRO_NAME=ci-tests \
--build-arg BASE_IMAGE="$BASE_IMAGE" \
--build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
$BUILD_ARGS \
-t llama-stack:ci-tests-ubi9
- name: Inspect UBI9 image


@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install uv
uses: astral-sh/setup-uv@2ddd2b9cb38ad8efd50337e8ab201519a34c9f24 # v7.1.1
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
with:
python-version: ${{ matrix.python-version }}
activate-environment: true


@ -4,9 +4,13 @@ run-name: Run the unit test suite
on:
push:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
pull_request:
branches: [ main ]
branches:
- main
- 'release-[0-9]+.[0-9]+.x'
paths:
- 'src/llama_stack/**'
- '!src/llama_stack/ui/**'

.gitignore vendored

@ -32,3 +32,6 @@ CLAUDE.md
docs/.docusaurus/
docs/node_modules/
docs/static/imported-files/
docs/docs/api-deprecated/
docs/docs/api-experimental/
docs/docs/api/


@ -52,22 +52,19 @@ repos:
additional_dependencies:
- black==24.3.0
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.7.20
hooks:
- id: uv-lock
- repo: local
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
hooks:
- id: mypy
name: mypy
additional_dependencies:
- uv==0.7.8
entry: uv run --group dev --group type_checking mypy
language: python
types: [python]
- uv==0.6.2
- mypy
- pytest
- rich
- types-requests
- pydantic
pass_filenames: false
require_serial: true
# - repo: https://github.com/tcort/markdown-link-check
# rev: v3.11.2
@ -77,11 +74,26 @@ repos:
- repo: local
hooks:
- id: uv-lock
name: uv-lock
additional_dependencies:
- uv==0.7.20
entry: ./scripts/uv-run-with-index.sh lock
language: python
pass_filenames: false
require_serial: true
files: ^(pyproject\.toml|uv\.lock)$
- id: mypy-full
name: mypy (full type_checking)
entry: ./scripts/uv-run-with-index.sh run --group dev --group type_checking mypy
language: system
pass_filenames: false
stages: [manual]
- id: distro-codegen
name: Distribution Template Codegen
additional_dependencies:
- uv==0.7.8
entry: uv run --group codegen ./scripts/distro_codegen.py
entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/distro_codegen.py
language: python
pass_filenames: false
require_serial: true
@ -90,7 +102,7 @@ repos:
name: Provider Codegen
additional_dependencies:
- uv==0.7.8
entry: uv run --group codegen ./scripts/provider_codegen.py
entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/provider_codegen.py
language: python
pass_filenames: false
require_serial: true
@ -99,7 +111,7 @@ repos:
name: API Spec Codegen
additional_dependencies:
- uv==0.7.8
entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
entry: sh -c './scripts/uv-run-with-index.sh run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
language: python
pass_filenames: false
require_serial: true
@ -140,7 +152,7 @@ repos:
name: Generate CI documentation
additional_dependencies:
- uv==0.7.8
entry: uv run ./scripts/gen-ci-docs.py
entry: ./scripts/uv-run-with-index.sh run ./scripts/gen-ci-docs.py
language: python
pass_filenames: false
require_serial: true
@ -171,6 +183,23 @@ repos:
exit 1
fi
exit 0
- id: fips-compliance
name: Ensure llama-stack remains FIPS compliant
entry: bash
language: system
types: [python]
pass_filenames: true
exclude: '^tests/.*$' # Exclude test dir as some safety tests used MD5
args:
- -c
- |
grep -EnH '^[^#]*\b(md5|sha1|uuid3|uuid5)\b' "$@" && {
echo;
echo "❌ Do not use any of the following functions: hashlib.md5, hashlib.sha1, uuid.uuid3, uuid.uuid5"
echo " These functions are not FIPS-compliant"
echo;
exit 1;
} || true
ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
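To illustrate what the new `fips-compliance` hook flags, here is a small standalone sketch that runs the same grep against two throwaway files (the file names and contents are hypothetical); only the uncommented `hashlib.md5` call is reported, since the `^[^#]*` prefix skips commented-out lines:

```bash
# Sketch: exercise the grep used by the fips-compliance hook.
cat > /tmp/bad_example.py <<'EOF'
import hashlib
digest = hashlib.md5(payload).hexdigest()
EOF
cat > /tmp/ok_example.py <<'EOF'
import hashlib
# hashlib.md5 is mentioned only in this comment
digest = hashlib.sha256(payload).hexdigest()
EOF
grep -EnH '^[^#]*\b(md5|sha1|uuid3|uuid5)\b' /tmp/bad_example.py /tmp/ok_example.py
# Reports only /tmp/bad_example.py:2; /tmp/ok_example.py passes.
```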


@ -61,6 +61,18 @@ uv run pre-commit run --all-files -v
The `-v` (verbose) parameter is optional but often helpful for getting more information about any issues that the pre-commit checks identify.
To run the expanded mypy configuration that CI enforces, use:
```bash
uv run pre-commit run mypy-full --hook-stage manual --all-files
```
or invoke mypy directly with all optional dependencies:
```bash
uv run --group dev --group type_checking mypy
```
```{caution}
Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
```


@ -1,610 +0,0 @@
# yaml-language-server: $schema=https://app.stainlessapi.com/config-internal.schema.json
organization:
# Name of your organization or company, used to determine the name of the client
# and headings.
name: llama-stack-client
docs: https://llama-stack.readthedocs.io/en/latest/
contact: llamastack@meta.com
security:
- {}
- BearerAuth: []
security_schemes:
BearerAuth:
type: http
scheme: bearer
# `targets` define the output targets and their customization options, such as
# whether to emit the Node SDK and what its package name should be.
targets:
node:
package_name: llama-stack-client
production_repo: llamastack/llama-stack-client-typescript
publish:
npm: false
python:
package_name: llama_stack_client
production_repo: llamastack/llama-stack-client-python
options:
use_uv: true
publish:
pypi: true
project_name: llama_stack_client
kotlin:
reverse_domain: com.llama_stack_client.api
production_repo: null
publish:
maven: false
go:
package_name: llama-stack-client
production_repo: llamastack/llama-stack-client-go
options:
enable_v2: true
back_compat_use_shared_package: false
# `client_settings` define settings for the API client, such as extra constructor
# arguments (used for authentication), retry behavior, idempotency, etc.
client_settings:
default_env_prefix: LLAMA_STACK_CLIENT
opts:
api_key:
type: string
read_env: LLAMA_STACK_CLIENT_API_KEY
auth: { security_scheme: BearerAuth }
nullable: true
# `environments` are a map of the name of the environment (e.g. "sandbox",
# "production") to the corresponding url to use.
environments:
production: http://any-hosted-llama-stack.com
# `pagination` defines [pagination schemes] which provides a template to match
# endpoints and generate next-page and auto-pagination helpers in the SDKs.
pagination:
- name: datasets_iterrows
type: offset
request:
dataset_id:
type: string
start_index:
type: integer
x-stainless-pagination-property:
purpose: offset_count_param
limit:
type: integer
response:
data:
type: array
items:
type: object
next_index:
type: integer
x-stainless-pagination-property:
purpose: offset_count_start_field
- name: openai_cursor_page
type: cursor
request:
limit:
type: integer
after:
type: string
x-stainless-pagination-property:
purpose: next_cursor_param
response:
data:
type: array
items: {}
has_more:
type: boolean
last_id:
type: string
x-stainless-pagination-property:
purpose: next_cursor_field
# `resources` define the structure and organization for your API, such as how
# methods and models are grouped together and accessed. See the [configuration
# guide] for more information.
#
# [configuration guide]:
# https://app.stainlessapi.com/docs/guides/configure#resources
resources:
$shared:
models:
agent_config: AgentConfig
interleaved_content_item: InterleavedContentItem
interleaved_content: InterleavedContent
param_type: ParamType
safety_violation: SafetyViolation
sampling_params: SamplingParams
scoring_result: ScoringResult
message: Message
user_message: UserMessage
completion_message: CompletionMessage
tool_response_message: ToolResponseMessage
system_message: SystemMessage
tool_call: ToolCall
query_result: RAGQueryResult
document: RAGDocument
query_config: RAGQueryConfig
response_format: ResponseFormat
toolgroups:
models:
tool_group: ToolGroup
list_tool_groups_response: ListToolGroupsResponse
methods:
register: post /v1/toolgroups
get: get /v1/toolgroups/{toolgroup_id}
list: get /v1/toolgroups
unregister: delete /v1/toolgroups/{toolgroup_id}
tools:
methods:
get: get /v1/tools/{tool_name}
list:
endpoint: get /v1/tools
paginated: false
tool_runtime:
models:
tool_def: ToolDef
tool_invocation_result: ToolInvocationResult
methods:
list_tools:
endpoint: get /v1/tool-runtime/list-tools
paginated: false
invoke_tool: post /v1/tool-runtime/invoke
subresources:
rag_tool:
methods:
insert: post /v1/tool-runtime/rag-tool/insert
query: post /v1/tool-runtime/rag-tool/query
responses:
models:
response_object_stream: OpenAIResponseObjectStream
response_object: OpenAIResponseObject
methods:
create:
type: http
endpoint: post /v1/responses
streaming:
stream_event_model: responses.response_object_stream
param_discriminator: stream
retrieve: get /v1/responses/{response_id}
list:
type: http
endpoint: get /v1/responses
delete:
type: http
endpoint: delete /v1/responses/{response_id}
subresources:
input_items:
methods:
list:
type: http
endpoint: get /v1/responses/{response_id}/input_items
conversations:
models:
conversation_object: Conversation
methods:
create:
type: http
endpoint: post /v1/conversations
retrieve: get /v1/conversations/{conversation_id}
update:
type: http
endpoint: post /v1/conversations/{conversation_id}
delete:
type: http
endpoint: delete /v1/conversations/{conversation_id}
subresources:
items:
methods:
get:
type: http
endpoint: get /v1/conversations/{conversation_id}/items/{item_id}
list:
type: http
endpoint: get /v1/conversations/{conversation_id}/items
create:
type: http
endpoint: post /v1/conversations/{conversation_id}/items
inspect:
models:
healthInfo: HealthInfo
providerInfo: ProviderInfo
routeInfo: RouteInfo
versionInfo: VersionInfo
methods:
health: get /v1/health
version: get /v1/version
embeddings:
models:
create_embeddings_response: OpenAIEmbeddingsResponse
methods:
create: post /v1/embeddings
chat:
models:
chat_completion_chunk: OpenAIChatCompletionChunk
subresources:
completions:
methods:
create:
type: http
endpoint: post /v1/chat/completions
streaming:
stream_event_model: chat.chat_completion_chunk
param_discriminator: stream
list:
type: http
endpoint: get /v1/chat/completions
retrieve:
type: http
endpoint: get /v1/chat/completions/{completion_id}
completions:
methods:
create:
type: http
endpoint: post /v1/completions
streaming:
param_discriminator: stream
vector_io:
models:
queryChunksResponse: QueryChunksResponse
methods:
insert: post /v1/vector-io/insert
query: post /v1/vector-io/query
vector_stores:
models:
vector_store: VectorStoreObject
list_vector_stores_response: VectorStoreListResponse
vector_store_delete_response: VectorStoreDeleteResponse
vector_store_search_response: VectorStoreSearchResponsePage
methods:
create: post /v1/vector_stores
list:
endpoint: get /v1/vector_stores
retrieve: get /v1/vector_stores/{vector_store_id}
update: post /v1/vector_stores/{vector_store_id}
delete: delete /v1/vector_stores/{vector_store_id}
search: post /v1/vector_stores/{vector_store_id}/search
subresources:
files:
models:
vector_store_file: VectorStoreFileObject
methods:
list: get /v1/vector_stores/{vector_store_id}/files
retrieve: get /v1/vector_stores/{vector_store_id}/files/{file_id}
update: post /v1/vector_stores/{vector_store_id}/files/{file_id}
delete: delete /v1/vector_stores/{vector_store_id}/files/{file_id}
create: post /v1/vector_stores/{vector_store_id}/files
content: get /v1/vector_stores/{vector_store_id}/files/{file_id}/content
file_batches:
models:
vector_store_file_batches: VectorStoreFileBatchObject
list_vector_store_files_in_batch_response: VectorStoreFilesListInBatchResponse
methods:
create: post /v1/vector_stores/{vector_store_id}/file_batches
retrieve: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}
list_files: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files
cancel: post /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel
models:
models:
model: Model
list_models_response: ListModelsResponse
methods:
retrieve: get /v1/models/{model_id}
list:
endpoint: get /v1/models
paginated: false
register: post /v1/models
unregister: delete /v1/models/{model_id}
subresources:
openai:
methods:
list:
endpoint: get /v1/models
paginated: false
providers:
models:
list_providers_response: ListProvidersResponse
methods:
list:
endpoint: get /v1/providers
paginated: false
retrieve: get /v1/providers/{provider_id}
routes:
models:
list_routes_response: ListRoutesResponse
methods:
list:
endpoint: get /v1/inspect/routes
paginated: false
moderations:
models:
create_response: ModerationObject
methods:
create: post /v1/moderations
safety:
models:
run_shield_response: RunShieldResponse
methods:
run_shield: post /v1/safety/run-shield
shields:
models:
shield: Shield
list_shields_response: ListShieldsResponse
methods:
retrieve: get /v1/shields/{identifier}
list:
endpoint: get /v1/shields
paginated: false
register: post /v1/shields
delete: delete /v1/shields/{identifier}
synthetic_data_generation:
models:
syntheticDataGenerationResponse: SyntheticDataGenerationResponse
methods:
generate: post /v1/synthetic-data-generation/generate
telemetry:
models:
span_with_status: SpanWithStatus
trace: Trace
query_spans_response: QuerySpansResponse
event: Event
query_condition: QueryCondition
methods:
query_traces:
endpoint: post /v1alpha/telemetry/traces
skip_test_reason: 'unsupported query params in java / kotlin'
get_span_tree: post /v1alpha/telemetry/spans/{span_id}/tree
query_spans:
endpoint: post /v1alpha/telemetry/spans
skip_test_reason: 'unsupported query params in java / kotlin'
query_metrics:
endpoint: post /v1alpha/telemetry/metrics/{metric_name}
skip_test_reason: 'unsupported query params in java / kotlin'
# log_event: post /v1alpha/telemetry/events
save_spans_to_dataset: post /v1alpha/telemetry/spans/export
get_span: get /v1alpha/telemetry/traces/{trace_id}/spans/{span_id}
get_trace: get /v1alpha/telemetry/traces/{trace_id}
scoring:
methods:
score: post /v1/scoring/score
score_batch: post /v1/scoring/score-batch
scoring_functions:
methods:
retrieve: get /v1/scoring-functions/{scoring_fn_id}
list:
endpoint: get /v1/scoring-functions
paginated: false
register: post /v1/scoring-functions
models:
scoring_fn: ScoringFn
scoring_fn_params: ScoringFnParams
list_scoring_functions_response: ListScoringFunctionsResponse
benchmarks:
methods:
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}
list:
endpoint: get /v1alpha/eval/benchmarks
paginated: false
register: post /v1alpha/eval/benchmarks
models:
benchmark: Benchmark
list_benchmarks_response: ListBenchmarksResponse
files:
methods:
create: post /v1/files
list: get /v1/files
retrieve: get /v1/files/{file_id}
delete: delete /v1/files/{file_id}
content: get /v1/files/{file_id}/content
models:
file: OpenAIFileObject
list_files_response: ListOpenAIFileResponse
delete_file_response: OpenAIFileDeleteResponse
alpha:
subresources:
inference:
methods:
rerank: post /v1alpha/inference/rerank
post_training:
models:
algorithm_config: AlgorithmConfig
post_training_job: PostTrainingJob
list_post_training_jobs_response: ListPostTrainingJobsResponse
methods:
preference_optimize: post /v1alpha/post-training/preference-optimize
supervised_fine_tune: post /v1alpha/post-training/supervised-fine-tune
subresources:
job:
methods:
artifacts: get /v1alpha/post-training/job/artifacts
cancel: post /v1alpha/post-training/job/cancel
status: get /v1alpha/post-training/job/status
list:
endpoint: get /v1alpha/post-training/jobs
paginated: false
eval:
methods:
evaluate_rows: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
run_eval: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
evaluate_rows_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
run_eval_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
subresources:
jobs:
methods:
cancel: delete /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
status: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result
models:
evaluate_response: EvaluateResponse
benchmark_config: BenchmarkConfig
job: Job
agents:
methods:
create: post /v1alpha/agents
list: get /v1alpha/agents
retrieve: get /v1alpha/agents/{agent_id}
delete: delete /v1alpha/agents/{agent_id}
models:
inference_step: InferenceStep
tool_execution_step: ToolExecutionStep
tool_response: ToolResponse
shield_call_step: ShieldCallStep
memory_retrieval_step: MemoryRetrievalStep
subresources:
session:
models:
session: Session
methods:
list: get /v1alpha/agents/{agent_id}/sessions
create: post /v1alpha/agents/{agent_id}/session
delete: delete /v1alpha/agents/{agent_id}/session/{session_id}
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}
steps:
methods:
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}
turn:
models:
turn: Turn
turn_response_event: AgentTurnResponseEvent
agent_turn_response_stream_chunk: AgentTurnResponseStreamChunk
methods:
create:
type: http
endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn
streaming:
stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
param_discriminator: stream
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}
resume:
type: http
endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume
streaming:
stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
param_discriminator: stream
beta:
subresources:
datasets:
models:
list_datasets_response: ListDatasetsResponse
methods:
register: post /v1beta/datasets
retrieve: get /v1beta/datasets/{dataset_id}
list:
endpoint: get /v1beta/datasets
paginated: false
unregister: delete /v1beta/datasets/{dataset_id}
iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
settings:
license: MIT
unwrap_response_fields: [ data ]
openapi:
transformations:
- command: renameValue
reason: pydantic reserved name
args:
filter:
only:
- '$.components.schemas.InferenceStep.properties.model_response'
rename:
python:
property_name: 'inference_model_response'
# - command: renameValue
# reason: pydantic reserved name
# args:
# filter:
# only:
# - '$.components.schemas.Model.properties.model_type'
# rename:
# python:
# property_name: 'type'
- command: mergeObject
reason: Better return_type using enum
args:
target:
- '$.components.schemas'
object:
ReturnType:
additionalProperties: false
properties:
type:
enum:
- string
- number
- boolean
- array
- object
- json
- union
- chat_completion_input
- completion_input
- agent_turn_input
required:
- type
type: object
- command: replaceProperties
reason: Replace return type properties with better model (see above)
args:
filter:
only:
- '$.components.schemas.ScoringFn.properties.return_type'
- '$.components.schemas.RegisterScoringFunctionRequest.properties.return_type'
value:
$ref: '#/components/schemas/ReturnType'
- command: oneOfToAnyOf
reason: Prism (mock server) doesn't like one of our requests as it technically matches multiple variants
- reason: For better names
command: extractToRefs
args:
ref:
target: '$.components.schemas.ToolCallDelta.properties.tool_call'
name: '#/components/schemas/ToolCallOrString'
# `readme` is used to configure the code snippets that will be rendered in the
# README.md of various SDKs. In particular, you can change the `headline`
# snippet's endpoint and the arguments to call it with.
readme:
example_requests:
default:
type: request
endpoint: post /v1/chat/completions
params: &ref_0 {}
headline:
type: request
endpoint: post /v1/models
params: *ref_0
pagination:
type: request
endpoint: post /v1/chat/completions
params: {}

File diff suppressed because it is too large.


@ -19,6 +19,8 @@ ARG KEEP_WORKSPACE=""
ARG DISTRO_NAME="starter"
ARG RUN_CONFIG_PATH=""
ARG UV_HTTP_TIMEOUT=500
ARG UV_EXTRA_INDEX_URL=""
ARG UV_INDEX_STRATEGY=""
ENV UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT}
ENV PYTHONDONTWRITEBYTECODE=1
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
@ -45,7 +47,7 @@ RUN set -eux; \
exit 1; \
fi
RUN pip install --no-cache uv
RUN pip install --no-cache-dir uv
ENV UV_SYSTEM_PYTHON=1
ENV INSTALL_MODE=${INSTALL_MODE}
@ -62,47 +64,60 @@ COPY . /workspace
# Install the client package if it is provided
# NOTE: this is installed before llama-stack since llama-stack depends on llama-stack-client-python
# Unset UV index env vars to ensure we only use PyPI for the client
RUN set -eux; \
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then \
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then \
echo "LLAMA_STACK_CLIENT_DIR is set but $LLAMA_STACK_CLIENT_DIR does not exist" >&2; \
exit 1; \
fi; \
uv pip install --no-cache -e "$LLAMA_STACK_CLIENT_DIR"; \
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"; \
fi;
# Install llama-stack
# Use UV_EXTRA_INDEX_URL inline only for editable install with RC dependencies
RUN set -eux; \
SAVED_UV_EXTRA_INDEX_URL="${UV_EXTRA_INDEX_URL:-}"; \
SAVED_UV_INDEX_STRATEGY="${UV_INDEX_STRATEGY:-}"; \
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
if [ "$INSTALL_MODE" = "editable" ]; then \
if [ ! -d "$LLAMA_STACK_DIR" ]; then \
echo "INSTALL_MODE=editable requires LLAMA_STACK_DIR to point to a directory inside the build context" >&2; \
exit 1; \
fi; \
uv pip install --no-cache -e "$LLAMA_STACK_DIR"; \
elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
uv pip install --no-cache fastapi libcst; \
if [ -n "$TEST_PYPI_VERSION" ]; then \
uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
if [ -n "$SAVED_UV_EXTRA_INDEX_URL" ] && [ -n "$SAVED_UV_INDEX_STRATEGY" ]; then \
UV_EXTRA_INDEX_URL="$SAVED_UV_EXTRA_INDEX_URL" UV_INDEX_STRATEGY="$SAVED_UV_INDEX_STRATEGY" \
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
else \
uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
fi; \
elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
uv pip install --no-cache-dir fastapi libcst; \
if [ -n "$TEST_PYPI_VERSION" ]; then \
uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
else \
uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
fi; \
else \
if [ -n "$PYPI_VERSION" ]; then \
uv pip install --no-cache "llama-stack==$PYPI_VERSION"; \
uv pip install --no-cache-dir "llama-stack==$PYPI_VERSION"; \
else \
uv pip install --no-cache llama-stack; \
uv pip install --no-cache-dir llama-stack; \
fi; \
fi;
# Install the dependencies for the distribution
# Explicitly unset UV index env vars to ensure we only use PyPI for distribution deps
RUN set -eux; \
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
if [ -z "$DISTRO_NAME" ]; then \
echo "DISTRO_NAME must be provided" >&2; \
exit 1; \
fi; \
deps="$(llama stack list-deps "$DISTRO_NAME")"; \
if [ -n "$deps" ]; then \
printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache; \
printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache-dir; \
fi
# Cleanup


@ -23,5 +23,4 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
We are working on adding a few more APIs to complete the application lifecycle. These will include:
- **Batch Inference**: run inference on a dataset of inputs
- **Batch Agents**: run agents on a dataset of inputs
- **Synthetic Data Generation**: generate synthetic data for model development
- **Batches**: OpenAI-compatible batch management for inference


@ -239,8 +239,13 @@ client = LlamaStackClient(base_url="http://localhost:8321")
models = client.models.list()
# Select the first LLM
llm = next(m for m in models if m.model_type == "llm" and m.provider_id == "ollama")
model_id = llm.identifier
llm = next(
m for m in models
if m.custom_metadata
and m.custom_metadata.get("model_type") == "llm"
and m.custom_metadata.get("provider_id") == "ollama"
)
model_id = llm.id
print("Model:", model_id)
@ -279,8 +284,13 @@ import uuid
client = LlamaStackClient(base_url=f"http://localhost:8321")
models = client.models.list()
llm = next(m for m in models if m.model_type == "llm" and m.provider_id == "ollama")
model_id = llm.identifier
llm = next(
m for m in models
if m.custom_metadata
and m.custom_metadata.get("model_type") == "llm"
and m.custom_metadata.get("provider_id") == "ollama"
)
model_id = llm.id
agent = Agent(client, model=model_id, instructions="You are a helpful assistant.")
@ -450,8 +460,11 @@ import uuid
client = LlamaStackClient(base_url="http://localhost:8321")
# Create a vector database instance
embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
embedding_model = embed_lm.identifier
embed_lm = next(
m for m in client.models.list()
if m.custom_metadata and m.custom_metadata.get("model_type") == "embedding"
)
embedding_model = embed_lm.id
vector_db_id = f"v{uuid.uuid4().hex}"
# The VectorDB API is deprecated; the server now returns its own authoritative ID.
# We capture the correct ID from the response's .identifier attribute.
@ -489,9 +502,11 @@ client.tool_runtime.rag_tool.insert(
llm = next(
m
for m in client.models.list()
if m.model_type == "llm" and m.provider_id == "ollama"
if m.custom_metadata
and m.custom_metadata.get("model_type") == "llm"
and m.custom_metadata.get("provider_id") == "ollama"
)
model = llm.identifier
model = llm.id
# Create the RAG agent
rag_agent = Agent(


@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
| `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
| `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
| `rerank_model_to_url` | `dict[str, str]` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |
## Sample Configuration

File diff suppressed because it is too large.


@ -84,7 +84,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo
)
yaml_filename = f"{filename_prefix}llama-stack-spec.yaml"
html_filename = f"{filename_prefix}llama-stack-spec.html"
with open(output_dir / yaml_filename, "w", encoding="utf-8") as fp:
y = yaml.YAML()
@ -102,11 +101,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo
fp,
)
with open(output_dir / html_filename, "w") as fp:
spec.write_html(fp, pretty_print=True)
print(f"Generated {yaml_filename} and {html_filename}")
def main(output_dir: str):
output_dir = Path(output_dir)
if not output_dir.exists():


@ -242,15 +242,6 @@ const sidebars: SidebarsConfig = {
'providers/eval/remote_nvidia'
],
},
{
type: 'category',
label: 'Telemetry',
collapsed: true,
items: [
'providers/telemetry/index',
'providers/telemetry/inline_meta-reference'
],
},
{
type: 'category',
label: 'Batches',

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because it is too large.


@ -7,7 +7,7 @@ required-version = ">=0.7.0"
[project]
name = "llama_stack"
version = "0.3.0"
version = "0.4.0.dev0"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@ -284,7 +284,6 @@ exclude = [
"^src/llama_stack/models/llama/llama3/interface\\.py$",
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
"^src/llama_stack/providers/inline/agents/meta_reference/",
"^src/llama_stack/providers/inline/datasetio/localfs/",
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",


@ -215,6 +215,16 @@ build_image() {
--build-arg "LLAMA_STACK_DIR=/workspace"
)
# Pass UV index configuration for release branches
if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
fi
if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
fi
if ! "${build_cmd[@]}"; then
echo "❌ Failed to build Docker image"
exit 1


@ -23,7 +23,7 @@ COLLECT_ONLY=false
# Function to display usage
usage() {
cat << EOF
cat <<EOF
Usage: $0 [OPTIONS]
Options:
@ -102,7 +102,6 @@ while [[ $# -gt 0 ]]; do
esac
done
# Validate required parameters
if [[ -z "$STACK_CONFIG" && "$COLLECT_ONLY" == false ]]; then
echo "Error: --stack-config is required"
@ -177,21 +176,45 @@ cd $ROOT_DIR
# check if "llama" and "pytest" are available. this script does not use `uv run` given
# it can be used in a pre-release environment where we have not been able to tell
# uv about pre-release dependencies properly (yet).
if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &> /dev/null; then
if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &>/dev/null; then
echo "llama could not be found, ensure llama-stack is installed"
exit 1
fi
if ! command -v pytest &> /dev/null; then
if ! command -v pytest &>/dev/null; then
echo "pytest could not be found, ensure pytest is installed"
exit 1
fi
# Helper function to find next available port
find_available_port() {
local start_port=$1
local port=$start_port
for ((i=0; i<100; i++)); do
if ! lsof -Pi :$port -sTCP:LISTEN -t >/dev/null 2>&1; then
echo $port
return 0
fi
((port++))
done
echo "Failed to find available port starting from $start_port" >&2
return 1
}
# Start Llama Stack Server if needed
if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
# Find an available port for the server
LLAMA_STACK_PORT=$(find_available_port 8321)
if [[ $? -ne 0 ]]; then
echo "Error: $LLAMA_STACK_PORT"
exit 1
fi
export LLAMA_STACK_PORT
echo "Will use port: $LLAMA_STACK_PORT"
stop_server() {
echo "Stopping Llama Stack Server..."
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
pids=$(lsof -i :$LLAMA_STACK_PORT | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
@ -201,20 +224,25 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
echo "Llama Stack Server stopped"
}
# check if server is already running
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
echo "Llama Stack Server is already running, skipping start"
else
echo "=== Starting Llama Stack Server ==="
export LLAMA_STACK_LOG_WIDTH=120
# Configure telemetry collector for server mode
# Use a fixed port for the OTEL collector so the server can connect to it
COLLECTOR_PORT=4317
export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}"
export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
export OTEL_BSP_SCHEDULE_DELAY="200"
export OTEL_BSP_EXPORT_TIMEOUT="2000"
# remove "server:" from STACK_CONFIG
stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
nohup llama stack run $stack_config > server.log 2>&1 &
nohup llama stack run $stack_config >server.log 2>&1 &
echo "Waiting for Llama Stack Server to start..."
echo "Waiting for Llama Stack Server to start on port $LLAMA_STACK_PORT..."
for i in {1..30}; do
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
if curl -s http://localhost:$LLAMA_STACK_PORT/v1/health 2>/dev/null | grep -q "OK"; then
echo "✅ Llama Stack Server started successfully"
break
fi
@ -227,7 +255,6 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
sleep 1
done
echo ""
fi
trap stop_server EXIT ERR INT TERM
fi
@ -239,7 +266,7 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
container_name="llama-stack-test-$DISTRO"
if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then
echo "Dumping container logs before stopping..."
docker logs "$container_name" > "docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
docker logs "$container_name" >"docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
echo "Stopping and removing container: $container_name"
docker stop "$container_name" 2>/dev/null || true
docker rm "$container_name" 2>/dev/null || true
@ -251,7 +278,14 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
# Extract distribution name from docker:distro format
DISTRO=$(echo "$STACK_CONFIG" | sed 's/^docker://')
export LLAMA_STACK_PORT=8321
# Find an available port for the docker container
LLAMA_STACK_PORT=$(find_available_port 8321)
if [[ $? -ne 0 ]]; then
echo "Error: $LLAMA_STACK_PORT"
exit 1
fi
export LLAMA_STACK_PORT
echo "Will use port: $LLAMA_STACK_PORT"
echo "=== Building Docker Image for distribution: $DISTRO ==="
containerfile="$ROOT_DIR/containers/Containerfile"
@ -271,6 +305,16 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
--build-arg "LLAMA_STACK_DIR=/workspace"
)
# Pass UV index configuration for release branches
if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
fi
if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
fi
if ! "${build_cmd[@]}"; then
echo "❌ Failed to build Docker image"
exit 1
@ -313,8 +357,20 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
fi
echo "Using image: $IMAGE_NAME"
docker run -d --network host --name "$container_name" \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
# On macOS/Darwin, --network host doesn't work as expected due to Docker running in a VM
# Use regular port mapping instead
NETWORK_MODE=""
PORT_MAPPINGS=""
if [[ "$(uname)" != "Darwin" ]] && [[ "$(uname)" != *"MINGW"* ]]; then
NETWORK_MODE="--network host"
else
# On non-Linux (macOS, Windows), need explicit port mappings for both app and telemetry
PORT_MAPPINGS="-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT -p $COLLECTOR_PORT:$COLLECTOR_PORT"
echo "Using bridge networking with port mapping (non-Linux)"
fi
docker run -d $NETWORK_MODE --name "$container_name" \
$PORT_MAPPINGS \
$DOCKER_ENV_VARS \
"$IMAGE_NAME" \
--port $LLAMA_STACK_PORT
@ -416,17 +472,13 @@ elif [ $exit_code -eq 5 ]; then
else
echo "❌ Tests failed"
echo ""
echo "=== Dumping last 100 lines of logs for debugging ==="
# Output server or container logs based on stack config
if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then
echo "--- Last 100 lines of server.log ---"
tail -100 server.log
echo "--- Server side failures can be located inside server.log (available from artifacts on CI) ---"
elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then
docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log"
if [[ -f "$docker_log_file" ]]; then
echo "--- Last 100 lines of $docker_log_file ---"
tail -100 "$docker_log_file"
echo "--- Server side failures can be located inside $docker_log_file (available from artifacts on CI) ---"
fi
fi
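For orientation, a hedged sketch of the two STACK_CONFIG shapes this script branches on; the distribution names below are illustrative placeholders, not values taken from this change:
# "server:<config>" starts a local `llama stack run <config>` on a free port found by find_available_port:
STACK_CONFIG="server:ci-tests"
# "docker:<distro>" builds the image from containers/Containerfile and runs the distribution in a container:
STACK_CONFIG="docker:starter"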

scripts/uv-run-with-index.sh (new executable file, 42 lines)
View file

@ -0,0 +1,42 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
# Detect current branch and target branch
# In GitHub Actions, use GITHUB_REF/GITHUB_BASE_REF
if [[ -n "${GITHUB_REF:-}" ]]; then
BRANCH="${GITHUB_REF#refs/heads/}"
else
BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "")
fi
# For PRs, check the target branch
if [[ -n "${GITHUB_BASE_REF:-}" ]]; then
TARGET_BRANCH="${GITHUB_BASE_REF}"
else
TARGET_BRANCH=$(git rev-parse --abbrev-ref HEAD@{upstream} 2>/dev/null | sed 's|origin/||' || echo "")
fi
# Check if on a release branch or targeting one, or LLAMA_STACK_RELEASE_MODE is set
IS_RELEASE=false
if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
IS_RELEASE=true
elif [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
IS_RELEASE=true
elif [[ "${LLAMA_STACK_RELEASE_MODE:-}" == "true" ]]; then
IS_RELEASE=true
fi
# On release branches, use test.pypi as extra index for RC versions
if [[ "$IS_RELEASE" == "true" ]]; then
export UV_EXTRA_INDEX_URL="https://test.pypi.org/simple/"
export UV_INDEX_STRATEGY="unsafe-best-match"
fi
# Run uv with all arguments passed through
exec uv "$@"

View file

@ -491,13 +491,6 @@ class Agents(Protocol):
APIs for creating and interacting with agentic systems."""
@webmethod(
route="/agents",
method="POST",
descriptive_name="create_agent",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents",
method="POST",
@ -515,13 +508,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
@ -552,13 +538,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
@ -586,12 +565,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
@ -612,12 +585,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
@ -640,13 +607,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
@ -666,12 +626,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
@ -692,12 +646,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
@ -715,12 +663,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}",
method="DELETE",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
@ -732,7 +674,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/agents", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
@ -743,12 +684,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
@ -758,12 +693,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/agents/{agent_id}/sessions",
method="GET",
deprecated=True,
level=LLAMA_STACK_API_V1,
)
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
@ -787,12 +716,6 @@ class Agents(Protocol):
#
# Both of these APIs are inherently stateful.
@webmethod(
route="/openai/v1/responses/{response_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/responses/{response_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_openai_response(
self,
@ -805,7 +728,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="POST", level=LLAMA_STACK_API_V1)
async def create_openai_response(
self,
@ -842,7 +764,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_responses(
self,
@ -861,9 +782,6 @@ class Agents(Protocol):
"""
...
@webmethod(
route="/openai/v1/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/responses/{response_id}/input_items", method="GET", level=LLAMA_STACK_API_V1)
async def list_openai_response_input_items(
self,
@ -886,7 +804,6 @@ class Agents(Protocol):
"""
...
@webmethod(route="/openai/v1/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/responses/{response_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
"""Delete a response.

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Sequence
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field, model_validator
@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
scenarios.
"""
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
type: Literal["message"] = "message"
@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
"""
id: str
queries: list[str]
queries: Sequence[str]
status: str
type: Literal["file_search_call"] = "file_search_call"
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
@json_schema_type
@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
id: str
model: str
object: Literal["response"] = "response"
output: list[OpenAIResponseOutput]
output: Sequence[OpenAIResponseOutput]
parallel_tool_calls: bool = False
previous_response_id: str | None = None
prompt: OpenAIResponsePrompt | None = None
@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
# before the field was added. New responses will have this set always.
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
top_p: float | None = None
tools: list[OpenAIResponseTool] | None = None
tools: Sequence[OpenAIResponseTool] | None = None
truncation: str | None = None
usage: OpenAIResponseUsage | None = None
instructions: str | None = None
@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
:param object: Object type identifier, always "list"
"""
data: list[OpenAIResponseInput]
data: Sequence[OpenAIResponseInput]
object: Literal["list"] = "list"
@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
:param input: List of input items that led to this response
"""
input: list[OpenAIResponseInput]
input: Sequence[OpenAIResponseInput]
def to_response_object(self) -> OpenAIResponseObject:
"""Convert to OpenAIResponseObject by excluding input field."""
@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
:param object: Object type identifier, always "list"
"""
data: list[OpenAIResponseObjectWithInput]
data: Sequence[OpenAIResponseObjectWithInput]
has_more: bool
first_id: str
last_id: str

View file

@ -43,7 +43,6 @@ class Batches(Protocol):
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/openai/v1/batches", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1)
async def create_batch(
self,
@ -64,7 +63,6 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
@ -74,7 +72,6 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1)
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
@ -84,7 +81,6 @@ class Batches(Protocol):
"""
...
@webmethod(route="/openai/v1/batches", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1)
async def list_batches(
self,

View file

@ -8,7 +8,7 @@ from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, webmethod
@ -54,7 +54,6 @@ class ListBenchmarksResponse(BaseModel):
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_benchmarks(self) -> ListBenchmarksResponse:
"""List all benchmarks.
@ -63,7 +62,6 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_benchmark(
self,
@ -76,7 +74,6 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def register_benchmark(
self,
@ -98,7 +95,6 @@ class Benchmarks(Protocol):
"""
...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def unregister_benchmark(self, benchmark_id: str) -> None:
"""Unregister a benchmark.

View file

@ -8,7 +8,7 @@ from typing import Any, Protocol, runtime_checkable
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
from llama_stack.schema_utils import webmethod
@ -21,7 +21,6 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
async def iterrows(
self,
@ -46,9 +45,6 @@ class DatasetIO(Protocol):
"""
...
@webmethod(
route="/datasetio/append-rows/{dataset_id:path}", method="POST", deprecated=True, level=LLAMA_STACK_API_V1
)
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST", level=LLAMA_STACK_API_V1BETA)
async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None:
"""Append rows to a dataset.

View file

@ -10,7 +10,7 @@ from typing import Annotated, Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1BETA
from llama_stack.apis.version import LLAMA_STACK_API_V1BETA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -146,7 +146,6 @@ class ListDatasetsResponse(BaseModel):
class Datasets(Protocol):
@webmethod(route="/datasets", method="POST", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets", method="POST", level=LLAMA_STACK_API_V1BETA)
async def register_dataset(
self,
@ -216,7 +215,6 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets/{dataset_id:path}", method="GET", level=LLAMA_STACK_API_V1BETA)
async def get_dataset(
self,
@ -229,7 +227,6 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets", method="GET", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets", method="GET", level=LLAMA_STACK_API_V1BETA)
async def list_datasets(self) -> ListDatasetsResponse:
"""List all datasets.
@ -238,7 +235,6 @@ class Datasets(Protocol):
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", deprecated=True, level=LLAMA_STACK_API_V1)
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE", level=LLAMA_STACK_API_V1BETA)
async def unregister_dataset(
self,

View file

@ -13,7 +13,7 @@ from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -86,7 +86,6 @@ class Eval(Protocol):
Llama Stack Evaluation API for running evaluations on model and agent candidates."""
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def run_eval(
self,
@ -101,9 +100,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def evaluate_rows(
self,
@ -122,9 +118,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of a job.
@ -135,12 +128,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel a job.
@ -150,12 +137,6 @@ class Eval(Protocol):
"""
...
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET", level=LLAMA_STACK_API_V1ALPHA
)

View file

@ -110,7 +110,6 @@ class Files(Protocol):
"""
# OpenAI Files API Endpoints
@webmethod(route="/openai/v1/files", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="POST", level=LLAMA_STACK_API_V1)
async def openai_upload_file(
self,
@ -134,7 +133,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_files(
self,
@ -155,7 +153,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file(
self,
@ -170,7 +167,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}", method="DELETE", level=LLAMA_STACK_API_V1)
async def openai_delete_file(
self,
@ -183,7 +179,6 @@ class Files(Protocol):
"""
...
@webmethod(route="/openai/v1/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/files/{file_id}/content", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_file_content(
self,

View file

@ -1189,7 +1189,6 @@ class InferenceProvider(Protocol):
raise NotImplementedError("Reranking is not implemented")
return # this is so mypy's safe-super rule will consider the method concrete
@webmethod(route="/openai/v1/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_completion(
self,
@ -1202,7 +1201,6 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/chat/completions", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/chat/completions", method="POST", level=LLAMA_STACK_API_V1)
async def openai_chat_completion(
self,
@ -1215,7 +1213,6 @@ class InferenceProvider(Protocol):
"""
...
@webmethod(route="/openai/v1/embeddings", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/embeddings", method="POST", level=LLAMA_STACK_API_V1)
async def openai_embeddings(
self,
@ -1240,7 +1237,6 @@ class Inference(InferenceProvider):
- Rerank models: these models reorder the documents based on their relevance to a query.
"""
@webmethod(route="/openai/v1/chat/completions", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/chat/completions", method="GET", level=LLAMA_STACK_API_V1)
async def list_chat_completions(
self,
@ -1259,9 +1255,6 @@ class Inference(InferenceProvider):
"""
raise NotImplementedError("List chat completions is not implemented")
@webmethod(
route="/openai/v1/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/chat/completions/{completion_id}", method="GET", level=LLAMA_STACK_API_V1)
async def get_chat_completion(self, completion_id: str) -> OpenAICompletionWithInputMessages:
"""Get chat completion.

View file

@ -4,14 +4,21 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Protocol, runtime_checkable
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.apis.version import (
LLAMA_STACK_API_V1,
)
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod
# Valid values for the route filter parameter.
# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
# Special filter value: "deprecated" (shows deprecated routes regardless of level)
ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]
@json_schema_type
class RouteInfo(BaseModel):
@ -64,11 +71,12 @@ class Inspect(Protocol):
"""
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
async def list_routes(self) -> ListRoutesResponse:
async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
"""List routes.
List all available API routes with their methods and implementing providers.
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
:returns: Response containing information about all available routes.
"""
...
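A hedged sketch of exercising the new filter against a locally running stack; it assumes api_filter is surfaced as a query parameter, that the route is mounted under /v1 like the other v1 routes in this change, and that the response carries a data array of route records:
# Default: only non-deprecated v1 routes
curl -s http://localhost:8321/v1/inspect/routes | jq '.data[].route'

# Deprecated routes across all levels
curl -s "http://localhost:8321/v1/inspect/routes?api_filter=deprecated" | jq '.data[].route'

# Alpha-level routes only
curl -s "http://localhost:8321/v1/inspect/routes?api_filter=v1alpha" | jq '.data[].route'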

View file

@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
:object: The object type, which will be "model"
:created: The Unix timestamp in seconds when the model was created
:owned_by: The owner of the model
:custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
"""
id: str
object: Literal["model"] = "model"
created: int
owned_by: str
custom_metadata: dict[str, Any] | None = None
class OpenAIListModelsResponse(BaseModel):
@ -105,7 +107,6 @@ class OpenAIListModelsResponse(BaseModel):
@runtime_checkable
@trace_protocol
class Models(Protocol):
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
async def list_models(self) -> ListModelsResponse:
"""List all models.
@ -113,7 +114,7 @@ class Models(Protocol):
"""
...
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/models", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_models(self) -> OpenAIListModelsResponse:
"""List models using the OpenAI API.

View file

@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@ -284,7 +284,6 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/supervised-fine-tune", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def supervised_fine_tune(
self,
@ -312,7 +311,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/preference-optimize", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def preference_optimize(
self,
@ -335,7 +333,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/jobs", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
"""Get all training jobs.
@ -344,7 +341,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/status", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse:
"""Get the status of a training job.
@ -354,7 +350,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/cancel", method="POST", level=LLAMA_STACK_API_V1ALPHA)
async def cancel_training_job(self, job_uuid: str) -> None:
"""Cancel a training job.
@ -363,7 +358,6 @@ class PostTraining(Protocol):
"""
...
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/post-training/job/artifacts", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
"""Get the artifacts of a training job.

View file

@ -121,7 +121,6 @@ class Safety(Protocol):
"""
...
@webmethod(route="/openai/v1/moderations", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/moderations", method="POST", level=LLAMA_STACK_API_V1)
async def run_moderation(self, input: str | list[str], model: str | None = None) -> ModerationObject:
"""Create moderation.

View file

@ -1,7 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .synthetic_data_generation import *

View file

@ -1,77 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Protocol
from pydantic import BaseModel
from llama_stack.apis.inference import Message
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import json_schema_type, webmethod
class FilteringFunction(Enum):
"""The type of filtering function.
:cvar none: No filtering applied, accept all generated synthetic data
:cvar random: Random sampling of generated data points
:cvar top_k: Keep only the top-k highest scoring synthetic data samples
:cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
:cvar top_k_top_p: Combined top-k and top-p filtering strategy
:cvar sigmoid: Apply sigmoid function for probability-based filtering
"""
none = "none"
random = "random"
top_k = "top_k"
top_p = "top_p"
top_k_top_p = "top_k_top_p"
sigmoid = "sigmoid"
@json_schema_type
class SyntheticDataGenerationRequest(BaseModel):
"""Request to generate synthetic data. A small batch of prompts and a filtering function
:param dialogs: List of conversation messages to use as input for synthetic data generation
:param filtering_function: Type of filtering to apply to generated synthetic data samples
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
"""
dialogs: list[Message]
filtering_function: FilteringFunction = FilteringFunction.none
model: str | None = None
@json_schema_type
class SyntheticDataGenerationResponse(BaseModel):
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
:param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
:param statistics: (Optional) Statistical information about the generation process and filtering results
"""
synthetic_data: list[dict[str, Any]]
statistics: dict[str, Any] | None = None
class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
def synthetic_data_generate(
self,
dialogs: list[Message],
filtering_function: FilteringFunction = FilteringFunction.none,
model: str | None = None,
) -> SyntheticDataGenerationResponse:
"""Generate synthetic data based on input dialogs and apply filtering.
:param dialogs: List of conversation messages to use as input for synthetic data generation
:param filtering_function: Type of filtering to apply to generated synthetic data samples
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
:returns: Response containing filtered synthetic data samples and optional statistics
"""
...

View file

@ -8,7 +8,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from fastapi import Body
@ -18,7 +17,6 @@ from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.telemetry.trace_protocol import trace_protocol
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.strong_typing.schema import register_schema
@ -61,38 +59,19 @@ class Chunk(BaseModel):
"""
A chunk of content that can be inserted into a vector database.
:param content: The content of the chunk, which can be interleaved text, images, or other types.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param chunk_id: Unique identifier for the chunk. Must be provided explicitly.
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
The `chunk_metadata` is required for backend functionality.
"""
content: InterleavedContent
chunk_id: str
metadata: dict[str, Any] = Field(default_factory=dict)
embedding: list[float] | None = None
# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
chunk_metadata: ChunkMetadata | None = None
model_config = {"populate_by_name": True}
def model_post_init(self, __context):
# Extract chunk_id from metadata if present
if self.metadata and "chunk_id" in self.metadata:
self.stored_chunk_id = self.metadata.pop("chunk_id")
@property
def chunk_id(self) -> str:
"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
if self.stored_chunk_id:
return self.stored_chunk_id
if "document_id" in self.metadata:
return generate_chunk_id(self.metadata["document_id"], str(self.content))
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
@property
def document_id(self) -> str | None:
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
@ -566,7 +545,6 @@ class VectorIO(Protocol):
...
# OpenAI Vector Stores API endpoints
@webmethod(route="/openai/v1/vector_stores", method="POST", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="POST", level=LLAMA_STACK_API_V1)
async def openai_create_vector_store(
self,
@ -579,7 +557,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(route="/openai/v1/vector_stores", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
@webmethod(route="/vector_stores", method="GET", level=LLAMA_STACK_API_V1)
async def openai_list_vector_stores(
self,
@ -598,9 +575,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1)
async def openai_retrieve_vector_store(
self,
@ -613,9 +587,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="POST", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="POST",
@ -638,9 +609,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE", level=LLAMA_STACK_API_V1, deprecated=True
)
@webmethod(
route="/vector_stores/{vector_store_id}",
method="DELETE",
@ -657,12 +625,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/search",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/search",
method="POST",
@ -695,12 +657,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="POST",
@ -723,12 +679,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files",
method="GET",
@ -755,12 +705,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="GET",
@ -779,12 +723,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}/content",
method="GET",
@ -803,12 +741,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="POST",
@ -829,12 +761,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/files/{file_id}",
method="DELETE",
@ -858,12 +784,6 @@ class VectorIO(Protocol):
method="POST",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
@ -882,12 +802,6 @@ class VectorIO(Protocol):
method="GET",
level=LLAMA_STACK_API_V1,
)
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
@ -901,12 +815,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/files",
method="GET",
@ -935,12 +843,6 @@ class VectorIO(Protocol):
"""
...
@webmethod(
route="/openai/v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",
level=LLAMA_STACK_API_V1,
deprecated=True,
)
@webmethod(
route="/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel",
method="POST",

View file

@ -8,16 +8,30 @@ import argparse
import os
import ssl
import subprocess
import sys
from pathlib import Path
import uvicorn
import yaml
from termcolor import cprint
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
from llama_stack.core.distribution import get_provider_registry
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.storage.datatypes import (
InferenceStoreReference,
KVStoreReference,
ServerStoresConfig,
SqliteKVStoreConfig,
SqliteSqlStoreConfig,
SqlStoreReference,
StorageConfig,
)
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import LoggingConfig, get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -68,6 +82,12 @@ class StackRun(Subcommand):
action="store_true",
help="Start the UI server",
)
self.parser.add_argument(
"--providers",
type=str,
default=None,
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
@ -93,6 +113,55 @@ class StackRun(Subcommand):
config_file = resolve_config_or_distro(args.config, Mode.RUN)
except ValueError as e:
self.parser.error(str(e))
elif args.providers:
provider_list: dict[str, list[Provider]] = dict()
for api_provider in args.providers.split(","):
if "=" not in api_provider:
cprint(
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
color="red",
file=sys.stderr,
)
sys.exit(1)
api, provider_type = api_provider.split("=")
providers_for_api = get_provider_registry().get(Api(api), None)
if providers_for_api is None:
cprint(
f"{api} is not a valid API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
if provider_type in providers_for_api:
config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
if config_type is not None and hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
else:
config = {}
provider = Provider(
provider_type=provider_type,
config=config,
provider_id=provider_type.split("::")[1],
)
provider_list.setdefault(api, []).append(provider)
else:
cprint(
f"{provider} is not a valid provider for the {api} API.",
color="red",
file=sys.stderr,
)
sys.exit(1)
run_config = self._generate_run_config_from_providers(providers=provider_list)
config_dict = run_config.model_dump(mode="json")
# Write config to disk in providers-run directory
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
config_file = distro_dir / "run.yaml"
logger.info(f"Writing generated config to: {config_file}")
with open(config_file, "w") as f:
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
else:
config_file = None
@ -106,7 +175,8 @@ class StackRun(Subcommand):
try:
config = parse_and_maybe_upgrade_config(config_dict)
if not os.path.exists(str(config.external_providers_dir)):
# Create external_providers_dir if it's specified and doesn't exist
if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)):
os.makedirs(str(config.external_providers_dir), exist_ok=True)
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
@ -127,7 +197,7 @@ class StackRun(Subcommand):
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
port = args.port or config.server.port
host = config.server.host or ["::", "0.0.0.0"]
host = config.server.host or "0.0.0.0"
# Set the config file in environment so create_app can find it
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
@ -139,6 +209,7 @@ class StackRun(Subcommand):
"lifespan": "on",
"log_level": logger.getEffectiveLevel(),
"log_config": logger_config,
"workers": config.server.workers,
}
keyfile = config.server.tls_keyfile
@ -212,3 +283,44 @@ class StackRun(Subcommand):
)
except Exception as e:
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
apis = list(providers.keys())
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
# need somewhere to put the storage.
os.makedirs(distro_dir, exist_ok=True)
storage = StorageConfig(
backends={
"kv_default": SqliteKVStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
),
"sql_default": SqliteSqlStoreConfig(
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
),
},
stores=ServerStoresConfig(
metadata=KVStoreReference(
backend="kv_default",
namespace="registry",
),
inference=InferenceStoreReference(
backend="sql_default",
table_name="inference_store",
),
conversations=SqlStoreReference(
backend="sql_default",
table_name="openai_conversations",
),
prompts=KVStoreReference(
backend="kv_default",
namespace="prompts",
),
),
)
return StackRunConfig(
image_name="providers-run",
apis=apis,
providers=providers,
storage=storage,
)
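A hedged example of the new flag end to end; the provider types are illustrative, and the generated run.yaml is written under the providers-run distribution directory before the server starts, as implemented above:
# Compose a stack ad hoc from individual providers instead of a named distribution:
llama stack run --providers inference=remote::ollama,safety=inline::llama-guard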

View file

@ -17,7 +17,6 @@ from llama_stack.core.distribution import (
get_provider_registry,
)
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.prompt_for_config import prompt_for_config
from llama_stack.log import get_logger
@ -194,19 +193,11 @@ def upgrade_from_routing_table(
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
version = config_dict.get("version", None)
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
if "routing_table" in config_dict:
logger.info("Upgrading config...")
config_dict = upgrade_from_routing_table(config_dict)
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
if not config_dict.get("external_providers_dir", None):
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
processed_config_dict = replace_env_vars(config_dict)
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))

View file

@ -473,6 +473,10 @@ class ServerConfig(BaseModel):
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
workers: int = Field(
default=1,
description="Number of workers to use for the server",
)
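A minimal sketch of where the new field sits in a run configuration; the surrounding keys are assumptions based on the existing ServerConfig fields, and the value flows straight into uvicorn's worker count in `llama stack run`:
# run.yaml fragment (illustrative):
#   server:
#     port: 8321
#     workers: 4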
class StackRunConfig(BaseModel):

View file

@ -15,6 +15,7 @@ from llama_stack.apis.inspect import (
RouteInfo,
VersionInfo,
)
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.core.external import load_external_apis
from llama_stack.core.server.routes import get_all_api_routes
@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect):
async def initialize(self) -> None:
pass
async def list_routes(self) -> ListRoutesResponse:
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
run_config: StackRunConfig = self.config.run_config
# Helper function to determine if a route should be included based on api_filter
def should_include_route(webmethod) -> bool:
if api_filter is None:
# Default: only non-deprecated v1 APIs
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
elif api_filter == "deprecated":
# Special filter: show deprecated routes regardless of their actual level
return bool(webmethod.deprecated)
else:
# Filter by API level (non-deprecated routes only)
return not webmethod.deprecated and webmethod.level == api_filter
ret = []
external_apis = load_external_apis(run_config)
all_endpoints = get_all_api_routes(external_apis)
@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
)
for e, _ in endpoints
if e.methods is not None
for e, webmethod in endpoints
if e.methods is not None and should_include_route(webmethod)
]
)
else:
@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect):
method=next(iter([m for m in e.methods if m != "HEAD"])),
provider_types=[p.provider_type for p in providers],
)
for e, _ in endpoints
if e.methods is not None
for e, webmethod in endpoints
if e.methods is not None and should_include_route(webmethod)
]
)

View file

@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
ModelWithOwner,
RegistryEntrySource,
)
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
@ -42,19 +44,104 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
await self.update_registered_models(provider_id, models)
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
"""
Fetch models from providers that have credentials in the current request's provider_data.
This allows users to see models available to them from providers that require
per-request API keys (via X-LlamaStack-Provider-Data header).
Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
cache them in the registry since they are user-specific.
"""
provider_data = PROVIDER_DATA_VAR.get()
if not provider_data:
return []
dynamic_models = []
for provider_id, provider in self.impls_by_provider_id.items():
# Check if this provider supports provider_data
if not isinstance(provider, NeedsRequestProviderData):
continue
# Check if provider has a validator (some providers like ollama don't need per-request credentials)
spec = getattr(provider, "__provider_spec__", None)
if not spec or not getattr(spec, "provider_data_validator", None):
continue
# Validate provider_data silently - we're speculatively checking all providers
# so validation failures are expected when user didn't provide keys for this provider
try:
validator = instantiate_class_type(spec.provider_data_validator)
validator(**provider_data)
except Exception:
# User didn't provide credentials for this provider - skip silently
continue
# Validation succeeded! User has credentials for this provider
# Now try to list models
try:
models = await provider.list_models()
if not models:
continue
# Ensure models have fully qualified identifiers with provider_id prefix
for model in models:
# Only add prefix if model identifier doesn't already have it
if not model.identifier.startswith(f"{provider_id}/"):
model.identifier = f"{provider_id}/{model.provider_resource_id}"
dynamic_models.append(model)
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
except Exception as e:
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
continue
return dynamic_models
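A hedged illustration of the per-request flow described above: the X-LlamaStack-Provider-Data header carries a JSON object of credentials, and the key name shown here (openai_api_key) is an assumption that depends on the provider's data validator:
# Models from providers you supplied credentials for are merged into the listing but never cached in the registry:
curl -s http://localhost:8321/v1/models \
  -H 'X-LlamaStack-Provider-Data: {"openai_api_key": "sk-..."}'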
async def list_models(self) -> ListModelsResponse:
return ListModelsResponse(data=await self.get_all_with_type("model"))
# Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
return ListModelsResponse(data=registry_models + unique_dynamic_models)
async def openai_list_models(self) -> OpenAIListModelsResponse:
models = await self.get_all_with_type("model")
# Get models from registry
registry_models = await self.get_all_with_type("model")
# Get additional models available via provider_data (user-specific, not cached)
dynamic_models = await self._get_dynamic_models_from_provider_data()
# Combine, avoiding duplicates (registry takes precedence)
registry_identifiers = {m.identifier for m in registry_models}
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
all_models = registry_models + unique_dynamic_models
openai_models = [
OpenAIModel(
id=model.identifier,
object="model",
created=int(time.time()),
owned_by="llama_stack",
custom_metadata={
"model_type": model.model_type,
"provider_id": model.provider_id,
"provider_resource_id": model.provider_resource_id,
**model.metadata,
},
)
for model in models
for model in all_models
]
return OpenAIListModelsResponse(data=openai_models)
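For reference, a hedged sketch of one element of the resulting /v1/models payload with the new custom_metadata block; identifiers and values are illustrative:
# Example entry in the "data" array returned by GET /v1/models (illustrative):
# {
#   "id": "remote-openai/gpt-4o-mini",
#   "object": "model",
#   "created": 1730000000,
#   "owned_by": "llama_stack",
#   "custom_metadata": {
#     "model_type": "llm",
#     "provider_id": "remote-openai",
#     "provider_resource_id": "gpt-4o-mini"
#   }
# }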

View file

@ -14,6 +14,7 @@ from typing import Any
import yaml
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.datasetio import DatasetIO
@ -30,7 +31,6 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
@ -63,8 +63,8 @@ class LlamaStack(
Providers,
Inference,
Agents,
Batches,
Safety,
SyntheticDataGeneration,
Datasets,
PostTraining,
VectorIO,

View file

@ -12,7 +12,7 @@ from llama_stack.core.ui.modules.api import llama_stack_api
def models():
# Models Section
st.header("Models")
models_info = {m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()}
models_info = {m.id: m.model_dump() for m in llama_stack_api.client.models.list()}
selected_model = st.selectbox("Select a model", list(models_info.keys()))
st.json(models_info[selected_model])

View file

@ -12,7 +12,11 @@ from llama_stack.core.ui.modules.api import llama_stack_api
with st.sidebar:
st.header("Configuration")
available_models = llama_stack_api.client.models.list()
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
available_models = [
model.id
for model in available_models
if model.custom_metadata and model.custom_metadata.get("model_type") == "llm"
]
selected_model = st.selectbox(
"Choose a model",
available_models,

View file

@ -11,6 +11,7 @@ import uuid
import warnings
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from typing import Any, cast
import httpx
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
)
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
messages: list[Message] = []
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set()
for step in turn.steps:
if step.step_type == StepType.tool_execution.value:
if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
messages.append(msg)
for step in turn.steps:
if step.step_type == StepType.inference.value:
if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
messages.append(step.model_response)
elif step.step_type == StepType.tool_execution.value:
elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
if step.violation:
elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
if step.violation and step.violation.user_message:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
messages: list[Message] = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):
steps = []
messages = await self.get_messages_from_turns(turns)
if is_resume:
assert isinstance(request, AgentTurnResumeRequest)
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
now = datetime.now(UTC).isoformat()
now_dt = datetime.now(UTC)
tool_execution_step = ToolExecutionStep(
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
completed_at=now_dt,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
)
steps.append(tool_execution_step)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=tool_execution_step.step_id,
step_details=tool_execution_step,
)
)
)
input_messages = last_turn.input_messages
# Cast needed due to list invariance - last_turn.input_messages is the right type
input_messages = last_turn.input_messages # type: ignore[assignment]
turn_id = request.turn_id
actual_turn_id = request.turn_id
start_time = last_turn.started_at
else:
assert isinstance(request, AgentTurnCreateRequest)
messages.extend(request.messages)
start_time = datetime.now(UTC).isoformat()
input_messages = request.messages
start_time = datetime.now(UTC)
# Cast needed due to list invariance - request.messages is the right type
input_messages = request.messages # type: ignore[assignment]
# Use the generated turn_id from beginning of function
actual_turn_id = turn_id if turn_id else str(uuid.uuid4())
output_message = None
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
req_sampling = (
self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
)
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
turn_id=actual_turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
sampling_params=req_sampling,
stream=request.stream,
documents=request.documents if not is_resume else None,
documents=req_documents,
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
event = chunk.event
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
steps.append(event.payload.step_details)
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
event.payload, "step_details"
):
step_details = event.payload.step_details
steps.append(step_details)
yield chunk
assert output_message is not None
turn = Turn(
turn_id=turn_id,
turn_id=actual_turn_id,
session_id=request.session_id,
input_messages=input_messages,
input_messages=input_messages, # type: ignore[arg-type]
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
@ -345,9 +361,9 @@ class ChatAgent(ShieldRunnerMixin):
# return a "final value" for the `yield from` statement. we simulate that by yielding a
# final boolean (to see whether an exception happened) and then explicitly testing for it.
if len(self.input_shields) > 0:
if self.input_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, input_messages, self.input_shields, "user-input"
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
):
if isinstance(res, bool):
return
@ -374,9 +390,9 @@ class ChatAgent(ShieldRunnerMixin):
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
if len(self.output_shields) > 0:
if self.output_shields:
async for res in self.run_multiple_shields_wrapper(
turn_id, messages, self.output_shields, "assistant-output"
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
):
if isinstance(res, bool):
return
@ -388,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:
@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
return
step_id = str(uuid.uuid4())
shield_call_start_time = datetime.now(UTC).isoformat()
shield_call_start_time = datetime.now(UTC)
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_type=StepType.shield_call,
step_id=step_id,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
started_at=shield_call_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
else:
self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
output_attachments = []
output_attachments: list[Attachment] = []
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat()
inference_start_time = datetime.now(UTC)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
)
)
@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
else:
return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
def _add_type(openai_msg: Any) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
messages=openai_messages,
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
response_format=self.agent_config.response_format,
response_format=self.agent_config.response_format, # type: ignore[arg-type]
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
openai_stream, # type: ignore[arg-type]
enable_incremental_tool_calls=True,
)
async for chunk in response_stream:
@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
delta=delta,
)
@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
"tool_calls": [
json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
],
}
)
span.set_attribute("output", output_attr)
@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
if tool_calls:
content = ""
# Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
message = CompletionMessage(
content=content,
stop_reason=stop_reason,
tool_calls=tool_calls,
tool_calls=valid_tool_calls if valid_tool_calls else None,
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.inference.value,
step_type=StepType.inference,
step_id=step_id,
step_details=InferenceStep(
# somewhere deep, we are re-assigning message or closing over some
@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
model_response=copy.deepcopy(message),
started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
),
)
)
)
if n_iter >= self.agent_config.max_infer_iters:
max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
if n_iter >= max_iters:
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
# Do not continue the tool call loop after this point
@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
yield message
break
if len(message.tool_calls) == 0:
if not message.tool_calls or len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += output_attachments
# List invariance - attachments are compatible at runtime
message.content += output_attachments # type: ignore[arg-type]
else:
message.content = [message.content] + output_attachments
# List invariance - attachments are compatible at runtime
message.content = [message.content] + output_attachments # type: ignore[assignment]
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
@ -725,6 +750,7 @@ class ChatAgent(ShieldRunnerMixin):
non_client_tool_calls = []
# Separate client and non-client tool calls
if message.tool_calls:
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
)
)
@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
if self.telemetry_enabled
else {},
) as span:
tool_execution_start_time = datetime.now(UTC).isoformat()
tool_execution_start_time = datetime.now(UTC)
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(),
completed_at=datetime.now(UTC),
)
# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_type=StepType.tool_execution,
step_id=step_id,
step_details=tool_execution_step,
)
@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(UTC).isoformat(),
started_at=datetime.now(UTC),
),
)
@ -868,9 +894,10 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_name_to_def: dict[str, ToolDefinition] = {}
tool_name_to_args = {}
if self.agent_config.client_tools:
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
else:
identifier = None
if tool_name_to_def.get(identifier, None):
raise ValueError(f"Tool {identifier} already exists")
if identifier:
tool_name_to_def[identifier] = ToolDefinition(
tool_name=identifier,
# Convert BuiltinTool to string for dictionary key
identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
if tool_name_to_def.get(identifier_str, None):
raise ValueError(f"Tool {identifier_str} already exists")
tool_name_to_def[identifier_str] = ToolDefinition(
tool_name=identifier_str,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})
self.tool_defs, self.tool_name_to_args = (
list(tool_name_to_def.values()),
@ -966,7 +995,9 @@ class ChatAgent(ShieldRunnerMixin):
except json.JSONDecodeError as e:
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
result = await self.tool_runtime_api.invoke_tool(
result = cast(
ToolInvocationResult,
await self.tool_runtime_api.invoke_tool(
tool_name=tool_name_str,
kwargs={
"session_id": session_id,
@ -974,6 +1005,7 @@ class ChatAgent(ShieldRunnerMixin):
**args,
**self.tool_name_to_args.get(tool_name_str, {}),
},
),
)
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
return result
@ -983,7 +1015,7 @@ async def load_data_from_url(url: str) -> str:
if url.startswith("http"):
async with httpx.AsyncClient() as client:
r = await client.get(url)
resp = r.text
resp: str = r.text
return resp
raise ValueError(f"Unexpected URL: {type(url)}")
@ -1017,7 +1049,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)

View file

@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputTool,
OpenAIResponseObject,
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at,
created_at=agent_info.created_at.isoformat(),
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
async def get_agents_session(
self,
agent_id: str,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
started_at=session_info.started_at,
)
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=chat_agent.created_at,
created_at=datetime.fromisoformat(chat_agent.created_at),
)
return agent
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
response_id: str,
) -> OpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.get_openai_response(response_id)
async def create_openai_response(
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrail] | None = None,
) -> OpenAIResponseObject:
return await self.openai_responses_impl.create_openai_response(
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
result = await self.openai_responses_impl.create_openai_response(
input,
model,
prompt,
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
max_infer_iters,
guardrails,
)
return result # type: ignore[no-any-return]
async def list_openai_responses(
self,
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
async def list_openai_response_input_items(
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.list_openai_response_input_items(
response_id, after, before, include, limit, order
)
async def delete_openai_response(self, response_id: str) -> None:
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
return await self.openai_responses_impl.delete_openai_response(response_id)

View file

@ -6,12 +6,14 @@
import json
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.datatypes import AccessRule
from llama_stack.core.access_control.conditions import User as ProtocolUser
from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
created_at: datetime
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
@ -53,8 +64,15 @@ class AgentPersistence:
turns=[],
identifier=name, # should this be qualified in any way?
)
if not is_action_allowed(self.policy, "create", session_info, user):
raise AccessDeniedError("create", session_info, user)
# Only perform access control if we have an authenticated user
if user is not None and session_info.identifier is not None:
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=user,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
@ -62,7 +80,7 @@ class AgentPersistence:
)
return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
@ -83,7 +101,22 @@ class AgentPersistence:
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
# Get current user - if None, skip access control (e.g., in tests)
user = get_authenticated_user()
if user is None:
return True
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, user)
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""

View file

@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
input: str | list[OpenAIResponseInput],
previous_response: _OpenAIResponseObjectWithInputAndMessages,
):
new_input_items = previous_response.input.copy()
# Convert Sequence to list for mutation
new_input_items = list(previous_response.input)
new_input_items.extend(previous_response.output)
if isinstance(input, str):
@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None,
previous_response_id: str | None,
conversation: str | None,
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
"""Process input with optional previous response context.
Returns:
@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
messages: list[OpenAIMessageParam],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
input_items_data: list[OpenAIResponseInput] = []
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
@ -251,7 +254,7 @@ class OpenAIResponsesImpl:
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
guardrails: list[ResponseGuardrailSpec] | None = None,
guardrails: list[str | ResponseGuardrailSpec] | None = None,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
@ -289,7 +292,8 @@ class OpenAIResponsesImpl:
failed_response = None
async for stream_chunk in stream_gen:
if stream_chunk.type in {"response.completed", "response.incomplete"}:
match stream_chunk.type:
case "response.completed" | "response.incomplete":
if final_response is not None:
raise ValueError(
"The response stream produced multiple terminal responses! "
@ -297,8 +301,10 @@ class OpenAIResponsesImpl:
)
final_response = stream_chunk.response
final_event_type = stream_chunk.type
elif stream_chunk.type == "response.failed":
case "response.failed":
failed_response = stream_chunk.response
case _:
pass # Other event types don't have .response
if failed_response is not None:
error_message = (
@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
max_infer_iters: int | None = 10,
guardrail_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# These should never be None when called from create_openai_response (which sets defaults)
# but we assert here to help mypy understand the types
assert text is not None, "text must not be None"
assert max_infer_iters is not None, "max_infer_iters must not be None"
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
input, tools, previous_response_id, conversation
@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
final_response = None
failed_response = None
output_items = []
# Type as ConversationItem to avoid list invariance issues
output_items: list[ConversationItem] = []
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type in {"response.completed", "response.incomplete"}:
match stream_chunk.type:
case "response.completed" | "response.incomplete":
final_response = stream_chunk.response
elif stream_chunk.type == "response.failed":
case "response.failed":
failed_response = stream_chunk.response
if stream_chunk.type == "response.output_item.done":
case "response.output_item.done":
item = stream_chunk.item
output_items.append(item)
case _:
pass # Other event types
# Store and sync before yielding terminal events
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
) -> None:
"""Sync content and response messages to the conversation."""
conversation_items = []
# Type as ConversationItem union to avoid list invariance issues
conversation_items: list[ConversationItem] = []
if isinstance(input, str):
conversation_items.append(

View file

@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
instructions: str,
instructions: str | None,
safety_api,
guardrail_ids: list[str] | None = None,
prompt: OpenAIResponsePrompt | None = None,
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
self.prompt = prompt
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
ctx.tool_context.previous_tools if ctx.tool_context else {}
)
# Track final messages after all tool executions
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
params = OpenAIChatCompletionRequestWithExtraBody(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
# Pydantic models are dict-compatible but mypy treats them as distinct types
tools=self.ctx.chat_tools, # type: ignore[arg-type]
stream=True,
temperature=self.ctx.temperature,
response_format=response_format,
@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
has_tool_calls = (
isinstance(choice.message, OpenAIAssistantMessageParam)
and choice.message.tool_calls
and self.ctx.response_tools
)
if not has_tool_calls:
output_messages.append(
await convert_chat_choice_to_response_message(
choice,
@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
if not is_new_tool_call and response_tool_call is not None:
# Both should have functions since we're inside the tool_call.function check above
assert response_tool_call.function is not None
assert tool_call.function is not None
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call = chat_response_tool_calls[tool_call_index]
# Ensure that arguments, if sent back to the inference provider, are not None
if tool_call.function:
tool_call.function.arguments = tool_call.function.arguments or "{}"
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = tool_call.function.arguments
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
func = chat_response_tool_calls[tool_call_index].function
tool_call_name = func.name if func else ""
# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:
self.sequence_number += 1
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
item = OpenAIResponseOutputMessageMCPCall(
item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
arguments="",
name=tool_call.function.name,
id=matching_item_id,
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
status="in_progress",
)
elif tool_call.function.name == "web_search":
item = OpenAIResponseOutputMessageWebSearchToolCall(
@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
description=tool.description,
input_schema=tool.input_schema,
)
return convert_tooldef_to_openai_tool(tool_def)
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:
@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
always_allowed = mcp_tool.allowed_tools.always
never_allowed = mcp_tool.allowed_tools.never
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
always_allowed = mcp_tool.allowed_tools.tool_names
# Call list_mcp_tools
tool_defs = None
@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_chat_tool(t)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:
@ -1120,12 +1133,16 @@ class StreamingResponseOrchestrator:
self, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Handle all mcp tool lists from previous response that are still valid:
# tool_context can be None when no tools are provided in the response request
if self.ctx.tool_context:
for tool in self.ctx.tool_context.previous_tool_listings:
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
yield evt
# Process all remaining tools (including MCP tools) and emit streaming events
if self.ctx.tool_context.tools_to_process:
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
async for stream_event in self._process_new_tools(
self.ctx.tool_context.tools_to_process, output_messages
):
yield stream_event
def _approval_required(self, tool_name: str) -> bool:
@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",

View file

@ -7,6 +7,7 @@
import asyncio
import json
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.content_types import (
@ -67,7 +69,7 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number)
@ -84,7 +86,16 @@ class ToolExecutor:
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
has_error = bool(
error_exc
or (
result
and (
((error_code := getattr(result, "error_code", None)) and error_code > 0)
or getattr(result, "error_message", None)
)
)
)
async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
):
@ -101,7 +112,9 @@ class ToolExecutor:
sequence_number=sequence_number,
final_output_message=output_message,
final_input_message=input_message,
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
citation_files=(
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
),
)
async def _execute_knowledge_search_via_vector_store(
@ -188,8 +201,9 @@ class ToolExecutor:
citation_files[file_id] = filename
# Cast to proper InterleavedContent type (list invariance)
return ToolInvocationResult(
content=content_items,
content=content_items, # type: ignore[arg-type]
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
@ -209,51 +223,60 @@ class ToolExecutor:
) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
elif function_name == "knowledge_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function_name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event
if function_name == "knowledge_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
yield ToolExecutionResult(
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
),
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
@ -261,7 +284,7 @@ class ToolExecutor:
tool_kwargs: dict,
ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]:
) -> tuple[Exception | None, Any]:
"""Execute the tool and return error exception and result."""
error_exc = None
result = None
@ -284,10 +307,14 @@ class ToolExecutor:
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
response_file_search_tool = (
next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if ctx.response_tools
else None
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
@ -322,35 +349,34 @@ class ToolExecutor:
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
elif function_name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
elif function_name == "knowledge_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
async def _build_result_messages(
self,
@ -360,21 +386,18 @@ class ToolExecutor:
tool_kwargs: dict,
ctx: ChatCompletionContext,
error_exc: Exception | None,
result: any,
result: Any,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]:
) -> tuple[Any, Any]:
"""Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
# Build output message
message: Any
if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=item_id,
arguments=function.arguments,
@ -383,10 +406,14 @@ class ToolExecutor:
)
if error_exc:
message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
result and getattr(result, "error_message", None)
):
ec = getattr(result, "error_code", "unknown")
em = getattr(result, "error_message", "")
message.error = f"Error (code {ec}): {em}"
elif result and (content := getattr(result, "content", None)):
message.output = interleaved_content_as_str(content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
@ -401,17 +428,17 @@ class ToolExecutor:
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if result and "document_ids" in result.metadata:
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
for i, doc_id in enumerate(metadata["document_ids"]):
text = metadata["chunks"][i] if "chunks" in metadata else None
score = metadata["scores"][i] if "scores" in metadata else None
message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id,
filename=doc_id,
text=text,
score=score,
text=text if text is not None else "",
score=score if score is not None else 0.0,
attributes={},
)
)
@ -421,27 +448,32 @@ class ToolExecutor:
raise ValueError(f"Unknown tool {function.name} called")
# Build input message
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
content = []
for item in result.content:
input_message: OpenAIToolMessageParam | None = None
if result and (result_content := getattr(result, "content", None)):
# all the mypy contortions here are still unsatisfactory with random Any typing
if isinstance(result_content, str):
msg_content: str | list[Any] = result_content
elif isinstance(result_content, list):
content_list: list[Any] = []
for item in result_content:
part: Any
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
url_value = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
url_value = str(item.image.url) if item.image.url else ""
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
content_list.append(part)
msg_content = content_list
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
raise ValueError(f"Unknown result content type: {type(result_content)}")
# OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
# This is runtime-safe as the API accepts it, but mypy complains
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
else:
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
from dataclasses import dataclass
from typing import cast
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
@ -100,17 +101,19 @@ class ToolContext(BaseModel):
if isinstance(tool, OpenAIResponseToolMCP):
previous_tools_by_label[tool.server_label] = tool
# collect tool definitions which are the same in current and previous requests:
tools_to_process = []
tools_to_process: list[OpenAIResponseInputTool] = []
matched: dict[str, OpenAIResponseInputToolMCP] = {}
for tool in self.current_tools:
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
previous_tool = previous_tools_by_label[tool.server_label]
if previous_tool.allowed_tools == tool.allowed_tools:
matched[tool.server_label] = tool
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
else:
tools_to_process.append(tool)
tools_to_process.append(tool) # type: ignore[arg-type]
# tools that are not the same or were not previously defined need to be processed:
self.tools_to_process = tools_to_process
# for all matched definitions, get the mcp_list_tools objects from the previous output:
@ -119,9 +122,11 @@ class ToolContext(BaseModel):
]
# reconstruct the tool to server mappings that can be reused:
for listing in self.previous_tool_listings:
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
definition = matched[listing.server_label]
for tool in listing.tools:
self.previous_tools[tool.name] = definition
for mcp_tool in listing.tools:
# mcp_tool is MCPListToolsTool which has a name: str field
self.previous_tools[mcp_tool.name] = definition
def available_tools(self) -> list[OpenAIResponseTool]:
if not self.current_tools:
@ -139,6 +144,8 @@ class ToolContext(BaseModel):
server_label=tool.server_label,
allowed_tools=tool.allowed_tools,
)
# Exhaustive check - all tool types should be handled above
raise AssertionError(f"Unexpected tool type: {type(tool)}")
return [convert_tool(tool) for tool in self.current_tools]

View file

@ -7,6 +7,7 @@
import asyncio
import re
import uuid
from collections.abc import Sequence
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.agents.openai_responses import (
@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
return OpenAIResponseMessage(
id=message_id or f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
if isinstance(content, str):
return content
converted_parts = []
# Type with union to avoid list invariance issues
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
# Output can be None, use empty string as fallback
output_content = input_item.output if input_item.output is not None else ""
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
content=output_content,
tool_call_id=input_item.id,
)
)
@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
):
# these are handled by the responses impl itself and not pass through to chat completions
pass
else:
elif isinstance(input_item, OpenAIResponseMessage):
# Narrow type to OpenAIResponseMessage which has content and role attributes
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
last_user_content = getattr(last_user_msg, "content", None)
if last_user_content == content:
continue # Skip duplicate user message
messages.append(message_type(content=content))
# Dynamic message type call - different message types have different content expectations
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
if len(tool_call_results):
# Check if unpaired function_call_outputs reference function_calls from previous messages
if previous_messages:
@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
# Assert name exists for json_schema format
assert text.format.get("name"), "json_schema format requires a name"
schema_name: str = text.format["name"] # type: ignore[assignment]
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
def _extract_citations_from_text(
@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
shields_list = await safety_api.routing_table.list_shields()
# TODO: list_shields not in Safety interface but available at runtime via API routing
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
for guardrail_id in guardrail_ids:
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
for result in response.results:
if result.flagged:
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
flagged_categories = (
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
)
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories:
@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
return message
# No violations found
return None
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""

View file

@ -6,7 +6,7 @@
import asyncio
from llama_stack.apis.inference import Message
from llama_stack.apis.inference import OpenAIMessageParam
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.core.telemetry import tracing
from llama_stack.log import get_logger
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
self.input_shields = input_shields
self.output_shields = output_shields
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
async def run_shield_with_span(identifier: str):
async with tracing.span(f"run_shield_{identifier}"):
return await self.safety_api.run_shield(

View file

@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
return "https://api.anthropic.com/v1"
async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
api_key = self._get_api_key_from_config_or_provider_data()
return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()]

View file

@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
async def list_provider_model_ids(self) -> Iterable[str]:
# Filter out None values from endpoint names
api_token = self._get_api_key_from_config_or_provider_data()
return [
endpoint.name # type: ignore[misc]
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
host=self.config.url, token=api_token
).serving_endpoints.list() # TODO: this is not async
]

View file

@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
print(f"VLM Response: {vlm_response.choices[0].message.content}")
```
### Rerank Example
The following example shows how to rerank documents using an NVIDIA NIM.
```python
rerank_response = client.alpha.inference.rerank(
model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
query="query",
items=[
"item_1",
"item_2",
"item_3",
],
)
for i, result in enumerate(rerank_response):
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
```

View file

@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
Attributes:
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
api_key (str): The access key for the hosted NIM endpoints
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
There are two ways to access NVIDIA NIMs -
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
rerank_model_to_url: dict[str, str] = Field(
default_factory=lambda: {
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
},
        description="Mapping of rerank model identifiers to their API endpoints.",
)
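    # A minimal usage sketch (assumptions: NVIDIAConfig is constructed directly,
    # and the /reranking path below is a hypothetical self-hosted endpoint).
    # Passing the field explicitly replaces the default mapping above:
    #
    #     NVIDIAConfig(
    #         url="http://localhost:8000",
    #         rerank_model_to_url={
    #             "nvidia/llama-3.2-nv-rerankqa-1b-v2": "http://localhost:8000/v1/reranking",
    #         },
    #     )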
@classmethod
def sample_run_config(

View file

@ -5,6 +5,19 @@
# the root directory of this source tree.
from collections.abc import Iterable
import aiohttp
from llama_stack.apis.inference import (
RerankData,
RerankResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
:return: The NVIDIA API base URL
"""
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
async def list_provider_model_ids(self) -> Iterable[str]:
"""
Return both dynamic model IDs and statically configured rerank model IDs.
"""
dynamic_ids: Iterable[str] = []
try:
dynamic_ids = await super().list_provider_model_ids()
except Exception:
# If the dynamic listing fails, proceed with just configured rerank IDs
dynamic_ids = []
configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates
def construct_model_from_identifier(self, identifier: str) -> Model:
"""
Classify rerank models from config; otherwise use the base behavior.
"""
if identifier in self.config.rerank_model_to_url:
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.rerank,
)
return super().construct_model_from_identifier(identifier)
async def rerank(
self,
model: str,
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
max_num_results: int | None = None,
) -> RerankResponse:
provider_model_id = await self._get_provider_model_id(model)
ranking_url = self.get_base_url()
if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
ranking_url = self.config.rerank_model_to_url[provider_model_id]
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
# Convert query to text format
if isinstance(query, str):
query_text = query
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
query_text = query.text
else:
raise ValueError("Query must be a string or text content part")
# Convert items to text format
passages = []
for item in items:
if isinstance(item, str):
passages.append({"text": item})
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
passages.append({"text": item.text})
else:
raise ValueError("Items must be strings or text content parts")
payload = {
"model": provider_model_id,
"query": {"text": query_text},
"passages": passages,
}
headers = {
"Authorization": f"Bearer {self.get_api_key()}",
"Content-Type": "application/json",
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(ranking_url, headers=headers, json=payload) as response:
if response.status != 200:
response_text = await response.text()
raise ConnectionError(
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
)
result = await response.json()
rankings = result.get("rankings", [])
# Convert to RerankData format
rerank_data = []
for ranking in rankings:
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
# Apply max_num_results limit
if max_num_results is not None:
rerank_data = rerank_data[:max_num_results]
return RerankResponse(data=rerank_data)
except aiohttp.ClientError as e:
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import Iterable
import google.auth.transport.requests
from google.auth import default
@ -42,3 +43,12 @@ class VertexAIInferenceAdapter(OpenAIMixin):
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def list_provider_model_ids(self) -> Iterable[str]:
"""
VertexAI doesn't currently offer a way to query a list of available models from Google's Model Garden
For now we return a hardcoded version of the available models
:return: An iterable of model IDs
"""
return ["vertexai/gemini-2.0-flash", "vertexai/gemini-2.5-flash", "vertexai/gemini-2.5-pro"]


@ -35,6 +35,7 @@ class InferenceStore:
self.reference = reference
self.sql_store = None
self.policy = policy
self.enable_write_queue = True
# Async write queue and worker control
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
@ -47,14 +48,13 @@ class InferenceStore:
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite to avoid concurrency issues
backend_name = self.reference.backend
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
# Disable write queue for SQLite since WAL mode handles concurrency
# Keep it enabled for other backends (like Postgres) for performance
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
await self.sql_store.create_table(
"chat_completions",
{
@ -70,8 +70,9 @@ class InferenceStore:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
logger.debug(
f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
if not self._worker_tasks:


@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
return schema
async def _get_params(self, request: ChatCompletionRequest) -> dict:
input_dict = {}
from typing import Any
input_dict: dict[str, Any] = {}
input_dict["messages"] = [
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin(
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
)
fmt = fmt.json_schema
name = fmt["title"]
del fmt["title"]
fmt["additionalProperties"] = False
# Convert to dict for manipulation
fmt_dict = dict(fmt.json_schema)
name = fmt_dict["title"]
del fmt_dict["title"]
fmt_dict["additionalProperties"] = False
# Apply additionalProperties: False recursively to all objects
fmt = self._add_additional_properties_recursive(fmt)
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
input_dict["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": name,
"schema": fmt,
"schema": fmt_dict,
"strict": self.json_schema_strict,
},
}
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
return {
"model": request.model,
@ -176,9 +175,9 @@ class LiteLLMOpenAIMixin(
def get_api_key(self) -> str:
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field
if provider_data and getattr(provider_data, key_field, None):
api_key = getattr(provider_data, key_field)
else:
if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
return str(api_key) # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection
api_key = self.api_key_from_config
if not api_key:
raise ValueError(
@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
if not self.model_store:
raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
# Convert input to list if it's a string
input_list = [params.input] if isinstance(params.input, str) else params.input
@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin(
# Call litellm embedding function
# litellm.drop_params = True
response = litellm.embedding(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
model=self.get_litellm_model_name(provider_resource_id),
input=input_list,
api_key=self.get_api_key(),
api_base=self.api_base,
@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin(
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.provider_resource_id,
model=provider_resource_id,
usage=usage,
)
@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
if not self.model_store:
raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
model=self.get_litellm_model_name(provider_resource_id),
prompt=params.prompt,
best_of=params.best_of,
echo=params.echo,
@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin(
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.atext_completion(**request_params)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
async def openai_chat_completion(
self,
@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin(
elif "include_usage" not in stream_options:
stream_options = {**stream_options, "include_usage": True}
if not self.model_store:
raise ValueError("Model store is not initialized")
model_obj = await self.model_store.get_model(params.model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {params.model} has no provider_resource_id")
provider_resource_id = model_obj.provider_resource_id
request_params = await prepare_openai_completion_params(
model=self.get_litellm_model_name(model_obj.provider_resource_id),
model=self.get_litellm_model_name(provider_resource_id),
messages=params.messages,
frequency_penalty=params.frequency_penalty,
function_call=params.function_call,
@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin(
api_key=self.get_api_key(),
api_base=self.api_base,
)
return await litellm.acompletion(**request_params)
# LiteLLM returns compatible type but mypy can't verify external library
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
async def check_model_availability(self, model: str) -> bool:
"""


@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
allowed_models: list[str] | None = Field(
default=None,
description="List of models that should be registered with the model registry. If None, all models are allowed.",
)


@ -161,7 +161,9 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
if isinstance(params.strategy, GreedySamplingStrategy):
options["temperature"] = 0.0
elif isinstance(params.strategy, TopPSamplingStrategy):
if params.strategy.temperature is not None:
options["temperature"] = params.strategy.temperature
if params.strategy.top_p is not None:
options["top_p"] = params.strategy.top_p
elif isinstance(params.strategy, TopKSamplingStrategy):
options["top_k"] = params.strategy.top_k
@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict:
def text_from_choice(choice) -> str:
if hasattr(choice, "delta") and choice.delta:
return choice.delta.content
return choice.delta.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
if hasattr(choice, "message"):
return choice.message.content
return choice.message.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
return choice.text
return choice.text # type: ignore[no-any-return] # external OpenAI types lack precise annotations
def get_stop_reason(finish_reason: str) -> StopReason:
@ -216,7 +218,7 @@ def convert_openai_completion_logprobs(
) -> list[TokenLogProbs] | None:
if not logprobs:
return None
if hasattr(logprobs, "top_logprobs"):
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
# Together supports logprobs with top_k=1 only. This means for each token position,
@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
if isinstance(logprobs, float):
# Adapt response from Together CompletionChoicesChunk
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
if hasattr(logprobs, "top_logprobs"):
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
return None
@ -245,23 +247,24 @@ def process_completion_response(
response: OpenAICompatCompletionResponse,
) -> CompletionResponse:
choice = response.choices[0]
text = choice.text or ""
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
if text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
content=text[: -len("<|eot_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
if text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
content=text[: -len("<|eom_id|>")],
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
stop_reason=get_stop_reason(choice.finish_reason or "stop"),
content=text,
logprobs=convert_openai_completion_logprobs(choice.logprobs),
)
@ -272,10 +275,10 @@ def process_chat_completion_response(
) -> ChatCompletionResponse:
choice = response.choices[0]
if choice.finish_reason == "tool_calls":
if not choice.message or not choice.message.tool_calls:
if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
raise ValueError("Tool calls are not present in the response")
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
# If we couldn't parse a tool call, jsonify the tool calls and return them
return ChatCompletionResponse(
@ -287,9 +290,11 @@ def process_chat_completion_response(
)
else:
# Otherwise, return tool calls as normal
# Filter to only valid ToolCall objects
valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
return ChatCompletionResponse(
completion_message=CompletionMessage(
tool_calls=tool_calls,
tool_calls=valid_tool_calls,
stop_reason=StopReason.end_of_turn,
# Content is not optional
content="",
@ -299,7 +304,7 @@ def process_chat_completion_response(
# TODO: This does not work well with tool calls for vLLM remote provider
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop"))
# NOTE: If we do not set tools in chat-completion request, we should not
# expect the ToolCall in the response. Instead, we should return the raw
@ -324,8 +329,8 @@ def process_chat_completion_response(
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=raw_message.content,
stop_reason=raw_message.stop_reason,
content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent]
stop_reason=raw_message.stop_reason or StopReason.end_of_turn,
tool_calls=raw_message.tool_calls,
),
logprobs=None,
@ -448,7 +453,7 @@ async def process_chat_completion_stream_response(
)
# parse tool calls and report errors
message = decode_assistant_message(buffer, stop_reason)
message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn)
parsed_tool_calls = len(message.tool_calls) > 0
if ipython and not parsed_tool_calls:
@ -463,7 +468,7 @@ async def process_chat_completion_stream_response(
)
)
request_tools = {t.tool_name: t for t in request.tools}
request_tools = {t.tool_name: t for t in (request.tools or [])}
for tool_call in message.tool_calls:
if tool_call.tool_name in request_tools:
yield ChatCompletionResponseStreamChunk(
@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
}
if hasattr(message, "tool_calls") and message.tool_calls:
result["tool_calls"] = []
tool_calls_list = []
for tc in message.tool_calls:
# The tool.tool_name can be a str or a BuiltinTool enum. If
# it's the latter, convert to a string.
@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
result["tool_calls"].append(
tool_calls_list.append(
{
"id": tc.call_id,
"type": "function",
@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
},
}
)
result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected
return result
@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new(
),
)
elif isinstance(content_, list):
return [await impl(item) for item in content_]
return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing
else:
raise ValueError(f"Unsupported content type: {type(content_)}")
@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new(
else:
return [ret]
out: OpenAIChatCompletionMessage = None
out: OpenAIChatCompletionMessage
if isinstance(message, UserMessage):
out = OpenAIChatCompletionUserMessage(
role="user",
@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new(
),
type="function",
)
for tool in message.tool_calls
for tool in (message.tool_calls or [])
]
params = {}
if tool_calls:
@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new(
out = OpenAIChatCompletionAssistantMessage(
role="assistant",
content=await _convert_message_content(message.content),
**params,
**params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
role="tool",
tool_call_id=message.call_id,
content=await _convert_message_content(message.content),
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
)
elif isinstance(message, SystemMessage):
out = OpenAIChatCompletionSystemMessage(
role="system",
content=await _convert_message_content(message.content),
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
function = out["function"]
if isinstance(tool.tool_name, BuiltinTool):
function["name"] = tool.tool_name.value
function["name"] = tool.tool_name.value # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
else:
function["name"] = tool.tool_name
function["name"] = tool.tool_name # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
if tool.description:
function["description"] = tool.description
function["description"] = tool.description # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
if tool.input_schema:
# Pass through the entire JSON Schema as-is
function["parameters"] = tool.input_schema
function["parameters"] = tool.input_schema # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
# NOTE: OpenAI does not support output_schema, so we drop it here
# It's stored in LlamaStack for validation and other provider usage
@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None
tool_config = ToolConfig()
if tool_choice:
try:
tool_choice = ToolChoice(tool_choice)
tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception
except ValueError:
pass
tool_config.tool_choice = tool_choice
tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type
return tool_config
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
lls_tools = []
lls_tools: list[ToolDefinition] = []
if not tools:
return lls_tools
@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
def _convert_openai_request_response_format(
response_format: OpenAIResponseFormatParam = None,
response_format: OpenAIResponseFormatParam | None = None,
):
if not response_format:
return None
# response_format can be a dict or a pydantic model
response_format = dict(response_format)
if response_format.get("type", "") == "json_schema":
response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
if response_format_dict.get("type", "") == "json_schema":
return JsonSchemaResponseFormat(
type="json_schema",
json_schema=response_format.get("json_schema", {}).get("schema", ""),
type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
)
return None
@ -938,16 +944,15 @@ def _convert_openai_sampling_params(
# Map an explicit temperature of 0 to greedy sampling
if temperature == 0:
strategy = GreedySamplingStrategy()
sampling_params.strategy = GreedySamplingStrategy()
else:
# OpenAI defaults to 1.0 for temperature and top_p if unset
if temperature is None:
temperature = 1.0
if top_p is None:
top_p = 1.0
strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) # type: ignore[assignment] # SamplingParams.strategy union accepts this type
sampling_params.strategy = strategy
return sampling_params
@ -957,23 +962,24 @@ def openai_messages_to_messages(
"""
Convert a list of OpenAIChatCompletionMessage into a list of Message.
"""
converted_messages = []
converted_messages: list[Message] = []
for message in messages:
converted_message: Message
if message.role == "system":
converted_message = SystemMessage(content=openai_content_to_content(message.content))
converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
elif message.role == "user":
converted_message = UserMessage(content=openai_content_to_content(message.content))
converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
elif message.role == "assistant":
converted_message = CompletionMessage(
content=openai_content_to_content(message.content),
tool_calls=_convert_openai_tool_calls(message.tool_calls),
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function
stop_reason=StopReason.end_of_turn,
)
elif message.role == "tool":
converted_message = ToolResponseMessage(
role="tool",
call_id=message.tool_call_id,
content=openai_content_to_content(message.content),
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
)
else:
raise ValueError(f"Unknown role {message.role}")
@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten
return [openai_content_to_content(c) for c in content]
elif hasattr(content, "type"):
if content.type == "text":
return TextContentItem(type="text", text=content.text)
return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
elif content.type == "image_url":
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
else:
raise ValueError(f"Unknown content type: {content.type}")
else:
@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice(
completion_message=CompletionMessage(
content=choice.message.content or "", # CompletionMessage content is not optional
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union
),
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection
)
@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream(
choice = chunk.choices[0] # assuming only one choice per chunk
# we assume there's only one finish_reason in the stream
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason
logprobs = getattr(choice, "logprobs", None)
# if there's a tool call, emit an event for each tool in the list
@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls([tool_call])[0],
tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
else:
@ -1125,11 +1131,14 @@ async def convert_openai_chat_completion_stream(
if tool_call.function.name:
buffer["name"] = tool_call.function.name
delta = f"{buffer['name']}("
if buffer["content"] is not None:
buffer["content"] += delta
if tool_call.function.arguments:
delta = tool_call.function.arguments
if buffer["arguments"] is not None and delta:
buffer["arguments"] += delta
if buffer["content"] is not None and delta:
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
elif choice.delta.content:
@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(logprobs),
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
)
)
@ -1155,6 +1164,7 @@ async def convert_openai_chat_completion_stream(
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
if buffer["name"]:
delta = ")"
if buffer["content"] is not None:
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream(
)
try:
tool_call = ToolCall(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=buffer["arguments"],
parsed_tool_call = ToolCall(
call_id=buffer["call_id"] or "",
tool_name=buffer["name"] or "",
arguments=buffer["arguments"] or "",
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=buffer["content"],
tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
messages = openai_messages_to_messages(messages)
messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format
response_format = _convert_openai_request_response_format(response_format)
sampling_params = _convert_openai_sampling_params(
max_tokens=max_tokens,
@ -1259,15 +1269,15 @@ class OpenAIChatCompletionToLlamaStackMixin:
)
tool_config = _convert_openai_request_tool_config(tool_choice)
tools = _convert_openai_request_tools(tools)
tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format
if tool_config.tool_choice == ToolChoice.none:
tools = []
tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type
outstanding_responses = []
# "n" is the number of completions to generate per prompt
n = n or 1
for _i in range(0, n):
response = self.chat_completion(
response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion
model_id=model,
messages=messages,
sampling_params=sampling_params,
@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
outstanding_responses.append(response)
if stream:
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
self, model, outstanding_responses
@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin:
response = await outstanding_response
async for chunk in response:
event = chunk.event
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
finish_reason = (
_convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None
)
if isinstance(event.delta, TextDelta):
text_delta = event.delta.text
delta = OpenAIChoiceDelta(content=text_delta)
yield OpenAIChatCompletionChunk(
id=id,
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
created=int(time.time()),
model=model,
object="chat.completion.chunk",
@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
elif isinstance(event.delta, ToolCallDelta):
if event.delta.parse_status == ToolCallParseStatus.succeeded:
tool_call = event.delta.tool_call
if isinstance(tool_call, str):
continue
# First chunk includes full structure
openai_tool_call = OpenAIChoiceDeltaToolCall(
index=0,
id=tool_call.call_id,
function=OpenAIChoiceDeltaToolCallFunction(
name=tool_call.tool_name,
name=tool_call.tool_name
if isinstance(tool_call.tool_name, str)
else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy
arguments="",
),
)
@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
],
created=int(time.time()),
model=model,
@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
yield OpenAIChatCompletionChunk(
id=id,
choices=[
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
],
created=int(time.time()),
model=model,
@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
async def _process_non_stream_response(
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
) -> OpenAIChatCompletion:
choices = []
choices: list[OpenAIChatCompletionChoice] = []
for outstanding_response in outstanding_responses:
response = await outstanding_response
completion_message = response.completion_message
@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin:
choice = OpenAIChatCompletionChoice(
index=len(choices),
message=message,
message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type
finish_reason=finish_reason,
)
choices.append(choice)
choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch
return OpenAIChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
choices=choices,
choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible
created=int(time.time()),
model=model,
object="chat.completion",


@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
# This is set in list_models() and used in check_model_availability()
_model_cache: dict[str, Model] = {}
# List of allowed models for this provider, if empty all models allowed
allowed_models: list[str] = []
# Optional field name in provider data to look for API key, which takes precedence
provider_data_api_key_field: str | None = None
@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
for provider_model_id in provider_models_ids:
if not isinstance(provider_model_id, str):
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
if self.allowed_models and provider_model_id not in self.allowed_models:
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
continue
model = self.construct_model_from_identifier(provider_model_id)


@ -196,6 +196,7 @@ def make_overlapped_chunks(
chunks.append(
Chunk(
content=chunk,
chunk_id=chunk_id,
metadata=chunk_metadata,
chunk_metadata=backend_chunk_metadata,
)


@ -70,13 +70,13 @@ class ResponsesStore:
base_store = sqlstore_impl(self.reference)
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
# Disable write queue for SQLite since WAL mode handles concurrency
# Keep it enabled for other backends (like Postgres) for performance
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
if backend_config is None:
raise ValueError(
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
)
if backend_config.type == StorageBackendType.SQL_SQLITE:
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
self.enable_write_queue = False
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
await self.sql_store.create_table(
"openai_responses",
{
@ -99,8 +99,9 @@ class ResponsesStore:
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
for _ in range(self._num_writers):
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
else:
logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
logger.debug(
f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
)
async def shutdown(self) -> None:
if not self._worker_tasks:


@ -17,6 +17,7 @@ from sqlalchemy import (
String,
Table,
Text,
event,
inspect,
select,
text,
@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
self.metadata = MetaData()
def create_engine(self) -> AsyncEngine:
return create_async_engine(self.config.engine_str, pool_pre_ping=True)
# Configure connection args for better concurrency support
connect_args = {}
if "sqlite" in self.config.engine_str:
# SQLite-specific optimizations for concurrent access
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
connect_args["timeout"] = 5.0
connect_args["check_same_thread"] = False # Allow usage across asyncio tasks
engine = create_async_engine(
self.config.engine_str,
pool_pre_ping=True,
connect_args=connect_args,
)
# Enable WAL mode for SQLite to support concurrent readers and writers
if "sqlite" in self.config.engine_str:
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_conn, connection_record):
cursor = dbapi_conn.cursor()
# Enable Write-Ahead Logging for better concurrency
cursor.execute("PRAGMA journal_mode=WAL")
# Set busy timeout to 5 seconds (retry instead of immediate failure)
# With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
cursor.execute("PRAGMA busy_timeout=5000")
# Use NORMAL synchronous mode for better performance (still safe with WAL)
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.close()
return engine
async def create_table(
self,
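As a standalone illustration of the PRAGMAs set above (plain `sqlite3` from the standard library, not the async engine used in the diff):
```python
import sqlite3

# Mirror the engine settings: 5s lock timeout and cross-thread access.
conn = sqlite3.connect("example.db", timeout=5.0, check_same_thread=False)
cur = conn.cursor()
cur.execute("PRAGMA journal_mode=WAL")    # concurrent readers alongside a writer
cur.execute("PRAGMA busy_timeout=5000")   # retry briefly instead of failing on a lock
cur.execute("PRAGMA synchronous=NORMAL")  # safe with WAL, fewer fsyncs
print(cur.execute("PRAGMA journal_mode").fetchone())  # ('wal',)
conn.close()
```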


@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
return list_type # type: ignore[no-any-return]
def is_generic_sequence(typ: object) -> bool:
"True if the specified type is a generic Sequence, i.e. `Sequence[T]`."
import collections.abc
typ = unwrap_annotated_type(typ)
return typing.get_origin(typ) is collections.abc.Sequence
def unwrap_generic_sequence(typ: object) -> type:
"""
Extracts the item type of a Sequence type.
:param typ: The Sequence type `Sequence[T]`.
:returns: The item type `T`.
"""
return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type]
def _unwrap_generic_sequence(typ: object) -> type:
"Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)."
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
return sequence_type # type: ignore[no-any-return]
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
"True if the specified type is a generic set, i.e. `Set[T]`."


@ -18,10 +18,12 @@ from .inspection import (
TypeLike,
is_generic_dict,
is_generic_list,
is_generic_sequence,
is_type_optional,
is_type_union,
unwrap_generic_dict,
unwrap_generic_list,
unwrap_generic_sequence,
unwrap_optional_type,
unwrap_union_types,
)
@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
if metadata is not None:
# type is Annotated[T, ...]
arg = typing.get_args(data_type)[0]
return python_type_to_name(arg)
return python_type_to_name(arg, force=force)
if force:
# generic types
if is_type_optional(data_type, strict=True):
inner_name = python_type_to_name(unwrap_optional_type(data_type))
inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True)
return f"Optional__{inner_name}"
elif is_generic_list(data_type):
item_name = python_type_to_name(unwrap_generic_list(data_type))
item_name = python_type_to_name(unwrap_generic_list(data_type), force=True)
return f"List__{item_name}"
elif is_generic_sequence(data_type):
# Treat Sequence the same as List for schema generation purposes
item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True)
return f"List__{item_name}"
elif is_generic_dict(data_type):
key_type, value_type = unwrap_generic_dict(data_type)
key_name = python_type_to_name(key_type)
value_name = python_type_to_name(value_type)
key_name = python_type_to_name(key_type, force=True)
value_name = python_type_to_name(value_type, force=True)
return f"Dict__{key_name}__{value_name}"
elif is_type_union(data_type):
member_types = unwrap_union_types(data_type)
member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types)
return f"Union__{member_names}"
# named system or user-defined type
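A quick sketch of the names the `force=True` path above produces, including the new `Sequence` branch (derived from the visible code, not copied from a repo doctest):
```python
from collections.abc import Sequence

# python_type_to_name(list[int], force=True)       -> "List__int"
# python_type_to_name(Sequence[int], force=True)   -> "List__int"   (Sequence treated like List)
# python_type_to_name(dict[str, int], force=True)  -> "Dict__str__int"
```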


@ -111,7 +111,7 @@ def get_class_property_docstrings(
def docstring_to_schema(data_type: type) -> Schema:
short_description, long_description = get_class_docstrings(data_type)
schema: Schema = {
"title": python_type_to_name(data_type),
"title": python_type_to_name(data_type, force=True),
}
description = "\n".join(filter(None, [short_description, long_description]))
@ -417,6 +417,10 @@ class JsonSchemaGenerator:
if origin_type is list:
(list_type,) = typing.get_args(typ) # unpack single tuple element
return {"type": "array", "items": self.type_to_schema(list_type)}
elif origin_type is collections.abc.Sequence:
# Treat Sequence the same as list for JSON schema (both are arrays)
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
return {"type": "array", "items": self.type_to_schema(sequence_type)}
elif origin_type is dict:
key_type, value_type = typing.get_args(typ)
if not (key_type is str or key_type is int or is_type_enum(key_type)):
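To make the effect concrete (a sketch based on the branch above; the generator invocation itself is assumed from context):
```python
from collections.abc import Sequence

# With the new branch, Sequence is serialized exactly like list:
#   type_to_schema(Sequence[int]) -> {"type": "array", "items": {"type": "integer"}}
#   type_to_schema(list[int])     -> {"type": "array", "items": {"type": "integer"}}
```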


@ -156,7 +156,7 @@ def normalize_inference_request(method: str, url: str, headers: dict[str, Any],
}
# Include test_id for isolation, except for shared infrastructure endpoints
if parsed.path not in ("/api/tags", "/v1/models"):
if parsed.path not in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
normalized["test_id"] = test_id
normalized_json = json.dumps(normalized, sort_keys=True)
@ -430,7 +430,7 @@ class ResponseStorage:
# For model-list endpoints, include digest in filename to distinguish different model sets
endpoint = request.get("endpoint")
if endpoint in ("/api/tags", "/v1/models"):
if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
digest = _model_identifiers_digest(endpoint, response)
response_file = f"models-{request_hash}-{digest}.json"
@ -554,13 +554,14 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
Supported endpoints:
- '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
- '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
- '/v1/openai/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
Returns a list of unique identifiers or None if structure doesn't match.
"""
if "models" in response["body"]:
# ollama
items = response["body"]["models"]
else:
# openai
# openai or openai-style endpoints
items = response["body"]
idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
return sorted(set(idents))
@ -581,7 +582,7 @@ def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]])
seen: dict[str, dict[str, Any]] = {}
for rec in records:
body = rec["response"]["body"]
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
for m in body:
key = m.id
seen[key] = m
@ -665,7 +666,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
logger.info(f" Test context: {get_test_context()}")
if mode == APIRecordingMode.LIVE or storage is None:
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
return original_method(self, *args, **kwargs)
else:
return await original_method(self, *args, **kwargs)
@ -699,7 +700,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
recording = None
if mode == APIRecordingMode.REPLAY or mode == APIRecordingMode.RECORD_IF_MISSING:
# Special handling for model-list endpoints: merge all recordings with this hash
if endpoint in ("/api/tags", "/v1/models"):
if endpoint in ("/api/tags", "/v1/models", "/v1/openai/v1/models"):
records = storage._model_list_responses(request_hash)
recording = _combine_model_list_responses(endpoint, records)
else:
@ -739,13 +740,13 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
)
if mode == APIRecordingMode.RECORD or (mode == APIRecordingMode.RECORD_IF_MISSING and not recording):
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
response = original_method(self, *args, **kwargs)
else:
response = await original_method(self, *args, **kwargs)
# we want to store the result of the iterator, not the iterator itself
if endpoint == "/v1/models":
if endpoint in ("/v1/models", "/v1/openai/v1/models"):
response = [m async for m in response]
request_data = {


@ -51,7 +51,11 @@ async function proxyRequest(request: NextRequest, method: string) {
);
// Create response with same status and headers
const proxyResponse = new NextResponse(responseText, {
// Handle 204 No Content responses specially
const proxyResponse =
response.status === 204
? new NextResponse(null, { status: 204 })
: new NextResponse(responseText, {
status: response.status,
statusText: response.statusText,
});


@ -0,0 +1,5 @@
import { PromptManagement } from "@/components/prompts";
export default function PromptsPage() {
return <PromptManagement />;
}


@ -8,6 +8,7 @@ import {
MessageCircle,
Settings2,
Compass,
FileText,
} from "lucide-react";
import Link from "next/link";
import { usePathname } from "next/navigation";
@ -50,6 +51,11 @@ const manageItems = [
url: "/logs/vector-stores",
icon: Database,
},
{
title: "Prompts",
url: "/prompts",
icon: FileText,
},
{
title: "Documentation",
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",


@ -0,0 +1,4 @@
export { PromptManagement } from "./prompt-management";
export { PromptList } from "./prompt-list";
export { PromptEditor } from "./prompt-editor";
export * from "./types";


@ -0,0 +1,309 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptEditor } from "./prompt-editor";
import type { Prompt, PromptFormData } from "./types";
describe("PromptEditor", () => {
const mockOnSave = jest.fn();
const mockOnCancel = jest.fn();
const mockOnDelete = jest.fn();
const defaultProps = {
onSave: mockOnSave,
onCancel: mockOnCancel,
onDelete: mockOnDelete,
};
beforeEach(() => {
jest.clearAllMocks();
});
describe("Create Mode", () => {
test("renders create form correctly", () => {
render(<PromptEditor {...defaultProps} />);
expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument();
expect(screen.getByText("Variables")).toBeInTheDocument();
expect(screen.getByText("Preview")).toBeInTheDocument();
expect(screen.getByText("Create Prompt")).toBeInTheDocument();
expect(screen.getByText("Cancel")).toBeInTheDocument();
});
test("shows preview placeholder when no content", () => {
render(<PromptEditor {...defaultProps} />);
expect(
screen.getByText("Enter content to preview the compiled prompt")
).toBeInTheDocument();
});
test("submits form with correct data", () => {
render(<PromptEditor {...defaultProps} />);
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Hello {{name}}, welcome!" },
});
fireEvent.click(screen.getByText("Create Prompt"));
expect(mockOnSave).toHaveBeenCalledWith({
prompt: "Hello {{name}}, welcome!",
variables: [],
});
});
test("prevents submission with empty prompt", () => {
render(<PromptEditor {...defaultProps} />);
fireEvent.click(screen.getByText("Create Prompt"));
expect(mockOnSave).not.toHaveBeenCalled();
});
});
describe("Edit Mode", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}, how is {{weather}}?",
version: 1,
variables: ["name", "weather"],
is_default: true,
};
test("renders edit form with existing data", () => {
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
expect(
screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?")
).toBeInTheDocument();
expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview
expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview
expect(screen.getByText("Update Prompt")).toBeInTheDocument();
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
});
test("submits updated data correctly", () => {
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Updated: Hello {{name}}!" },
});
fireEvent.click(screen.getByText("Update Prompt"));
expect(mockOnSave).toHaveBeenCalledWith({
prompt: "Updated: Hello {{name}}!",
variables: ["name", "weather"],
});
});
});
describe("Variables Management", () => {
test("adds new variable", () => {
render(<PromptEditor {...defaultProps} />);
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
fireEvent.change(variableInput, { target: { value: "testVar" } });
fireEvent.click(screen.getByText("Add"));
expect(screen.getByText("testVar")).toBeInTheDocument();
});
test("prevents adding duplicate variables", () => {
render(<PromptEditor {...defaultProps} />);
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
// Add first variable
fireEvent.change(variableInput, { target: { value: "test" } });
fireEvent.click(screen.getByText("Add"));
// Try to add same variable again
fireEvent.change(variableInput, { target: { value: "test" } });
// Button should be disabled
expect(screen.getByText("Add")).toBeDisabled();
});
test("removes variable", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}",
version: 1,
variables: ["name", "location"],
is_default: true,
};
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
// Check that both variables are present initially
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
expect(screen.getAllByText("location").length).toBeGreaterThan(0);
// Remove the location variable by clicking the X button with the specific title
const removeLocationButton = screen.getByTitle(
"Remove location variable"
);
fireEvent.click(removeLocationButton);
// Name should still be there, location should be gone from the variables section
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
expect(
screen.queryByTitle("Remove location variable")
).not.toBeInTheDocument();
});
test("adds variable on Enter key", () => {
render(<PromptEditor {...defaultProps} />);
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
fireEvent.change(variableInput, { target: { value: "enterVar" } });
// Simulate Enter key press
fireEvent.keyPress(variableInput, {
key: "Enter",
code: "Enter",
charCode: 13,
preventDefault: jest.fn(),
});
// Check if the variable was added by looking for the badge
expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0);
});
});
describe("Preview Functionality", () => {
test("shows live preview with variables", () => {
render(<PromptEditor {...defaultProps} />);
// Add prompt content
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Hello {{name}}, welcome to {{place}}!" },
});
// Add variables
const variableInput = screen.getByPlaceholderText(
"Add variable name (e.g. user_name, topic)"
);
fireEvent.change(variableInput, { target: { value: "name" } });
fireEvent.click(screen.getByText("Add"));
fireEvent.change(variableInput, { target: { value: "place" } });
fireEvent.click(screen.getByText("Add"));
// Check that preview area shows the content
expect(screen.getByText("Compiled Prompt")).toBeInTheDocument();
});
test("shows variable value inputs in preview", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}",
version: 1,
variables: ["name"],
is_default: true,
};
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
expect(screen.getByText("Variable Values")).toBeInTheDocument();
expect(
screen.getByPlaceholderText("Enter value for name")
).toBeInTheDocument();
});
test("shows color legend for variable states", () => {
render(<PromptEditor {...defaultProps} />);
// Add content to show preview
const promptInput = screen.getByLabelText("Prompt Content *");
fireEvent.change(promptInput, {
target: { value: "Hello {{name}}" },
});
expect(screen.getByText("Used")).toBeInTheDocument();
expect(screen.getByText("Unused")).toBeInTheDocument();
expect(screen.getByText("Undefined")).toBeInTheDocument();
});
});
describe("Error Handling", () => {
test("displays error message", () => {
const errorMessage = "Prompt contains undeclared variables";
render(<PromptEditor {...defaultProps} error={errorMessage} />);
expect(screen.getByText(errorMessage)).toBeInTheDocument();
});
});
describe("Delete Functionality", () => {
const mockPrompt: Prompt = {
prompt_id: "prompt_123",
prompt: "Hello {{name}}",
version: 1,
variables: ["name"],
is_default: true,
};
test("shows delete button in edit mode", () => {
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
});
test("hides delete button in create mode", () => {
render(<PromptEditor {...defaultProps} />);
expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument();
});
test("calls onDelete with confirmation", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => true);
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
fireEvent.click(screen.getByText("Delete Prompt"));
expect(window.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this prompt? This action cannot be undone."
);
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
window.confirm = originalConfirm;
});
test("does not delete when confirmation is cancelled", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => false);
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
fireEvent.click(screen.getByText("Delete Prompt"));
expect(mockOnDelete).not.toHaveBeenCalled();
window.confirm = originalConfirm;
});
});
describe("Cancel Functionality", () => {
test("calls onCancel when cancel button is clicked", () => {
render(<PromptEditor {...defaultProps} />);
fireEvent.click(screen.getByText("Cancel"));
expect(mockOnCancel).toHaveBeenCalled();
});
});
});


@ -0,0 +1,346 @@
"use client";
import { useState, useEffect } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
import { Badge } from "@/components/ui/badge";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Separator } from "@/components/ui/separator";
import { X, Plus, Save, Trash2 } from "lucide-react";
import { Prompt, PromptFormData } from "./types";
interface PromptEditorProps {
prompt?: Prompt;
onSave: (prompt: PromptFormData) => void;
onCancel: () => void;
onDelete?: (promptId: string) => void;
error?: string | null;
}
export function PromptEditor({
prompt,
onSave,
onCancel,
onDelete,
error,
}: PromptEditorProps) {
const [formData, setFormData] = useState<PromptFormData>({
prompt: "",
variables: [],
});
const [newVariable, setNewVariable] = useState("");
const [variableValues, setVariableValues] = useState<Record<string, string>>(
{}
);
useEffect(() => {
if (prompt) {
setFormData({
prompt: prompt.prompt || "",
variables: prompt.variables || [],
});
}
}, [prompt]);
const handleSubmit = (e: React.FormEvent) => {
e.preventDefault();
if (!formData.prompt.trim()) {
return;
}
onSave(formData);
};
const addVariable = () => {
if (
newVariable.trim() &&
!formData.variables.includes(newVariable.trim())
) {
setFormData(prev => ({
...prev,
variables: [...prev.variables, newVariable.trim()],
}));
setNewVariable("");
}
};
const removeVariable = (variableToRemove: string) => {
setFormData(prev => ({
...prev,
variables: prev.variables.filter(
variable => variable !== variableToRemove
),
}));
};
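// Build the compiled-prompt preview: each {{variable}} placeholder is highlighted
// red (not declared), green (declared with a value), or yellow (declared but empty).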
const renderPreview = () => {
const text = formData.prompt;
if (!text) return text;
// Split text by variable patterns and process each part
const parts = text.split(/(\{\{\s*\w+\s*\}\})/g);
return parts.map((part, index) => {
const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/);
if (variableMatch) {
const variableName = variableMatch[1];
const isDefined = formData.variables.includes(variableName);
const value = variableValues[variableName];
if (!isDefined) {
// Variable not in variables list - likely a typo/bug (RED)
return (
<span
key={index}
className="bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200 px-1 rounded font-medium"
>
{part}
</span>
);
} else if (value && value.trim()) {
// Variable defined and has value - show the value (GREEN)
return (
<span
key={index}
className="bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200 px-1 rounded font-medium"
>
{value}
</span>
);
} else {
// Variable defined but empty (YELLOW)
return (
<span
key={index}
className="bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200 px-1 rounded font-medium"
>
{part}
</span>
);
}
}
return part;
});
};
const updateVariableValue = (variable: string, value: string) => {
setVariableValues(prev => ({
...prev,
[variable]: value,
}));
};
return (
<form onSubmit={handleSubmit} className="space-y-6">
{error && (
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-md">
<p className="text-destructive text-sm">{error}</p>
</div>
)}
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
{/* Form Section */}
<div className="space-y-4">
<div>
<Label htmlFor="prompt">Prompt Content *</Label>
<Textarea
id="prompt"
value={formData.prompt}
onChange={e =>
setFormData(prev => ({ ...prev, prompt: e.target.value }))
}
placeholder="Enter your prompt content here. Use {{variable_name}} for dynamic variables."
className="min-h-32 font-mono mt-2"
required
/>
<p className="text-xs text-muted-foreground mt-2">
Use double curly braces around variable names, e.g.,{" "}
{`{{user_name}}`} or {`{{topic}}`}
</p>
</div>
<div className="space-y-3">
<Label className="text-sm font-medium">Variables</Label>
<div className="flex gap-2 mt-2">
<Input
value={newVariable}
onChange={e => setNewVariable(e.target.value)}
placeholder="Add variable name (e.g. user_name, topic)"
onKeyPress={e =>
e.key === "Enter" && (e.preventDefault(), addVariable())
}
className="flex-1"
/>
<Button
type="button"
onClick={addVariable}
size="sm"
disabled={
!newVariable.trim() ||
formData.variables.includes(newVariable.trim())
}
>
<Plus className="h-4 w-4" />
Add
</Button>
</div>
{formData.variables.length > 0 && (
<div className="border rounded-lg p-3 bg-muted/20">
<div className="flex flex-wrap gap-2">
{formData.variables.map(variable => (
<Badge
key={variable}
variant="secondary"
className="text-sm px-2 py-1"
>
{variable}
<button
type="button"
onClick={() => removeVariable(variable)}
className="ml-2 hover:text-destructive transition-colors"
title={`Remove ${variable} variable`}
>
<X className="h-3 w-3" />
</button>
</Badge>
))}
</div>
</div>
)}
<p className="text-xs text-muted-foreground">
Variables that can be used in the prompt template. Each variable
should match a {`{{variable}}`} placeholder in the content above.
</p>
</div>
</div>
{/* Preview Section */}
<div className="space-y-4">
<Card>
<CardHeader>
<CardTitle className="text-lg">Preview</CardTitle>
<CardDescription>
Live preview of compiled prompt and variable substitution.
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
{formData.prompt ? (
<>
{/* Variable Values */}
{formData.variables.length > 0 && (
<div className="space-y-3">
<Label className="text-sm font-medium">
Variable Values
</Label>
<div className="space-y-2">
{formData.variables.map(variable => (
<div
key={variable}
className="grid grid-cols-2 gap-3 items-center"
>
<div className="text-sm font-mono text-muted-foreground">
{variable}
</div>
<Input
id={`var-${variable}`}
value={variableValues[variable] || ""}
onChange={e =>
updateVariableValue(variable, e.target.value)
}
placeholder={`Enter value for ${variable}`}
className="text-sm"
/>
</div>
))}
</div>
<Separator />
</div>
)}
{/* Live Preview */}
<div>
<Label className="text-sm font-medium mb-2 block">
Compiled Prompt
</Label>
<div className="bg-muted/50 p-4 rounded-lg border">
<div className="text-sm leading-relaxed whitespace-pre-wrap">
{renderPreview()}
</div>
</div>
<div className="flex flex-wrap gap-4 mt-2 text-xs">
<div className="flex items-center gap-1">
<div className="w-3 h-3 bg-green-500 dark:bg-green-400 border rounded"></div>
<span className="text-muted-foreground">Used</span>
</div>
<div className="flex items-center gap-1">
<div className="w-3 h-3 bg-yellow-500 dark:bg-yellow-400 border rounded"></div>
<span className="text-muted-foreground">Unused</span>
</div>
<div className="flex items-center gap-1">
<div className="w-3 h-3 bg-red-500 dark:bg-red-400 border rounded"></div>
<span className="text-muted-foreground">Undefined</span>
</div>
</div>
</div>
</>
) : (
<div className="text-center py-8">
<div className="text-muted-foreground text-sm">
Enter content to preview the compiled prompt
</div>
<div className="text-xs text-muted-foreground mt-2">
Use {`{{variable_name}}`} to add dynamic variables
</div>
</div>
)}
</CardContent>
</Card>
</div>
</div>
<Separator />
<div className="flex justify-between">
<div>
{prompt && onDelete && (
<Button
type="button"
variant="destructive"
onClick={() => {
if (
confirm(
"Are you sure you want to delete this prompt? This action cannot be undone."
)
) {
onDelete(prompt.prompt_id);
}
}}
>
<Trash2 className="h-4 w-4 mr-2" />
Delete Prompt
</Button>
)}
</div>
<div className="flex gap-2">
<Button type="button" variant="outline" onClick={onCancel}>
Cancel
</Button>
<Button type="submit">
<Save className="h-4 w-4 mr-2" />
{prompt ? "Update" : "Create"} Prompt
</Button>
</div>
</div>
</form>
);
}

View file

@ -0,0 +1,259 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptList } from "./prompt-list";
import type { Prompt } from "./types";
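// Unit tests for PromptList: empty states, table rendering, edit/delete actions
// (with window.confirm), and the case-insensitive search filter.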
describe("PromptList", () => {
const mockOnEdit = jest.fn();
const mockOnDelete = jest.fn();
const defaultProps = {
prompts: [],
onEdit: mockOnEdit,
onDelete: mockOnDelete,
};
beforeEach(() => {
jest.clearAllMocks();
});
describe("Empty State", () => {
test("renders empty message when no prompts", () => {
render(<PromptList {...defaultProps} />);
expect(screen.getByText("No prompts yet")).toBeInTheDocument();
});
test("shows filtered empty message when search has no results", () => {
const prompts: Prompt[] = [
{
prompt_id: "prompt_123",
prompt: "Hello world",
version: 1,
variables: [],
is_default: false,
},
];
render(<PromptList {...defaultProps} prompts={prompts} />);
// Search for something that doesn't exist
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "nonexistent" } });
expect(
screen.getByText("No prompts match your filters")
).toBeInTheDocument();
});
});
describe("Prompts Display", () => {
const mockPrompts: Prompt[] = [
{
prompt_id: "prompt_123",
prompt: "Hello {{name}}, how are you?",
version: 1,
variables: ["name"],
is_default: true,
},
{
prompt_id: "prompt_456",
prompt: "Summarize this {{text}} in {{length}} words",
version: 2,
variables: ["text", "length"],
is_default: false,
},
{
prompt_id: "prompt_789",
prompt: "Simple prompt with no variables",
version: 1,
variables: [],
is_default: false,
},
];
test("renders prompts table with correct headers", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
expect(screen.getByText("ID")).toBeInTheDocument();
expect(screen.getByText("Content")).toBeInTheDocument();
expect(screen.getByText("Variables")).toBeInTheDocument();
expect(screen.getByText("Version")).toBeInTheDocument();
expect(screen.getByText("Actions")).toBeInTheDocument();
});
test("renders prompt data correctly", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
// Check prompt IDs
expect(screen.getByText("prompt_123")).toBeInTheDocument();
expect(screen.getByText("prompt_456")).toBeInTheDocument();
expect(screen.getByText("prompt_789")).toBeInTheDocument();
// Check content
expect(
screen.getByText("Hello {{name}}, how are you?")
).toBeInTheDocument();
expect(
screen.getByText("Summarize this {{text}} in {{length}} words")
).toBeInTheDocument();
expect(
screen.getByText("Simple prompt with no variables")
).toBeInTheDocument();
// Check versions
expect(screen.getAllByText("1")).toHaveLength(2); // Two prompts with version 1
expect(screen.getByText("2")).toBeInTheDocument();
// Check default badge
expect(screen.getByText("Default")).toBeInTheDocument();
});
test("renders variables correctly", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
// Check variables display
expect(screen.getByText("name")).toBeInTheDocument();
expect(screen.getByText("text")).toBeInTheDocument();
expect(screen.getByText("length")).toBeInTheDocument();
expect(screen.getByText("None")).toBeInTheDocument(); // For prompt with no variables
});
test("prompt ID links are clickable and call onEdit", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
// Click on the first prompt ID link
const promptLink = screen.getByRole("button", { name: "prompt_123" });
fireEvent.click(promptLink);
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
});
test("edit buttons call onEdit", () => {
const { container } = render(
<PromptList {...defaultProps} prompts={mockPrompts} />
);
// Find the action buttons in the table - they should be in the last column
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const editButton = firstActionCell?.querySelector("button");
expect(editButton).toBeInTheDocument();
fireEvent.click(editButton!);
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
});
test("delete buttons call onDelete with confirmation", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => true);
const { container } = render(
<PromptList {...defaultProps} prompts={mockPrompts} />
);
// Find the delete button (second button in the first action cell)
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const buttons = firstActionCell?.querySelectorAll("button");
const deleteButton = buttons?.[1]; // Second button should be delete
expect(deleteButton).toBeInTheDocument();
fireEvent.click(deleteButton!);
expect(window.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this prompt? This action cannot be undone."
);
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
window.confirm = originalConfirm;
});
test("delete does not execute when confirmation is cancelled", () => {
const originalConfirm = window.confirm;
window.confirm = jest.fn(() => false);
const { container } = render(
<PromptList {...defaultProps} prompts={mockPrompts} />
);
const actionCells = container.querySelectorAll("td:last-child");
const firstActionCell = actionCells[0];
const buttons = firstActionCell?.querySelectorAll("button");
const deleteButton = buttons?.[1]; // Second button should be delete
expect(deleteButton).toBeInTheDocument();
fireEvent.click(deleteButton!);
expect(mockOnDelete).not.toHaveBeenCalled();
window.confirm = originalConfirm;
});
});
describe("Search Functionality", () => {
const mockPrompts: Prompt[] = [
{
prompt_id: "user_greeting",
prompt: "Hello {{name}}, welcome!",
version: 1,
variables: ["name"],
is_default: true,
},
{
prompt_id: "system_summary",
prompt: "Summarize the following text",
version: 1,
variables: [],
is_default: false,
},
];
test("filters prompts by prompt ID", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "user" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
});
test("filters prompts by content", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "welcome" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
});
test("search is case insensitive", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
fireEvent.change(searchInput, { target: { value: "HELLO" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
});
test("clearing search shows all prompts", () => {
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
const searchInput = screen.getByPlaceholderText("Search prompts...");
// Filter first
fireEvent.change(searchInput, { target: { value: "user" } });
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
// Clear search
fireEvent.change(searchInput, { target: { value: "" } });
expect(screen.getByText("user_greeting")).toBeInTheDocument();
expect(screen.getByText("system_summary")).toBeInTheDocument();
});
});
});

View file

@ -0,0 +1,164 @@
"use client";
import { useState } from "react";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Input } from "@/components/ui/input";
import { Edit, Search, Trash2 } from "lucide-react";
import { Prompt, PromptFilters } from "./types";
interface PromptListProps {
prompts: Prompt[];
onEdit: (prompt: Prompt) => void;
onDelete: (promptId: string) => void;
}
export function PromptList({ prompts, onEdit, onDelete }: PromptListProps) {
const [filters, setFilters] = useState<PromptFilters>({});
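// Case-insensitive match against either the prompt content or the prompt ID.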
const filteredPrompts = prompts.filter(prompt => {
if (
filters.searchTerm &&
!(
prompt.prompt
?.toLowerCase()
.includes(filters.searchTerm.toLowerCase()) ||
prompt.prompt_id
.toLowerCase()
.includes(filters.searchTerm.toLowerCase())
)
) {
return false;
}
return true;
});
return (
<div className="space-y-4">
{/* Filters */}
<div className="flex flex-col sm:flex-row gap-4">
<div className="relative flex-1">
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
<Input
placeholder="Search prompts..."
value={filters.searchTerm || ""}
onChange={e =>
setFilters(prev => ({ ...prev, searchTerm: e.target.value }))
}
className="pl-10"
/>
</div>
</div>
{/* Prompts Table */}
<div className="overflow-auto">
<Table>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Content</TableHead>
<TableHead>Variables</TableHead>
<TableHead>Version</TableHead>
<TableHead>Actions</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{filteredPrompts.map(prompt => (
<TableRow key={prompt.prompt_id}>
<TableCell className="max-w-48">
<Button
variant="link"
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300 max-w-full justify-start"
onClick={() => onEdit(prompt)}
title={prompt.prompt_id}
>
<div className="truncate">{prompt.prompt_id}</div>
</Button>
</TableCell>
<TableCell className="max-w-64">
<div
className="font-mono text-xs text-muted-foreground truncate"
title={prompt.prompt || "No content"}
>
{prompt.prompt || "No content"}
</div>
</TableCell>
<TableCell>
{prompt.variables.length > 0 ? (
<div className="flex flex-wrap gap-1">
{prompt.variables.map(variable => (
<Badge
key={variable}
variant="outline"
className="text-xs"
>
{variable}
</Badge>
))}
</div>
) : (
<span className="text-muted-foreground text-sm">None</span>
)}
</TableCell>
<TableCell className="text-sm">
{prompt.version}
{prompt.is_default && (
<Badge variant="secondary" className="text-xs ml-2">
Default
</Badge>
)}
</TableCell>
<TableCell>
<div className="flex gap-1">
<Button
size="sm"
variant="outline"
onClick={() => onEdit(prompt)}
className="h-8 w-8 p-0"
>
<Edit className="h-3 w-3" />
</Button>
<Button
size="sm"
variant="outline"
onClick={() => {
if (
confirm(
"Are you sure you want to delete this prompt? This action cannot be undone."
)
) {
onDelete(prompt.prompt_id);
}
}}
className="h-8 w-8 p-0 text-destructive hover:text-destructive"
>
<Trash2 className="h-3 w-3" />
</Button>
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
{filteredPrompts.length === 0 && (
<div className="text-center py-12">
<div className="text-muted-foreground">
{prompts.length === 0
? "No prompts yet"
: "No prompts match your filters"}
</div>
</div>
)}
</div>
);
}

Some files were not shown because too many files have changed in this diff.