mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
Merge branch 'main' into add-mongodb-vector_io
This commit is contained in:
commit
d0064fc915
426 changed files with 99110 additions and 62778 deletions
60
.github/actions/install-llama-stack-client/action.yml
vendored
Normal file
60
.github/actions/install-llama-stack-client/action.yml
vendored
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
name: Install llama-stack-client
|
||||
description: Install llama-stack-client based on branch context and client-version input
|
||||
|
||||
inputs:
|
||||
client-version:
|
||||
description: 'Client version to install on non-release branches (latest or published). Ignored on release branches.'
|
||||
required: false
|
||||
default: ""
|
||||
|
||||
outputs:
|
||||
uv-extra-index-url:
|
||||
description: 'UV_EXTRA_INDEX_URL to use (set for release branches)'
|
||||
value: ${{ steps.configure.outputs.uv-extra-index-url }}
|
||||
install-after-sync:
|
||||
description: 'Whether to install client after uv sync'
|
||||
value: ${{ steps.configure.outputs.install-after-sync }}
|
||||
install-source:
|
||||
description: 'Where to install client from after sync'
|
||||
value: ${{ steps.configure.outputs.install-source }}
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Configure client installation
|
||||
id: configure
|
||||
shell: bash
|
||||
run: |
|
||||
# Determine the branch we're working with
|
||||
BRANCH="${{ github.base_ref || github.ref }}"
|
||||
BRANCH="${BRANCH#refs/heads/}"
|
||||
|
||||
echo "Working with branch: $BRANCH"
|
||||
|
||||
# On release branches: use test.pypi for uv sync, then install from git
|
||||
# On non-release branches: install based on client-version after sync
|
||||
if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
|
||||
echo "Detected release branch: $BRANCH"
|
||||
|
||||
# Check if matching branch exists in client repo
|
||||
if ! git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$BRANCH" > /dev/null 2>&1; then
|
||||
echo "::error::Branch $BRANCH not found in llama-stack-client-python repository"
|
||||
echo "::error::Please create the matching release branch in llama-stack-client-python before testing"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Configure to use test.pypi as extra index (PyPI is primary)
|
||||
echo "uv-extra-index-url=https://test.pypi.org/simple/" >> $GITHUB_OUTPUT
|
||||
echo "install-after-sync=true" >> $GITHUB_OUTPUT
|
||||
echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@$BRANCH" >> $GITHUB_OUTPUT
|
||||
elif [ "${{ inputs.client-version }}" = "latest" ]; then
|
||||
# Install from main git after sync
|
||||
echo "install-after-sync=true" >> $GITHUB_OUTPUT
|
||||
echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@main" >> $GITHUB_OUTPUT
|
||||
elif [ "${{ inputs.client-version }}" = "published" ]; then
|
||||
# Use published version from PyPI (installed by sync)
|
||||
echo "install-after-sync=false" >> $GITHUB_OUTPUT
|
||||
elif [ -n "${{ inputs.client-version }}" ]; then
|
||||
echo "::error::Invalid client-version: ${{ inputs.client-version }}"
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -94,7 +94,7 @@ runs:
|
|||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }}
|
||||
name: logs-${{ github.run_id }}-${{ github.run_attempt || '1' }}-${{ strategy.job-index || github.job }}-${{ github.action }}
|
||||
path: |
|
||||
*.log
|
||||
retention-days: 1
|
||||
|
|
|
|||
30
.github/actions/setup-runner/action.yml
vendored
30
.github/actions/setup-runner/action.yml
vendored
|
|
@ -18,25 +18,35 @@ runs:
|
|||
python-version: ${{ inputs.python-version }}
|
||||
version: 0.7.6
|
||||
|
||||
- name: Configure client installation
|
||||
id: client-config
|
||||
uses: ./.github/actions/install-llama-stack-client
|
||||
with:
|
||||
client-version: ${{ inputs.client-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
env:
|
||||
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
|
||||
run: |
|
||||
# Export UV env vars for current step and persist to GITHUB_ENV for subsequent steps
|
||||
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
|
||||
export UV_INDEX_STRATEGY=unsafe-best-match
|
||||
echo "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" >> $GITHUB_ENV
|
||||
echo "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" >> $GITHUB_ENV
|
||||
echo "Exported UV environment variables for current and subsequent steps"
|
||||
fi
|
||||
|
||||
echo "Updating project dependencies via uv sync"
|
||||
uv sync --all-groups
|
||||
|
||||
echo "Installing ad-hoc dependencies"
|
||||
uv pip install faiss-cpu
|
||||
|
||||
# Install llama-stack-client-python based on the client-version input
|
||||
if [ "${{ inputs.client-version }}" = "latest" ]; then
|
||||
echo "Installing latest llama-stack-client-python from main branch"
|
||||
uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main
|
||||
elif [ "${{ inputs.client-version }}" = "published" ]; then
|
||||
echo "Installing published llama-stack-client-python from PyPI"
|
||||
uv pip install llama-stack-client
|
||||
else
|
||||
echo "Invalid client-version: ${{ inputs.client-version }}"
|
||||
exit 1
|
||||
# Install specific client version after sync if needed
|
||||
if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
|
||||
echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
|
||||
uv pip install ${{ steps.client-config.outputs.install-source }}
|
||||
fi
|
||||
|
||||
echo "Installed llama packages"
|
||||
|
|
|
|||
|
|
@ -42,18 +42,7 @@ runs:
|
|||
- name: Build Llama Stack
|
||||
shell: bash
|
||||
run: |
|
||||
# Install llama-stack-client-python based on the client-version input
|
||||
if [ "${{ inputs.client-version }}" = "latest" ]; then
|
||||
echo "Installing latest llama-stack-client-python from main branch"
|
||||
export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main
|
||||
elif [ "${{ inputs.client-version }}" = "published" ]; then
|
||||
echo "Installing published llama-stack-client-python from PyPI"
|
||||
unset LLAMA_STACK_CLIENT_DIR
|
||||
else
|
||||
echo "Invalid client-version: ${{ inputs.client-version }}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Client is already installed by setup-runner (handles both main and release branches)
|
||||
echo "Building Llama Stack"
|
||||
|
||||
LLAMA_STACK_DIR=. \
|
||||
|
|
|
|||
23
.github/mergify.yml
vendored
Normal file
23
.github/mergify.yml
vendored
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
pull_request_rules:
|
||||
- name: ping author on conflicts and add 'needs-rebase' label
|
||||
conditions:
|
||||
- conflict
|
||||
- -closed
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- needs-rebase
|
||||
comment:
|
||||
message: >
|
||||
This pull request has merge conflicts that must be resolved before it
|
||||
can be merged. @{{author}} please rebase it.
|
||||
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
|
||||
|
||||
- name: remove 'needs-rebase' label when conflict is resolved
|
||||
conditions:
|
||||
- -conflict
|
||||
- -closed
|
||||
actions:
|
||||
label:
|
||||
remove:
|
||||
- needs-rebase
|
||||
2
.github/workflows/README.md
vendored
2
.github/workflows/README.md
vendored
|
|
@ -4,6 +4,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
|
|||
|
||||
| Name | File | Purpose |
|
||||
| ---- | ---- | ------- |
|
||||
| Backward Compatibility Check | [backward-compat.yml](backward-compat.yml) | Check backward compatibility for run.yaml configs |
|
||||
| Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md |
|
||||
| API Conformance Tests | [conformance.yml](conformance.yml) | Run the API Conformance test suite on the changes. |
|
||||
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
|
||||
|
|
@ -12,7 +13,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
|
|||
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
|
||||
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
|
||||
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
|
||||
| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR |
|
||||
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
|
||||
| Test llama stack list-deps | [providers-list-deps.yml](providers-list-deps.yml) | Test llama stack list-deps |
|
||||
| Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |
|
||||
|
|
|
|||
578
.github/workflows/backward-compat.yml
vendored
Normal file
578
.github/workflows/backward-compat.yml
vendored
Normal file
|
|
@ -0,0 +1,578 @@
|
|||
name: Backward Compatibility Check
|
||||
|
||||
run-name: Check backward compatibility for run.yaml configs
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+'
|
||||
- 'release-[0-9]+.[0-9]+.[0-9]+'
|
||||
- 'release-[0-9]+.[0-9]+'
|
||||
paths:
|
||||
- 'src/llama_stack/core/datatypes.py'
|
||||
- 'src/llama_stack/providers/datatypes.py'
|
||||
- 'src/llama_stack/distributions/**/run.yaml'
|
||||
- 'tests/backward_compat/**'
|
||||
- '.github/workflows/backward-compat.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
check-main-compatibility:
|
||||
name: Check Compatibility with main
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
fetch-depth: 0 # Need full history to access main branch
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev
|
||||
|
||||
- name: Extract run.yaml files from main branch
|
||||
id: extract_configs
|
||||
run: |
|
||||
# Get list of run.yaml paths from main
|
||||
git fetch origin main
|
||||
CONFIG_PATHS=$(git ls-tree -r --name-only origin/main | grep "src/llama_stack/distributions/.*/run.yaml$" || true)
|
||||
|
||||
if [ -z "$CONFIG_PATHS" ]; then
|
||||
echo "No run.yaml files found in main branch"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Extract all configs to a temp directory
|
||||
mkdir -p /tmp/main_configs
|
||||
echo "Extracting configs from main branch:"
|
||||
|
||||
while IFS= read -r config_path; do
|
||||
if [ -z "$config_path" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# Extract filename for storage
|
||||
filename=$(basename $(dirname "$config_path"))
|
||||
echo " - $filename (from $config_path)"
|
||||
|
||||
git show origin/main:"$config_path" > "/tmp/main_configs/${filename}.yaml"
|
||||
done <<< "$CONFIG_PATHS"
|
||||
|
||||
echo ""
|
||||
echo "Extracted $(ls /tmp/main_configs/*.yaml | wc -l) config files"
|
||||
|
||||
- name: Test all configs from main
|
||||
id: test_configs
|
||||
continue-on-error: true
|
||||
run: |
|
||||
# Run pytest once with all configs parameterized
|
||||
if COMPAT_TEST_CONFIGS_DIR=/tmp/main_configs uv run pytest tests/backward_compat/test_run_config.py -v; then
|
||||
echo "failed=false" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "failed=true" >> $GITHUB_OUTPUT
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Check for breaking change acknowledgment
|
||||
id: check_ack
|
||||
if: steps.test_configs.outputs.failed == 'true'
|
||||
run: |
|
||||
echo "Breaking changes detected. Checking for acknowledgment..."
|
||||
|
||||
# Check PR title for '!:' marker (conventional commits)
|
||||
PR_TITLE="${{ github.event.pull_request.title }}"
|
||||
if [[ "$PR_TITLE" =~ ^[a-z]+\!: ]]; then
|
||||
echo "✓ Breaking change acknowledged in PR title"
|
||||
echo "acknowledged=true" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check commit messages for BREAKING CHANGE:
|
||||
if git log origin/main..HEAD --format=%B | grep -q "BREAKING CHANGE:"; then
|
||||
echo "✓ Breaking change acknowledged in commit message"
|
||||
echo "acknowledged=true" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "✗ Breaking change NOT acknowledged"
|
||||
echo "acknowledged=false" >> $GITHUB_OUTPUT
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
|
||||
- name: Evaluate results
|
||||
if: always()
|
||||
run: |
|
||||
FAILED="${{ steps.test_configs.outputs.failed }}"
|
||||
ACKNOWLEDGED="${{ steps.check_ack.outputs.acknowledged }}"
|
||||
|
||||
if [[ "$FAILED" == "true" ]]; then
|
||||
if [[ "$ACKNOWLEDGED" == "true" ]]; then
|
||||
echo ""
|
||||
echo "⚠️ WARNING: Breaking changes detected but acknowledged"
|
||||
echo ""
|
||||
echo "This PR introduces backward-incompatible changes to run.yaml."
|
||||
echo "The changes have been properly acknowledged."
|
||||
echo ""
|
||||
exit 0 # Pass the check
|
||||
else
|
||||
echo ""
|
||||
echo "❌ ERROR: Breaking changes detected without acknowledgment"
|
||||
echo ""
|
||||
echo "This PR introduces backward-incompatible changes to run.yaml"
|
||||
echo "that will break existing user configurations."
|
||||
echo ""
|
||||
echo "To acknowledge this breaking change, do ONE of:"
|
||||
echo " 1. Add '!:' to your PR title (e.g., 'feat!: change xyz')"
|
||||
echo " 2. Add the 'breaking-change' label to this PR"
|
||||
echo " 3. Include 'BREAKING CHANGE:' in a commit message"
|
||||
echo ""
|
||||
exit 1 # Fail the check
|
||||
fi
|
||||
fi
|
||||
|
||||
test-integration-main:
|
||||
name: Run Integration Tests with main Config
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Extract ci-tests run.yaml from main
|
||||
run: |
|
||||
git fetch origin main
|
||||
git show origin/main:src/llama_stack/distributions/ci-tests/run.yaml > /tmp/main-ci-tests-run.yaml
|
||||
echo "Extracted ci-tests run.yaml from main branch"
|
||||
|
||||
- name: Setup test environment
|
||||
uses: ./.github/actions/setup-test-environment
|
||||
with:
|
||||
python-version: '3.12'
|
||||
client-version: 'latest'
|
||||
setup: 'ollama'
|
||||
suite: 'base'
|
||||
inference-mode: 'replay'
|
||||
|
||||
- name: Run integration tests with main config
|
||||
id: test_integration
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
stack-config: /tmp/main-ci-tests-run.yaml
|
||||
setup: 'ollama'
|
||||
inference-mode: 'replay'
|
||||
suite: 'base'
|
||||
|
||||
- name: Check for breaking change acknowledgment
|
||||
id: check_ack
|
||||
if: steps.test_integration.outcome == 'failure'
|
||||
run: |
|
||||
echo "Integration tests failed. Checking for acknowledgment..."
|
||||
|
||||
# Check PR title for '!:' marker (conventional commits)
|
||||
PR_TITLE="${{ github.event.pull_request.title }}"
|
||||
if [[ "$PR_TITLE" =~ ^[a-z]+\!: ]]; then
|
||||
echo "✓ Breaking change acknowledged in PR title"
|
||||
echo "acknowledged=true" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Check commit messages for BREAKING CHANGE:
|
||||
if git log origin/main..HEAD --format=%B | grep -q "BREAKING CHANGE:"; then
|
||||
echo "✓ Breaking change acknowledged in commit message"
|
||||
echo "acknowledged=true" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "✗ Breaking change NOT acknowledged"
|
||||
echo "acknowledged=false" >> $GITHUB_OUTPUT
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
|
||||
- name: Evaluate integration test results
|
||||
if: always()
|
||||
run: |
|
||||
TEST_FAILED="${{ steps.test_integration.outcome == 'failure' }}"
|
||||
ACKNOWLEDGED="${{ steps.check_ack.outputs.acknowledged }}"
|
||||
|
||||
if [[ "$TEST_FAILED" == "true" ]]; then
|
||||
if [[ "$ACKNOWLEDGED" == "true" ]]; then
|
||||
echo ""
|
||||
echo "⚠️ WARNING: Integration tests failed with main config but acknowledged"
|
||||
echo ""
|
||||
exit 0 # Pass the check
|
||||
else
|
||||
echo ""
|
||||
echo "❌ ERROR: Integration tests failed with main config without acknowledgment"
|
||||
echo ""
|
||||
echo "To acknowledge this breaking change, do ONE of:"
|
||||
echo " 1. Add '!:' to your PR title (e.g., 'feat!: change xyz')"
|
||||
echo " 2. Include 'BREAKING CHANGE:' in a commit message"
|
||||
echo ""
|
||||
exit 1 # Fail the check
|
||||
fi
|
||||
fi
|
||||
|
||||
test-integration-release:
|
||||
name: Run Integration Tests with Latest Release (Informational)
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Get latest release
|
||||
id: get_release
|
||||
run: |
|
||||
# Get the latest release from GitHub
|
||||
LATEST_TAG=$(gh release list --limit 1 --json tagName --jq '.[0].tagName' 2>/dev/null || echo "")
|
||||
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
echo "No releases found, skipping release compatibility check"
|
||||
echo "has_release=false" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Latest release: $LATEST_TAG"
|
||||
echo "has_release=true" >> $GITHUB_OUTPUT
|
||||
echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
|
||||
- name: Extract ci-tests run.yaml from release
|
||||
if: steps.get_release.outputs.has_release == 'true'
|
||||
id: extract_config
|
||||
run: |
|
||||
RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
|
||||
|
||||
# Try with src/ prefix first (newer releases), then without (older releases)
|
||||
if git show "$RELEASE_TAG:src/llama_stack/distributions/ci-tests/run.yaml" > /tmp/release-ci-tests-run.yaml 2>/dev/null; then
|
||||
echo "Extracted ci-tests run.yaml from release $RELEASE_TAG (src/ path)"
|
||||
echo "has_config=true" >> $GITHUB_OUTPUT
|
||||
elif git show "$RELEASE_TAG:llama_stack/distributions/ci-tests/run.yaml" > /tmp/release-ci-tests-run.yaml 2>/dev/null; then
|
||||
echo "Extracted ci-tests run.yaml from release $RELEASE_TAG (old path)"
|
||||
echo "has_config=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "::warning::ci-tests/run.yaml not found in release $RELEASE_TAG"
|
||||
echo "has_config=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Setup test environment
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
|
||||
uses: ./.github/actions/setup-test-environment
|
||||
with:
|
||||
python-version: '3.12'
|
||||
client-version: 'latest'
|
||||
setup: 'ollama'
|
||||
suite: 'base'
|
||||
inference-mode: 'replay'
|
||||
|
||||
- name: Run integration tests with release config (PR branch)
|
||||
id: test_release_pr
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
stack-config: /tmp/release-ci-tests-run.yaml
|
||||
setup: 'ollama'
|
||||
inference-mode: 'replay'
|
||||
suite: 'base'
|
||||
|
||||
- name: Checkout main branch to test baseline
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
|
||||
run: |
|
||||
git checkout origin/main
|
||||
|
||||
- name: Setup test environment for main
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
|
||||
uses: ./.github/actions/setup-test-environment
|
||||
with:
|
||||
python-version: '3.12'
|
||||
client-version: 'latest'
|
||||
setup: 'ollama'
|
||||
suite: 'base'
|
||||
inference-mode: 'replay'
|
||||
|
||||
- name: Run integration tests with release config (main branch)
|
||||
id: test_release_main
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/run-and-record-tests
|
||||
with:
|
||||
stack-config: /tmp/release-ci-tests-run.yaml
|
||||
setup: 'ollama'
|
||||
inference-mode: 'replay'
|
||||
suite: 'base'
|
||||
|
||||
- name: Report results and post PR comment
|
||||
if: always() && steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
|
||||
run: |
|
||||
RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
|
||||
PR_OUTCOME="${{ steps.test_release_pr.outcome }}"
|
||||
MAIN_OUTCOME="${{ steps.test_release_main.outcome }}"
|
||||
|
||||
if [[ "$PR_OUTCOME" == "failure" && "$MAIN_OUTCOME" == "success" ]]; then
|
||||
# NEW breaking change - PR fails but main passes
|
||||
echo "::error::🚨 This PR introduces a NEW breaking change!"
|
||||
|
||||
# Check if we already posted a comment (to avoid spam on every push)
|
||||
EXISTING_COMMENT=$(gh pr view ${{ github.event.pull_request.number }} --json comments --jq '.comments[] | select(.body | contains("🚨 New Breaking Change Detected") and contains("Integration tests")) | .id' | head -1)
|
||||
|
||||
if [[ -z "$EXISTING_COMMENT" ]]; then
|
||||
gh pr comment ${{ github.event.pull_request.number }} --body "## 🚨 New Breaking Change Detected
|
||||
|
||||
**Integration tests against release \`$RELEASE_TAG\` are now failing**
|
||||
|
||||
⚠️ This PR introduces a breaking change that affects compatibility with the latest release.
|
||||
|
||||
- Users on release \`$RELEASE_TAG\` may not be able to upgrade
|
||||
- Existing configurations may break
|
||||
|
||||
The tests pass on \`main\` but fail with this PR's changes.
|
||||
|
||||
> **Note:** This is informational only and does not block merge.
|
||||
> Consider whether this breaking change is acceptable for users."
|
||||
else
|
||||
echo "Comment already exists, skipping to avoid spam"
|
||||
fi
|
||||
|
||||
cat >> $GITHUB_STEP_SUMMARY <<EOF
|
||||
## 🚨 NEW Breaking Change Detected
|
||||
|
||||
**Integration tests against release \`$RELEASE_TAG\` FAILED**
|
||||
|
||||
⚠️ **This PR introduces a NEW breaking change**
|
||||
|
||||
- Tests **PASS** on main branch ✅
|
||||
- Tests **FAIL** on PR branch ❌
|
||||
- Users on release \`$RELEASE_TAG\` may not be able to upgrade
|
||||
- Existing configurations may break
|
||||
|
||||
> **Note:** This is informational only and does not block merge.
|
||||
> Consider whether this breaking change is acceptable for users.
|
||||
EOF
|
||||
|
||||
elif [[ "$PR_OUTCOME" == "failure" ]]; then
|
||||
# Existing breaking change - both PR and main fail
|
||||
echo "::warning::Breaking change already exists in main branch"
|
||||
|
||||
cat >> $GITHUB_STEP_SUMMARY <<EOF
|
||||
## ⚠️ Release Compatibility Test Failed (Existing Issue)
|
||||
|
||||
**Integration tests against release \`$RELEASE_TAG\` FAILED**
|
||||
|
||||
- Tests **FAIL** on main branch ❌
|
||||
- Tests **FAIL** on PR branch ❌
|
||||
- This breaking change already exists in main (not introduced by this PR)
|
||||
|
||||
> **Note:** This is informational only.
|
||||
EOF
|
||||
|
||||
else
|
||||
# Success - tests pass
|
||||
cat >> $GITHUB_STEP_SUMMARY <<EOF
|
||||
## ✅ Release Compatibility Test Passed
|
||||
|
||||
Integration tests against release \`$RELEASE_TAG\` passed successfully.
|
||||
This PR maintains compatibility with the latest release.
|
||||
EOF
|
||||
fi
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
|
||||
check-schema-release-compatibility:
|
||||
name: Check Schema Compatibility with Latest Release (Informational)
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --group dev
|
||||
|
||||
- name: Get latest release
|
||||
id: get_release
|
||||
run: |
|
||||
# Get the latest release from GitHub
|
||||
LATEST_TAG=$(gh release list --limit 1 --json tagName --jq '.[0].tagName' 2>/dev/null || echo "")
|
||||
|
||||
if [ -z "$LATEST_TAG" ]; then
|
||||
echo "No releases found, skipping release compatibility check"
|
||||
echo "has_release=false" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "Latest release: $LATEST_TAG"
|
||||
echo "has_release=true" >> $GITHUB_OUTPUT
|
||||
echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
|
||||
- name: Extract configs from release
|
||||
if: steps.get_release.outputs.has_release == 'true'
|
||||
id: extract_release_configs
|
||||
run: |
|
||||
RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
|
||||
|
||||
# Get run.yaml files from the release (try both src/ and old path)
|
||||
CONFIG_PATHS=$(git ls-tree -r --name-only "$RELEASE_TAG" | grep "llama_stack/distributions/.*/run.yaml$" || true)
|
||||
|
||||
if [ -z "$CONFIG_PATHS" ]; then
|
||||
echo "::warning::No run.yaml files found in release $RELEASE_TAG"
|
||||
echo "has_configs=false" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Extract all configs to a temp directory
|
||||
mkdir -p /tmp/release_configs
|
||||
echo "Extracting configs from release $RELEASE_TAG:"
|
||||
|
||||
while IFS= read -r config_path; do
|
||||
if [ -z "$config_path" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
filename=$(basename $(dirname "$config_path"))
|
||||
echo " - $filename (from $config_path)"
|
||||
|
||||
git show "$RELEASE_TAG:$config_path" > "/tmp/release_configs/${filename}.yaml" 2>/dev/null || true
|
||||
done <<< "$CONFIG_PATHS"
|
||||
|
||||
echo ""
|
||||
echo "Extracted $(ls /tmp/release_configs/*.yaml 2>/dev/null | wc -l) config files"
|
||||
echo "has_configs=true" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Test against release configs (PR branch)
|
||||
id: test_schema_pr
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
|
||||
continue-on-error: true
|
||||
run: |
|
||||
RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
|
||||
COMPAT_TEST_CONFIGS_DIR=/tmp/release_configs uv run pytest tests/backward_compat/test_run_config.py -v --tb=short
|
||||
|
||||
- name: Checkout main branch to test baseline
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
|
||||
run: |
|
||||
git checkout origin/main
|
||||
|
||||
- name: Install dependencies for main
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
|
||||
run: |
|
||||
uv sync --group dev
|
||||
|
||||
- name: Test against release configs (main branch)
|
||||
id: test_schema_main
|
||||
if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
|
||||
continue-on-error: true
|
||||
run: |
|
||||
RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
|
||||
COMPAT_TEST_CONFIGS_DIR=/tmp/release_configs uv run pytest tests/backward_compat/test_run_config.py -v --tb=short
|
||||
|
||||
- name: Report results and post PR comment
|
||||
if: always() && steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
|
||||
run: |
|
||||
RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
|
||||
PR_OUTCOME="${{ steps.test_schema_pr.outcome }}"
|
||||
MAIN_OUTCOME="${{ steps.test_schema_main.outcome }}"
|
||||
|
||||
if [[ "$PR_OUTCOME" == "failure" && "$MAIN_OUTCOME" == "success" ]]; then
|
||||
# NEW breaking change - PR fails but main passes
|
||||
echo "::error::🚨 This PR introduces a NEW schema breaking change!"
|
||||
|
||||
# Check if we already posted a comment (to avoid spam on every push)
|
||||
EXISTING_COMMENT=$(gh pr view ${{ github.event.pull_request.number }} --json comments --jq '.comments[] | select(.body | contains("🚨 New Schema Breaking Change Detected")) | .id' | head -1)
|
||||
|
||||
if [[ -z "$EXISTING_COMMENT" ]]; then
|
||||
gh pr comment ${{ github.event.pull_request.number }} --body "## 🚨 New Schema Breaking Change Detected
|
||||
|
||||
**Schema validation against release \`$RELEASE_TAG\` is now failing**
|
||||
|
||||
⚠️ This PR introduces a schema breaking change that affects compatibility with the latest release.
|
||||
|
||||
- Users on release \`$RELEASE_TAG\` will not be able to upgrade
|
||||
- Existing run.yaml configurations will fail validation
|
||||
|
||||
The tests pass on \`main\` but fail with this PR's changes.
|
||||
|
||||
> **Note:** This is informational only and does not block merge.
|
||||
> Consider whether this breaking change is acceptable for users."
|
||||
else
|
||||
echo "Comment already exists, skipping to avoid spam"
|
||||
fi
|
||||
|
||||
cat >> $GITHUB_STEP_SUMMARY <<EOF
|
||||
## 🚨 NEW Schema Breaking Change Detected
|
||||
|
||||
**Schema validation against release \`$RELEASE_TAG\` FAILED**
|
||||
|
||||
⚠️ **This PR introduces a NEW schema breaking change**
|
||||
|
||||
- Tests **PASS** on main branch ✅
|
||||
- Tests **FAIL** on PR branch ❌
|
||||
- Users on release \`$RELEASE_TAG\` will not be able to upgrade
|
||||
- Existing run.yaml configurations will fail validation
|
||||
|
||||
> **Note:** This is informational only and does not block merge.
|
||||
> Consider whether this breaking change is acceptable for users.
|
||||
EOF
|
||||
|
||||
elif [[ "$PR_OUTCOME" == "failure" ]]; then
|
||||
# Existing breaking change - both PR and main fail
|
||||
echo "::warning::Schema breaking change already exists in main branch"
|
||||
|
||||
cat >> $GITHUB_STEP_SUMMARY <<EOF
|
||||
## ⚠️ Release Schema Compatibility Failed (Existing Issue)
|
||||
|
||||
**Schema validation against release \`$RELEASE_TAG\` FAILED**
|
||||
|
||||
- Tests **FAIL** on main branch ❌
|
||||
- Tests **FAIL** on PR branch ❌
|
||||
- This schema breaking change already exists in main (not introduced by this PR)
|
||||
|
||||
> **Note:** This is informational only.
|
||||
EOF
|
||||
|
||||
else
|
||||
# Success - tests pass
|
||||
cat >> $GITHUB_STEP_SUMMARY <<EOF
|
||||
## ✅ Release Schema Compatibility Passed
|
||||
|
||||
All run.yaml configs from release \`$RELEASE_TAG\` are compatible.
|
||||
This PR maintains backward compatibility with the latest release.
|
||||
EOF
|
||||
fi
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
1
.github/workflows/conformance.yml
vendored
1
.github/workflows/conformance.yml
vendored
|
|
@ -22,7 +22,6 @@ on:
|
|||
- 'docs/static/stable-llama-stack-spec.yaml' # Stable APIs spec
|
||||
- 'docs/static/experimental-llama-stack-spec.yaml' # Experimental APIs spec
|
||||
- 'docs/static/deprecated-llama-stack-spec.yaml' # Deprecated APIs spec
|
||||
- 'docs/static/llama-stack-spec.html' # Legacy HTML spec
|
||||
- '.github/workflows/conformance.yml' # This workflow itself
|
||||
|
||||
concurrency:
|
||||
|
|
|
|||
10
.github/workflows/install-script-ci.yml
vendored
10
.github/workflows/install-script-ci.yml
vendored
|
|
@ -30,10 +30,16 @@ jobs:
|
|||
|
||||
- name: Build a single provider
|
||||
run: |
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=starter"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=starter \
|
||||
$BUILD_ARGS \
|
||||
--tag llama-stack:starter-ci
|
||||
|
||||
- name: Run installer end-to-end
|
||||
|
|
|
|||
8
.github/workflows/integration-auth-tests.yml
vendored
8
.github/workflows/integration-auth-tests.yml
vendored
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suite with Kubernetes authentication
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'distributions/**'
|
||||
- 'src/llama_stack/**'
|
||||
|
|
|
|||
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suite with SqlStore
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'src/llama_stack/providers/utils/sqlstore/**'
|
||||
- 'tests/integration/sqlstore/**'
|
||||
|
|
|
|||
11
.github/workflows/integration-tests.yml
vendored
11
.github/workflows/integration-tests.yml
vendored
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suites from tests/integration in replay mode
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
types: [opened, synchronize, reopened]
|
||||
paths:
|
||||
- 'src/llama_stack/**'
|
||||
|
|
@ -18,6 +22,7 @@ on:
|
|||
- '.github/actions/setup-ollama/action.yml'
|
||||
- '.github/actions/setup-test-environment/action.yml'
|
||||
- '.github/actions/run-and-record-tests/action.yml'
|
||||
- 'scripts/integration-tests.sh'
|
||||
schedule:
|
||||
# If changing the cron schedule, update the provider in the test-matrix job
|
||||
- cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC
|
||||
|
|
@ -47,7 +52,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
client-type: [library, docker]
|
||||
client-type: [library, docker, server]
|
||||
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
|
||||
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
||||
|
|
|
|||
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suite with various VectorIO providers
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'src/llama_stack/**'
|
||||
- '!src/llama_stack/ui/**'
|
||||
|
|
|
|||
68
.github/workflows/pre-commit.yml
vendored
68
.github/workflows/pre-commit.yml
vendored
|
|
@ -5,7 +5,9 @@ run-name: Run pre-commit checks
|
|||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [main]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||
|
|
@ -43,23 +45,41 @@ jobs:
|
|||
cache: 'npm'
|
||||
cache-dependency-path: 'src/llama_stack/ui/'
|
||||
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
|
||||
|
||||
- name: Install npm dependencies
|
||||
run: npm ci
|
||||
working-directory: src/llama_stack/ui
|
||||
|
||||
- name: Install pre-commit
|
||||
run: python -m pip install pre-commit
|
||||
|
||||
- name: Cache pre-commit
|
||||
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-3|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
|
||||
- name: Run pre-commit
|
||||
id: precommit
|
||||
uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
continue-on-error: true
|
||||
run: |
|
||||
set +e
|
||||
pre-commit run --show-diff-on-failure --color=always --all-files 2>&1 | tee /tmp/precommit.log
|
||||
status=${PIPESTATUS[0]}
|
||||
echo "status=$status" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
SKIP: no-commit-to-branch,mypy
|
||||
RUFF_OUTPUT_FORMAT: github
|
||||
|
||||
- name: Check pre-commit results
|
||||
if: steps.precommit.outcome == 'failure'
|
||||
if: steps.precommit.outputs.status != '0'
|
||||
run: |
|
||||
echo "::error::Pre-commit hooks failed. Please run 'pre-commit run --all-files' locally and commit the fixes."
|
||||
echo "::warning::Some pre-commit hooks failed. Check the output above for details."
|
||||
echo ""
|
||||
echo "Failed hooks output:"
|
||||
cat /tmp/precommit.log
|
||||
exit 1
|
||||
|
||||
- name: Debug
|
||||
|
|
@ -109,3 +129,39 @@ jobs:
|
|||
echo "$unstaged_files"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Configure client installation
|
||||
id: client-config
|
||||
uses: ./.github/actions/install-llama-stack-client
|
||||
|
||||
- name: Sync dev + type_checking dependencies
|
||||
env:
|
||||
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
|
||||
run: |
|
||||
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
|
||||
export UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
fi
|
||||
|
||||
uv sync --group dev --group type_checking
|
||||
|
||||
# Install specific client version after sync if needed
|
||||
if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
|
||||
echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
|
||||
uv pip install ${{ steps.client-config.outputs.install-source }}
|
||||
fi
|
||||
|
||||
- name: Run mypy (full type_checking)
|
||||
env:
|
||||
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
|
||||
run: |
|
||||
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
|
||||
export UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
fi
|
||||
|
||||
set +e
|
||||
uv run --group dev --group type_checking mypy
|
||||
status=$?
|
||||
if [ $status -ne 0 ]; then
|
||||
echo "::error::Full mypy failed. Reproduce locally with 'uv run pre-commit run mypy-full --hook-stage manual --all-files'."
|
||||
fi
|
||||
exit $status
|
||||
|
|
|
|||
227
.github/workflows/precommit-trigger.yml
vendored
227
.github/workflows/precommit-trigger.yml
vendored
|
|
@ -1,227 +0,0 @@
|
|||
name: Pre-commit Bot
|
||||
|
||||
run-name: Pre-commit bot for PR #${{ github.event.issue.number }}
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
# Only run on pull request comments
|
||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit')
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- name: Check comment author and get PR details
|
||||
id: check_author
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
// Get PR details
|
||||
const pr = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.issue.number
|
||||
});
|
||||
|
||||
// Check if commenter has write access or is the PR author
|
||||
const commenter = context.payload.comment.user.login;
|
||||
const prAuthor = pr.data.user.login;
|
||||
|
||||
let hasPermission = false;
|
||||
|
||||
// Check if commenter is PR author
|
||||
if (commenter === prAuthor) {
|
||||
hasPermission = true;
|
||||
console.log(`Comment author ${commenter} is the PR author`);
|
||||
} else {
|
||||
// Check if commenter has write/admin access
|
||||
try {
|
||||
const permission = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
username: commenter
|
||||
});
|
||||
|
||||
const level = permission.data.permission;
|
||||
hasPermission = ['write', 'admin', 'maintain'].includes(level);
|
||||
console.log(`Comment author ${commenter} has permission: ${level}`);
|
||||
} catch (error) {
|
||||
console.log(`Could not check permissions for ${commenter}: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasPermission) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.`
|
||||
});
|
||||
core.setFailed(`User ${commenter} does not have permission`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Save PR info for later steps
|
||||
core.setOutput('pr_number', context.issue.number);
|
||||
core.setOutput('pr_head_ref', pr.data.head.ref);
|
||||
core.setOutput('pr_head_sha', pr.data.head.sha);
|
||||
core.setOutput('pr_head_repo', pr.data.head.repo.full_name);
|
||||
core.setOutput('pr_base_ref', pr.data.base.ref);
|
||||
core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name);
|
||||
core.setOutput('authorized', 'true');
|
||||
|
||||
- name: React to comment
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.reactions.createForIssueComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: context.payload.comment.id,
|
||||
content: 'rocket'
|
||||
});
|
||||
|
||||
- name: Comment starting
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...`
|
||||
});
|
||||
|
||||
- name: Checkout PR branch (same-repo)
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false'
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
ref: ${{ steps.check_author.outputs.pr_head_ref }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Checkout PR branch (fork)
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true'
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
repository: ${{ steps.check_author.outputs.pr_head_repo }}
|
||||
ref: ${{ steps.check_author.outputs.pr_head_ref }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Verify checkout
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
run: |
|
||||
echo "Current SHA: $(git rev-parse HEAD)"
|
||||
echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}"
|
||||
if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then
|
||||
echo "::error::Checked out SHA does not match expected SHA"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: pip
|
||||
cache-dependency-path: |
|
||||
**/requirements*.txt
|
||||
.pre-commit-config.yaml
|
||||
|
||||
- name: Set up Node.js
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: 'src/llama_stack/ui/'
|
||||
|
||||
- name: Install npm dependencies
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
run: npm ci
|
||||
working-directory: src/llama_stack/ui
|
||||
|
||||
- name: Run pre-commit
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
id: precommit
|
||||
uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
continue-on-error: true
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
RUFF_OUTPUT_FORMAT: github
|
||||
|
||||
- name: Check for changes
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
id: changes
|
||||
run: |
|
||||
if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
echo "Changes detected after pre-commit"
|
||||
else
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
echo "No changes after pre-commit"
|
||||
fi
|
||||
|
||||
- name: Commit and push changes
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
|
||||
run: |
|
||||
git config --local user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config --local user.name "github-actions[bot]"
|
||||
|
||||
git add -A
|
||||
git commit -m "style: apply pre-commit fixes
|
||||
|
||||
🤖 Applied by @github-actions bot via pre-commit workflow"
|
||||
|
||||
# Push changes
|
||||
git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }}
|
||||
|
||||
- name: Comment success with changes
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.`
|
||||
});
|
||||
|
||||
- name: Comment success without changes
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.`
|
||||
});
|
||||
|
||||
- name: Comment failure
|
||||
if: failure()
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.`
|
||||
});
|
||||
38
.github/workflows/providers-build.yml
vendored
38
.github/workflows/providers-build.yml
vendored
|
|
@ -72,10 +72,16 @@ jobs:
|
|||
- name: Build container image
|
||||
if: matrix.image-type == 'container'
|
||||
run: |
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=${{ matrix.distro }}"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=${{ matrix.distro }} \
|
||||
$BUILD_ARGS \
|
||||
--tag llama-stack:${{ matrix.distro }}-ci
|
||||
|
||||
- name: Print dependencies in the image
|
||||
|
|
@ -108,12 +114,18 @@ jobs:
|
|||
- name: Build container image
|
||||
run: |
|
||||
BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "python:3.12-slim"' src/llama_stack/distributions/ci-tests/build.yaml)
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=ci-tests \
|
||||
--build-arg BASE_IMAGE="$BASE_IMAGE" \
|
||||
--build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
|
||||
$BUILD_ARGS \
|
||||
-t llama-stack:ci-tests
|
||||
|
||||
- name: Inspect the container image entrypoint
|
||||
|
|
@ -148,12 +160,18 @@ jobs:
|
|||
- name: Build UBI9 container image
|
||||
run: |
|
||||
BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "registry.access.redhat.com/ubi9:latest"' src/llama_stack/distributions/ci-tests/build.yaml)
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=ci-tests \
|
||||
--build-arg BASE_IMAGE="$BASE_IMAGE" \
|
||||
--build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
|
||||
$BUILD_ARGS \
|
||||
-t llama-stack:ci-tests-ubi9
|
||||
|
||||
- name: Inspect UBI9 image
|
||||
|
|
|
|||
2
.github/workflows/python-build-test.yml
vendored
2
.github/workflows/python-build-test.yml
vendored
|
|
@ -24,7 +24,7 @@ jobs:
|
|||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@2ddd2b9cb38ad8efd50337e8ab201519a34c9f24 # v7.1.1
|
||||
uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
activate-environment: true
|
||||
|
|
|
|||
8
.github/workflows/unit-tests.yml
vendored
8
.github/workflows/unit-tests.yml
vendored
|
|
@ -4,9 +4,13 @@ run-name: Run the unit test suite
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'src/llama_stack/**'
|
||||
- '!src/llama_stack/ui/**'
|
||||
|
|
|
|||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -32,3 +32,6 @@ CLAUDE.md
|
|||
docs/.docusaurus/
|
||||
docs/node_modules/
|
||||
docs/static/imported-files/
|
||||
docs/docs/api-deprecated/
|
||||
docs/docs/api-experimental/
|
||||
docs/docs/api/
|
||||
|
|
|
|||
|
|
@ -52,22 +52,19 @@ repos:
|
|||
additional_dependencies:
|
||||
- black==24.3.0
|
||||
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.7.20
|
||||
hooks:
|
||||
- id: uv-lock
|
||||
|
||||
- repo: local
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
hooks:
|
||||
- id: mypy
|
||||
name: mypy
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: uv run --group dev --group type_checking mypy
|
||||
language: python
|
||||
types: [python]
|
||||
- uv==0.6.2
|
||||
- mypy
|
||||
- pytest
|
||||
- rich
|
||||
- types-requests
|
||||
- pydantic
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
||||
# - repo: https://github.com/tcort/markdown-link-check
|
||||
# rev: v3.11.2
|
||||
|
|
@ -77,11 +74,26 @@ repos:
|
|||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: uv-lock
|
||||
name: uv-lock
|
||||
additional_dependencies:
|
||||
- uv==0.7.20
|
||||
entry: ./scripts/uv-run-with-index.sh lock
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^(pyproject\.toml|uv\.lock)$
|
||||
- id: mypy-full
|
||||
name: mypy (full type_checking)
|
||||
entry: ./scripts/uv-run-with-index.sh run --group dev --group type_checking mypy
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [manual]
|
||||
- id: distro-codegen
|
||||
name: Distribution Template Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: uv run --group codegen ./scripts/distro_codegen.py
|
||||
entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/distro_codegen.py
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -90,7 +102,7 @@ repos:
|
|||
name: Provider Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: uv run --group codegen ./scripts/provider_codegen.py
|
||||
entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/provider_codegen.py
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -99,7 +111,7 @@ repos:
|
|||
name: API Spec Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
|
||||
entry: sh -c './scripts/uv-run-with-index.sh run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -140,7 +152,7 @@ repos:
|
|||
name: Generate CI documentation
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: uv run ./scripts/gen-ci-docs.py
|
||||
entry: ./scripts/uv-run-with-index.sh run ./scripts/gen-ci-docs.py
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -171,6 +183,23 @@ repos:
|
|||
exit 1
|
||||
fi
|
||||
exit 0
|
||||
- id: fips-compliance
|
||||
name: Ensure llama-stack remains FIPS compliant
|
||||
entry: bash
|
||||
language: system
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
exclude: '^tests/.*$' # Exclude test dir as some safety tests used MD5
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
grep -EnH '^[^#]*\b(md5|sha1|uuid3|uuid5)\b' "$@" && {
|
||||
echo;
|
||||
echo "❌ Do not use any of the following functions: hashlib.md5, hashlib.sha1, uuid.uuid3, uuid.uuid5"
|
||||
echo " These functions are not FIPS-compliant"
|
||||
echo;
|
||||
exit 1;
|
||||
} || true
|
||||
|
||||
ci:
|
||||
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
|
||||
|
|
|
|||
|
|
@ -61,6 +61,18 @@ uv run pre-commit run --all-files -v
|
|||
|
||||
The `-v` (verbose) parameter is optional but often helpful for getting more information about any issues with that the pre-commit checks identify.
|
||||
|
||||
To run the expanded mypy configuration that CI enforces, use:
|
||||
|
||||
```bash
|
||||
uv run pre-commit run mypy-full --hook-stage manual --all-files
|
||||
```
|
||||
|
||||
or invoke mypy directly with all optional dependencies:
|
||||
|
||||
```bash
|
||||
uv run --group dev --group type_checking mypy
|
||||
```
|
||||
|
||||
```{caution}
|
||||
Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,610 +0,0 @@
|
|||
# yaml-language-server: $schema=https://app.stainlessapi.com/config-internal.schema.json
|
||||
|
||||
organization:
|
||||
# Name of your organization or company, used to determine the name of the client
|
||||
# and headings.
|
||||
name: llama-stack-client
|
||||
docs: https://llama-stack.readthedocs.io/en/latest/
|
||||
contact: llamastack@meta.com
|
||||
security:
|
||||
- {}
|
||||
- BearerAuth: []
|
||||
security_schemes:
|
||||
BearerAuth:
|
||||
type: http
|
||||
scheme: bearer
|
||||
# `targets` define the output targets and their customization options, such as
|
||||
# whether to emit the Node SDK and what it's package name should be.
|
||||
targets:
|
||||
node:
|
||||
package_name: llama-stack-client
|
||||
production_repo: llamastack/llama-stack-client-typescript
|
||||
publish:
|
||||
npm: false
|
||||
python:
|
||||
package_name: llama_stack_client
|
||||
production_repo: llamastack/llama-stack-client-python
|
||||
options:
|
||||
use_uv: true
|
||||
publish:
|
||||
pypi: true
|
||||
project_name: llama_stack_client
|
||||
kotlin:
|
||||
reverse_domain: com.llama_stack_client.api
|
||||
production_repo: null
|
||||
publish:
|
||||
maven: false
|
||||
go:
|
||||
package_name: llama-stack-client
|
||||
production_repo: llamastack/llama-stack-client-go
|
||||
options:
|
||||
enable_v2: true
|
||||
back_compat_use_shared_package: false
|
||||
|
||||
# `client_settings` define settings for the API client, such as extra constructor
|
||||
# arguments (used for authentication), retry behavior, idempotency, etc.
|
||||
client_settings:
|
||||
default_env_prefix: LLAMA_STACK_CLIENT
|
||||
opts:
|
||||
api_key:
|
||||
type: string
|
||||
read_env: LLAMA_STACK_CLIENT_API_KEY
|
||||
auth: { security_scheme: BearerAuth }
|
||||
nullable: true
|
||||
|
||||
# `environments` are a map of the name of the environment (e.g. "sandbox",
|
||||
# "production") to the corresponding url to use.
|
||||
environments:
|
||||
production: http://any-hosted-llama-stack.com
|
||||
|
||||
# `pagination` defines [pagination schemes] which provides a template to match
|
||||
# endpoints and generate next-page and auto-pagination helpers in the SDKs.
|
||||
pagination:
|
||||
- name: datasets_iterrows
|
||||
type: offset
|
||||
request:
|
||||
dataset_id:
|
||||
type: string
|
||||
start_index:
|
||||
type: integer
|
||||
x-stainless-pagination-property:
|
||||
purpose: offset_count_param
|
||||
limit:
|
||||
type: integer
|
||||
response:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
next_index:
|
||||
type: integer
|
||||
x-stainless-pagination-property:
|
||||
purpose: offset_count_start_field
|
||||
- name: openai_cursor_page
|
||||
type: cursor
|
||||
request:
|
||||
limit:
|
||||
type: integer
|
||||
after:
|
||||
type: string
|
||||
x-stainless-pagination-property:
|
||||
purpose: next_cursor_param
|
||||
response:
|
||||
data:
|
||||
type: array
|
||||
items: {}
|
||||
has_more:
|
||||
type: boolean
|
||||
last_id:
|
||||
type: string
|
||||
x-stainless-pagination-property:
|
||||
purpose: next_cursor_field
|
||||
# `resources` define the structure and organziation for your API, such as how
|
||||
# methods and models are grouped together and accessed. See the [configuration
|
||||
# guide] for more information.
|
||||
#
|
||||
# [configuration guide]:
|
||||
# https://app.stainlessapi.com/docs/guides/configure#resources
|
||||
resources:
|
||||
$shared:
|
||||
models:
|
||||
agent_config: AgentConfig
|
||||
interleaved_content_item: InterleavedContentItem
|
||||
interleaved_content: InterleavedContent
|
||||
param_type: ParamType
|
||||
safety_violation: SafetyViolation
|
||||
sampling_params: SamplingParams
|
||||
scoring_result: ScoringResult
|
||||
message: Message
|
||||
user_message: UserMessage
|
||||
completion_message: CompletionMessage
|
||||
tool_response_message: ToolResponseMessage
|
||||
system_message: SystemMessage
|
||||
tool_call: ToolCall
|
||||
query_result: RAGQueryResult
|
||||
document: RAGDocument
|
||||
query_config: RAGQueryConfig
|
||||
response_format: ResponseFormat
|
||||
toolgroups:
|
||||
models:
|
||||
tool_group: ToolGroup
|
||||
list_tool_groups_response: ListToolGroupsResponse
|
||||
methods:
|
||||
register: post /v1/toolgroups
|
||||
get: get /v1/toolgroups/{toolgroup_id}
|
||||
list: get /v1/toolgroups
|
||||
unregister: delete /v1/toolgroups/{toolgroup_id}
|
||||
tools:
|
||||
methods:
|
||||
get: get /v1/tools/{tool_name}
|
||||
list:
|
||||
endpoint: get /v1/tools
|
||||
paginated: false
|
||||
|
||||
tool_runtime:
|
||||
models:
|
||||
tool_def: ToolDef
|
||||
tool_invocation_result: ToolInvocationResult
|
||||
methods:
|
||||
list_tools:
|
||||
endpoint: get /v1/tool-runtime/list-tools
|
||||
paginated: false
|
||||
invoke_tool: post /v1/tool-runtime/invoke
|
||||
subresources:
|
||||
rag_tool:
|
||||
methods:
|
||||
insert: post /v1/tool-runtime/rag-tool/insert
|
||||
query: post /v1/tool-runtime/rag-tool/query
|
||||
|
||||
responses:
|
||||
models:
|
||||
response_object_stream: OpenAIResponseObjectStream
|
||||
response_object: OpenAIResponseObject
|
||||
methods:
|
||||
create:
|
||||
type: http
|
||||
endpoint: post /v1/responses
|
||||
streaming:
|
||||
stream_event_model: responses.response_object_stream
|
||||
param_discriminator: stream
|
||||
retrieve: get /v1/responses/{response_id}
|
||||
list:
|
||||
type: http
|
||||
endpoint: get /v1/responses
|
||||
delete:
|
||||
type: http
|
||||
endpoint: delete /v1/responses/{response_id}
|
||||
subresources:
|
||||
input_items:
|
||||
methods:
|
||||
list:
|
||||
type: http
|
||||
endpoint: get /v1/responses/{response_id}/input_items
|
||||
|
||||
conversations:
|
||||
models:
|
||||
conversation_object: Conversation
|
||||
methods:
|
||||
create:
|
||||
type: http
|
||||
endpoint: post /v1/conversations
|
||||
retrieve: get /v1/conversations/{conversation_id}
|
||||
update:
|
||||
type: http
|
||||
endpoint: post /v1/conversations/{conversation_id}
|
||||
delete:
|
||||
type: http
|
||||
endpoint: delete /v1/conversations/{conversation_id}
|
||||
subresources:
|
||||
items:
|
||||
methods:
|
||||
get:
|
||||
type: http
|
||||
endpoint: get /v1/conversations/{conversation_id}/items/{item_id}
|
||||
list:
|
||||
type: http
|
||||
endpoint: get /v1/conversations/{conversation_id}/items
|
||||
create:
|
||||
type: http
|
||||
endpoint: post /v1/conversations/{conversation_id}/items
|
||||
|
||||
inspect:
|
||||
models:
|
||||
healthInfo: HealthInfo
|
||||
providerInfo: ProviderInfo
|
||||
routeInfo: RouteInfo
|
||||
versionInfo: VersionInfo
|
||||
methods:
|
||||
health: get /v1/health
|
||||
version: get /v1/version
|
||||
|
||||
embeddings:
|
||||
models:
|
||||
create_embeddings_response: OpenAIEmbeddingsResponse
|
||||
methods:
|
||||
create: post /v1/embeddings
|
||||
|
||||
chat:
|
||||
models:
|
||||
chat_completion_chunk: OpenAIChatCompletionChunk
|
||||
subresources:
|
||||
completions:
|
||||
methods:
|
||||
create:
|
||||
type: http
|
||||
endpoint: post /v1/chat/completions
|
||||
streaming:
|
||||
stream_event_model: chat.chat_completion_chunk
|
||||
param_discriminator: stream
|
||||
list:
|
||||
type: http
|
||||
endpoint: get /v1/chat/completions
|
||||
retrieve:
|
||||
type: http
|
||||
endpoint: get /v1/chat/completions/{completion_id}
|
||||
completions:
|
||||
methods:
|
||||
create:
|
||||
type: http
|
||||
endpoint: post /v1/completions
|
||||
streaming:
|
||||
param_discriminator: stream
|
||||
|
||||
vector_io:
|
||||
models:
|
||||
queryChunksResponse: QueryChunksResponse
|
||||
methods:
|
||||
insert: post /v1/vector-io/insert
|
||||
query: post /v1/vector-io/query
|
||||
|
||||
vector_stores:
|
||||
models:
|
||||
vector_store: VectorStoreObject
|
||||
list_vector_stores_response: VectorStoreListResponse
|
||||
vector_store_delete_response: VectorStoreDeleteResponse
|
||||
vector_store_search_response: VectorStoreSearchResponsePage
|
||||
methods:
|
||||
create: post /v1/vector_stores
|
||||
list:
|
||||
endpoint: get /v1/vector_stores
|
||||
retrieve: get /v1/vector_stores/{vector_store_id}
|
||||
update: post /v1/vector_stores/{vector_store_id}
|
||||
delete: delete /v1/vector_stores/{vector_store_id}
|
||||
search: post /v1/vector_stores/{vector_store_id}/search
|
||||
subresources:
|
||||
files:
|
||||
models:
|
||||
vector_store_file: VectorStoreFileObject
|
||||
methods:
|
||||
list: get /v1/vector_stores/{vector_store_id}/files
|
||||
retrieve: get /v1/vector_stores/{vector_store_id}/files/{file_id}
|
||||
update: post /v1/vector_stores/{vector_store_id}/files/{file_id}
|
||||
delete: delete /v1/vector_stores/{vector_store_id}/files/{file_id}
|
||||
create: post /v1/vector_stores/{vector_store_id}/files
|
||||
content: get /v1/vector_stores/{vector_store_id}/files/{file_id}/content
|
||||
file_batches:
|
||||
models:
|
||||
vector_store_file_batches: VectorStoreFileBatchObject
|
||||
list_vector_store_files_in_batch_response: VectorStoreFilesListInBatchResponse
|
||||
methods:
|
||||
create: post /v1/vector_stores/{vector_store_id}/file_batches
|
||||
retrieve: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}
|
||||
list_files: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files
|
||||
cancel: post /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel
|
||||
|
||||
models:
|
||||
models:
|
||||
model: Model
|
||||
list_models_response: ListModelsResponse
|
||||
methods:
|
||||
retrieve: get /v1/models/{model_id}
|
||||
list:
|
||||
endpoint: get /v1/models
|
||||
paginated: false
|
||||
register: post /v1/models
|
||||
unregister: delete /v1/models/{model_id}
|
||||
subresources:
|
||||
openai:
|
||||
methods:
|
||||
list:
|
||||
endpoint: get /v1/models
|
||||
paginated: false
|
||||
|
||||
providers:
|
||||
models:
|
||||
list_providers_response: ListProvidersResponse
|
||||
methods:
|
||||
list:
|
||||
endpoint: get /v1/providers
|
||||
paginated: false
|
||||
retrieve: get /v1/providers/{provider_id}
|
||||
|
||||
routes:
|
||||
models:
|
||||
list_routes_response: ListRoutesResponse
|
||||
methods:
|
||||
list:
|
||||
endpoint: get /v1/inspect/routes
|
||||
paginated: false
|
||||
|
||||
|
||||
moderations:
|
||||
models:
|
||||
create_response: ModerationObject
|
||||
methods:
|
||||
create: post /v1/moderations
|
||||
|
||||
|
||||
safety:
|
||||
models:
|
||||
run_shield_response: RunShieldResponse
|
||||
methods:
|
||||
run_shield: post /v1/safety/run-shield
|
||||
|
||||
|
||||
shields:
|
||||
models:
|
||||
shield: Shield
|
||||
list_shields_response: ListShieldsResponse
|
||||
methods:
|
||||
retrieve: get /v1/shields/{identifier}
|
||||
list:
|
||||
endpoint: get /v1/shields
|
||||
paginated: false
|
||||
register: post /v1/shields
|
||||
delete: delete /v1/shields/{identifier}
|
||||
|
||||
synthetic_data_generation:
|
||||
models:
|
||||
syntheticDataGenerationResponse: SyntheticDataGenerationResponse
|
||||
methods:
|
||||
generate: post /v1/synthetic-data-generation/generate
|
||||
|
||||
telemetry:
|
||||
models:
|
||||
span_with_status: SpanWithStatus
|
||||
trace: Trace
|
||||
query_spans_response: QuerySpansResponse
|
||||
event: Event
|
||||
query_condition: QueryCondition
|
||||
methods:
|
||||
query_traces:
|
||||
endpoint: post /v1alpha/telemetry/traces
|
||||
skip_test_reason: 'unsupported query params in java / kotlin'
|
||||
get_span_tree: post /v1alpha/telemetry/spans/{span_id}/tree
|
||||
query_spans:
|
||||
endpoint: post /v1alpha/telemetry/spans
|
||||
skip_test_reason: 'unsupported query params in java / kotlin'
|
||||
query_metrics:
|
||||
endpoint: post /v1alpha/telemetry/metrics/{metric_name}
|
||||
skip_test_reason: 'unsupported query params in java / kotlin'
|
||||
# log_event: post /v1alpha/telemetry/events
|
||||
save_spans_to_dataset: post /v1alpha/telemetry/spans/export
|
||||
get_span: get /v1alpha/telemetry/traces/{trace_id}/spans/{span_id}
|
||||
get_trace: get /v1alpha/telemetry/traces/{trace_id}
|
||||
|
||||
scoring:
|
||||
methods:
|
||||
score: post /v1/scoring/score
|
||||
score_batch: post /v1/scoring/score-batch
|
||||
scoring_functions:
|
||||
methods:
|
||||
retrieve: get /v1/scoring-functions/{scoring_fn_id}
|
||||
list:
|
||||
endpoint: get /v1/scoring-functions
|
||||
paginated: false
|
||||
register: post /v1/scoring-functions
|
||||
models:
|
||||
scoring_fn: ScoringFn
|
||||
scoring_fn_params: ScoringFnParams
|
||||
list_scoring_functions_response: ListScoringFunctionsResponse
|
||||
|
||||
benchmarks:
|
||||
methods:
|
||||
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}
|
||||
list:
|
||||
endpoint: get /v1alpha/eval/benchmarks
|
||||
paginated: false
|
||||
register: post /v1alpha/eval/benchmarks
|
||||
models:
|
||||
benchmark: Benchmark
|
||||
list_benchmarks_response: ListBenchmarksResponse
|
||||
|
||||
files:
|
||||
methods:
|
||||
create: post /v1/files
|
||||
list: get /v1/files
|
||||
retrieve: get /v1/files/{file_id}
|
||||
delete: delete /v1/files/{file_id}
|
||||
content: get /v1/files/{file_id}/content
|
||||
models:
|
||||
file: OpenAIFileObject
|
||||
list_files_response: ListOpenAIFileResponse
|
||||
delete_file_response: OpenAIFileDeleteResponse
|
||||
|
||||
alpha:
|
||||
subresources:
|
||||
inference:
|
||||
methods:
|
||||
rerank: post /v1alpha/inference/rerank
|
||||
|
||||
post_training:
|
||||
models:
|
||||
algorithm_config: AlgorithmConfig
|
||||
post_training_job: PostTrainingJob
|
||||
list_post_training_jobs_response: ListPostTrainingJobsResponse
|
||||
methods:
|
||||
preference_optimize: post /v1alpha/post-training/preference-optimize
|
||||
supervised_fine_tune: post /v1alpha/post-training/supervised-fine-tune
|
||||
subresources:
|
||||
job:
|
||||
methods:
|
||||
artifacts: get /v1alpha/post-training/job/artifacts
|
||||
cancel: post /v1alpha/post-training/job/cancel
|
||||
status: get /v1alpha/post-training/job/status
|
||||
list:
|
||||
endpoint: get /v1alpha/post-training/jobs
|
||||
paginated: false
|
||||
|
||||
eval:
|
||||
methods:
|
||||
evaluate_rows: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
|
||||
run_eval: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
|
||||
evaluate_rows_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
|
||||
run_eval_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
|
||||
|
||||
subresources:
|
||||
jobs:
|
||||
methods:
|
||||
cancel: delete /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
||||
status: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
||||
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result
|
||||
models:
|
||||
evaluate_response: EvaluateResponse
|
||||
benchmark_config: BenchmarkConfig
|
||||
job: Job
|
||||
|
||||
agents:
|
||||
methods:
|
||||
create: post /v1alpha/agents
|
||||
list: get /v1alpha/agents
|
||||
retrieve: get /v1alpha/agents/{agent_id}
|
||||
delete: delete /v1alpha/agents/{agent_id}
|
||||
models:
|
||||
inference_step: InferenceStep
|
||||
tool_execution_step: ToolExecutionStep
|
||||
tool_response: ToolResponse
|
||||
shield_call_step: ShieldCallStep
|
||||
memory_retrieval_step: MemoryRetrievalStep
|
||||
subresources:
|
||||
session:
|
||||
models:
|
||||
session: Session
|
||||
methods:
|
||||
list: get /v1alpha/agents/{agent_id}/sessions
|
||||
create: post /v1alpha/agents/{agent_id}/session
|
||||
delete: delete /v1alpha/agents/{agent_id}/session/{session_id}
|
||||
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}
|
||||
steps:
|
||||
methods:
|
||||
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}
|
||||
turn:
|
||||
models:
|
||||
turn: Turn
|
||||
turn_response_event: AgentTurnResponseEvent
|
||||
agent_turn_response_stream_chunk: AgentTurnResponseStreamChunk
|
||||
methods:
|
||||
create:
|
||||
type: http
|
||||
endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn
|
||||
streaming:
|
||||
stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
|
||||
param_discriminator: stream
|
||||
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}
|
||||
resume:
|
||||
type: http
|
||||
endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume
|
||||
streaming:
|
||||
stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
|
||||
param_discriminator: stream
|
||||
|
||||
beta:
|
||||
subresources:
|
||||
datasets:
|
||||
models:
|
||||
list_datasets_response: ListDatasetsResponse
|
||||
methods:
|
||||
register: post /v1beta/datasets
|
||||
retrieve: get /v1beta/datasets/{dataset_id}
|
||||
list:
|
||||
endpoint: get /v1beta/datasets
|
||||
paginated: false
|
||||
unregister: delete /v1beta/datasets/{dataset_id}
|
||||
iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
|
||||
appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
|
||||
|
||||
|
||||
settings:
|
||||
license: MIT
|
||||
unwrap_response_fields: [ data ]
|
||||
|
||||
openapi:
|
||||
transformations:
|
||||
- command: renameValue
|
||||
reason: pydantic reserved name
|
||||
args:
|
||||
filter:
|
||||
only:
|
||||
- '$.components.schemas.InferenceStep.properties.model_response'
|
||||
rename:
|
||||
python:
|
||||
property_name: 'inference_model_response'
|
||||
|
||||
# - command: renameValue
|
||||
# reason: pydantic reserved name
|
||||
# args:
|
||||
# filter:
|
||||
# only:
|
||||
# - '$.components.schemas.Model.properties.model_type'
|
||||
# rename:
|
||||
# python:
|
||||
# property_name: 'type'
|
||||
- command: mergeObject
|
||||
reason: Better return_type using enum
|
||||
args:
|
||||
target:
|
||||
- '$.components.schemas'
|
||||
object:
|
||||
ReturnType:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
type:
|
||||
enum:
|
||||
- string
|
||||
- number
|
||||
- boolean
|
||||
- array
|
||||
- object
|
||||
- json
|
||||
- union
|
||||
- chat_completion_input
|
||||
- completion_input
|
||||
- agent_turn_input
|
||||
required:
|
||||
- type
|
||||
type: object
|
||||
- command: replaceProperties
|
||||
reason: Replace return type properties with better model (see above)
|
||||
args:
|
||||
filter:
|
||||
only:
|
||||
- '$.components.schemas.ScoringFn.properties.return_type'
|
||||
- '$.components.schemas.RegisterScoringFunctionRequest.properties.return_type'
|
||||
value:
|
||||
$ref: '#/components/schemas/ReturnType'
|
||||
- command: oneOfToAnyOf
|
||||
reason: Prism (mock server) doesn't like one of our requests as it technically matches multiple variants
|
||||
- reason: For better names
|
||||
command: extractToRefs
|
||||
args:
|
||||
ref:
|
||||
target: '$.components.schemas.ToolCallDelta.properties.tool_call'
|
||||
name: '#/components/schemas/ToolCallOrString'
|
||||
|
||||
# `readme` is used to configure the code snippets that will be rendered in the
|
||||
# README.md of various SDKs. In particular, you can change the `headline`
|
||||
# snippet's endpoint and the arguments to call it with.
|
||||
readme:
|
||||
example_requests:
|
||||
default:
|
||||
type: request
|
||||
endpoint: post /v1/chat/completions
|
||||
params: &ref_0 {}
|
||||
headline:
|
||||
type: request
|
||||
endpoint: post /v1/models
|
||||
params: *ref_0
|
||||
pagination:
|
||||
type: request
|
||||
endpoint: post /v1/chat/completions
|
||||
params: {}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -19,6 +19,8 @@ ARG KEEP_WORKSPACE=""
|
|||
ARG DISTRO_NAME="starter"
|
||||
ARG RUN_CONFIG_PATH=""
|
||||
ARG UV_HTTP_TIMEOUT=500
|
||||
ARG UV_EXTRA_INDEX_URL=""
|
||||
ARG UV_INDEX_STRATEGY=""
|
||||
ENV UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT}
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
|
|
@ -45,7 +47,7 @@ RUN set -eux; \
|
|||
exit 1; \
|
||||
fi
|
||||
|
||||
RUN pip install --no-cache uv
|
||||
RUN pip install --no-cache-dir uv
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
||||
ENV INSTALL_MODE=${INSTALL_MODE}
|
||||
|
|
@ -62,47 +64,60 @@ COPY . /workspace
|
|||
|
||||
# Install the client package if it is provided
|
||||
# NOTE: this is installed before llama-stack since llama-stack depends on llama-stack-client-python
|
||||
# Unset UV index env vars to ensure we only use PyPI for the client
|
||||
RUN set -eux; \
|
||||
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then \
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then \
|
||||
echo "LLAMA_STACK_CLIENT_DIR is set but $LLAMA_STACK_CLIENT_DIR does not exist" >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
uv pip install --no-cache -e "$LLAMA_STACK_CLIENT_DIR"; \
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"; \
|
||||
fi;
|
||||
|
||||
# Install llama-stack
|
||||
# Use UV_EXTRA_INDEX_URL inline only for editable install with RC dependencies
|
||||
RUN set -eux; \
|
||||
SAVED_UV_EXTRA_INDEX_URL="${UV_EXTRA_INDEX_URL:-}"; \
|
||||
SAVED_UV_INDEX_STRATEGY="${UV_INDEX_STRATEGY:-}"; \
|
||||
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
|
||||
if [ "$INSTALL_MODE" = "editable" ]; then \
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then \
|
||||
echo "INSTALL_MODE=editable requires LLAMA_STACK_DIR to point to a directory inside the build context" >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
uv pip install --no-cache -e "$LLAMA_STACK_DIR"; \
|
||||
elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
|
||||
uv pip install --no-cache fastapi libcst; \
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then \
|
||||
uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
|
||||
if [ -n "$SAVED_UV_EXTRA_INDEX_URL" ] && [ -n "$SAVED_UV_INDEX_STRATEGY" ]; then \
|
||||
UV_EXTRA_INDEX_URL="$SAVED_UV_EXTRA_INDEX_URL" UV_INDEX_STRATEGY="$SAVED_UV_INDEX_STRATEGY" \
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
|
||||
else \
|
||||
uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
|
||||
fi; \
|
||||
elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
|
||||
uv pip install --no-cache-dir fastapi libcst; \
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then \
|
||||
uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
|
||||
else \
|
||||
uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
|
||||
fi; \
|
||||
else \
|
||||
if [ -n "$PYPI_VERSION" ]; then \
|
||||
uv pip install --no-cache "llama-stack==$PYPI_VERSION"; \
|
||||
uv pip install --no-cache-dir "llama-stack==$PYPI_VERSION"; \
|
||||
else \
|
||||
uv pip install --no-cache llama-stack; \
|
||||
uv pip install --no-cache-dir llama-stack; \
|
||||
fi; \
|
||||
fi;
|
||||
|
||||
# Install the dependencies for the distribution
|
||||
# Explicitly unset UV index env vars to ensure we only use PyPI for distribution deps
|
||||
RUN set -eux; \
|
||||
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
|
||||
if [ -z "$DISTRO_NAME" ]; then \
|
||||
echo "DISTRO_NAME must be provided" >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
deps="$(llama stack list-deps "$DISTRO_NAME")"; \
|
||||
if [ -n "$deps" ]; then \
|
||||
printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache; \
|
||||
printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache-dir; \
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
|
|
|
|||
|
|
@ -23,5 +23,4 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
|
|||
We are working on adding a few more APIs to complete the application lifecycle. These will include:
|
||||
- **Batch Inference**: run inference on a dataset of inputs
|
||||
- **Batch Agents**: run agents on a dataset of inputs
|
||||
- **Synthetic Data Generation**: generate synthetic data for model development
|
||||
- **Batches**: OpenAI-compatible batch management for inference
|
||||
|
|
|
|||
|
|
@ -79,6 +79,33 @@ docker run \
|
|||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
llamastack/distribution-meta-reference-gpu \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
Available run configurations for this distribution:
|
||||
- `run.yaml`
|
||||
- `run-with-safety.yaml`
|
||||
|
||||
### Via venv
|
||||
|
||||
Make sure you have the Llama Stack CLI available.
|
||||
|
|
|
|||
|
|
@ -127,13 +127,39 @@ docker run \
|
|||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-nvidia \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-nvidia \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
Available run configurations for this distribution:
|
||||
- `run.yaml`
|
||||
- `run-with-safety.yaml`
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment.
|
||||
|
|
|
|||
27
docs/docs/providers/files/remote_openai.mdx
Normal file
27
docs/docs/providers/files/remote_openai.mdx
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
---
|
||||
description: "OpenAI Files API provider for managing files through OpenAI's native file storage service."
|
||||
sidebar_label: Remote - Openai
|
||||
title: remote::openai
|
||||
---
|
||||
|
||||
# remote::openai
|
||||
|
||||
## Description
|
||||
|
||||
OpenAI Files API provider for managing files through OpenAI's native file storage service.
|
||||
|
||||
## Configuration
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `api_key` | `<class 'str'>` | No | | OpenAI API key for authentication |
|
||||
| `metadata_store` | `<class 'llama_stack.core.storage.datatypes.SqlStoreReference'>` | No | | SQL store configuration for file metadata |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
api_key: ${env.OPENAI_API_KEY}
|
||||
metadata_store:
|
||||
table_name: openai_files_metadata
|
||||
backend: sql_default
|
||||
```
|
||||
|
|
@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
|
|||
| `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
|
||||
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
|
||||
| `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
|
||||
| `rerank_model_to_url` | `dict[str, str` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
|||
1036
docs/notebooks/llamastack_agents_getting_started_examples.ipynb
Normal file
1036
docs/notebooks/llamastack_agents_getting_started_examples.ipynb
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -84,7 +84,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo
|
|||
)
|
||||
|
||||
yaml_filename = f"{filename_prefix}llama-stack-spec.yaml"
|
||||
html_filename = f"{filename_prefix}llama-stack-spec.html"
|
||||
|
||||
with open(output_dir / yaml_filename, "w", encoding="utf-8") as fp:
|
||||
y = yaml.YAML()
|
||||
|
|
@ -102,11 +101,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo
|
|||
fp,
|
||||
)
|
||||
|
||||
with open(output_dir / html_filename, "w") as fp:
|
||||
spec.write_html(fp, pretty_print=True)
|
||||
|
||||
print(f"Generated {yaml_filename} and {html_filename}")
|
||||
|
||||
def main(output_dir: str):
|
||||
output_dir = Path(output_dir)
|
||||
if not output_dir.exists():
|
||||
|
|
|
|||
|
|
@ -242,15 +242,6 @@ const sidebars: SidebarsConfig = {
|
|||
'providers/eval/remote_nvidia'
|
||||
],
|
||||
},
|
||||
{
|
||||
type: 'category',
|
||||
label: 'Telemetry',
|
||||
collapsed: true,
|
||||
items: [
|
||||
'providers/telemetry/index',
|
||||
'providers/telemetry/inline_meta-reference'
|
||||
],
|
||||
},
|
||||
{
|
||||
type: 'category',
|
||||
label: 'Batches',
|
||||
|
|
|
|||
13582
docs/static/deprecated-llama-stack-spec.html
vendored
13582
docs/static/deprecated-llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
534
docs/static/deprecated-llama-stack-spec.yaml
vendored
534
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
|
@ -1012,6 +1012,141 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: true
|
||||
/v1/openai/v1/batches:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A list of batch objects.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListBatchesResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: List all batches for the current user.
|
||||
description: List all batches for the current user.
|
||||
parameters:
|
||||
- name: after
|
||||
in: query
|
||||
description: >-
|
||||
A cursor for pagination; returns batches after this batch ID.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: limit
|
||||
in: query
|
||||
description: >-
|
||||
Number of batches to return (default 20, max 100).
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
deprecated: true
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: The created batch object.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: >-
|
||||
Create a new batch for processing multiple API requests.
|
||||
description: >-
|
||||
Create a new batch for processing multiple API requests.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreateBatchRequest'
|
||||
required: true
|
||||
deprecated: true
|
||||
/v1/openai/v1/batches/{batch_id}:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: The batch object.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: >-
|
||||
Retrieve information about a specific batch.
|
||||
description: >-
|
||||
Retrieve information about a specific batch.
|
||||
parameters:
|
||||
- name: batch_id
|
||||
in: path
|
||||
description: The ID of the batch to retrieve.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: true
|
||||
/v1/openai/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: The updated batch object.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: Cancel a batch that is in progress.
|
||||
description: Cancel a batch that is in progress.
|
||||
parameters:
|
||||
- name: batch_id
|
||||
in: path
|
||||
description: The ID of the batch to cancel.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: true
|
||||
/v1/openai/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1426,31 +1561,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: true
|
||||
/v1/openai/v1/models:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A OpenAIListModelsResponse.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAIListModelsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: List models using the OpenAI API.
|
||||
description: List models using the OpenAI API.
|
||||
parameters: []
|
||||
deprecated: true
|
||||
/v1/openai/v1/moderations:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -4736,6 +4846,331 @@ components:
|
|||
title: Job
|
||||
description: >-
|
||||
A job execution instance with status tracking.
|
||||
ListBatchesResponse:
|
||||
type: object
|
||||
properties:
|
||||
object:
|
||||
type: string
|
||||
const: list
|
||||
default: list
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
completion_window:
|
||||
type: string
|
||||
created_at:
|
||||
type: integer
|
||||
endpoint:
|
||||
type: string
|
||||
input_file_id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: batch
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- validating
|
||||
- failed
|
||||
- in_progress
|
||||
- finalizing
|
||||
- completed
|
||||
- expired
|
||||
- cancelling
|
||||
- cancelled
|
||||
cancelled_at:
|
||||
type: integer
|
||||
cancelling_at:
|
||||
type: integer
|
||||
completed_at:
|
||||
type: integer
|
||||
error_file_id:
|
||||
type: string
|
||||
errors:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
line:
|
||||
type: integer
|
||||
message:
|
||||
type: string
|
||||
param:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: BatchError
|
||||
object:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: Errors
|
||||
expired_at:
|
||||
type: integer
|
||||
expires_at:
|
||||
type: integer
|
||||
failed_at:
|
||||
type: integer
|
||||
finalizing_at:
|
||||
type: integer
|
||||
in_progress_at:
|
||||
type: integer
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
output_file_id:
|
||||
type: string
|
||||
request_counts:
|
||||
type: object
|
||||
properties:
|
||||
completed:
|
||||
type: integer
|
||||
failed:
|
||||
type: integer
|
||||
total:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- completed
|
||||
- failed
|
||||
- total
|
||||
title: BatchRequestCounts
|
||||
usage:
|
||||
type: object
|
||||
properties:
|
||||
input_tokens:
|
||||
type: integer
|
||||
input_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
cached_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- cached_tokens
|
||||
title: InputTokensDetails
|
||||
output_tokens:
|
||||
type: integer
|
||||
output_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
reasoning_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- reasoning_tokens
|
||||
title: OutputTokensDetails
|
||||
total_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input_tokens
|
||||
- input_tokens_details
|
||||
- output_tokens
|
||||
- output_tokens_details
|
||||
- total_tokens
|
||||
title: BatchUsage
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- completion_window
|
||||
- created_at
|
||||
- endpoint
|
||||
- input_file_id
|
||||
- object
|
||||
- status
|
||||
title: Batch
|
||||
first_id:
|
||||
type: string
|
||||
last_id:
|
||||
type: string
|
||||
has_more:
|
||||
type: boolean
|
||||
default: false
|
||||
additionalProperties: false
|
||||
required:
|
||||
- object
|
||||
- data
|
||||
- has_more
|
||||
title: ListBatchesResponse
|
||||
description: >-
|
||||
Response containing a list of batch objects.
|
||||
CreateBatchRequest:
|
||||
type: object
|
||||
properties:
|
||||
input_file_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of an uploaded file containing requests for the batch.
|
||||
endpoint:
|
||||
type: string
|
||||
description: >-
|
||||
The endpoint to be used for all requests in the batch.
|
||||
completion_window:
|
||||
type: string
|
||||
const: 24h
|
||||
description: >-
|
||||
The time window within which the batch should be processed.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: Optional metadata for the batch.
|
||||
idempotency_key:
|
||||
type: string
|
||||
description: >-
|
||||
Optional idempotency key. When provided, enables idempotent behavior.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input_file_id
|
||||
- endpoint
|
||||
- completion_window
|
||||
title: CreateBatchRequest
|
||||
Batch:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
completion_window:
|
||||
type: string
|
||||
created_at:
|
||||
type: integer
|
||||
endpoint:
|
||||
type: string
|
||||
input_file_id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: batch
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- validating
|
||||
- failed
|
||||
- in_progress
|
||||
- finalizing
|
||||
- completed
|
||||
- expired
|
||||
- cancelling
|
||||
- cancelled
|
||||
cancelled_at:
|
||||
type: integer
|
||||
cancelling_at:
|
||||
type: integer
|
||||
completed_at:
|
||||
type: integer
|
||||
error_file_id:
|
||||
type: string
|
||||
errors:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
line:
|
||||
type: integer
|
||||
message:
|
||||
type: string
|
||||
param:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: BatchError
|
||||
object:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: Errors
|
||||
expired_at:
|
||||
type: integer
|
||||
expires_at:
|
||||
type: integer
|
||||
failed_at:
|
||||
type: integer
|
||||
finalizing_at:
|
||||
type: integer
|
||||
in_progress_at:
|
||||
type: integer
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
output_file_id:
|
||||
type: string
|
||||
request_counts:
|
||||
type: object
|
||||
properties:
|
||||
completed:
|
||||
type: integer
|
||||
failed:
|
||||
type: integer
|
||||
total:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- completed
|
||||
- failed
|
||||
- total
|
||||
title: BatchRequestCounts
|
||||
usage:
|
||||
type: object
|
||||
properties:
|
||||
input_tokens:
|
||||
type: integer
|
||||
input_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
cached_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- cached_tokens
|
||||
title: InputTokensDetails
|
||||
output_tokens:
|
||||
type: integer
|
||||
output_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
reasoning_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- reasoning_tokens
|
||||
title: OutputTokensDetails
|
||||
total_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input_tokens
|
||||
- input_tokens_details
|
||||
- output_tokens
|
||||
- output_tokens_details
|
||||
- total_tokens
|
||||
title: BatchUsage
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- completion_window
|
||||
- created_at
|
||||
- endpoint
|
||||
- input_file_id
|
||||
- object
|
||||
- status
|
||||
title: Batch
|
||||
Order:
|
||||
type: string
|
||||
enum:
|
||||
|
|
@ -6056,38 +6491,6 @@ components:
|
|||
Response:
|
||||
type: object
|
||||
title: Response
|
||||
OpenAIModel:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: model
|
||||
default: model
|
||||
created:
|
||||
type: integer
|
||||
owned_by:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- object
|
||||
- created
|
||||
- owned_by
|
||||
title: OpenAIModel
|
||||
description: A model from OpenAI.
|
||||
OpenAIListModelsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIModel'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: OpenAIListModelsResponse
|
||||
RunModerationRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -10263,6 +10666,19 @@ tags:
|
|||
|
||||
- **Responses API**: Use the stable v1 Responses API endpoints
|
||||
x-displayName: Agents
|
||||
- name: Batches
|
||||
description: >-
|
||||
The API is designed to allow use of openai client libraries for seamless integration.
|
||||
|
||||
|
||||
This API provides the following extensions:
|
||||
- idempotent batch creation
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
x-displayName: >-
|
||||
The Batches API enables efficient processing of multiple requests in a single
|
||||
operation, particularly useful for processing large datasets, batch evaluation
|
||||
workflows, and cost-effective inference at scale.
|
||||
- name: Benchmarks
|
||||
description: ''
|
||||
- name: DatasetIO
|
||||
|
|
@ -10295,8 +10711,6 @@ tags:
|
|||
- Rerank models: these models reorder the documents based on their relevance
|
||||
to a query.
|
||||
x-displayName: Inference
|
||||
- name: Models
|
||||
description: ''
|
||||
- name: PostTraining (Coming Soon)
|
||||
description: ''
|
||||
- name: Safety
|
||||
|
|
@ -10308,13 +10722,13 @@ x-tagGroups:
|
|||
- name: Operations
|
||||
tags:
|
||||
- Agents
|
||||
- Batches
|
||||
- Benchmarks
|
||||
- DatasetIO
|
||||
- Datasets
|
||||
- Eval
|
||||
- Files
|
||||
- Inference
|
||||
- Models
|
||||
- PostTraining (Coming Soon)
|
||||
- Safety
|
||||
- VectorIO
|
||||
|
|
|
|||
5552
docs/static/experimental-llama-stack-spec.html
vendored
5552
docs/static/experimental-llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
1075
docs/static/llama-stack-spec.html
vendored
1075
docs/static/llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
873
docs/static/llama-stack-spec.yaml
vendored
873
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -12,6 +12,141 @@ info:
|
|||
servers:
|
||||
- url: http://any-hosted-llama-stack.com
|
||||
paths:
|
||||
/v1/batches:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A list of batch objects.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ListBatchesResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: List all batches for the current user.
|
||||
description: List all batches for the current user.
|
||||
parameters:
|
||||
- name: after
|
||||
in: query
|
||||
description: >-
|
||||
A cursor for pagination; returns batches after this batch ID.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: limit
|
||||
in: query
|
||||
description: >-
|
||||
Number of batches to return (default 20, max 100).
|
||||
required: true
|
||||
schema:
|
||||
type: integer
|
||||
deprecated: false
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: The created batch object.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: >-
|
||||
Create a new batch for processing multiple API requests.
|
||||
description: >-
|
||||
Create a new batch for processing multiple API requests.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/CreateBatchRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/batches/{batch_id}:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: The batch object.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: >-
|
||||
Retrieve information about a specific batch.
|
||||
description: >-
|
||||
Retrieve information about a specific batch.
|
||||
parameters:
|
||||
- name: batch_id
|
||||
in: path
|
||||
description: The ID of the batch to retrieve.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/batches/{batch_id}/cancel:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: The updated batch object.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Batch'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Batches
|
||||
summary: Cancel a batch that is in progress.
|
||||
description: Cancel a batch that is in progress.
|
||||
parameters:
|
||||
- name: batch_id
|
||||
in: path
|
||||
description: The ID of the batch to cancel.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/chat/completions:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -818,7 +953,22 @@ paths:
|
|||
List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
parameters: []
|
||||
parameters:
|
||||
- name: api_filter
|
||||
in: query
|
||||
description: >-
|
||||
Optional filter to control which routes are returned. Can be an API level
|
||||
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||
returns only non-deprecated v1 routes.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- v1
|
||||
- v1alpha
|
||||
- v1beta
|
||||
- deprecated
|
||||
deprecated: false
|
||||
/v1/models:
|
||||
get:
|
||||
|
|
@ -976,6 +1126,31 @@ paths:
|
|||
$ref: '#/components/schemas/RunModerationRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/openai/v1/models:
|
||||
get:
|
||||
responses:
|
||||
'200':
|
||||
description: A OpenAIListModelsResponse.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/OpenAIListModelsResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- Models
|
||||
summary: List models using the OpenAI API.
|
||||
description: List models using the OpenAI API.
|
||||
parameters: []
|
||||
deprecated: false
|
||||
/v1/prompts:
|
||||
get:
|
||||
responses:
|
||||
|
|
@ -1832,40 +2007,6 @@ paths:
|
|||
schema:
|
||||
type: string
|
||||
deprecated: false
|
||||
/v1/synthetic-data-generation/generate:
|
||||
post:
|
||||
responses:
|
||||
'200':
|
||||
description: >-
|
||||
Response containing filtered synthetic data samples and optional statistics
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SyntheticDataGenerationResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
$ref: >-
|
||||
#/components/responses/TooManyRequests429
|
||||
'500':
|
||||
$ref: >-
|
||||
#/components/responses/InternalServerError500
|
||||
default:
|
||||
$ref: '#/components/responses/DefaultError'
|
||||
tags:
|
||||
- SyntheticDataGeneration (Coming Soon)
|
||||
summary: >-
|
||||
Generate synthetic data based on input dialogs and apply filtering.
|
||||
description: >-
|
||||
Generate synthetic data based on input dialogs and apply filtering.
|
||||
parameters: []
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/SyntheticDataGenerateRequest'
|
||||
required: true
|
||||
deprecated: false
|
||||
/v1/tool-runtime/invoke:
|
||||
post:
|
||||
responses:
|
||||
|
|
@ -2999,6 +3140,331 @@ components:
|
|||
title: Error
|
||||
description: >-
|
||||
Error response from the API. Roughly follows RFC 7807.
|
||||
ListBatchesResponse:
|
||||
type: object
|
||||
properties:
|
||||
object:
|
||||
type: string
|
||||
const: list
|
||||
default: list
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
completion_window:
|
||||
type: string
|
||||
created_at:
|
||||
type: integer
|
||||
endpoint:
|
||||
type: string
|
||||
input_file_id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: batch
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- validating
|
||||
- failed
|
||||
- in_progress
|
||||
- finalizing
|
||||
- completed
|
||||
- expired
|
||||
- cancelling
|
||||
- cancelled
|
||||
cancelled_at:
|
||||
type: integer
|
||||
cancelling_at:
|
||||
type: integer
|
||||
completed_at:
|
||||
type: integer
|
||||
error_file_id:
|
||||
type: string
|
||||
errors:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
line:
|
||||
type: integer
|
||||
message:
|
||||
type: string
|
||||
param:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: BatchError
|
||||
object:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: Errors
|
||||
expired_at:
|
||||
type: integer
|
||||
expires_at:
|
||||
type: integer
|
||||
failed_at:
|
||||
type: integer
|
||||
finalizing_at:
|
||||
type: integer
|
||||
in_progress_at:
|
||||
type: integer
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
output_file_id:
|
||||
type: string
|
||||
request_counts:
|
||||
type: object
|
||||
properties:
|
||||
completed:
|
||||
type: integer
|
||||
failed:
|
||||
type: integer
|
||||
total:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- completed
|
||||
- failed
|
||||
- total
|
||||
title: BatchRequestCounts
|
||||
usage:
|
||||
type: object
|
||||
properties:
|
||||
input_tokens:
|
||||
type: integer
|
||||
input_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
cached_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- cached_tokens
|
||||
title: InputTokensDetails
|
||||
output_tokens:
|
||||
type: integer
|
||||
output_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
reasoning_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- reasoning_tokens
|
||||
title: OutputTokensDetails
|
||||
total_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input_tokens
|
||||
- input_tokens_details
|
||||
- output_tokens
|
||||
- output_tokens_details
|
||||
- total_tokens
|
||||
title: BatchUsage
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- completion_window
|
||||
- created_at
|
||||
- endpoint
|
||||
- input_file_id
|
||||
- object
|
||||
- status
|
||||
title: Batch
|
||||
first_id:
|
||||
type: string
|
||||
last_id:
|
||||
type: string
|
||||
has_more:
|
||||
type: boolean
|
||||
default: false
|
||||
additionalProperties: false
|
||||
required:
|
||||
- object
|
||||
- data
|
||||
- has_more
|
||||
title: ListBatchesResponse
|
||||
description: >-
|
||||
Response containing a list of batch objects.
|
||||
CreateBatchRequest:
|
||||
type: object
|
||||
properties:
|
||||
input_file_id:
|
||||
type: string
|
||||
description: >-
|
||||
The ID of an uploaded file containing requests for the batch.
|
||||
endpoint:
|
||||
type: string
|
||||
description: >-
|
||||
The endpoint to be used for all requests in the batch.
|
||||
completion_window:
|
||||
type: string
|
||||
const: 24h
|
||||
description: >-
|
||||
The time window within which the batch should be processed.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
description: Optional metadata for the batch.
|
||||
idempotency_key:
|
||||
type: string
|
||||
description: >-
|
||||
Optional idempotency key. When provided, enables idempotent behavior.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input_file_id
|
||||
- endpoint
|
||||
- completion_window
|
||||
title: CreateBatchRequest
|
||||
Batch:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
completion_window:
|
||||
type: string
|
||||
created_at:
|
||||
type: integer
|
||||
endpoint:
|
||||
type: string
|
||||
input_file_id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: batch
|
||||
status:
|
||||
type: string
|
||||
enum:
|
||||
- validating
|
||||
- failed
|
||||
- in_progress
|
||||
- finalizing
|
||||
- completed
|
||||
- expired
|
||||
- cancelling
|
||||
- cancelled
|
||||
cancelled_at:
|
||||
type: integer
|
||||
cancelling_at:
|
||||
type: integer
|
||||
completed_at:
|
||||
type: integer
|
||||
error_file_id:
|
||||
type: string
|
||||
errors:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
line:
|
||||
type: integer
|
||||
message:
|
||||
type: string
|
||||
param:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: BatchError
|
||||
object:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
title: Errors
|
||||
expired_at:
|
||||
type: integer
|
||||
expires_at:
|
||||
type: integer
|
||||
failed_at:
|
||||
type: integer
|
||||
finalizing_at:
|
||||
type: integer
|
||||
in_progress_at:
|
||||
type: integer
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: string
|
||||
model:
|
||||
type: string
|
||||
output_file_id:
|
||||
type: string
|
||||
request_counts:
|
||||
type: object
|
||||
properties:
|
||||
completed:
|
||||
type: integer
|
||||
failed:
|
||||
type: integer
|
||||
total:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- completed
|
||||
- failed
|
||||
- total
|
||||
title: BatchRequestCounts
|
||||
usage:
|
||||
type: object
|
||||
properties:
|
||||
input_tokens:
|
||||
type: integer
|
||||
input_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
cached_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- cached_tokens
|
||||
title: InputTokensDetails
|
||||
output_tokens:
|
||||
type: integer
|
||||
output_tokens_details:
|
||||
type: object
|
||||
properties:
|
||||
reasoning_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- reasoning_tokens
|
||||
title: OutputTokensDetails
|
||||
total_tokens:
|
||||
type: integer
|
||||
additionalProperties: false
|
||||
required:
|
||||
- input_tokens
|
||||
- input_tokens_details
|
||||
- output_tokens
|
||||
- output_tokens_details
|
||||
- total_tokens
|
||||
title: BatchUsage
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- completion_window
|
||||
- created_at
|
||||
- endpoint
|
||||
- input_file_id
|
||||
- object
|
||||
- status
|
||||
title: Batch
|
||||
Order:
|
||||
type: string
|
||||
enum:
|
||||
|
|
@ -5341,6 +5807,48 @@ components:
|
|||
- metadata
|
||||
title: ModerationObjectResults
|
||||
description: A moderation object.
|
||||
OpenAIModel:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: string
|
||||
object:
|
||||
type: string
|
||||
const: model
|
||||
default: model
|
||||
created:
|
||||
type: integer
|
||||
owned_by:
|
||||
type: string
|
||||
custom_metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
additionalProperties: false
|
||||
required:
|
||||
- id
|
||||
- object
|
||||
- created
|
||||
- owned_by
|
||||
title: OpenAIModel
|
||||
description: A model from OpenAI.
|
||||
OpenAIListModelsResponse:
|
||||
type: object
|
||||
properties:
|
||||
data:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/OpenAIModel'
|
||||
additionalProperties: false
|
||||
required:
|
||||
- data
|
||||
title: OpenAIListModelsResponse
|
||||
Prompt:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8265,45 +8773,29 @@ components:
|
|||
required:
|
||||
- shield_id
|
||||
title: RegisterShieldRequest
|
||||
CompletionMessage:
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
properties:
|
||||
role:
|
||||
tool_name:
|
||||
type: string
|
||||
const: assistant
|
||||
default: assistant
|
||||
description: The name of the tool to invoke.
|
||||
kwargs:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
Must be "assistant" to identify this as the model's response
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: The content of the model's response
|
||||
stop_reason:
|
||||
type: string
|
||||
enum:
|
||||
- end_of_turn
|
||||
- end_of_message
|
||||
- out_of_tokens
|
||||
description: >-
|
||||
Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`:
|
||||
The model finished generating the entire response. - `StopReason.end_of_message`:
|
||||
The model finished generating but generated a partial response -- usually,
|
||||
a tool call. The user may call the tool and continue the conversation
|
||||
with the tool's response. - `StopReason.out_of_tokens`: The model ran
|
||||
out of token budget.
|
||||
tool_calls:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ToolCall'
|
||||
description: >-
|
||||
List of tool calls. Each tool call is a ToolCall object.
|
||||
A dictionary of arguments to pass to the tool.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- role
|
||||
- content
|
||||
- stop_reason
|
||||
title: CompletionMessage
|
||||
description: >-
|
||||
A message containing the model's (assistant) response in a chat conversation.
|
||||
- tool_name
|
||||
- kwargs
|
||||
title: InvokeToolRequest
|
||||
ImageContentItem:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8350,41 +8842,6 @@ components:
|
|||
mapping:
|
||||
image: '#/components/schemas/ImageContentItem'
|
||||
text: '#/components/schemas/TextContentItem'
|
||||
Message:
|
||||
oneOf:
|
||||
- $ref: '#/components/schemas/UserMessage'
|
||||
- $ref: '#/components/schemas/SystemMessage'
|
||||
- $ref: '#/components/schemas/ToolResponseMessage'
|
||||
- $ref: '#/components/schemas/CompletionMessage'
|
||||
discriminator:
|
||||
propertyName: role
|
||||
mapping:
|
||||
user: '#/components/schemas/UserMessage'
|
||||
system: '#/components/schemas/SystemMessage'
|
||||
tool: '#/components/schemas/ToolResponseMessage'
|
||||
assistant: '#/components/schemas/CompletionMessage'
|
||||
SystemMessage:
|
||||
type: object
|
||||
properties:
|
||||
role:
|
||||
type: string
|
||||
const: system
|
||||
default: system
|
||||
description: >-
|
||||
Must be "system" to identify this as a system message
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
The content of the "system prompt". If multiple system messages are provided,
|
||||
they are concatenated. The underlying Llama Stack code may also add other
|
||||
system messages (for example, for formatting tool definitions).
|
||||
additionalProperties: false
|
||||
required:
|
||||
- role
|
||||
- content
|
||||
title: SystemMessage
|
||||
description: >-
|
||||
A system message providing instructions or context to the model.
|
||||
TextContentItem:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8403,179 +8860,6 @@ components:
|
|||
- text
|
||||
title: TextContentItem
|
||||
description: A text content item
|
||||
ToolCall:
|
||||
type: object
|
||||
properties:
|
||||
call_id:
|
||||
type: string
|
||||
tool_name:
|
||||
oneOf:
|
||||
- type: string
|
||||
enum:
|
||||
- brave_search
|
||||
- wolfram_alpha
|
||||
- photogen
|
||||
- code_interpreter
|
||||
title: BuiltinTool
|
||||
- type: string
|
||||
arguments:
|
||||
type: string
|
||||
additionalProperties: false
|
||||
required:
|
||||
- call_id
|
||||
- tool_name
|
||||
- arguments
|
||||
title: ToolCall
|
||||
ToolResponseMessage:
|
||||
type: object
|
||||
properties:
|
||||
role:
|
||||
type: string
|
||||
const: tool
|
||||
default: tool
|
||||
description: >-
|
||||
Must be "tool" to identify this as a tool response
|
||||
call_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier for the tool call this response is for
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: The response content from the tool
|
||||
additionalProperties: false
|
||||
required:
|
||||
- role
|
||||
- call_id
|
||||
- content
|
||||
title: ToolResponseMessage
|
||||
description: >-
|
||||
A message representing the result of a tool invocation.
|
||||
URL:
|
||||
type: object
|
||||
properties:
|
||||
uri:
|
||||
type: string
|
||||
description: The URL string pointing to the resource
|
||||
additionalProperties: false
|
||||
required:
|
||||
- uri
|
||||
title: URL
|
||||
description: A URL reference to external content.
|
||||
UserMessage:
|
||||
type: object
|
||||
properties:
|
||||
role:
|
||||
type: string
|
||||
const: user
|
||||
default: user
|
||||
description: >-
|
||||
Must be "user" to identify this as a user message
|
||||
content:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
The content of the message, which can include text and other media
|
||||
context:
|
||||
$ref: '#/components/schemas/InterleavedContent'
|
||||
description: >-
|
||||
(Optional) This field is used internally by Llama Stack to pass RAG context.
|
||||
This field may be removed in the API in the future.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- role
|
||||
- content
|
||||
title: UserMessage
|
||||
description: >-
|
||||
A message from the user in a chat conversation.
|
||||
SyntheticDataGenerateRequest:
|
||||
type: object
|
||||
properties:
|
||||
dialogs:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Message'
|
||||
description: >-
|
||||
List of conversation messages to use as input for synthetic data generation
|
||||
filtering_function:
|
||||
type: string
|
||||
enum:
|
||||
- none
|
||||
- random
|
||||
- top_k
|
||||
- top_p
|
||||
- top_k_top_p
|
||||
- sigmoid
|
||||
description: >-
|
||||
Type of filtering to apply to generated synthetic data samples
|
||||
model:
|
||||
type: string
|
||||
description: >-
|
||||
(Optional) The identifier of the model to use. The model must be registered
|
||||
with Llama Stack and available via the /models endpoint
|
||||
additionalProperties: false
|
||||
required:
|
||||
- dialogs
|
||||
- filtering_function
|
||||
title: SyntheticDataGenerateRequest
|
||||
SyntheticDataGenerationResponse:
|
||||
type: object
|
||||
properties:
|
||||
synthetic_data:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
List of generated synthetic data samples that passed the filtering criteria
|
||||
statistics:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
(Optional) Statistical information about the generation process and filtering
|
||||
results
|
||||
additionalProperties: false
|
||||
required:
|
||||
- synthetic_data
|
||||
title: SyntheticDataGenerationResponse
|
||||
description: >-
|
||||
Response from the synthetic data generation. Batch of (prompt, response, score)
|
||||
tuples that pass the threshold.
|
||||
InvokeToolRequest:
|
||||
type: object
|
||||
properties:
|
||||
tool_name:
|
||||
type: string
|
||||
description: The name of the tool to invoke.
|
||||
kwargs:
|
||||
type: object
|
||||
additionalProperties:
|
||||
oneOf:
|
||||
- type: 'null'
|
||||
- type: boolean
|
||||
- type: number
|
||||
- type: string
|
||||
- type: array
|
||||
- type: object
|
||||
description: >-
|
||||
A dictionary of arguments to pass to the tool.
|
||||
additionalProperties: false
|
||||
required:
|
||||
- tool_name
|
||||
- kwargs
|
||||
title: InvokeToolRequest
|
||||
ToolInvocationResult:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -8606,6 +8890,17 @@ components:
|
|||
additionalProperties: false
|
||||
title: ToolInvocationResult
|
||||
description: Result of a tool invocation.
|
||||
URL:
|
||||
type: object
|
||||
properties:
|
||||
uri:
|
||||
type: string
|
||||
description: The URL string pointing to the resource
|
||||
additionalProperties: false
|
||||
required:
|
||||
- uri
|
||||
title: URL
|
||||
description: A URL reference to external content.
|
||||
ToolDef:
|
||||
type: object
|
||||
properties:
|
||||
|
|
@ -9045,6 +9340,10 @@ components:
|
|||
description: >-
|
||||
The content of the chunk, which can be interleaved text, images, or other
|
||||
types.
|
||||
chunk_id:
|
||||
type: string
|
||||
description: >-
|
||||
Unique identifier for the chunk. Must be provided explicitly.
|
||||
metadata:
|
||||
type: object
|
||||
additionalProperties:
|
||||
|
|
@ -9065,10 +9364,6 @@ components:
|
|||
description: >-
|
||||
Optional embedding for the chunk. If not provided, it will be computed
|
||||
later.
|
||||
stored_chunk_id:
|
||||
type: string
|
||||
description: >-
|
||||
The chunk ID that is stored in the vector database. Used for backend functionality.
|
||||
chunk_metadata:
|
||||
$ref: '#/components/schemas/ChunkMetadata'
|
||||
description: >-
|
||||
|
|
@ -9077,6 +9372,7 @@ components:
|
|||
additionalProperties: false
|
||||
required:
|
||||
- content
|
||||
- chunk_id
|
||||
- metadata
|
||||
title: Chunk
|
||||
description: >-
|
||||
|
|
@ -10143,6 +10439,19 @@ tags:
|
|||
|
||||
- `background`
|
||||
x-displayName: Agents
|
||||
- name: Batches
|
||||
description: >-
|
||||
The API is designed to allow use of openai client libraries for seamless integration.
|
||||
|
||||
|
||||
This API provides the following extensions:
|
||||
- idempotent batch creation
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
x-displayName: >-
|
||||
The Batches API enables efficient processing of multiple requests in a single
|
||||
operation, particularly useful for processing large datasets, batch evaluation
|
||||
workflows, and cost-effective inference at scale.
|
||||
- name: Conversations
|
||||
description: >-
|
||||
Protocol for conversation management operations.
|
||||
|
|
@ -10193,8 +10502,6 @@ tags:
|
|||
description: ''
|
||||
- name: Shields
|
||||
description: ''
|
||||
- name: SyntheticDataGeneration (Coming Soon)
|
||||
description: ''
|
||||
- name: ToolGroups
|
||||
description: ''
|
||||
- name: ToolRuntime
|
||||
|
|
@ -10205,6 +10512,7 @@ x-tagGroups:
|
|||
- name: Operations
|
||||
tags:
|
||||
- Agents
|
||||
- Batches
|
||||
- Conversations
|
||||
- Files
|
||||
- Inference
|
||||
|
|
@ -10216,7 +10524,6 @@ x-tagGroups:
|
|||
- Scoring
|
||||
- ScoringFunctions
|
||||
- Shields
|
||||
- SyntheticDataGeneration (Coming Soon)
|
||||
- ToolGroups
|
||||
- ToolRuntime
|
||||
- VectorIO
|
||||
|
|
|
|||
18091
docs/static/stainless-llama-stack-spec.html
vendored
18091
docs/static/stainless-llama-stack-spec.html
vendored
File diff suppressed because it is too large
Load diff
1006
docs/static/stainless-llama-stack-spec.yaml
vendored
1006
docs/static/stainless-llama-stack-spec.yaml
vendored
File diff suppressed because it is too large
Load diff
|
|
@ -7,7 +7,7 @@ required-version = ">=0.7.0"
|
|||
|
||||
[project]
|
||||
name = "llama_stack"
|
||||
version = "0.3.0"
|
||||
version = "0.4.0.dev0"
|
||||
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
|
||||
description = "Llama Stack"
|
||||
readme = "README.md"
|
||||
|
|
@ -284,7 +284,6 @@ exclude = [
|
|||
"^src/llama_stack/models/llama/llama3/interface\\.py$",
|
||||
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
|
||||
"^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
|
||||
"^src/llama_stack/providers/inline/agents/meta_reference/",
|
||||
"^src/llama_stack/providers/inline/datasetio/localfs/",
|
||||
"^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
|
||||
"^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
|
||||
|
|
|
|||
|
|
@ -215,6 +215,16 @@ build_image() {
|
|||
--build-arg "LLAMA_STACK_DIR=/workspace"
|
||||
)
|
||||
|
||||
# Pass UV index configuration for release branches
|
||||
if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
|
||||
echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
|
||||
build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
|
||||
fi
|
||||
if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
|
||||
echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
|
||||
build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
|
||||
fi
|
||||
|
||||
if ! "${build_cmd[@]}"; then
|
||||
echo "❌ Failed to build Docker image"
|
||||
exit 1
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ COLLECT_ONLY=false
|
|||
|
||||
# Function to display usage
|
||||
usage() {
|
||||
cat << EOF
|
||||
cat <<EOF
|
||||
Usage: $0 [OPTIONS]
|
||||
|
||||
Options:
|
||||
|
|
@ -102,7 +102,6 @@ while [[ $# -gt 0 ]]; do
|
|||
esac
|
||||
done
|
||||
|
||||
|
||||
# Validate required parameters
|
||||
if [[ -z "$STACK_CONFIG" && "$COLLECT_ONLY" == false ]]; then
|
||||
echo "Error: --stack-config is required"
|
||||
|
|
@ -177,21 +176,45 @@ cd $ROOT_DIR
|
|||
# check if "llama" and "pytest" are available. this script does not use `uv run` given
|
||||
# it can be used in a pre-release environment where we have not been able to tell
|
||||
# uv about pre-release dependencies properly (yet).
|
||||
if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &> /dev/null; then
|
||||
if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &>/dev/null; then
|
||||
echo "llama could not be found, ensure llama-stack is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v pytest &> /dev/null; then
|
||||
if ! command -v pytest &>/dev/null; then
|
||||
echo "pytest could not be found, ensure pytest is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Helper function to find next available port
|
||||
find_available_port() {
|
||||
local start_port=$1
|
||||
local port=$start_port
|
||||
for ((i=0; i<100; i++)); do
|
||||
if ! lsof -Pi :$port -sTCP:LISTEN -t >/dev/null 2>&1; then
|
||||
echo $port
|
||||
return 0
|
||||
fi
|
||||
((port++))
|
||||
done
|
||||
echo "Failed to find available port starting from $start_port" >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
# Start Llama Stack Server if needed
|
||||
if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
|
||||
# Find an available port for the server
|
||||
LLAMA_STACK_PORT=$(find_available_port 8321)
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "Error: $LLAMA_STACK_PORT"
|
||||
exit 1
|
||||
fi
|
||||
export LLAMA_STACK_PORT
|
||||
echo "Will use port: $LLAMA_STACK_PORT"
|
||||
|
||||
stop_server() {
|
||||
echo "Stopping Llama Stack Server..."
|
||||
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
|
||||
pids=$(lsof -i :$LLAMA_STACK_PORT | awk 'NR>1 {print $2}')
|
||||
if [[ -n "$pids" ]]; then
|
||||
echo "Killing Llama Stack Server processes: $pids"
|
||||
kill -9 $pids
|
||||
|
|
@ -201,20 +224,25 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
echo "Llama Stack Server stopped"
|
||||
}
|
||||
|
||||
# check if server is already running
|
||||
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
|
||||
echo "Llama Stack Server is already running, skipping start"
|
||||
else
|
||||
echo "=== Starting Llama Stack Server ==="
|
||||
export LLAMA_STACK_LOG_WIDTH=120
|
||||
|
||||
# Configure telemetry collector for server mode
|
||||
# Use a fixed port for the OTEL collector so the server can connect to it
|
||||
COLLECTOR_PORT=4317
|
||||
export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
|
||||
export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}"
|
||||
export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
|
||||
export OTEL_BSP_SCHEDULE_DELAY="200"
|
||||
export OTEL_BSP_EXPORT_TIMEOUT="2000"
|
||||
|
||||
# remove "server:" from STACK_CONFIG
|
||||
stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
|
||||
nohup llama stack run $stack_config > server.log 2>&1 &
|
||||
nohup llama stack run $stack_config >server.log 2>&1 &
|
||||
|
||||
echo "Waiting for Llama Stack Server to start..."
|
||||
echo "Waiting for Llama Stack Server to start on port $LLAMA_STACK_PORT..."
|
||||
for i in {1..30}; do
|
||||
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
|
||||
if curl -s http://localhost:$LLAMA_STACK_PORT/v1/health 2>/dev/null | grep -q "OK"; then
|
||||
echo "✅ Llama Stack Server started successfully"
|
||||
break
|
||||
fi
|
||||
|
|
@ -227,7 +255,6 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
sleep 1
|
||||
done
|
||||
echo ""
|
||||
fi
|
||||
|
||||
trap stop_server EXIT ERR INT TERM
|
||||
fi
|
||||
|
|
@ -239,7 +266,7 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
container_name="llama-stack-test-$DISTRO"
|
||||
if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then
|
||||
echo "Dumping container logs before stopping..."
|
||||
docker logs "$container_name" > "docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
|
||||
docker logs "$container_name" >"docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
|
||||
echo "Stopping and removing container: $container_name"
|
||||
docker stop "$container_name" 2>/dev/null || true
|
||||
docker rm "$container_name" 2>/dev/null || true
|
||||
|
|
@ -251,7 +278,14 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
|
||||
# Extract distribution name from docker:distro format
|
||||
DISTRO=$(echo "$STACK_CONFIG" | sed 's/^docker://')
|
||||
export LLAMA_STACK_PORT=8321
|
||||
# Find an available port for the docker container
|
||||
LLAMA_STACK_PORT=$(find_available_port 8321)
|
||||
if [[ $? -ne 0 ]]; then
|
||||
echo "Error: $LLAMA_STACK_PORT"
|
||||
exit 1
|
||||
fi
|
||||
export LLAMA_STACK_PORT
|
||||
echo "Will use port: $LLAMA_STACK_PORT"
|
||||
|
||||
echo "=== Building Docker Image for distribution: $DISTRO ==="
|
||||
containerfile="$ROOT_DIR/containers/Containerfile"
|
||||
|
|
@ -271,6 +305,16 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
--build-arg "LLAMA_STACK_DIR=/workspace"
|
||||
)
|
||||
|
||||
# Pass UV index configuration for release branches
|
||||
if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
|
||||
echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
|
||||
build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
|
||||
fi
|
||||
if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
|
||||
echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
|
||||
build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
|
||||
fi
|
||||
|
||||
if ! "${build_cmd[@]}"; then
|
||||
echo "❌ Failed to build Docker image"
|
||||
exit 1
|
||||
|
|
@ -284,10 +328,15 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
docker stop "$container_name" 2>/dev/null || true
|
||||
docker rm "$container_name" 2>/dev/null || true
|
||||
|
||||
# Configure telemetry collector port shared between host and container
|
||||
COLLECTOR_PORT=4317
|
||||
export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
|
||||
|
||||
# Build environment variables for docker run
|
||||
DOCKER_ENV_VARS=""
|
||||
DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_INFERENCE_MODE=$INFERENCE_MODE"
|
||||
DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server"
|
||||
DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:${COLLECTOR_PORT}"
|
||||
|
||||
# Pass through API keys if they exist
|
||||
[ -n "${TOGETHER_API_KEY:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e TOGETHER_API_KEY=$TOGETHER_API_KEY"
|
||||
|
|
@ -308,8 +357,20 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
fi
|
||||
echo "Using image: $IMAGE_NAME"
|
||||
|
||||
docker run -d --network host --name "$container_name" \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
# On macOS/Darwin, --network host doesn't work as expected due to Docker running in a VM
|
||||
# Use regular port mapping instead
|
||||
NETWORK_MODE=""
|
||||
PORT_MAPPINGS=""
|
||||
if [[ "$(uname)" != "Darwin" ]] && [[ "$(uname)" != *"MINGW"* ]]; then
|
||||
NETWORK_MODE="--network host"
|
||||
else
|
||||
# On non-Linux (macOS, Windows), need explicit port mappings for both app and telemetry
|
||||
PORT_MAPPINGS="-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT -p $COLLECTOR_PORT:$COLLECTOR_PORT"
|
||||
echo "Using bridge networking with port mapping (non-Linux)"
|
||||
fi
|
||||
|
||||
docker run -d $NETWORK_MODE --name "$container_name" \
|
||||
$PORT_MAPPINGS \
|
||||
$DOCKER_ENV_VARS \
|
||||
"$IMAGE_NAME" \
|
||||
--port $LLAMA_STACK_PORT
|
||||
|
|
@ -411,17 +472,13 @@ elif [ $exit_code -eq 5 ]; then
|
|||
else
|
||||
echo "❌ Tests failed"
|
||||
echo ""
|
||||
echo "=== Dumping last 100 lines of logs for debugging ==="
|
||||
|
||||
# Output server or container logs based on stack config
|
||||
if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then
|
||||
echo "--- Last 100 lines of server.log ---"
|
||||
tail -100 server.log
|
||||
echo "--- Server side failures can be located inside server.log (available from artifacts on CI) ---"
|
||||
elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then
|
||||
docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log"
|
||||
if [[ -f "$docker_log_file" ]]; then
|
||||
echo "--- Last 100 lines of $docker_log_file ---"
|
||||
tail -100 "$docker_log_file"
|
||||
echo "--- Server side failures can be located inside $docker_log_file (available from artifacts on CI) ---"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
|
|||
42
scripts/uv-run-with-index.sh
Executable file
42
scripts/uv-run-with-index.sh
Executable file
|
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Detect current branch and target branch
|
||||
# In GitHub Actions, use GITHUB_REF/GITHUB_BASE_REF
|
||||
if [[ -n "${GITHUB_REF:-}" ]]; then
|
||||
BRANCH="${GITHUB_REF#refs/heads/}"
|
||||
else
|
||||
BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
# For PRs, check the target branch
|
||||
if [[ -n "${GITHUB_BASE_REF:-}" ]]; then
|
||||
TARGET_BRANCH="${GITHUB_BASE_REF}"
|
||||
else
|
||||
TARGET_BRANCH=$(git rev-parse --abbrev-ref HEAD@{upstream} 2>/dev/null | sed 's|origin/||' || echo "")
|
||||
fi
|
||||
|
||||
# Check if on a release branch or targeting one, or LLAMA_STACK_RELEASE_MODE is set
|
||||
IS_RELEASE=false
|
||||
if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
|
||||
IS_RELEASE=true
|
||||
elif [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
|
||||
IS_RELEASE=true
|
||||
elif [[ "${LLAMA_STACK_RELEASE_MODE:-}" == "true" ]]; then
|
||||
IS_RELEASE=true
|
||||
fi
|
||||
|
||||
# On release branches, use test.pypi as extra index for RC versions
|
||||
if [[ "$IS_RELEASE" == "true" ]]; then
|
||||
export UV_EXTRA_INDEX_URL="https://test.pypi.org/simple/"
|
||||
export UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
fi
|
||||
|
||||
# Run uv with all arguments passed through
|
||||
exec uv "$@"
|
||||
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
|
@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
|
|||
scenarios.
|
||||
"""
|
||||
|
||||
content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
|
||||
content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
|
||||
role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
|
||||
type: Literal["message"] = "message"
|
||||
|
||||
|
|
@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
|
|||
"""
|
||||
|
||||
id: str
|
||||
queries: list[str]
|
||||
queries: Sequence[str]
|
||||
status: str
|
||||
type: Literal["file_search_call"] = "file_search_call"
|
||||
results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
|
||||
results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
id: str
|
||||
model: str
|
||||
object: Literal["response"] = "response"
|
||||
output: list[OpenAIResponseOutput]
|
||||
output: Sequence[OpenAIResponseOutput]
|
||||
parallel_tool_calls: bool = False
|
||||
previous_response_id: str | None = None
|
||||
prompt: OpenAIResponsePrompt | None = None
|
||||
|
|
@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
|
|||
# before the field was added. New responses will have this set always.
|
||||
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
top_p: float | None = None
|
||||
tools: list[OpenAIResponseTool] | None = None
|
||||
tools: Sequence[OpenAIResponseTool] | None = None
|
||||
truncation: str | None = None
|
||||
usage: OpenAIResponseUsage | None = None
|
||||
instructions: str | None = None
|
||||
|
|
@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
|
|||
:param object: Object type identifier, always "list"
|
||||
"""
|
||||
|
||||
data: list[OpenAIResponseInput]
|
||||
data: Sequence[OpenAIResponseInput]
|
||||
object: Literal["list"] = "list"
|
||||
|
||||
|
||||
|
|
@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
|
|||
:param input: List of input items that led to this response
|
||||
"""
|
||||
|
||||
input: list[OpenAIResponseInput]
|
||||
input: Sequence[OpenAIResponseInput]
|
||||
|
||||
def to_response_object(self) -> OpenAIResponseObject:
|
||||
"""Convert to OpenAIResponseObject by excluding input field."""
|
||||
|
|
@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
|
|||
:param object: Object type identifier, always "list"
|
||||
"""
|
||||
|
||||
data: list[OpenAIResponseObjectWithInput]
|
||||
data: Sequence[OpenAIResponseObjectWithInput]
|
||||
has_more: bool
|
||||
first_id: str
|
||||
last_id: str
|
||||
|
|
|
|||
|
|
@ -4,14 +4,21 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.apis.version import (
|
||||
LLAMA_STACK_API_V1,
|
||||
)
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
# Valid values for the route filter parameter.
|
||||
# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
|
||||
# Special filter value: "deprecated" (shows deprecated routes regardless of level)
|
||||
ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
|
|
@ -64,11 +71,12 @@ class Inspect(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
|
||||
"""List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
|
||||
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
|
|||
:object: The object type, which will be "model"
|
||||
:created: The Unix timestamp in seconds when the model was created
|
||||
:owned_by: The owner of the model
|
||||
:custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
|
||||
"""
|
||||
|
||||
id: str
|
||||
object: Literal["model"] = "model"
|
||||
created: int
|
||||
owned_by: str
|
||||
custom_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class OpenAIListModelsResponse(BaseModel):
|
||||
|
|
@ -113,7 +115,7 @@ class Models(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
|
||||
@webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
"""List models using the OpenAI API.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .synthetic_data_generation import *
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
class FilteringFunction(Enum):
|
||||
"""The type of filtering function.
|
||||
|
||||
:cvar none: No filtering applied, accept all generated synthetic data
|
||||
:cvar random: Random sampling of generated data points
|
||||
:cvar top_k: Keep only the top-k highest scoring synthetic data samples
|
||||
:cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
|
||||
:cvar top_k_top_p: Combined top-k and top-p filtering strategy
|
||||
:cvar sigmoid: Apply sigmoid function for probability-based filtering
|
||||
"""
|
||||
|
||||
none = "none"
|
||||
random = "random"
|
||||
top_k = "top_k"
|
||||
top_p = "top_p"
|
||||
top_k_top_p = "top_k_top_p"
|
||||
sigmoid = "sigmoid"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationRequest(BaseModel):
|
||||
"""Request to generate synthetic data. A small batch of prompts and a filtering function
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
"""
|
||||
|
||||
dialogs: list[Message]
|
||||
filtering_function: FilteringFunction = FilteringFunction.none
|
||||
model: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SyntheticDataGenerationResponse(BaseModel):
|
||||
"""Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
|
||||
|
||||
:param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
|
||||
:param statistics: (Optional) Statistical information about the generation process and filtering results
|
||||
"""
|
||||
|
||||
synthetic_data: list[dict[str, Any]]
|
||||
statistics: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SyntheticDataGeneration(Protocol):
|
||||
@webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
|
||||
def synthetic_data_generate(
|
||||
self,
|
||||
dialogs: list[Message],
|
||||
filtering_function: FilteringFunction = FilteringFunction.none,
|
||||
model: str | None = None,
|
||||
) -> SyntheticDataGenerationResponse:
|
||||
"""Generate synthetic data based on input dialogs and apply filtering.
|
||||
|
||||
:param dialogs: List of conversation messages to use as input for synthetic data generation
|
||||
:param filtering_function: Type of filtering to apply to generated synthetic data samples
|
||||
:param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
|
||||
:returns: Response containing filtered synthetic data samples and optional statistics
|
||||
"""
|
||||
...
|
||||
|
|
@ -8,7 +8,6 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import uuid
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import Body
|
||||
|
|
@ -18,7 +17,6 @@ from llama_stack.apis.inference import InterleavedContent
|
|||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
from llama_stack.strong_typing.schema import register_schema
|
||||
|
||||
|
|
@ -61,38 +59,19 @@ class Chunk(BaseModel):
|
|||
"""
|
||||
A chunk of content that can be inserted into a vector database.
|
||||
:param content: The content of the chunk, which can be interleaved text, images, or other types.
|
||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||
:param chunk_id: Unique identifier for the chunk. Must be provided explicitly.
|
||||
:param metadata: Metadata associated with the chunk that will be used in the model context during inference.
|
||||
:param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
|
||||
:param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
|
||||
:param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
|
||||
The `chunk_metadata` is required backend functionality.
|
||||
"""
|
||||
|
||||
content: InterleavedContent
|
||||
chunk_id: str
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
embedding: list[float] | None = None
|
||||
# The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
|
||||
stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
|
||||
chunk_metadata: ChunkMetadata | None = None
|
||||
|
||||
model_config = {"populate_by_name": True}
|
||||
|
||||
def model_post_init(self, __context):
|
||||
# Extract chunk_id from metadata if present
|
||||
if self.metadata and "chunk_id" in self.metadata:
|
||||
self.stored_chunk_id = self.metadata.pop("chunk_id")
|
||||
|
||||
@property
|
||||
def chunk_id(self) -> str:
|
||||
"""Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
|
||||
if self.stored_chunk_id:
|
||||
return self.stored_chunk_id
|
||||
|
||||
if "document_id" in self.metadata:
|
||||
return generate_chunk_id(self.metadata["document_id"], str(self.content))
|
||||
|
||||
return generate_chunk_id(str(uuid.uuid4()), str(self.content))
|
||||
|
||||
@property
|
||||
def document_id(self) -> str | None:
|
||||
"""Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
|
||||
|
|
|
|||
|
|
@ -8,16 +8,30 @@ import argparse
|
|||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import uvicorn
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import LoggingConfig, get_logger
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||
|
|
@ -68,6 +82,12 @@ class StackRun(Subcommand):
|
|||
action="store_true",
|
||||
help="Start the UI server",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
|
||||
)
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
|
@ -93,6 +113,55 @@ class StackRun(Subcommand):
|
|||
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
elif args.providers:
|
||||
provider_list: dict[str, list[Provider]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider_type = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
f"{api} is not a valid API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider_type in providers_for_api:
|
||||
config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
|
||||
if config_type is not None and hasattr(config_type, "sample_run_config"):
|
||||
config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
|
||||
else:
|
||||
config = {}
|
||||
provider = Provider(
|
||||
provider_type=provider_type,
|
||||
config=config,
|
||||
provider_id=provider_type.split("::")[1],
|
||||
)
|
||||
provider_list.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
run_config = self._generate_run_config_from_providers(providers=provider_list)
|
||||
config_dict = run_config.model_dump(mode="json")
|
||||
|
||||
# Write config to disk in providers-run directory
|
||||
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||
config_file = distro_dir / "run.yaml"
|
||||
|
||||
logger.info(f"Writing generated config to: {config_file}")
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
|
|
@ -106,7 +175,8 @@ class StackRun(Subcommand):
|
|||
|
||||
try:
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(str(config.external_providers_dir)):
|
||||
# Create external_providers_dir if it's specified and doesn't exist
|
||||
if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)):
|
||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
|
@ -127,7 +197,7 @@ class StackRun(Subcommand):
|
|||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or ["::", "0.0.0.0"]
|
||||
host = config.server.host or "0.0.0.0"
|
||||
|
||||
# Set the config file in environment so create_app can find it
|
||||
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
|
||||
|
|
@ -139,6 +209,7 @@ class StackRun(Subcommand):
|
|||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
"workers": config.server.workers,
|
||||
}
|
||||
|
||||
keyfile = config.server.tls_keyfile
|
||||
|
|
@ -212,3 +283,44 @@ class StackRun(Subcommand):
|
|||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
||||
|
||||
def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
|
||||
apis = list(providers.keys())
|
||||
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||
# need somewhere to put the storage.
|
||||
os.makedirs(distro_dir, exist_ok=True)
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(
|
||||
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
|
||||
),
|
||||
"sql_default": SqliteSqlStoreConfig(
|
||||
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
|
||||
),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="registry",
|
||||
),
|
||||
inference=InferenceStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="inference_store",
|
||||
),
|
||||
conversations=SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="openai_conversations",
|
||||
),
|
||||
prompts=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="prompts",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return StackRunConfig(
|
||||
image_name="providers-run",
|
||||
apis=apis,
|
||||
providers=providers,
|
||||
storage=storage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from llama_stack.core.distribution import (
|
|||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -194,19 +193,11 @@ def upgrade_from_routing_table(
|
|||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
|
|
|||
|
|
@ -473,6 +473,10 @@ class ServerConfig(BaseModel):
|
|||
"- true: Enable localhost CORS for development\n"
|
||||
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
|
||||
)
|
||||
workers: int = Field(
|
||||
default=1,
|
||||
description="Number of workers to use for the server",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.inspect import (
|
|||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
|
|
@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
|
||||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
# Helper function to determine if a route should be included based on api_filter
|
||||
def should_include_route(webmethod) -> bool:
|
||||
if api_filter is None:
|
||||
# Default: only non-deprecated v1 APIs
|
||||
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
|
||||
elif api_filter == "deprecated":
|
||||
# Special filter: show deprecated routes regardless of their actual level
|
||||
return bool(webmethod.deprecated)
|
||||
else:
|
||||
# Filter by API level (non-deprecated routes only)
|
||||
return not webmethod.deprecated and webmethod.level == api_filter
|
||||
|
||||
ret = []
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
|
|
@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
for e, webmethod in endpoints
|
||||
if e.methods is not None and should_include_route(webmethod)
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
|
@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
for e, webmethod in endpoints
|
||||
if e.methods is not None and should_include_route(webmethod)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
|
|
@ -15,20 +15,10 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
|
|||
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
InterleavedContent,
|
||||
)
|
||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||
from llama_stack.apis.inference import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseEventType,
|
||||
ChatCompletionResponseStreamChunk,
|
||||
CompletionMessage,
|
||||
CompletionResponse,
|
||||
CompletionResponseStreamChunk,
|
||||
Inference,
|
||||
ListOpenAIChatCompletionResponse,
|
||||
Message,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionChunk,
|
||||
|
|
@ -45,15 +35,13 @@ from llama_stack.apis.inference import (
|
|||
OpenAIMessageParam,
|
||||
Order,
|
||||
RerankResponse,
|
||||
StopReason,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent, MetricInResponse
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent
|
||||
from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
|
|
@ -153,35 +141,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
return metric_events
|
||||
|
||||
async def _compute_and_log_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
model: Model,
|
||||
) -> list[MetricInResponse]:
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
|
||||
)
|
||||
if self.telemetry_enabled:
|
||||
for metric in metrics:
|
||||
enqueue_event(metric)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def _count_tokens(
|
||||
self,
|
||||
messages: list[Message] | InterleavedContent,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> int | None:
|
||||
if not hasattr(self, "formatter") or self.formatter is None:
|
||||
return None
|
||||
|
||||
if isinstance(messages, list):
|
||||
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
|
||||
else:
|
||||
encoded = self.formatter.encode_content(messages)
|
||||
return len(encoded.tokens) if encoded and encoded.tokens else 0
|
||||
|
||||
async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
|
||||
model = await self.routing_table.get_object_by_identifier("model", model_id)
|
||||
if model:
|
||||
|
|
@ -375,121 +334,6 @@ class InferenceRouter(Inference):
|
|||
)
|
||||
return health_statuses
|
||||
|
||||
async def stream_tokens_and_compute_metrics(
|
||||
self,
|
||||
response,
|
||||
prompt_tokens,
|
||||
fully_qualified_model_id: str,
|
||||
provider_id: str,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||
completion_text = ""
|
||||
async for chunk in response:
|
||||
complete = False
|
||||
if hasattr(chunk, "event"): # only ChatCompletions have .event
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.progress:
|
||||
if chunk.event.delta.type == "text":
|
||||
completion_text += chunk.event.delta.text
|
||||
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(
|
||||
[
|
||||
CompletionMessage(
|
||||
content=completion_text,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
],
|
||||
tool_prompt_format=tool_prompt_format,
|
||||
)
|
||||
else:
|
||||
if hasattr(chunk, "delta"):
|
||||
completion_text += chunk.delta
|
||||
if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled:
|
||||
complete = True
|
||||
completion_tokens = await self._count_tokens(completion_text)
|
||||
# if we are done receiving tokens
|
||||
if complete:
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for streaming completion metrics
|
||||
if self.telemetry_enabled:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in [
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
]: # Only log completion and total tokens
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
else:
|
||||
# Fallback if no telemetry
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
async_metrics = [
|
||||
MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
|
||||
]
|
||||
chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
|
||||
yield chunk
|
||||
|
||||
async def count_tokens_and_compute_metrics(
|
||||
self,
|
||||
response: ChatCompletionResponse | CompletionResponse,
|
||||
prompt_tokens,
|
||||
fully_qualified_model_id: str,
|
||||
provider_id: str,
|
||||
tool_prompt_format: ToolPromptFormat | None = None,
|
||||
):
|
||||
if isinstance(response, ChatCompletionResponse):
|
||||
content = [response.completion_message]
|
||||
else:
|
||||
content = response.content
|
||||
completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
|
||||
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
|
||||
|
||||
# Create a separate span for completion metrics
|
||||
if self.telemetry_enabled:
|
||||
# Log metrics in the new span context
|
||||
completion_metrics = self._construct_metrics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
for metric in completion_metrics:
|
||||
if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens
|
||||
enqueue_event(metric)
|
||||
|
||||
# Return metrics in response
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
|
||||
|
||||
# Fallback if no telemetry
|
||||
metrics = self._construct_metrics(
|
||||
prompt_tokens or 0,
|
||||
completion_tokens or 0,
|
||||
total_tokens,
|
||||
fully_qualified_model_id=fully_qualified_model_id,
|
||||
provider_id=provider_id,
|
||||
)
|
||||
return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
|
||||
|
||||
async def stream_tokens_and_compute_metrics_openai_chat(
|
||||
self,
|
||||
response: AsyncIterator[OpenAIChatCompletionChunk],
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
|
|||
ModelWithOwner,
|
||||
RegistryEntrySource,
|
||||
)
|
||||
from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
from .common import CommonRoutingTableImpl, lookup_model
|
||||
|
|
@ -42,19 +44,104 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
await self.update_registered_models(provider_id, models)
|
||||
|
||||
async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
|
||||
"""
|
||||
Fetch models from providers that have credentials in the current request's provider_data.
|
||||
|
||||
This allows users to see models available to them from providers that require
|
||||
per-request API keys (via X-LlamaStack-Provider-Data header).
|
||||
|
||||
Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
|
||||
cache them in the registry since they are user-specific.
|
||||
"""
|
||||
provider_data = PROVIDER_DATA_VAR.get()
|
||||
if not provider_data:
|
||||
return []
|
||||
|
||||
dynamic_models = []
|
||||
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
# Check if this provider supports provider_data
|
||||
if not isinstance(provider, NeedsRequestProviderData):
|
||||
continue
|
||||
|
||||
# Check if provider has a validator (some providers like ollama don't need per-request credentials)
|
||||
spec = getattr(provider, "__provider_spec__", None)
|
||||
if not spec or not getattr(spec, "provider_data_validator", None):
|
||||
continue
|
||||
|
||||
# Validate provider_data silently - we're speculatively checking all providers
|
||||
# so validation failures are expected when user didn't provide keys for this provider
|
||||
try:
|
||||
validator = instantiate_class_type(spec.provider_data_validator)
|
||||
validator(**provider_data)
|
||||
except Exception:
|
||||
# User didn't provide credentials for this provider - skip silently
|
||||
continue
|
||||
|
||||
# Validation succeeded! User has credentials for this provider
|
||||
# Now try to list models
|
||||
try:
|
||||
models = await provider.list_models()
|
||||
if not models:
|
||||
continue
|
||||
|
||||
# Ensure models have fully qualified identifiers with provider_id prefix
|
||||
for model in models:
|
||||
# Only add prefix if model identifier doesn't already have it
|
||||
if not model.identifier.startswith(f"{provider_id}/"):
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
dynamic_models.append(model)
|
||||
|
||||
logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
|
||||
continue
|
||||
|
||||
return dynamic_models
|
||||
|
||||
async def list_models(self) -> ListModelsResponse:
|
||||
return ListModelsResponse(data=await self.get_all_with_type("model"))
|
||||
# Get models from registry
|
||||
registry_models = await self.get_all_with_type("model")
|
||||
|
||||
# Get additional models available via provider_data (user-specific, not cached)
|
||||
dynamic_models = await self._get_dynamic_models_from_provider_data()
|
||||
|
||||
# Combine, avoiding duplicates (registry takes precedence)
|
||||
registry_identifiers = {m.identifier for m in registry_models}
|
||||
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
|
||||
|
||||
return ListModelsResponse(data=registry_models + unique_dynamic_models)
|
||||
|
||||
async def openai_list_models(self) -> OpenAIListModelsResponse:
|
||||
models = await self.get_all_with_type("model")
|
||||
# Get models from registry
|
||||
registry_models = await self.get_all_with_type("model")
|
||||
|
||||
# Get additional models available via provider_data (user-specific, not cached)
|
||||
dynamic_models = await self._get_dynamic_models_from_provider_data()
|
||||
|
||||
# Combine, avoiding duplicates (registry takes precedence)
|
||||
registry_identifiers = {m.identifier for m in registry_models}
|
||||
unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
|
||||
|
||||
all_models = registry_models + unique_dynamic_models
|
||||
|
||||
openai_models = [
|
||||
OpenAIModel(
|
||||
id=model.identifier,
|
||||
object="model",
|
||||
created=int(time.time()),
|
||||
owned_by="llama_stack",
|
||||
custom_metadata={
|
||||
"model_type": model.model_type,
|
||||
"provider_id": model.provider_id,
|
||||
"provider_resource_id": model.provider_resource_id,
|
||||
**model.metadata,
|
||||
},
|
||||
)
|
||||
for model in models
|
||||
for model in all_models
|
||||
]
|
||||
return OpenAIListModelsResponse(data=openai_models)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from typing import Any
|
|||
import yaml
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
|
@ -30,7 +31,6 @@ from llama_stack.apis.safety import Safety
|
|||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||
from llama_stack.apis.shields import Shields
|
||||
from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
|
||||
from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
|
||||
from llama_stack.apis.vector_io import VectorIO
|
||||
from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl
|
||||
|
|
@ -63,8 +63,8 @@ class LlamaStack(
|
|||
Providers,
|
||||
Inference,
|
||||
Agents,
|
||||
Batches,
|
||||
Safety,
|
||||
SyntheticDataGeneration,
|
||||
Datasets,
|
||||
PostTraining,
|
||||
VectorIO,
|
||||
|
|
|
|||
|
|
@ -152,6 +152,37 @@ docker run \
|
|||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
|
||||
docker run -it \
|
||||
--pull always \
|
||||
--network host \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v $HOME/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
-e INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
-e DEH_URL=$DEH_URL \
|
||||
-e CHROMA_URL=$CHROMA_URL \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
{% if run_configs %}
|
||||
Available run configurations for this distribution:
|
||||
{% for config in run_configs %}
|
||||
- `{{ config }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
### Via Conda
|
||||
|
||||
Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.
|
||||
|
|
|
|||
|
|
@ -68,6 +68,36 @@ docker run \
|
|||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
--gpu all \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
{% if run_configs %}
|
||||
Available run configurations for this distribution:
|
||||
{% for config in run_configs %}
|
||||
- `{{ config }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
### Via venv
|
||||
|
||||
Make sure you have the Llama Stack CLI available.
|
||||
|
|
|
|||
|
|
@ -117,13 +117,42 @@ docker run \
|
|||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ./run.yaml:/root/my-run.yaml \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--config /root/my-run.yaml \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
### Via Docker with Custom Run Configuration
|
||||
|
||||
You can also run the Docker container with a custom run configuration file by mounting it into the container:
|
||||
|
||||
```bash
|
||||
# Set the path to your custom run.yaml file
|
||||
CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
|
||||
LLAMA_STACK_PORT=8321
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
--pull always \
|
||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||
-v ~/.llama:/root/.llama \
|
||||
-v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
|
||||
-e RUN_CONFIG_PATH=/app/custom-run.yaml \
|
||||
-e NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||
llamastack/distribution-{{ name }} \
|
||||
--port $LLAMA_STACK_PORT
|
||||
```
|
||||
|
||||
**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
|
||||
|
||||
{% if run_configs %}
|
||||
Available run configurations for this distribution:
|
||||
{% for config in run_configs %}
|
||||
- `{{ config }}`
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
### Via venv
|
||||
|
||||
If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment.
|
||||
|
|
|
|||
|
|
@ -424,6 +424,7 @@ class DistributionTemplate(BaseModel):
|
|||
providers_table=providers_table,
|
||||
run_config_env_vars=self.run_config_env_vars,
|
||||
default_models=default_models,
|
||||
run_configs=list(self.run_configs.keys()),
|
||||
)
|
||||
return ""
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import uuid
|
|||
import warnings
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
import httpx
|
||||
|
||||
|
|
@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> list[Message]:
|
||||
messages = []
|
||||
messages: list[Message] = []
|
||||
|
||||
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
|
||||
tool_call_ids = set()
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.tool_execution.value:
|
||||
if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
|
||||
for response in step.tool_responses:
|
||||
tool_call_ids.add(response.call_id)
|
||||
|
||||
|
|
@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages.append(msg)
|
||||
|
||||
for step in turn.steps:
|
||||
if step.step_type == StepType.inference.value:
|
||||
if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
|
||||
messages.append(step.model_response)
|
||||
elif step.step_type == StepType.tool_execution.value:
|
||||
elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
|
||||
for response in step.tool_responses:
|
||||
messages.append(
|
||||
ToolResponseMessage(
|
||||
|
|
@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content=response.content,
|
||||
)
|
||||
)
|
||||
elif step.step_type == StepType.shield_call.value:
|
||||
if step.violation:
|
||||
elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
|
||||
if step.violation and step.violation.user_message:
|
||||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
|
|
@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return await self.storage.create_session(name)
|
||||
|
||||
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
|
||||
messages = []
|
||||
messages: list[Message] = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
|
|
@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
steps = []
|
||||
messages = await self.get_messages_from_turns(turns)
|
||||
|
||||
if is_resume:
|
||||
assert isinstance(request, AgentTurnResumeRequest)
|
||||
tool_response_messages = [
|
||||
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
|
||||
]
|
||||
|
|
@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
|
||||
request.session_id, request.turn_id
|
||||
)
|
||||
now = datetime.now(UTC).isoformat()
|
||||
now_dt = datetime.now(UTC)
|
||||
tool_execution_step = ToolExecutionStep(
|
||||
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
|
||||
turn_id=request.turn_id,
|
||||
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
|
||||
tool_responses=request.tool_responses,
|
||||
completed_at=now,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
|
||||
completed_at=now_dt,
|
||||
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
|
||||
)
|
||||
steps.append(tool_execution_step)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=tool_execution_step.step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
)
|
||||
)
|
||||
input_messages = last_turn.input_messages
|
||||
# Cast needed due to list invariance - last_turn.input_messages is the right type
|
||||
input_messages = last_turn.input_messages # type: ignore[assignment]
|
||||
|
||||
turn_id = request.turn_id
|
||||
actual_turn_id = request.turn_id
|
||||
start_time = last_turn.started_at
|
||||
else:
|
||||
assert isinstance(request, AgentTurnCreateRequest)
|
||||
messages.extend(request.messages)
|
||||
start_time = datetime.now(UTC).isoformat()
|
||||
input_messages = request.messages
|
||||
start_time = datetime.now(UTC)
|
||||
# Cast needed due to list invariance - request.messages is the right type
|
||||
input_messages = request.messages # type: ignore[assignment]
|
||||
# Use the generated turn_id from beginning of function
|
||||
actual_turn_id = turn_id if turn_id else str(uuid.uuid4())
|
||||
|
||||
output_message = None
|
||||
req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
|
||||
req_sampling = (
|
||||
self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
|
||||
)
|
||||
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
turn_id=actual_turn_id,
|
||||
input_messages=messages,
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
sampling_params=req_sampling,
|
||||
stream=request.stream,
|
||||
documents=request.documents if not is_resume else None,
|
||||
documents=req_documents,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
output_message = chunk
|
||||
|
|
@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
|
||||
steps.append(event.payload.step_details)
|
||||
if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
|
||||
event.payload, "step_details"
|
||||
):
|
||||
step_details = event.payload.step_details
|
||||
steps.append(step_details)
|
||||
|
||||
yield chunk
|
||||
|
||||
assert output_message is not None
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
turn_id=actual_turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=input_messages,
|
||||
input_messages=input_messages, # type: ignore[arg-type]
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
|
@ -345,9 +361,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# return a "final value" for the `yield from` statement. we simulate that by yielding a
|
||||
# final boolean (to see whether an exception happened) and then explicitly testing for it.
|
||||
|
||||
if len(self.input_shields) > 0:
|
||||
if self.input_shields:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, input_messages, self.input_shields, "user-input"
|
||||
turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -374,9 +390,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# for output shields run on the full input and output combination
|
||||
messages = input_messages + [final_response]
|
||||
|
||||
if len(self.output_shields) > 0:
|
||||
if self.output_shields:
|
||||
async for res in self.run_multiple_shields_wrapper(
|
||||
turn_id, messages, self.output_shields, "assistant-output"
|
||||
turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
|
||||
):
|
||||
if isinstance(res, bool):
|
||||
return
|
||||
|
|
@ -388,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
messages: list[Message],
|
||||
messages: list[OpenAIMessageParam],
|
||||
shields: list[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
|
|
@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
shield_call_start_time = datetime.now(UTC).isoformat()
|
||||
shield_call_start_time = datetime.now(UTC)
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
)
|
||||
|
|
@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_type=StepType.shield_call,
|
||||
step_id=step_id,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
started_at=shield_call_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)
|
||||
|
||||
output_attachments = []
|
||||
output_attachments: list[Attachment] = []
|
||||
|
||||
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
|
||||
|
||||
# Build a map of custom tools to their definitions for faster lookup
|
||||
client_tools = {}
|
||||
if self.agent_config.client_tools:
|
||||
for tool in self.agent_config.client_tools:
|
||||
client_tools[tool.name] = tool
|
||||
while True:
|
||||
step_id = str(uuid.uuid4())
|
||||
inference_start_time = datetime.now(UTC).isoformat()
|
||||
inference_start_time = datetime.now(UTC)
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
|
|
@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
return value
|
||||
|
||||
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
|
||||
def _add_type(openai_msg: Any) -> OpenAIMessageParam:
|
||||
# Serialize any nested Pydantic models to plain dicts
|
||||
openai_msg = _serialize_nested(openai_msg)
|
||||
|
||||
|
|
@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
messages=openai_messages,
|
||||
tools=openai_tools if openai_tools else None,
|
||||
tool_choice=tool_choice,
|
||||
response_format=self.agent_config.response_format,
|
||||
response_format=self.agent_config.response_format, # type: ignore[arg-type]
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# Convert OpenAI stream back to Llama Stack format
|
||||
response_stream = convert_openai_chat_completion_stream(
|
||||
openai_stream, enable_incremental_tool_calls=True
|
||||
openai_stream, # type: ignore[arg-type]
|
||||
enable_incremental_tool_calls=True,
|
||||
)
|
||||
|
||||
async for chunk in response_stream:
|
||||
|
|
@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
delta=delta,
|
||||
)
|
||||
|
|
@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
delta=delta,
|
||||
)
|
||||
|
|
@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_attr = json.dumps(
|
||||
{
|
||||
"content": content,
|
||||
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
|
||||
"tool_calls": [
|
||||
json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
|
||||
],
|
||||
}
|
||||
)
|
||||
span.set_attribute("output", output_attr)
|
||||
|
|
@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if tool_calls:
|
||||
content = ""
|
||||
|
||||
# Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
|
||||
valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
|
||||
message = CompletionMessage(
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
tool_calls=tool_calls,
|
||||
tool_calls=valid_tool_calls if valid_tool_calls else None,
|
||||
)
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_type=StepType.inference,
|
||||
step_id=step_id,
|
||||
step_details=InferenceStep(
|
||||
# somewhere deep, we are re-assigning message or closing over some
|
||||
|
|
@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
model_response=copy.deepcopy(message),
|
||||
started_at=inference_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if n_iter >= self.agent_config.max_infer_iters:
|
||||
max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
|
||||
if n_iter >= max_iters:
|
||||
logger.info(f"done with MAX iterations ({n_iter}), exiting.")
|
||||
# NOTE: mark end_of_turn to indicate to client that we are done with the turn
|
||||
# Do not continue the tool call loop after this point
|
||||
|
|
@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield message
|
||||
break
|
||||
|
||||
if len(message.tool_calls) == 0:
|
||||
if not message.tool_calls or len(message.tool_calls) == 0:
|
||||
if stop_reason == StopReason.end_of_turn:
|
||||
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
|
||||
if len(output_attachments) > 0:
|
||||
if isinstance(message.content, list):
|
||||
message.content += output_attachments
|
||||
# List invariance - attachments are compatible at runtime
|
||||
message.content += output_attachments # type: ignore[arg-type]
|
||||
else:
|
||||
message.content = [message.content] + output_attachments
|
||||
# List invariance - attachments are compatible at runtime
|
||||
message.content = [message.content] + output_attachments # type: ignore[assignment]
|
||||
yield message
|
||||
else:
|
||||
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
|
||||
|
|
@ -725,6 +750,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
non_client_tool_calls = []
|
||||
|
||||
# Separate client and non-client tool calls
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in client_tools:
|
||||
client_tool_calls.append(tool_call)
|
||||
|
|
@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
)
|
||||
)
|
||||
|
|
@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
delta=ToolCallDelta(
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
|
|
@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if self.telemetry_enabled
|
||||
else {},
|
||||
) as span:
|
||||
tool_execution_start_time = datetime.now(UTC).isoformat()
|
||||
tool_execution_start_time = datetime.now(UTC)
|
||||
tool_result = await self.execute_tool_call_maybe(
|
||||
session_id,
|
||||
tool_call,
|
||||
|
|
@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
],
|
||||
started_at=tool_execution_start_time,
|
||||
completed_at=datetime.now(UTC).isoformat(),
|
||||
completed_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Yield the step completion event
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.tool_execution.value,
|
||||
step_type=StepType.tool_execution,
|
||||
step_id=step_id,
|
||||
step_details=tool_execution_step,
|
||||
)
|
||||
|
|
@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
turn_id=turn_id,
|
||||
tool_calls=client_tool_calls,
|
||||
tool_responses=[],
|
||||
started_at=datetime.now(UTC).isoformat(),
|
||||
started_at=datetime.now(UTC),
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -868,9 +894,10 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
toolgroup_to_args = toolgroup_to_args or {}
|
||||
|
||||
tool_name_to_def = {}
|
||||
tool_name_to_def: dict[str, ToolDefinition] = {}
|
||||
tool_name_to_args = {}
|
||||
|
||||
if self.agent_config.client_tools:
|
||||
for tool_def in self.agent_config.client_tools:
|
||||
if tool_name_to_def.get(tool_def.name, None):
|
||||
raise ValueError(f"Tool {tool_def.name} already exists")
|
||||
|
|
@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
else:
|
||||
identifier = None
|
||||
|
||||
if tool_name_to_def.get(identifier, None):
|
||||
raise ValueError(f"Tool {identifier} already exists")
|
||||
if identifier:
|
||||
tool_name_to_def[identifier] = ToolDefinition(
|
||||
tool_name=identifier,
|
||||
# Convert BuiltinTool to string for dictionary key
|
||||
identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
|
||||
if tool_name_to_def.get(identifier_str, None):
|
||||
raise ValueError(f"Tool {identifier_str} already exists")
|
||||
tool_name_to_def[identifier_str] = ToolDefinition(
|
||||
tool_name=identifier_str,
|
||||
description=tool_def.description,
|
||||
input_schema=tool_def.input_schema,
|
||||
)
|
||||
tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})
|
||||
|
||||
self.tool_defs, self.tool_name_to_args = (
|
||||
list(tool_name_to_def.values()),
|
||||
|
|
@ -966,7 +995,9 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e
|
||||
|
||||
result = await self.tool_runtime_api.invoke_tool(
|
||||
result = cast(
|
||||
ToolInvocationResult,
|
||||
await self.tool_runtime_api.invoke_tool(
|
||||
tool_name=tool_name_str,
|
||||
kwargs={
|
||||
"session_id": session_id,
|
||||
|
|
@ -974,6 +1005,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
**args,
|
||||
**self.tool_name_to_args.get(tool_name_str, {}),
|
||||
},
|
||||
),
|
||||
)
|
||||
logger.debug(f"tool call {tool_name_str} completed with result: {result}")
|
||||
return result
|
||||
|
|
@ -983,7 +1015,7 @@ async def load_data_from_url(url: str) -> str:
|
|||
if url.startswith("http"):
|
||||
async with httpx.AsyncClient() as client:
|
||||
r = await client.get(url)
|
||||
resp = r.text
|
||||
resp: str = r.text
|
||||
return resp
|
||||
raise ValueError(f"Unexpected URL: {type(url)}")
|
||||
|
||||
|
|
@ -1017,7 +1049,7 @@ def _interpret_content_as_attachment(
|
|||
snippet = match.group(1)
|
||||
data = json.loads(snippet)
|
||||
return Attachment(
|
||||
url=URL(uri="file://" + data["filepath"]),
|
||||
content=URL(uri="file://" + data["filepath"]),
|
||||
mime_type=data["mimetype"],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
|||
Document,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
OpenAIResponseInput,
|
||||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
|
|
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
persistence_store=(
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at,
|
||||
created_at=agent_info.created_at.isoformat(),
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
|
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
documents: list[Document] | None = None,
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
|
|
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
if turn is None:
|
||||
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
|
|
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
async def get_agents_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
|
|
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
||||
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
# Delete turns first, then the session
|
||||
|
|
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=chat_agent.agent_config,
|
||||
created_at=chat_agent.created_at,
|
||||
created_at=datetime.fromisoformat(chat_agent.created_at),
|
||||
)
|
||||
return agent
|
||||
|
||||
|
|
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self,
|
||||
response_id: str,
|
||||
) -> OpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.get_openai_response(response_id)
|
||||
|
||||
async def create_openai_response(
|
||||
|
|
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrail] | None = None,
|
||||
) -> OpenAIResponseObject:
|
||||
return await self.openai_responses_impl.create_openai_response(
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
result = await self.openai_responses_impl.create_openai_response(
|
||||
input,
|
||||
model,
|
||||
prompt,
|
||||
|
|
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
max_infer_iters,
|
||||
guardrails,
|
||||
)
|
||||
return result # type: ignore[no-any-return]
|
||||
|
||||
async def list_openai_responses(
|
||||
self,
|
||||
|
|
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
model: str | None = None,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||
|
||||
async def list_openai_response_input_items(
|
||||
|
|
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
limit: int | None = 20,
|
||||
order: Order | None = Order.desc,
|
||||
) -> ListOpenAIResponseInputItem:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.list_openai_response_input_items(
|
||||
response_id, after, before, include, limit, order
|
||||
)
|
||||
|
||||
async def delete_openai_response(self, response_id: str) -> None:
|
||||
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||
return await self.openai_responses_impl.delete_openai_response(response_id)
|
||||
|
|
|
|||
|
|
@ -6,12 +6,14 @@
|
|||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.datatypes import AccessRule
|
||||
from llama_stack.core.access_control.conditions import User as ProtocolUser
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
|
|||
created_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResource:
|
||||
"""Concrete implementation of ProtectedResource for session access control."""
|
||||
|
||||
type: str
|
||||
identifier: str
|
||||
owner: ProtocolUser # Use the protocol type for structural compatibility
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||
self.agent_id = agent_id
|
||||
|
|
@ -53,8 +64,15 @@ class AgentPersistence:
|
|||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
if not is_action_allowed(self.policy, "create", session_info, user):
|
||||
raise AccessDeniedError("create", session_info, user)
|
||||
# Only perform access control if we have an authenticated user
|
||||
if user is not None and session_info.identifier is not None:
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=user,
|
||||
)
|
||||
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
||||
raise AccessDeniedError(Action.CREATE, resource, user)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
|
|
@ -62,7 +80,7 @@ class AgentPersistence:
|
|||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo:
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
|
|
@ -83,7 +101,22 @@ class AgentPersistence:
|
|||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
|
||||
# Get current user - if None, skip access control (e.g., in tests)
|
||||
user = get_authenticated_user()
|
||||
if user is None:
|
||||
return True
|
||||
|
||||
# Access control requires identifier and owner to be set
|
||||
if session_info.identifier is None or session_info.owner is None:
|
||||
return True
|
||||
|
||||
# At this point, both identifier and owner are guaranteed to be non-None
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=session_info.owner,
|
||||
)
|
||||
return is_action_allowed(self.policy, Action.READ, resource, user)
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
|
|
|
|||
|
|
@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
|
|||
input: str | list[OpenAIResponseInput],
|
||||
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
||||
):
|
||||
new_input_items = previous_response.input.copy()
|
||||
# Convert Sequence to list for mutation
|
||||
new_input_items = list(previous_response.input)
|
||||
new_input_items.extend(previous_response.output)
|
||||
|
||||
if isinstance(input, str):
|
||||
|
|
@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None,
|
||||
previous_response_id: str | None,
|
||||
conversation: str | None,
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
|
||||
"""Process input with optional previous response context.
|
||||
|
||||
Returns:
|
||||
|
|
@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
|
|||
messages: list[OpenAIMessageParam],
|
||||
) -> None:
|
||||
new_input_id = f"msg_{uuid.uuid4()}"
|
||||
# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
|
||||
input_items_data: list[OpenAIResponseInput] = []
|
||||
|
||||
if isinstance(input, str):
|
||||
# synthesize a message from the input string
|
||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||
|
|
@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
|
|||
input_items_data = [input_content_item]
|
||||
else:
|
||||
# we already have a list of messages
|
||||
input_items_data = []
|
||||
for input_item in input:
|
||||
if isinstance(input_item, OpenAIResponseMessage):
|
||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||
|
|
@ -251,7 +254,7 @@ class OpenAIResponsesImpl:
|
|||
tools: list[OpenAIResponseInputTool] | None = None,
|
||||
include: list[str] | None = None,
|
||||
max_infer_iters: int | None = 10,
|
||||
guardrails: list[ResponseGuardrailSpec] | None = None,
|
||||
guardrails: list[str | ResponseGuardrailSpec] | None = None,
|
||||
):
|
||||
stream = bool(stream)
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||
|
|
@ -289,7 +292,8 @@ class OpenAIResponsesImpl:
|
|||
failed_response = None
|
||||
|
||||
async for stream_chunk in stream_gen:
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
match stream_chunk.type:
|
||||
case "response.completed" | "response.incomplete":
|
||||
if final_response is not None:
|
||||
raise ValueError(
|
||||
"The response stream produced multiple terminal responses! "
|
||||
|
|
@ -297,8 +301,10 @@ class OpenAIResponsesImpl:
|
|||
)
|
||||
final_response = stream_chunk.response
|
||||
final_event_type = stream_chunk.type
|
||||
elif stream_chunk.type == "response.failed":
|
||||
case "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
case _:
|
||||
pass # Other event types don't have .response
|
||||
|
||||
if failed_response is not None:
|
||||
error_message = (
|
||||
|
|
@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
|
|||
max_infer_iters: int | None = 10,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# These should never be None when called from create_openai_response (which sets defaults)
|
||||
# but we assert here to help mypy understand the types
|
||||
assert text is not None, "text must not be None"
|
||||
assert max_infer_iters is not None, "max_infer_iters must not be None"
|
||||
|
||||
# Input preprocessing
|
||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||
input, tools, previous_response_id, conversation
|
||||
|
|
@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
|
|||
final_response = None
|
||||
failed_response = None
|
||||
|
||||
output_items = []
|
||||
# Type as ConversationItem to avoid list invariance issues
|
||||
output_items: list[ConversationItem] = []
|
||||
async for stream_chunk in orchestrator.create_response():
|
||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
||||
match stream_chunk.type:
|
||||
case "response.completed" | "response.incomplete":
|
||||
final_response = stream_chunk.response
|
||||
elif stream_chunk.type == "response.failed":
|
||||
case "response.failed":
|
||||
failed_response = stream_chunk.response
|
||||
|
||||
if stream_chunk.type == "response.output_item.done":
|
||||
case "response.output_item.done":
|
||||
item = stream_chunk.item
|
||||
output_items.append(item)
|
||||
case _:
|
||||
pass # Other event types
|
||||
|
||||
# Store and sync before yielding terminal events
|
||||
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
|
||||
|
|
@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
|
|||
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
|
||||
) -> None:
|
||||
"""Sync content and response messages to the conversation."""
|
||||
conversation_items = []
|
||||
# Type as ConversationItem union to avoid list invariance issues
|
||||
conversation_items: list[ConversationItem] = []
|
||||
|
||||
if isinstance(input, str):
|
||||
conversation_items.append(
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
|
|||
text: OpenAIResponseText,
|
||||
max_infer_iters: int,
|
||||
tool_executor, # Will be the tool execution logic from the main class
|
||||
instructions: str,
|
||||
instructions: str | None,
|
||||
safety_api,
|
||||
guardrail_ids: list[str] | None = None,
|
||||
prompt: OpenAIResponsePrompt | None = None,
|
||||
|
|
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
|
|||
self.prompt = prompt
|
||||
self.sequence_number = 0
|
||||
# Store MCP tool mapping that gets built during tool processing
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
|
||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||
ctx.tool_context.previous_tools if ctx.tool_context else {}
|
||||
)
|
||||
# Track final messages after all tool executions
|
||||
self.final_messages: list[OpenAIMessageParam] = []
|
||||
# mapping for annotations
|
||||
|
|
@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
|
|||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=self.ctx.model,
|
||||
messages=messages,
|
||||
tools=self.ctx.chat_tools,
|
||||
# Pydantic models are dict-compatible but mypy treats them as distinct types
|
||||
tools=self.ctx.chat_tools, # type: ignore[arg-type]
|
||||
stream=True,
|
||||
temperature=self.ctx.temperature,
|
||||
response_format=response_format,
|
||||
|
|
@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
# Handle choices with no tool calls
|
||||
for choice in current_response.choices:
|
||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
||||
has_tool_calls = (
|
||||
isinstance(choice.message, OpenAIAssistantMessageParam)
|
||||
and choice.message.tool_calls
|
||||
and self.ctx.response_tools
|
||||
)
|
||||
if not has_tool_calls:
|
||||
output_messages.append(
|
||||
await convert_chat_choice_to_response_message(
|
||||
choice,
|
||||
|
|
@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
|
|||
)
|
||||
|
||||
# Accumulate arguments for final response (only for subsequent chunks)
|
||||
if not is_new_tool_call:
|
||||
if not is_new_tool_call and response_tool_call is not None:
|
||||
# Both should have functions since we're inside the tool_call.function check above
|
||||
assert response_tool_call.function is not None
|
||||
assert tool_call.function is not None
|
||||
response_tool_call.function.arguments = (
|
||||
response_tool_call.function.arguments or ""
|
||||
) + tool_call.function.arguments
|
||||
|
|
@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
|
|||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||
tool_call = chat_response_tool_calls[tool_call_index]
|
||||
# Ensure that arguments, if sent back to the inference provider, are not None
|
||||
if tool_call.function:
|
||||
tool_call.function.arguments = tool_call.function.arguments or "{}"
|
||||
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||
final_arguments = tool_call.function.arguments
|
||||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
||||
final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
|
||||
func = chat_response_tool_calls[tool_call_index].function
|
||||
|
||||
tool_call_name = func.name if func else ""
|
||||
|
||||
# Check if this is an MCP tool call
|
||||
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
||||
|
|
@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
self.sequence_number += 1
|
||||
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
|
||||
item = OpenAIResponseOutputMessageMCPCall(
|
||||
item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
|
||||
arguments="",
|
||||
name=tool_call.function.name,
|
||||
id=matching_item_id,
|
||||
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
|
||||
status="in_progress",
|
||||
)
|
||||
elif tool_call.function.name == "web_search":
|
||||
item = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
|
|
@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
|
|||
description=tool.description,
|
||||
input_schema=tool.input_schema,
|
||||
)
|
||||
return convert_tooldef_to_openai_tool(tool_def)
|
||||
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
# Initialize chat_tools if not already set
|
||||
if self.ctx.chat_tools is None:
|
||||
|
|
@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:
|
|||
|
||||
for input_tool in tools:
|
||||
if input_tool.type == "function":
|
||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
|
||||
elif input_tool.type in WebSearchToolTypes:
|
||||
tool_name = "web_search"
|
||||
# Need to access tool_groups_api from tool_executor
|
||||
|
|
@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
|
|||
if isinstance(mcp_tool.allowed_tools, list):
|
||||
always_allowed = mcp_tool.allowed_tools
|
||||
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||
always_allowed = mcp_tool.allowed_tools.always
|
||||
never_allowed = mcp_tool.allowed_tools.never
|
||||
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
|
||||
always_allowed = mcp_tool.allowed_tools.tool_names
|
||||
|
||||
# Call list_mcp_tools
|
||||
tool_defs = None
|
||||
|
|
@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
|
|||
openai_tool = convert_tooldef_to_chat_tool(t)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
# Add to MCP tool mapping
|
||||
if t.name in self.mcp_tool_to_server:
|
||||
|
|
@ -1120,12 +1133,16 @@ class StreamingResponseOrchestrator:
|
|||
self, output_messages: list[OpenAIResponseOutput]
|
||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||
# Handle all mcp tool lists from previous response that are still valid:
|
||||
# tool_context can be None when no tools are provided in the response request
|
||||
if self.ctx.tool_context:
|
||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||
yield evt
|
||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||
if self.ctx.tool_context.tools_to_process:
|
||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
||||
async for stream_event in self._process_new_tools(
|
||||
self.ctx.tool_context.tools_to_process, output_messages
|
||||
):
|
||||
yield stream_event
|
||||
|
||||
def _approval_required(self, tool_name: str) -> bool:
|
||||
|
|
@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
|
|||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||
if self.ctx.chat_tools is None:
|
||||
self.ctx.chat_tools = []
|
||||
self.ctx.chat_tools.append(openai_tool)
|
||||
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||
|
||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||
id=f"mcp_list_{uuid.uuid4()}",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
|
|
@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
)
|
||||
from llama_stack.apis.common.content_types import (
|
||||
|
|
@ -67,7 +69,7 @@ class ToolExecutor:
|
|||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
tool_call_id = tool_call.id
|
||||
function = tool_call.function
|
||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
||||
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
|
||||
|
||||
if not function or not tool_call_id or not function.name:
|
||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||
|
|
@ -84,7 +86,16 @@ class ToolExecutor:
|
|||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||
|
||||
# Emit completion events for tool execution
|
||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
||||
has_error = bool(
|
||||
error_exc
|
||||
or (
|
||||
result
|
||||
and (
|
||||
((error_code := getattr(result, "error_code", None)) and error_code > 0)
|
||||
or getattr(result, "error_message", None)
|
||||
)
|
||||
)
|
||||
)
|
||||
async for event_result in self._emit_completion_events(
|
||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||
):
|
||||
|
|
@ -101,7 +112,9 @@ class ToolExecutor:
|
|||
sequence_number=sequence_number,
|
||||
final_output_message=output_message,
|
||||
final_input_message=input_message,
|
||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
||||
citation_files=(
|
||||
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
|
||||
),
|
||||
)
|
||||
|
||||
async def _execute_knowledge_search_via_vector_store(
|
||||
|
|
@ -188,8 +201,9 @@ class ToolExecutor:
|
|||
|
||||
citation_files[file_id] = filename
|
||||
|
||||
# Cast to proper InterleavedContent type (list invariance)
|
||||
return ToolInvocationResult(
|
||||
content=content_items,
|
||||
content=content_items, # type: ignore[arg-type]
|
||||
metadata={
|
||||
"document_ids": [r.file_id for r in search_results],
|
||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||
|
|
@ -209,51 +223,60 @@ class ToolExecutor:
|
|||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit progress events for tool execution start."""
|
||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||
progress_event = None
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if progress_event:
|
||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
||||
|
||||
# For web search, emit searching event
|
||||
if function_name == "web_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
# For file search, emit searching event
|
||||
if function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
yield ToolExecutionResult(
|
||||
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
),
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
||||
|
||||
async def _execute_tool(
|
||||
self,
|
||||
|
|
@ -261,7 +284,7 @@ class ToolExecutor:
|
|||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[Exception | None, any]:
|
||||
) -> tuple[Exception | None, Any]:
|
||||
"""Execute the tool and return error exception and result."""
|
||||
error_exc = None
|
||||
result = None
|
||||
|
|
@ -284,10 +307,14 @@ class ToolExecutor:
|
|||
kwargs=tool_kwargs,
|
||||
)
|
||||
elif function_name == "knowledge_search":
|
||||
response_file_search_tool = next(
|
||||
response_file_search_tool = (
|
||||
next(
|
||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||
None,
|
||||
)
|
||||
if ctx.response_tools
|
||||
else None
|
||||
)
|
||||
if response_file_search_tool:
|
||||
# Use vector_stores.search API instead of knowledge_search tool
|
||||
# to support filters and ranking_options
|
||||
|
|
@ -322,35 +349,34 @@ class ToolExecutor:
|
|||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> AsyncIterator[ToolExecutionResult]:
|
||||
"""Emit completion or failure events for tool execution."""
|
||||
completion_event = None
|
||||
|
||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||
sequence_number += 1
|
||||
if has_error:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
|
||||
else:
|
||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
|
||||
elif function_name == "web_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
|
||||
elif function_name == "knowledge_search":
|
||||
sequence_number += 1
|
||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||
item_id=item_id,
|
||||
output_index=output_index,
|
||||
sequence_number=sequence_number,
|
||||
)
|
||||
|
||||
if completion_event:
|
||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
||||
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
|
||||
|
||||
async def _build_result_messages(
|
||||
self,
|
||||
|
|
@ -360,21 +386,18 @@ class ToolExecutor:
|
|||
tool_kwargs: dict,
|
||||
ctx: ChatCompletionContext,
|
||||
error_exc: Exception | None,
|
||||
result: any,
|
||||
result: Any,
|
||||
has_error: bool,
|
||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||
) -> tuple[any, any]:
|
||||
) -> tuple[Any, Any]:
|
||||
"""Build output and input messages from tool execution results."""
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
|
||||
# Build output message
|
||||
message: Any
|
||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
)
|
||||
|
||||
message = OpenAIResponseOutputMessageMCPCall(
|
||||
id=item_id,
|
||||
arguments=function.arguments,
|
||||
|
|
@ -383,10 +406,14 @@ class ToolExecutor:
|
|||
)
|
||||
if error_exc:
|
||||
message.error = str(error_exc)
|
||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
||||
elif result and result.content:
|
||||
message.output = interleaved_content_as_str(result.content)
|
||||
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
|
||||
result and getattr(result, "error_message", None)
|
||||
):
|
||||
ec = getattr(result, "error_code", "unknown")
|
||||
em = getattr(result, "error_message", "")
|
||||
message.error = f"Error (code {ec}): {em}"
|
||||
elif result and (content := getattr(result, "content", None)):
|
||||
message.output = interleaved_content_as_str(content)
|
||||
else:
|
||||
if function.name == "web_search":
|
||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||
|
|
@ -401,17 +428,17 @@ class ToolExecutor:
|
|||
queries=[tool_kwargs.get("query", "")],
|
||||
status="completed",
|
||||
)
|
||||
if result and "document_ids" in result.metadata:
|
||||
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
|
||||
message.results = []
|
||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
||||
for i, doc_id in enumerate(metadata["document_ids"]):
|
||||
text = metadata["chunks"][i] if "chunks" in metadata else None
|
||||
score = metadata["scores"][i] if "scores" in metadata else None
|
||||
message.results.append(
|
||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||
file_id=doc_id,
|
||||
filename=doc_id,
|
||||
text=text,
|
||||
score=score,
|
||||
text=text if text is not None else "",
|
||||
score=score if score is not None else 0.0,
|
||||
attributes={},
|
||||
)
|
||||
)
|
||||
|
|
@ -421,27 +448,32 @@ class ToolExecutor:
|
|||
raise ValueError(f"Unknown tool {function.name} called")
|
||||
|
||||
# Build input message
|
||||
input_message = None
|
||||
if result and result.content:
|
||||
if isinstance(result.content, str):
|
||||
content = result.content
|
||||
elif isinstance(result.content, list):
|
||||
content = []
|
||||
for item in result.content:
|
||||
input_message: OpenAIToolMessageParam | None = None
|
||||
if result and (result_content := getattr(result, "content", None)):
|
||||
# all the mypy contortions here are still unsatisfactory with random Any typing
|
||||
if isinstance(result_content, str):
|
||||
msg_content: str | list[Any] = result_content
|
||||
elif isinstance(result_content, list):
|
||||
content_list: list[Any] = []
|
||||
for item in result_content:
|
||||
part: Any
|
||||
if isinstance(item, TextContentItem):
|
||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||
elif isinstance(item, ImageContentItem):
|
||||
if item.image.data:
|
||||
url = f"data:image;base64,{item.image.data}"
|
||||
url_value = f"data:image;base64,{item.image.data}"
|
||||
else:
|
||||
url = item.image.url
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
||||
url_value = str(item.image.url) if item.image.url else ""
|
||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||
content.append(part)
|
||||
content_list.append(part)
|
||||
msg_content = content_list
|
||||
else:
|
||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
||||
raise ValueError(f"Unknown result content type: {type(result_content)}")
|
||||
# OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
|
||||
# This is runtime-safe as the API accepts it, but mypy complains
|
||||
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||
else:
|
||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -100,17 +101,19 @@ class ToolContext(BaseModel):
|
|||
if isinstance(tool, OpenAIResponseToolMCP):
|
||||
previous_tools_by_label[tool.server_label] = tool
|
||||
# collect tool definitions which are the same in current and previous requests:
|
||||
tools_to_process = []
|
||||
tools_to_process: list[OpenAIResponseInputTool] = []
|
||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||
for tool in self.current_tools:
|
||||
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
|
||||
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
|
||||
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
|
||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||
previous_tool = previous_tools_by_label[tool.server_label]
|
||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||
matched[tool.server_label] = tool
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||
else:
|
||||
tools_to_process.append(tool)
|
||||
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||
# tools that are not the same or were not previously defined need to be processed:
|
||||
self.tools_to_process = tools_to_process
|
||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||
|
|
@ -119,9 +122,11 @@ class ToolContext(BaseModel):
|
|||
]
|
||||
# reconstruct the tool to server mappings that can be reused:
|
||||
for listing in self.previous_tool_listings:
|
||||
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
|
||||
definition = matched[listing.server_label]
|
||||
for tool in listing.tools:
|
||||
self.previous_tools[tool.name] = definition
|
||||
for mcp_tool in listing.tools:
|
||||
# mcp_tool is MCPListToolsTool which has a name: str field
|
||||
self.previous_tools[mcp_tool.name] = definition
|
||||
|
||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||
if not self.current_tools:
|
||||
|
|
@ -139,6 +144,8 @@ class ToolContext(BaseModel):
|
|||
server_label=tool.server_label,
|
||||
allowed_tools=tool.allowed_tools,
|
||||
)
|
||||
# Exhaustive check - all tool types should be handled above
|
||||
raise AssertionError(f"Unexpected tool type: {type(tool)}")
|
||||
|
||||
return [convert_tool(tool) for tool in self.current_tools]
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
import asyncio
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
|
|
@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
|
|||
|
||||
return OpenAIResponseMessage(
|
||||
id=message_id or f"msg_{uuid.uuid4()}",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
|
||||
async def convert_response_content_to_chat_content(
|
||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
||||
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
|
||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||
"""
|
||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||
|
|
@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
|
|||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
converted_parts = []
|
||||
# Type with union to avoid list invariance issues
|
||||
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
|
||||
for content_part in content:
|
||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||
|
|
@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
|
|||
),
|
||||
)
|
||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||
# Output can be None, use empty string as fallback
|
||||
output_content = input_item.output if input_item.output is not None else ""
|
||||
messages.append(
|
||||
OpenAIToolMessageParam(
|
||||
content=input_item.output,
|
||||
content=output_content,
|
||||
tool_call_id=input_item.id,
|
||||
)
|
||||
)
|
||||
|
|
@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
|
|||
):
|
||||
# these are handled by the responses impl itself and not pass through to chat completions
|
||||
pass
|
||||
else:
|
||||
elif isinstance(input_item, OpenAIResponseMessage):
|
||||
# Narrow type to OpenAIResponseMessage which has content and role attributes
|
||||
content = await convert_response_content_to_chat_content(input_item.content)
|
||||
message_type = await get_message_type_by_role(input_item.role)
|
||||
if message_type is None:
|
||||
|
|
@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
|
|||
last_user_content = getattr(last_user_msg, "content", None)
|
||||
if last_user_content == content:
|
||||
continue # Skip duplicate user message
|
||||
messages.append(message_type(content=content))
|
||||
# Dynamic message type call - different message types have different content expectations
|
||||
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
|
||||
if len(tool_call_results):
|
||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||
if previous_messages:
|
||||
|
|
@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
|
|||
if text.format["type"] == "json_object":
|
||||
return OpenAIResponseFormatJSONObject()
|
||||
if text.format["type"] == "json_schema":
|
||||
# Assert name exists for json_schema format
|
||||
assert text.format.get("name"), "json_schema format requires a name"
|
||||
schema_name: str = text.format["name"] # type: ignore[assignment]
|
||||
return OpenAIResponseFormatJSONSchema(
|
||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
||||
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
|
||||
)
|
||||
raise ValueError(f"Unsupported text format: {text.format}")
|
||||
|
||||
|
|
@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
|
|||
"assistant": OpenAIAssistantMessageParam,
|
||||
"developer": OpenAIDeveloperMessageParam,
|
||||
}
|
||||
return role_to_type.get(role)
|
||||
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
|
||||
|
||||
|
||||
def _extract_citations_from_text(
|
||||
|
|
@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
# Look up shields to get their provider_resource_id (actual model ID)
|
||||
model_ids = []
|
||||
shields_list = await safety_api.routing_table.list_shields()
|
||||
# TODO: list_shields not in Safety interface but available at runtime via API routing
|
||||
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
|
||||
|
||||
for guardrail_id in guardrail_ids:
|
||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||
|
|
@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
for result in response.results:
|
||||
if result.flagged:
|
||||
message = result.user_message or "Content blocked by safety guardrails"
|
||||
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
|
||||
flagged_categories = (
|
||||
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
|
||||
)
|
||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||
|
||||
if flagged_categories:
|
||||
|
|
@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
|||
|
||||
return message
|
||||
|
||||
# No violations found
|
||||
return None
|
||||
|
||||
|
||||
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
||||
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
import asyncio
|
||||
|
||||
from llama_stack.apis.inference import Message
|
||||
from llama_stack.apis.inference import OpenAIMessageParam
|
||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||
from llama_stack.core.telemetry import tracing
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
|
|||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
||||
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
||||
async def run_shield_with_span(identifier: str):
|
||||
async with tracing.span(f"run_shield_{identifier}"):
|
||||
return await self.safety_api.run_shield(
|
||||
|
|
|
|||
|
|
@ -28,4 +28,13 @@ def available_providers() -> list[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||
),
|
||||
RemoteProviderSpec(
|
||||
api=Api.files,
|
||||
provider_type="remote::openai",
|
||||
adapter_type="openai",
|
||||
pip_packages=["openai"] + sql_store_pip_packages,
|
||||
module="llama_stack.providers.remote.files.openai",
|
||||
config_class="llama_stack.providers.remote.files.openai.config.OpenAIFilesImplConfig",
|
||||
description="OpenAI Files API provider for managing files through OpenAI's native file storage service.",
|
||||
),
|
||||
]
|
||||
|
|
|
|||
19
src/llama_stack/providers/remote/files/openai/__init__.py
Normal file
19
src/llama_stack/providers/remote/files/openai/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
|
||||
from .config import OpenAIFilesImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: OpenAIFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
|
||||
from .files import OpenAIFilesImpl
|
||||
|
||||
impl = OpenAIFilesImpl(config, policy or [])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
28
src/llama_stack/providers/remote/files/openai/config.py
Normal file
28
src/llama_stack/providers/remote/files/openai/config.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||
|
||||
|
||||
class OpenAIFilesImplConfig(BaseModel):
|
||||
"""Configuration for OpenAI Files API provider."""
|
||||
|
||||
api_key: str = Field(description="OpenAI API key for authentication")
|
||||
metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.OPENAI_API_KEY}",
|
||||
"metadata_store": SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="openai_files_metadata",
|
||||
).model_dump(exclude_none=True),
|
||||
}
|
||||
239
src/llama_stack/providers/remote/files/openai/files.py
Normal file
239
src/llama_stack/providers/remote/files/openai/files.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import Depends, File, Form, Response, UploadFile
|
||||
|
||||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
from openai import OpenAI
|
||||
|
||||
from .config import OpenAIFilesImplConfig
|
||||
|
||||
|
||||
def _make_file_object(
|
||||
*,
|
||||
id: str,
|
||||
filename: str,
|
||||
purpose: str,
|
||||
bytes: int,
|
||||
created_at: int,
|
||||
expires_at: int,
|
||||
**kwargs: Any,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Construct an OpenAIFileObject and normalize expires_at.
|
||||
|
||||
If expires_at is greater than the max we treat it as no-expiration and
|
||||
return None for expires_at.
|
||||
"""
|
||||
obj = OpenAIFileObject(
|
||||
id=id,
|
||||
filename=filename,
|
||||
purpose=OpenAIFilePurpose(purpose),
|
||||
bytes=bytes,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
|
||||
obj.expires_at = None # type: ignore
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class OpenAIFilesImpl(Files):
|
||||
"""OpenAI Files API implementation."""
|
||||
|
||||
def __init__(self, config: OpenAIFilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self._config = config
|
||||
self.policy = policy
|
||||
self._client: OpenAI | None = None
|
||||
self._sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
def _now(self) -> int:
|
||||
"""Return current UTC timestamp as int seconds."""
|
||||
return int(datetime.now(UTC).timestamp())
|
||||
|
||||
async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
|
||||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
async def _delete_file(self, file_id: str) -> None:
|
||||
"""Delete a file from OpenAI and the database."""
|
||||
try:
|
||||
self.client.files.delete(file_id)
|
||||
except Exception as e:
|
||||
# If file doesn't exist on OpenAI side, just remove from metadata store
|
||||
if "not found" not in str(e).lower():
|
||||
raise RuntimeError(f"Failed to delete file from OpenAI: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
async def _delete_if_expired(self, file_id: str) -> None:
|
||||
"""If the file exists and is expired, delete it."""
|
||||
if row := await self._get_file(file_id, return_expired=True):
|
||||
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
|
||||
await self._delete_file(file_id)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = OpenAI(api_key=self._config.api_key)
|
||||
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||
"filename": ColumnType.STRING,
|
||||
"purpose": ColumnType.STRING,
|
||||
"bytes": ColumnType.INTEGER,
|
||||
"created_at": ColumnType.INTEGER,
|
||||
"expires_at": ColumnType.INTEGER,
|
||||
},
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def client(self) -> OpenAI:
|
||||
assert self._client is not None, "Provider not initialized"
|
||||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> AuthorizedSqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
async def openai_upload_file(
|
||||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||
) -> OpenAIFileObject:
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
created_at = self._now()
|
||||
|
||||
expires_at = created_at + ExpiresAfter.MAX * 42
|
||||
if purpose == OpenAIFilePurpose.BATCH:
|
||||
expires_at = created_at + ExpiresAfter.MAX
|
||||
|
||||
if expires_after is not None:
|
||||
expires_at = created_at + expires_after.seconds
|
||||
|
||||
try:
|
||||
from io import BytesIO
|
||||
|
||||
file_obj = BytesIO(content)
|
||||
file_obj.name = filename
|
||||
|
||||
response = self.client.files.create(
|
||||
file=file_obj,
|
||||
purpose=purpose.value,
|
||||
)
|
||||
|
||||
file_id = response.id
|
||||
|
||||
entry: dict[str, Any] = {
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
await self.sql_store.insert("openai_files", entry)
|
||||
|
||||
return _make_file_object(**entry)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to upload file to OpenAI: {e}") from e
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int | None = 10000,
|
||||
order: Order | None = Order.desc,
|
||||
purpose: OpenAIFilePurpose | None = None,
|
||||
) -> ListOpenAIFileResponse:
|
||||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [_make_file_object(**row) for row in paginated_result.data]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
has_more=paginated_result.has_more,
|
||||
first_id=files[0].id if files else "",
|
||||
last_id=files[-1].id if files else "",
|
||||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
await self._delete_if_expired(file_id)
|
||||
row = await self._get_file(file_id)
|
||||
return _make_file_object(**row)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
await self._delete_if_expired(file_id)
|
||||
_ = await self._get_file(file_id)
|
||||
await self._delete_file(file_id)
|
||||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
await self._delete_if_expired(file_id)
|
||||
|
||||
row = await self._get_file(file_id)
|
||||
|
||||
try:
|
||||
response = self.client.files.content(file_id)
|
||||
file_content = response.content
|
||||
|
||||
except Exception as e:
|
||||
if "not found" in str(e).lower():
|
||||
await self._delete_file(file_id)
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||
raise RuntimeError(f"Failed to download file from OpenAI: {e}") from e
|
||||
|
||||
return Response(
|
||||
content=file_content,
|
||||
media_type="application/octet-stream",
|
||||
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||
)
|
||||
|
|
@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
|
|||
return "https://api.anthropic.com/v1"
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
|
||||
api_key = self._get_api_key_from_config_or_provider_data()
|
||||
return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()]
|
||||
|
|
|
|||
|
|
@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
|||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
# Filter out None values from endpoint names
|
||||
api_token = self._get_api_key_from_config_or_provider_data()
|
||||
return [
|
||||
endpoint.name # type: ignore[misc]
|
||||
for endpoint in WorkspaceClient(
|
||||
host=self.config.url, token=self.get_api_key()
|
||||
host=self.config.url, token=api_token
|
||||
).serving_endpoints.list() # TODO: this is not async
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
|
|||
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Rerank Example
|
||||
|
||||
The following example shows how to rerank documents using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
rerank_response = client.alpha.inference.rerank(
|
||||
model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||
query="query",
|
||||
items=[
|
||||
"item_1",
|
||||
"item_2",
|
||||
"item_3",
|
||||
],
|
||||
)
|
||||
|
||||
for i, result in enumerate(rerank_response):
|
||||
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
|
||||
```
|
||||
|
|
@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
Attributes:
|
||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||
api_key (str): The access key for the hosted NIM endpoints
|
||||
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
|
||||
|
||||
There are two ways to access NVIDIA NIMs -
|
||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||
|
|
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
rerank_model_to_url: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||
},
|
||||
description="Mapping of rerank model identifiers to their API endpoints. ",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,19 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import aiohttp
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
RerankData,
|
||||
RerankResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
"""
|
||||
Return both dynamic model IDs and statically configured rerank model IDs.
|
||||
"""
|
||||
dynamic_ids: Iterable[str] = []
|
||||
try:
|
||||
dynamic_ids = await super().list_provider_model_ids()
|
||||
except Exception:
|
||||
# If the dynamic listing fails, proceed with just configured rerank IDs
|
||||
dynamic_ids = []
|
||||
|
||||
configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
|
||||
return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates
|
||||
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
|
||||
"""
|
||||
Classify rerank models from config; otherwise use the base behavior.
|
||||
"""
|
||||
if identifier in self.config.rerank_model_to_url:
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
return super().construct_model_from_identifier(identifier)
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
ranking_url = self.get_base_url()
|
||||
|
||||
if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
|
||||
ranking_url = self.config.rerank_model_to_url[provider_model_id]
|
||||
|
||||
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
|
||||
|
||||
# Convert query to text format
|
||||
if isinstance(query, str):
|
||||
query_text = query
|
||||
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
|
||||
query_text = query.text
|
||||
else:
|
||||
raise ValueError("Query must be a string or text content part")
|
||||
|
||||
# Convert items to text format
|
||||
passages = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
passages.append({"text": item})
|
||||
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
|
||||
passages.append({"text": item.text})
|
||||
else:
|
||||
raise ValueError("Items must be strings or text content parts")
|
||||
|
||||
payload = {
|
||||
"model": provider_model_id,
|
||||
"query": {"text": query_text},
|
||||
"passages": passages,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_api_key()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(ranking_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
raise ConnectionError(
|
||||
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
rankings = result.get("rankings", [])
|
||||
|
||||
# Convert to RerankData format
|
||||
rerank_data = []
|
||||
for ranking in rankings:
|
||||
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
|
||||
|
||||
# Apply max_num_results limit
|
||||
if max_num_results is not None:
|
||||
rerank_data = rerank_data[:max_num_results]
|
||||
|
||||
return RerankResponse(data=rerank_data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ class InferenceStore:
|
|||
self.reference = reference
|
||||
self.sql_store = None
|
||||
self.policy = policy
|
||||
self.enable_write_queue = True
|
||||
|
||||
# Async write queue and worker control
|
||||
self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
|
||||
|
|
@ -47,14 +48,13 @@ class InferenceStore:
|
|||
base_store = sqlstore_impl(self.reference)
|
||||
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
||||
|
||||
# Disable write queue for SQLite to avoid concurrency issues
|
||||
backend_name = self.reference.backend
|
||||
backend_config = _SQLSTORE_BACKENDS.get(backend_name)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
|
||||
# Disable write queue for SQLite since WAL mode handles concurrency
|
||||
# Keep it enabled for other backends (like Postgres) for performance
|
||||
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
|
||||
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
|
||||
self.enable_write_queue = False
|
||||
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"chat_completions",
|
||||
{
|
||||
|
|
@ -70,8 +70,9 @@ class InferenceStore:
|
|||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
logger.debug(
|
||||
f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
|
|
|
|||
|
|
@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
|
|||
return schema
|
||||
|
||||
async def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
input_dict = {}
|
||||
from typing import Any
|
||||
|
||||
input_dict: dict[str, Any] = {}
|
||||
|
||||
input_dict["messages"] = [
|
||||
await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
|
||||
|
|
@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin(
|
|||
f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
|
||||
)
|
||||
|
||||
fmt = fmt.json_schema
|
||||
name = fmt["title"]
|
||||
del fmt["title"]
|
||||
fmt["additionalProperties"] = False
|
||||
# Convert to dict for manipulation
|
||||
fmt_dict = dict(fmt.json_schema)
|
||||
name = fmt_dict["title"]
|
||||
del fmt_dict["title"]
|
||||
fmt_dict["additionalProperties"] = False
|
||||
|
||||
# Apply additionalProperties: False recursively to all objects
|
||||
fmt = self._add_additional_properties_recursive(fmt)
|
||||
fmt_dict = self._add_additional_properties_recursive(fmt_dict)
|
||||
|
||||
input_dict["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": name,
|
||||
"schema": fmt,
|
||||
"schema": fmt_dict,
|
||||
"strict": self.json_schema_strict,
|
||||
},
|
||||
}
|
||||
if request.tools:
|
||||
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
|
||||
if request.tool_config.tool_choice:
|
||||
input_dict["tool_choice"] = (
|
||||
request.tool_config.tool_choice.value
|
||||
if isinstance(request.tool_config.tool_choice, ToolChoice)
|
||||
else request.tool_config.tool_choice
|
||||
)
|
||||
if request.tool_config and (tool_choice := request.tool_config.tool_choice):
|
||||
input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
|
||||
|
||||
return {
|
||||
"model": request.model,
|
||||
|
|
@ -176,9 +175,9 @@ class LiteLLMOpenAIMixin(
|
|||
def get_api_key(self) -> str:
|
||||
provider_data = self.get_request_provider_data()
|
||||
key_field = self.provider_data_api_key_field
|
||||
if provider_data and getattr(provider_data, key_field, None):
|
||||
api_key = getattr(provider_data, key_field)
|
||||
else:
|
||||
if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
|
||||
return str(api_key) # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection
|
||||
|
||||
api_key = self.api_key_from_config
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
|
|
@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
params: OpenAIEmbeddingsRequestWithExtraBody,
|
||||
) -> OpenAIEmbeddingsResponse:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
# Convert input to list if it's a string
|
||||
input_list = [params.input] if isinstance(params.input, str) else params.input
|
||||
|
|
@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin(
|
|||
# Call litellm embedding function
|
||||
# litellm.drop_params = True
|
||||
response = litellm.embedding(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
input=input_list,
|
||||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
|
|
@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin(
|
|||
|
||||
return OpenAIEmbeddingsResponse(
|
||||
data=data,
|
||||
model=model_obj.provider_resource_id,
|
||||
model=provider_resource_id,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
|
@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin(
|
|||
self,
|
||||
params: OpenAICompletionRequestWithExtraBody,
|
||||
) -> OpenAICompletion:
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
prompt=params.prompt,
|
||||
best_of=params.best_of,
|
||||
echo=params.echo,
|
||||
|
|
@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin(
|
|||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.atext_completion(**request_params)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
|
|
@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin(
|
|||
elif "include_usage" not in stream_options:
|
||||
stream_options = {**stream_options, "include_usage": True}
|
||||
|
||||
if not self.model_store:
|
||||
raise ValueError("Model store is not initialized")
|
||||
|
||||
model_obj = await self.model_store.get_model(params.model)
|
||||
if model_obj.provider_resource_id is None:
|
||||
raise ValueError(f"Model {params.model} has no provider_resource_id")
|
||||
provider_resource_id = model_obj.provider_resource_id
|
||||
|
||||
request_params = await prepare_openai_completion_params(
|
||||
model=self.get_litellm_model_name(model_obj.provider_resource_id),
|
||||
model=self.get_litellm_model_name(provider_resource_id),
|
||||
messages=params.messages,
|
||||
frequency_penalty=params.frequency_penalty,
|
||||
function_call=params.function_call,
|
||||
|
|
@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin(
|
|||
api_key=self.get_api_key(),
|
||||
api_base=self.api_base,
|
||||
)
|
||||
return await litellm.acompletion(**request_params)
|
||||
# LiteLLM returns compatible type but mypy can't verify external library
|
||||
return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
|
|||
|
||||
|
||||
class RemoteInferenceProviderConfig(BaseModel):
|
||||
allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default
|
||||
allowed_models: list[str] | None = Field(
|
||||
default=None,
|
||||
description="List of models that should be registered with the model registry. If None, all models are allowed.",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -161,7 +161,9 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
|
|||
if isinstance(params.strategy, GreedySamplingStrategy):
|
||||
options["temperature"] = 0.0
|
||||
elif isinstance(params.strategy, TopPSamplingStrategy):
|
||||
if params.strategy.temperature is not None:
|
||||
options["temperature"] = params.strategy.temperature
|
||||
if params.strategy.top_p is not None:
|
||||
options["top_p"] = params.strategy.top_p
|
||||
elif isinstance(params.strategy, TopKSamplingStrategy):
|
||||
options["top_k"] = params.strategy.top_k
|
||||
|
|
@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict:
|
|||
|
||||
def text_from_choice(choice) -> str:
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
return choice.delta.content
|
||||
return choice.delta.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
|
||||
|
||||
if hasattr(choice, "message"):
|
||||
return choice.message.content
|
||||
return choice.message.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations
|
||||
|
||||
return choice.text
|
||||
return choice.text # type: ignore[no-any-return] # external OpenAI types lack precise annotations
|
||||
|
||||
|
||||
def get_stop_reason(finish_reason: str) -> StopReason:
|
||||
|
|
@ -216,7 +218,7 @@ def convert_openai_completion_logprobs(
|
|||
) -> list[TokenLogProbs] | None:
|
||||
if not logprobs:
|
||||
return None
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
|
||||
# Together supports logprobs with top_k=1 only. This means for each token position,
|
||||
|
|
@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
|
|||
if isinstance(logprobs, float):
|
||||
# Adapt response from Together CompletionChoicesChunk
|
||||
return [TokenLogProbs(logprobs_by_token={text: logprobs})]
|
||||
if hasattr(logprobs, "top_logprobs"):
|
||||
if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
|
||||
return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
|
||||
return None
|
||||
|
||||
|
|
@ -245,23 +247,24 @@ def process_completion_response(
|
|||
response: OpenAICompatCompletionResponse,
|
||||
) -> CompletionResponse:
|
||||
choice = response.choices[0]
|
||||
text = choice.text or ""
|
||||
# drop suffix <eot_id> if present and return stop reason as end of turn
|
||||
if choice.text.endswith("<|eot_id|>"):
|
||||
if text.endswith("<|eot_id|>"):
|
||||
return CompletionResponse(
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
content=choice.text[: -len("<|eot_id|>")],
|
||||
content=text[: -len("<|eot_id|>")],
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
# drop suffix <eom_id> if present and return stop reason as end of message
|
||||
if choice.text.endswith("<|eom_id|>"):
|
||||
if text.endswith("<|eom_id|>"):
|
||||
return CompletionResponse(
|
||||
stop_reason=StopReason.end_of_message,
|
||||
content=choice.text[: -len("<|eom_id|>")],
|
||||
content=text[: -len("<|eom_id|>")],
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
return CompletionResponse(
|
||||
stop_reason=get_stop_reason(choice.finish_reason),
|
||||
content=choice.text,
|
||||
stop_reason=get_stop_reason(choice.finish_reason or "stop"),
|
||||
content=text,
|
||||
logprobs=convert_openai_completion_logprobs(choice.logprobs),
|
||||
)
|
||||
|
||||
|
|
@ -272,10 +275,10 @@ def process_chat_completion_response(
|
|||
) -> ChatCompletionResponse:
|
||||
choice = response.choices[0]
|
||||
if choice.finish_reason == "tool_calls":
|
||||
if not choice.message or not choice.message.tool_calls:
|
||||
if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
|
||||
raise ValueError("Tool calls are not present in the response")
|
||||
|
||||
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
|
||||
tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
|
||||
if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
|
||||
# If we couldn't parse a tool call, jsonify the tool calls and return them
|
||||
return ChatCompletionResponse(
|
||||
|
|
@ -287,9 +290,11 @@ def process_chat_completion_response(
|
|||
)
|
||||
else:
|
||||
# Otherwise, return tool calls as normal
|
||||
# Filter to only valid ToolCall objects
|
||||
valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
tool_calls=tool_calls,
|
||||
tool_calls=valid_tool_calls,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
# Content is not optional
|
||||
content="",
|
||||
|
|
@ -299,7 +304,7 @@ def process_chat_completion_response(
|
|||
|
||||
# TODO: This does not work well with tool calls for vLLM remote provider
|
||||
# Ref: https://github.com/meta-llama/llama-stack/issues/1058
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
|
||||
raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop"))
|
||||
|
||||
# NOTE: If we do not set tools in chat-completion request, we should not
|
||||
# expect the ToolCall in the response. Instead, we should return the raw
|
||||
|
|
@ -324,8 +329,8 @@ def process_chat_completion_response(
|
|||
|
||||
return ChatCompletionResponse(
|
||||
completion_message=CompletionMessage(
|
||||
content=raw_message.content,
|
||||
stop_reason=raw_message.stop_reason,
|
||||
content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent]
|
||||
stop_reason=raw_message.stop_reason or StopReason.end_of_turn,
|
||||
tool_calls=raw_message.tool_calls,
|
||||
),
|
||||
logprobs=None,
|
||||
|
|
@ -448,7 +453,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
|
||||
# parse tool calls and report errors
|
||||
message = decode_assistant_message(buffer, stop_reason)
|
||||
message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn)
|
||||
|
||||
parsed_tool_calls = len(message.tool_calls) > 0
|
||||
if ipython and not parsed_tool_calls:
|
||||
|
|
@ -463,7 +468,7 @@ async def process_chat_completion_stream_response(
|
|||
)
|
||||
)
|
||||
|
||||
request_tools = {t.tool_name: t for t in request.tools}
|
||||
request_tools = {t.tool_name: t for t in (request.tools or [])}
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.tool_name in request_tools:
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
|
|
@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
}
|
||||
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
result["tool_calls"] = []
|
||||
tool_calls_list = []
|
||||
for tc in message.tool_calls:
|
||||
# The tool.tool_name can be a str or a BuiltinTool enum. If
|
||||
# it's the latter, convert to a string.
|
||||
|
|
@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
if isinstance(tool_name, BuiltinTool):
|
||||
tool_name = tool_name.value
|
||||
|
||||
result["tool_calls"].append(
|
||||
tool_calls_list.append(
|
||||
{
|
||||
"id": tc.call_id,
|
||||
"type": "function",
|
||||
|
|
@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
|||
},
|
||||
}
|
||||
)
|
||||
result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected
|
||||
return result
|
||||
|
||||
|
||||
|
|
@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new(
|
|||
),
|
||||
)
|
||||
elif isinstance(content_, list):
|
||||
return [await impl(item) for item in content_]
|
||||
return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {type(content_)}")
|
||||
|
||||
|
|
@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new(
|
|||
else:
|
||||
return [ret]
|
||||
|
||||
out: OpenAIChatCompletionMessage = None
|
||||
out: OpenAIChatCompletionMessage
|
||||
if isinstance(message, UserMessage):
|
||||
out = OpenAIChatCompletionUserMessage(
|
||||
role="user",
|
||||
|
|
@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new(
|
|||
),
|
||||
type="function",
|
||||
)
|
||||
for tool in message.tool_calls
|
||||
for tool in (message.tool_calls or [])
|
||||
]
|
||||
params = {}
|
||||
if tool_calls:
|
||||
|
|
@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new(
|
|||
out = OpenAIChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=await _convert_message_content(message.content),
|
||||
**params,
|
||||
**params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field
|
||||
)
|
||||
elif isinstance(message, ToolResponseMessage):
|
||||
out = OpenAIChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=message.call_id,
|
||||
content=await _convert_message_content(message.content),
|
||||
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
|
||||
)
|
||||
elif isinstance(message, SystemMessage):
|
||||
out = OpenAIChatCompletionSystemMessage(
|
||||
role="system",
|
||||
content=await _convert_message_content(message.content),
|
||||
content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported message type: {type(message)}")
|
||||
|
|
@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
|
|||
function = out["function"]
|
||||
|
||||
if isinstance(tool.tool_name, BuiltinTool):
|
||||
function["name"] = tool.tool_name.value
|
||||
function["name"] = tool.tool_name.value # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
else:
|
||||
function["name"] = tool.tool_name
|
||||
function["name"] = tool.tool_name # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
if tool.description:
|
||||
function["description"] = tool.description
|
||||
function["description"] = tool.description # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
if tool.input_schema:
|
||||
# Pass through the entire JSON Schema as-is
|
||||
function["parameters"] = tool.input_schema
|
||||
function["parameters"] = tool.input_schema # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
|
||||
|
||||
# NOTE: OpenAI does not support output_schema, so we drop it here
|
||||
# It's stored in LlamaStack for validation and other provider usage
|
||||
|
|
@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None
|
|||
tool_config = ToolConfig()
|
||||
if tool_choice:
|
||||
try:
|
||||
tool_choice = ToolChoice(tool_choice)
|
||||
tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception
|
||||
except ValueError:
|
||||
pass
|
||||
tool_config.tool_choice = tool_choice
|
||||
tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type
|
||||
return tool_config
|
||||
|
||||
|
||||
def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
|
||||
lls_tools = []
|
||||
lls_tools: list[ToolDefinition] = []
|
||||
if not tools:
|
||||
return lls_tools
|
||||
|
||||
|
|
@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
|
|||
|
||||
|
||||
def _convert_openai_request_response_format(
|
||||
response_format: OpenAIResponseFormatParam = None,
|
||||
response_format: OpenAIResponseFormatParam | None = None,
|
||||
):
|
||||
if not response_format:
|
||||
return None
|
||||
# response_format can be a dict or a pydantic model
|
||||
response_format = dict(response_format)
|
||||
if response_format.get("type", "") == "json_schema":
|
||||
response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
|
||||
if response_format_dict.get("type", "") == "json_schema":
|
||||
return JsonSchemaResponseFormat(
|
||||
type="json_schema",
|
||||
json_schema=response_format.get("json_schema", {}).get("schema", ""),
|
||||
type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
|
||||
json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -938,16 +944,15 @@ def _convert_openai_sampling_params(
|
|||
|
||||
# Map an explicit temperature of 0 to greedy sampling
|
||||
if temperature == 0:
|
||||
strategy = GreedySamplingStrategy()
|
||||
sampling_params.strategy = GreedySamplingStrategy()
|
||||
else:
|
||||
# OpenAI defaults to 1.0 for temperature and top_p if unset
|
||||
if temperature is None:
|
||||
temperature = 1.0
|
||||
if top_p is None:
|
||||
top_p = 1.0
|
||||
strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
|
||||
sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) # type: ignore[assignment] # SamplingParams.strategy union accepts this type
|
||||
|
||||
sampling_params.strategy = strategy
|
||||
return sampling_params
|
||||
|
||||
|
||||
|
|
@ -957,23 +962,24 @@ def openai_messages_to_messages(
|
|||
"""
|
||||
Convert a list of OpenAIChatCompletionMessage into a list of Message.
|
||||
"""
|
||||
converted_messages = []
|
||||
converted_messages: list[Message] = []
|
||||
for message in messages:
|
||||
converted_message: Message
|
||||
if message.role == "system":
|
||||
converted_message = SystemMessage(content=openai_content_to_content(message.content))
|
||||
converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
elif message.role == "user":
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content))
|
||||
converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
elif message.role == "assistant":
|
||||
converted_message = CompletionMessage(
|
||||
content=openai_content_to_content(message.content),
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls),
|
||||
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
elif message.role == "tool":
|
||||
converted_message = ToolResponseMessage(
|
||||
role="tool",
|
||||
call_id=message.tool_call_id,
|
||||
content=openai_content_to_content(message.content),
|
||||
content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown role {message.role}")
|
||||
|
|
@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten
|
|||
return [openai_content_to_content(c) for c in content]
|
||||
elif hasattr(content, "type"):
|
||||
if content.type == "text":
|
||||
return TextContentItem(type="text", text=content.text)
|
||||
return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
|
||||
elif content.type == "image_url":
|
||||
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
|
||||
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
|
||||
else:
|
||||
raise ValueError(f"Unknown content type: {content.type}")
|
||||
else:
|
||||
|
|
@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice(
|
|||
completion_message=CompletionMessage(
|
||||
content=choice.message.content or "", # CompletionMessage content is not optional
|
||||
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
|
||||
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
|
||||
tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),
|
||||
logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream(
|
|||
choice = chunk.choices[0] # assuming only one choice per chunk
|
||||
|
||||
# we assume there's only one finish_reason in the stream
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
|
||||
stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason
|
||||
logprobs = getattr(choice, "logprobs", None)
|
||||
|
||||
# if there's a tool call, emit an event for each tool in the list
|
||||
|
|
@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=_convert_openai_tool_calls([tool_call])[0],
|
||||
tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
@ -1125,11 +1131,14 @@ async def convert_openai_chat_completion_stream(
|
|||
if tool_call.function.name:
|
||||
buffer["name"] = tool_call.function.name
|
||||
delta = f"{buffer['name']}("
|
||||
if buffer["content"] is not None:
|
||||
buffer["content"] += delta
|
||||
|
||||
if tool_call.function.arguments:
|
||||
delta = tool_call.function.arguments
|
||||
if buffer["arguments"] is not None and delta:
|
||||
buffer["arguments"] += delta
|
||||
if buffer["content"] is not None and delta:
|
||||
buffer["content"] += delta
|
||||
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
|
|
@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream(
|
|||
tool_call=delta,
|
||||
parse_status=ToolCallParseStatus.in_progress,
|
||||
),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
elif choice.delta.content:
|
||||
|
|
@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=event_type,
|
||||
delta=TextDelta(text=choice.delta.content or ""),
|
||||
logprobs=_convert_openai_logprobs(logprobs),
|
||||
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -1155,6 +1164,7 @@ async def convert_openai_chat_completion_stream(
|
|||
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
|
||||
if buffer["name"]:
|
||||
delta = ")"
|
||||
if buffer["content"] is not None:
|
||||
buffer["content"] += delta
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
|
|
@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream(
|
|||
)
|
||||
|
||||
try:
|
||||
tool_call = ToolCall(
|
||||
call_id=buffer["call_id"],
|
||||
tool_name=buffer["name"],
|
||||
arguments=buffer["arguments"],
|
||||
parsed_tool_call = ToolCall(
|
||||
call_id=buffer["call_id"] or "",
|
||||
tool_name=buffer["name"] or "",
|
||||
arguments=buffer["arguments"] or "",
|
||||
)
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=tool_call,
|
||||
tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
|
||||
parse_status=ToolCallParseStatus.succeeded,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream(
|
|||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.progress,
|
||||
delta=ToolCallDelta(
|
||||
tool_call=buffer["content"],
|
||||
tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
|
||||
parse_status=ToolCallParseStatus.failed,
|
||||
),
|
||||
stop_reason=stop_reason,
|
||||
|
|
@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
|
||||
messages = openai_messages_to_messages(messages)
|
||||
messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format
|
||||
response_format = _convert_openai_request_response_format(response_format)
|
||||
sampling_params = _convert_openai_sampling_params(
|
||||
max_tokens=max_tokens,
|
||||
|
|
@ -1259,15 +1269,15 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
)
|
||||
tool_config = _convert_openai_request_tool_config(tool_choice)
|
||||
|
||||
tools = _convert_openai_request_tools(tools)
|
||||
tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format
|
||||
if tool_config.tool_choice == ToolChoice.none:
|
||||
tools = []
|
||||
tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type
|
||||
|
||||
outstanding_responses = []
|
||||
# "n" is the number of completions to generate per prompt
|
||||
n = n or 1
|
||||
for _i in range(0, n):
|
||||
response = self.chat_completion(
|
||||
response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion
|
||||
model_id=model,
|
||||
messages=messages,
|
||||
sampling_params=sampling_params,
|
||||
|
|
@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
outstanding_responses.append(response)
|
||||
|
||||
if stream:
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
|
||||
return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy
|
||||
|
||||
return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
|
||||
self, model, outstanding_responses
|
||||
|
|
@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
response = await outstanding_response
|
||||
async for chunk in response:
|
||||
event = chunk.event
|
||||
finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
|
||||
finish_reason = (
|
||||
_convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None
|
||||
)
|
||||
|
||||
if isinstance(event.delta, TextDelta):
|
||||
text_delta = event.delta.text
|
||||
delta = OpenAIChoiceDelta(content=text_delta)
|
||||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
|
||||
choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion.chunk",
|
||||
|
|
@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
elif isinstance(event.delta, ToolCallDelta):
|
||||
if event.delta.parse_status == ToolCallParseStatus.succeeded:
|
||||
tool_call = event.delta.tool_call
|
||||
if isinstance(tool_call, str):
|
||||
continue
|
||||
|
||||
# First chunk includes full structure
|
||||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id=tool_call.call_id,
|
||||
function=OpenAIChoiceDeltaToolCallFunction(
|
||||
name=tool_call.tool_name,
|
||||
name=tool_call.tool_name
|
||||
if isinstance(tool_call.tool_name, str)
|
||||
else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy
|
||||
arguments="",
|
||||
),
|
||||
)
|
||||
|
|
@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
|
||||
],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
|
|
@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
yield OpenAIChatCompletionChunk(
|
||||
id=id,
|
||||
choices=[
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
|
||||
],
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
|
|
@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
async def _process_non_stream_response(
|
||||
self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
|
||||
) -> OpenAIChatCompletion:
|
||||
choices = []
|
||||
choices: list[OpenAIChatCompletionChoice] = []
|
||||
for outstanding_response in outstanding_responses:
|
||||
response = await outstanding_response
|
||||
completion_message = response.completion_message
|
||||
|
|
@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
|||
|
||||
choice = OpenAIChatCompletionChoice(
|
||||
index=len(choices),
|
||||
message=message,
|
||||
message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
choices.append(choice)
|
||||
choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch
|
||||
|
||||
return OpenAIChatCompletion(
|
||||
id=f"chatcmpl-{uuid.uuid4()}",
|
||||
choices=choices,
|
||||
choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
object="chat.completion",
|
||||
|
|
|
|||
|
|
@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
# This is set in list_models() and used in check_model_availability()
|
||||
_model_cache: dict[str, Model] = {}
|
||||
|
||||
# List of allowed models for this provider, if empty all models allowed
|
||||
allowed_models: list[str] = []
|
||||
|
||||
# Optional field name in provider data to look for API key, which takes precedence
|
||||
provider_data_api_key_field: str | None = None
|
||||
|
||||
|
|
@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
|
|||
for provider_model_id in provider_models_ids:
|
||||
if not isinstance(provider_model_id, str):
|
||||
raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
|
||||
if self.allowed_models and provider_model_id not in self.allowed_models:
|
||||
if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
|
||||
logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
|
||||
continue
|
||||
model = self.construct_model_from_identifier(provider_model_id)
|
||||
|
|
|
|||
|
|
@ -196,6 +196,7 @@ def make_overlapped_chunks(
|
|||
chunks.append(
|
||||
Chunk(
|
||||
content=chunk,
|
||||
chunk_id=chunk_id,
|
||||
metadata=chunk_metadata,
|
||||
chunk_metadata=backend_chunk_metadata,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -70,13 +70,13 @@ class ResponsesStore:
|
|||
base_store = sqlstore_impl(self.reference)
|
||||
self.sql_store = AuthorizedSqlStore(base_store, self.policy)
|
||||
|
||||
# Disable write queue for SQLite since WAL mode handles concurrency
|
||||
# Keep it enabled for other backends (like Postgres) for performance
|
||||
backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
|
||||
if backend_config is None:
|
||||
raise ValueError(
|
||||
f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
|
||||
)
|
||||
if backend_config.type == StorageBackendType.SQL_SQLITE:
|
||||
if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
|
||||
self.enable_write_queue = False
|
||||
logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
|
||||
|
||||
await self.sql_store.create_table(
|
||||
"openai_responses",
|
||||
{
|
||||
|
|
@ -99,8 +99,9 @@ class ResponsesStore:
|
|||
self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
|
||||
for _ in range(self._num_writers):
|
||||
self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
|
||||
else:
|
||||
logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
|
||||
logger.debug(
|
||||
f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._worker_tasks:
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from sqlalchemy import (
|
|||
String,
|
||||
Table,
|
||||
Text,
|
||||
event,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
|
|
@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
self.metadata = MetaData()
|
||||
|
||||
def create_engine(self) -> AsyncEngine:
|
||||
return create_async_engine(self.config.engine_str, pool_pre_ping=True)
|
||||
# Configure connection args for better concurrency support
|
||||
connect_args = {}
|
||||
if "sqlite" in self.config.engine_str:
|
||||
# SQLite-specific optimizations for concurrent access
|
||||
# With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
|
||||
connect_args["timeout"] = 5.0
|
||||
connect_args["check_same_thread"] = False # Allow usage across asyncio tasks
|
||||
|
||||
engine = create_async_engine(
|
||||
self.config.engine_str,
|
||||
pool_pre_ping=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
|
||||
# Enable WAL mode for SQLite to support concurrent readers and writers
|
||||
if "sqlite" in self.config.engine_str:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_conn, connection_record):
|
||||
cursor = dbapi_conn.cursor()
|
||||
# Enable Write-Ahead Logging for better concurrency
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
# Set busy timeout to 5 seconds (retry instead of immediate failure)
|
||||
# With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
|
||||
cursor.execute("PRAGMA busy_timeout=5000")
|
||||
# Use NORMAL synchronous mode for better performance (still safe with WAL)
|
||||
cursor.execute("PRAGMA synchronous=NORMAL")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
async def create_table(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
|
|||
return list_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_sequence(typ: object) -> bool:
|
||||
"True if the specified type is a generic Sequence, i.e. `Sequence[T]`."
|
||||
import collections.abc
|
||||
|
||||
typ = unwrap_annotated_type(typ)
|
||||
return typing.get_origin(typ) is collections.abc.Sequence
|
||||
|
||||
|
||||
def unwrap_generic_sequence(typ: object) -> type:
|
||||
"""
|
||||
Extracts the item type of a Sequence type.
|
||||
|
||||
:param typ: The Sequence type `Sequence[T]`.
|
||||
:returns: The item type `T`.
|
||||
"""
|
||||
|
||||
return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _unwrap_generic_sequence(typ: object) -> type:
|
||||
"Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)."
|
||||
|
||||
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return sequence_type # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def is_generic_set(typ: object) -> TypeGuard[type[set]]:
|
||||
"True if the specified type is a generic set, i.e. `Set[T]`."
|
||||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,12 @@ from .inspection import (
|
|||
TypeLike,
|
||||
is_generic_dict,
|
||||
is_generic_list,
|
||||
is_generic_sequence,
|
||||
is_type_optional,
|
||||
is_type_union,
|
||||
unwrap_generic_dict,
|
||||
unwrap_generic_list,
|
||||
unwrap_generic_sequence,
|
||||
unwrap_optional_type,
|
||||
unwrap_union_types,
|
||||
)
|
||||
|
|
@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
|
|||
if metadata is not None:
|
||||
# type is Annotated[T, ...]
|
||||
arg = typing.get_args(data_type)[0]
|
||||
return python_type_to_name(arg)
|
||||
return python_type_to_name(arg, force=force)
|
||||
|
||||
if force:
|
||||
# generic types
|
||||
if is_type_optional(data_type, strict=True):
|
||||
inner_name = python_type_to_name(unwrap_optional_type(data_type))
|
||||
inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True)
|
||||
return f"Optional__{inner_name}"
|
||||
elif is_generic_list(data_type):
|
||||
item_name = python_type_to_name(unwrap_generic_list(data_type))
|
||||
item_name = python_type_to_name(unwrap_generic_list(data_type), force=True)
|
||||
return f"List__{item_name}"
|
||||
elif is_generic_sequence(data_type):
|
||||
# Treat Sequence the same as List for schema generation purposes
|
||||
item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True)
|
||||
return f"List__{item_name}"
|
||||
elif is_generic_dict(data_type):
|
||||
key_type, value_type = unwrap_generic_dict(data_type)
|
||||
key_name = python_type_to_name(key_type)
|
||||
value_name = python_type_to_name(value_type)
|
||||
key_name = python_type_to_name(key_type, force=True)
|
||||
value_name = python_type_to_name(value_type, force=True)
|
||||
return f"Dict__{key_name}__{value_name}"
|
||||
elif is_type_union(data_type):
|
||||
member_types = unwrap_union_types(data_type)
|
||||
member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
|
||||
member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types)
|
||||
return f"Union__{member_names}"
|
||||
|
||||
# named system or user-defined type
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ def get_class_property_docstrings(
|
|||
def docstring_to_schema(data_type: type) -> Schema:
|
||||
short_description, long_description = get_class_docstrings(data_type)
|
||||
schema: Schema = {
|
||||
"title": python_type_to_name(data_type),
|
||||
"title": python_type_to_name(data_type, force=True),
|
||||
}
|
||||
|
||||
description = "\n".join(filter(None, [short_description, long_description]))
|
||||
|
|
@ -417,6 +417,10 @@ class JsonSchemaGenerator:
|
|||
if origin_type is list:
|
||||
(list_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {"type": "array", "items": self.type_to_schema(list_type)}
|
||||
elif origin_type is collections.abc.Sequence:
|
||||
# Treat Sequence the same as list for JSON schema (both are arrays)
|
||||
(sequence_type,) = typing.get_args(typ) # unpack single tuple element
|
||||
return {"type": "array", "items": self.type_to_schema(sequence_type)}
|
||||
elif origin_type is dict:
|
||||
key_type, value_type = typing.get_args(typ)
|
||||
if not (key_type is str or key_type is int or is_type_enum(key_type)):
|
||||
|
|
|
|||
|
|
@ -51,7 +51,11 @@ async function proxyRequest(request: NextRequest, method: string) {
|
|||
);
|
||||
|
||||
// Create response with same status and headers
|
||||
const proxyResponse = new NextResponse(responseText, {
|
||||
// Handle 204 No Content responses specially
|
||||
const proxyResponse =
|
||||
response.status === 204
|
||||
? new NextResponse(null, { status: 204 })
|
||||
: new NextResponse(responseText, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
|
|
|
|||
5
src/llama_stack/ui/app/prompts/page.tsx
Normal file
5
src/llama_stack/ui/app/prompts/page.tsx
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
import { PromptManagement } from "@/components/prompts";
|
||||
|
||||
export default function PromptsPage() {
|
||||
return <PromptManagement />;
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ import {
|
|||
MessageCircle,
|
||||
Settings2,
|
||||
Compass,
|
||||
FileText,
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { usePathname } from "next/navigation";
|
||||
|
|
@ -50,6 +51,11 @@ const manageItems = [
|
|||
url: "/logs/vector-stores",
|
||||
icon: Database,
|
||||
},
|
||||
{
|
||||
title: "Prompts",
|
||||
url: "/prompts",
|
||||
icon: FileText,
|
||||
},
|
||||
{
|
||||
title: "Documentation",
|
||||
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
|
||||
|
|
|
|||
4
src/llama_stack/ui/components/prompts/index.ts
Normal file
4
src/llama_stack/ui/components/prompts/index.ts
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
export { PromptManagement } from "./prompt-management";
|
||||
export { PromptList } from "./prompt-list";
|
||||
export { PromptEditor } from "./prompt-editor";
|
||||
export * from "./types";
|
||||
309
src/llama_stack/ui/components/prompts/prompt-editor.test.tsx
Normal file
309
src/llama_stack/ui/components/prompts/prompt-editor.test.tsx
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptEditor } from "./prompt-editor";
|
||||
import type { Prompt, PromptFormData } from "./types";
|
||||
|
||||
describe("PromptEditor", () => {
|
||||
const mockOnSave = jest.fn();
|
||||
const mockOnCancel = jest.fn();
|
||||
const mockOnDelete = jest.fn();
|
||||
|
||||
const defaultProps = {
|
||||
onSave: mockOnSave,
|
||||
onCancel: mockOnCancel,
|
||||
onDelete: mockOnDelete,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Create Mode", () => {
|
||||
test("renders create form correctly", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument();
|
||||
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||
expect(screen.getByText("Preview")).toBeInTheDocument();
|
||||
expect(screen.getByText("Create Prompt")).toBeInTheDocument();
|
||||
expect(screen.getByText("Cancel")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows preview placeholder when no content", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(
|
||||
screen.getByText("Enter content to preview the compiled prompt")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("submits form with correct data", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}, welcome!" },
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith({
|
||||
prompt: "Hello {{name}}, welcome!",
|
||||
variables: [],
|
||||
});
|
||||
});
|
||||
|
||||
test("prevents submission with empty prompt", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
expect(mockOnSave).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edit Mode", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how is {{weather}}?",
|
||||
version: 1,
|
||||
variables: ["name", "weather"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
test("renders edit form with existing data", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(
|
||||
screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview
|
||||
expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview
|
||||
expect(screen.getByText("Update Prompt")).toBeInTheDocument();
|
||||
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("submits updated data correctly", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Updated: Hello {{name}}!" },
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Update Prompt"));
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith({
|
||||
prompt: "Updated: Hello {{name}}!",
|
||||
variables: ["name", "weather"],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Variables Management", () => {
|
||||
test("adds new variable", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "testVar" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
expect(screen.getByText("testVar")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("prevents adding duplicate variables", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
|
||||
// Add first variable
|
||||
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
// Try to add same variable again
|
||||
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||
|
||||
// Button should be disabled
|
||||
expect(screen.getByText("Add")).toBeDisabled();
|
||||
});
|
||||
|
||||
test("removes variable", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name", "location"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
// Check that both variables are present initially
|
||||
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||
expect(screen.getAllByText("location").length).toBeGreaterThan(0);
|
||||
|
||||
// Remove the location variable by clicking the X button with the specific title
|
||||
const removeLocationButton = screen.getByTitle(
|
||||
"Remove location variable"
|
||||
);
|
||||
fireEvent.click(removeLocationButton);
|
||||
|
||||
// Name should still be there, location should be gone from the variables section
|
||||
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||
expect(
|
||||
screen.queryByTitle("Remove location variable")
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("adds variable on Enter key", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "enterVar" } });
|
||||
|
||||
// Simulate Enter key press
|
||||
fireEvent.keyPress(variableInput, {
|
||||
key: "Enter",
|
||||
code: "Enter",
|
||||
charCode: 13,
|
||||
preventDefault: jest.fn(),
|
||||
});
|
||||
|
||||
// Check if the variable was added by looking for the badge
|
||||
expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Preview Functionality", () => {
|
||||
test("shows live preview with variables", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
// Add prompt content
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}, welcome to {{place}}!" },
|
||||
});
|
||||
|
||||
// Add variables
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "name" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
fireEvent.change(variableInput, { target: { value: "place" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
// Check that preview area shows the content
|
||||
expect(screen.getByText("Compiled Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows variable value inputs in preview", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(screen.getByText("Variable Values")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByPlaceholderText("Enter value for name")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows color legend for variable states", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
// Add content to show preview
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}" },
|
||||
});
|
||||
|
||||
expect(screen.getByText("Used")).toBeInTheDocument();
|
||||
expect(screen.getByText("Unused")).toBeInTheDocument();
|
||||
expect(screen.getByText("Undefined")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Handling", () => {
|
||||
test("displays error message", () => {
|
||||
const errorMessage = "Prompt contains undeclared variables";
|
||||
render(<PromptEditor {...defaultProps} error={errorMessage} />);
|
||||
|
||||
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Delete Functionality", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
test("shows delete button in edit mode", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("hides delete button in create mode", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("calls onDelete with confirmation", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||
|
||||
expect(window.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||
);
|
||||
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
|
||||
test("does not delete when confirmation is cancelled", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => false);
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||
|
||||
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cancel Functionality", () => {
|
||||
test("calls onCancel when cancel button is clicked", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Cancel"));
|
||||
|
||||
expect(mockOnCancel).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
346
src/llama_stack/ui/components/prompts/prompt-editor.tsx
Normal file
346
src/llama_stack/ui/components/prompts/prompt-editor.tsx
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { X, Plus, Save, Trash2 } from "lucide-react";
|
||||
import { Prompt, PromptFormData } from "./types";
|
||||
|
||||
interface PromptEditorProps {
|
||||
prompt?: Prompt;
|
||||
onSave: (prompt: PromptFormData) => void;
|
||||
onCancel: () => void;
|
||||
onDelete?: (promptId: string) => void;
|
||||
error?: string | null;
|
||||
}
|
||||
|
||||
export function PromptEditor({
|
||||
prompt,
|
||||
onSave,
|
||||
onCancel,
|
||||
onDelete,
|
||||
error,
|
||||
}: PromptEditorProps) {
|
||||
const [formData, setFormData] = useState<PromptFormData>({
|
||||
prompt: "",
|
||||
variables: [],
|
||||
});
|
||||
|
||||
const [newVariable, setNewVariable] = useState("");
|
||||
const [variableValues, setVariableValues] = useState<Record<string, string>>(
|
||||
{}
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (prompt) {
|
||||
setFormData({
|
||||
prompt: prompt.prompt || "",
|
||||
variables: prompt.variables || [],
|
||||
});
|
||||
}
|
||||
}, [prompt]);
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (!formData.prompt.trim()) {
|
||||
return;
|
||||
}
|
||||
onSave(formData);
|
||||
};
|
||||
|
||||
const addVariable = () => {
|
||||
if (
|
||||
newVariable.trim() &&
|
||||
!formData.variables.includes(newVariable.trim())
|
||||
) {
|
||||
setFormData(prev => ({
|
||||
...prev,
|
||||
variables: [...prev.variables, newVariable.trim()],
|
||||
}));
|
||||
setNewVariable("");
|
||||
}
|
||||
};
|
||||
|
||||
const removeVariable = (variableToRemove: string) => {
|
||||
setFormData(prev => ({
|
||||
...prev,
|
||||
variables: prev.variables.filter(
|
||||
variable => variable !== variableToRemove
|
||||
),
|
||||
}));
|
||||
};
|
||||
|
||||
const renderPreview = () => {
|
||||
const text = formData.prompt;
|
||||
if (!text) return text;
|
||||
|
||||
// Split text by variable patterns and process each part
|
||||
const parts = text.split(/(\{\{\s*\w+\s*\}\})/g);
|
||||
|
||||
return parts.map((part, index) => {
|
||||
const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/);
|
||||
if (variableMatch) {
|
||||
const variableName = variableMatch[1];
|
||||
const isDefined = formData.variables.includes(variableName);
|
||||
const value = variableValues[variableName];
|
||||
|
||||
if (!isDefined) {
|
||||
// Variable not in variables list - likely a typo/bug (RED)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200 px-1 rounded font-medium"
|
||||
>
|
||||
{part}
|
||||
</span>
|
||||
);
|
||||
} else if (value && value.trim()) {
|
||||
// Variable defined and has value - show the value (GREEN)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200 px-1 rounded font-medium"
|
||||
>
|
||||
{value}
|
||||
</span>
|
||||
);
|
||||
} else {
|
||||
// Variable defined but empty (YELLOW)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200 px-1 rounded font-medium"
|
||||
>
|
||||
{part}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
}
|
||||
return part;
|
||||
});
|
||||
};
|
||||
|
||||
const updateVariableValue = (variable: string, value: string) => {
|
||||
setVariableValues(prev => ({
|
||||
...prev,
|
||||
[variable]: value,
|
||||
}));
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
{error && (
|
||||
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-md">
|
||||
<p className="text-destructive text-sm">{error}</p>
|
||||
</div>
|
||||
)}
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
|
||||
{/* Form Section */}
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Label htmlFor="prompt">Prompt Content *</Label>
|
||||
<Textarea
|
||||
id="prompt"
|
||||
value={formData.prompt}
|
||||
onChange={e =>
|
||||
setFormData(prev => ({ ...prev, prompt: e.target.value }))
|
||||
}
|
||||
placeholder="Enter your prompt content here. Use {{variable_name}} for dynamic variables."
|
||||
className="min-h-32 font-mono mt-2"
|
||||
required
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground mt-2">
|
||||
Use double curly braces around variable names, e.g.,{" "}
|
||||
{`{{user_name}}`} or {`{{topic}}`}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-3">
|
||||
<Label className="text-sm font-medium">Variables</Label>
|
||||
|
||||
<div className="flex gap-2 mt-2">
|
||||
<Input
|
||||
value={newVariable}
|
||||
onChange={e => setNewVariable(e.target.value)}
|
||||
placeholder="Add variable name (e.g. user_name, topic)"
|
||||
onKeyPress={e =>
|
||||
e.key === "Enter" && (e.preventDefault(), addVariable())
|
||||
}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={addVariable}
|
||||
size="sm"
|
||||
disabled={
|
||||
!newVariable.trim() ||
|
||||
formData.variables.includes(newVariable.trim())
|
||||
}
|
||||
>
|
||||
<Plus className="h-4 w-4" />
|
||||
Add
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{formData.variables.length > 0 && (
|
||||
<div className="border rounded-lg p-3 bg-muted/20">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{formData.variables.map(variable => (
|
||||
<Badge
|
||||
key={variable}
|
||||
variant="secondary"
|
||||
className="text-sm px-2 py-1"
|
||||
>
|
||||
{variable}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => removeVariable(variable)}
|
||||
className="ml-2 hover:text-destructive transition-colors"
|
||||
title={`Remove ${variable} variable`}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</button>
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Variables that can be used in the prompt template. Each variable
|
||||
should match a {`{{variable}}`} placeholder in the content above.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Preview Section */}
|
||||
<div className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="text-lg">Preview</CardTitle>
|
||||
<CardDescription>
|
||||
Live preview of compiled prompt and variable substitution.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
{formData.prompt ? (
|
||||
<>
|
||||
{/* Variable Values */}
|
||||
{formData.variables.length > 0 && (
|
||||
<div className="space-y-3">
|
||||
<Label className="text-sm font-medium">
|
||||
Variable Values
|
||||
</Label>
|
||||
<div className="space-y-2">
|
||||
{formData.variables.map(variable => (
|
||||
<div
|
||||
key={variable}
|
||||
className="grid grid-cols-2 gap-3 items-center"
|
||||
>
|
||||
<div className="text-sm font-mono text-muted-foreground">
|
||||
{variable}
|
||||
</div>
|
||||
<Input
|
||||
id={`var-${variable}`}
|
||||
value={variableValues[variable] || ""}
|
||||
onChange={e =>
|
||||
updateVariableValue(variable, e.target.value)
|
||||
}
|
||||
placeholder={`Enter value for ${variable}`}
|
||||
className="text-sm"
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<Separator />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Live Preview */}
|
||||
<div>
|
||||
<Label className="text-sm font-medium mb-2 block">
|
||||
Compiled Prompt
|
||||
</Label>
|
||||
<div className="bg-muted/50 p-4 rounded-lg border">
|
||||
<div className="text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{renderPreview()}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-4 mt-2 text-xs">
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-green-500 dark:bg-green-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Used</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-yellow-500 dark:bg-yellow-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Unused</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-red-500 dark:bg-red-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Undefined</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="text-center py-8">
|
||||
<div className="text-muted-foreground text-sm">
|
||||
Enter content to preview the compiled prompt
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground mt-2">
|
||||
Use {`{{variable_name}}`} to add dynamic variables
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex justify-between">
|
||||
<div>
|
||||
{prompt && onDelete && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
onClick={() => {
|
||||
if (
|
||||
confirm(
|
||||
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||
)
|
||||
) {
|
||||
onDelete(prompt.prompt_id);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Trash2 className="h-4 w-4 mr-2" />
|
||||
Delete Prompt
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button type="button" variant="outline" onClick={onCancel}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button type="submit">
|
||||
<Save className="h-4 w-4 mr-2" />
|
||||
{prompt ? "Update" : "Create"} Prompt
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
259
src/llama_stack/ui/components/prompts/prompt-list.test.tsx
Normal file
259
src/llama_stack/ui/components/prompts/prompt-list.test.tsx
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptList } from "./prompt-list";
|
||||
import type { Prompt } from "./types";
|
||||
|
||||
describe("PromptList", () => {
|
||||
const mockOnEdit = jest.fn();
|
||||
const mockOnDelete = jest.fn();
|
||||
|
||||
const defaultProps = {
|
||||
prompts: [],
|
||||
onEdit: mockOnEdit,
|
||||
onDelete: mockOnDelete,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Empty State", () => {
|
||||
test("renders empty message when no prompts", () => {
|
||||
render(<PromptList {...defaultProps} />);
|
||||
|
||||
expect(screen.getByText("No prompts yet")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows filtered empty message when search has no results", () => {
|
||||
const prompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello world",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
render(<PromptList {...defaultProps} prompts={prompts} />);
|
||||
|
||||
// Search for something that doesn't exist
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "nonexistent" } });
|
||||
|
||||
expect(
|
||||
screen.getByText("No prompts match your filters")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Prompts Display", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how are you?",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_456",
|
||||
prompt: "Summarize this {{text}} in {{length}} words",
|
||||
version: 2,
|
||||
variables: ["text", "length"],
|
||||
is_default: false,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_789",
|
||||
prompt: "Simple prompt with no variables",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("renders prompts table with correct headers", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
expect(screen.getByText("ID")).toBeInTheDocument();
|
||||
expect(screen.getByText("Content")).toBeInTheDocument();
|
||||
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||
expect(screen.getByText("Version")).toBeInTheDocument();
|
||||
expect(screen.getByText("Actions")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders prompt data correctly", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Check prompt IDs
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
expect(screen.getByText("prompt_456")).toBeInTheDocument();
|
||||
expect(screen.getByText("prompt_789")).toBeInTheDocument();
|
||||
|
||||
// Check content
|
||||
expect(
|
||||
screen.getByText("Hello {{name}}, how are you?")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Summarize this {{text}} in {{length}} words")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Simple prompt with no variables")
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Check versions
|
||||
expect(screen.getAllByText("1")).toHaveLength(2); // Two prompts with version 1
|
||||
expect(screen.getByText("2")).toBeInTheDocument();
|
||||
|
||||
// Check default badge
|
||||
expect(screen.getByText("Default")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders variables correctly", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Check variables display
|
||||
expect(screen.getByText("name")).toBeInTheDocument();
|
||||
expect(screen.getByText("text")).toBeInTheDocument();
|
||||
expect(screen.getByText("length")).toBeInTheDocument();
|
||||
expect(screen.getByText("None")).toBeInTheDocument(); // For prompt with no variables
|
||||
});
|
||||
|
||||
test("prompt ID links are clickable and call onEdit", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Click on the first prompt ID link
|
||||
const promptLink = screen.getByRole("button", { name: "prompt_123" });
|
||||
fireEvent.click(promptLink);
|
||||
|
||||
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||
});
|
||||
|
||||
test("edit buttons call onEdit", () => {
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
// Find the action buttons in the table - they should be in the last column
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const editButton = firstActionCell?.querySelector("button");
|
||||
|
||||
expect(editButton).toBeInTheDocument();
|
||||
fireEvent.click(editButton!);
|
||||
|
||||
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||
});
|
||||
|
||||
test("delete buttons call onDelete with confirmation", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
// Find the delete button (second button in the first action cell)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
expect(window.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||
);
|
||||
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
|
||||
test("delete does not execute when confirmation is cancelled", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => false);
|
||||
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
|
||||
describe("Search Functionality", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "user_greeting",
|
||||
prompt: "Hello {{name}}, welcome!",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "system_summary",
|
||||
prompt: "Summarize the following text",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("filters prompts by prompt ID", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("filters prompts by content", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "welcome" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("search is case insensitive", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "HELLO" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("clearing search shows all prompts", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
|
||||
// Filter first
|
||||
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
|
||||
// Clear search
|
||||
fireEvent.change(searchInput, { target: { value: "" } });
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.getByText("system_summary")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
164
src/llama_stack/ui/components/prompts/prompt-list.tsx
Normal file
164
src/llama_stack/ui/components/prompts/prompt-list.tsx
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Edit, Search, Trash2 } from "lucide-react";
|
||||
import { Prompt, PromptFilters } from "./types";
|
||||
|
||||
interface PromptListProps {
|
||||
prompts: Prompt[];
|
||||
onEdit: (prompt: Prompt) => void;
|
||||
onDelete: (promptId: string) => void;
|
||||
}
|
||||
|
||||
export function PromptList({ prompts, onEdit, onDelete }: PromptListProps) {
|
||||
const [filters, setFilters] = useState<PromptFilters>({});
|
||||
|
||||
const filteredPrompts = prompts.filter(prompt => {
|
||||
if (
|
||||
filters.searchTerm &&
|
||||
!(
|
||||
prompt.prompt
|
||||
?.toLowerCase()
|
||||
.includes(filters.searchTerm.toLowerCase()) ||
|
||||
prompt.prompt_id
|
||||
.toLowerCase()
|
||||
.includes(filters.searchTerm.toLowerCase())
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{/* Filters */}
|
||||
<div className="flex flex-col sm:flex-row gap-4">
|
||||
<div className="relative flex-1">
|
||||
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
|
||||
<Input
|
||||
placeholder="Search prompts..."
|
||||
value={filters.searchTerm || ""}
|
||||
onChange={e =>
|
||||
setFilters(prev => ({ ...prev, searchTerm: e.target.value }))
|
||||
}
|
||||
className="pl-10"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Prompts Table */}
|
||||
<div className="overflow-auto">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>ID</TableHead>
|
||||
<TableHead>Content</TableHead>
|
||||
<TableHead>Variables</TableHead>
|
||||
<TableHead>Version</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{filteredPrompts.map(prompt => (
|
||||
<TableRow key={prompt.prompt_id}>
|
||||
<TableCell className="max-w-48">
|
||||
<Button
|
||||
variant="link"
|
||||
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300 max-w-full justify-start"
|
||||
onClick={() => onEdit(prompt)}
|
||||
title={prompt.prompt_id}
|
||||
>
|
||||
<div className="truncate">{prompt.prompt_id}</div>
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
<div
|
||||
className="font-mono text-xs text-muted-foreground truncate"
|
||||
title={prompt.prompt || "No content"}
|
||||
>
|
||||
{prompt.prompt || "No content"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{prompt.variables.length > 0 ? (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{prompt.variables.map(variable => (
|
||||
<Badge
|
||||
key={variable}
|
||||
variant="outline"
|
||||
className="text-xs"
|
||||
>
|
||||
{variable}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-muted-foreground text-sm">None</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm">
|
||||
{prompt.version}
|
||||
{prompt.is_default && (
|
||||
<Badge variant="secondary" className="text-xs ml-2">
|
||||
Default
|
||||
</Badge>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => onEdit(prompt)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
<Edit className="h-3 w-3" />
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
if (
|
||||
confirm(
|
||||
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||
)
|
||||
) {
|
||||
onDelete(prompt.prompt_id);
|
||||
}
|
||||
}}
|
||||
className="h-8 w-8 p-0 text-destructive hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{filteredPrompts.length === 0 && (
|
||||
<div className="text-center py-12">
|
||||
<div className="text-muted-foreground">
|
||||
{prompts.length === 0
|
||||
? "No prompts yet"
|
||||
: "No prompts match your filters"}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
304
src/llama_stack/ui/components/prompts/prompt-management.test.tsx
Normal file
304
src/llama_stack/ui/components/prompts/prompt-management.test.tsx
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptManagement } from "./prompt-management";
|
||||
import type { Prompt } from "./types";
|
||||
|
||||
// Mock the auth client
|
||||
const mockPromptsClient = {
|
||||
list: jest.fn(),
|
||||
create: jest.fn(),
|
||||
update: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => ({
|
||||
prompts: mockPromptsClient,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("PromptManagement", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Loading State", () => {
|
||||
test("renders loading state initially", () => {
|
||||
mockPromptsClient.list.mockReturnValue(new Promise(() => {})); // Never resolves
|
||||
render(<PromptManagement />);
|
||||
|
||||
expect(screen.getByText("Loading prompts...")).toBeInTheDocument();
|
||||
expect(screen.getByText("Prompts")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Empty State", () => {
|
||||
test("renders empty state when no prompts", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue([]);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("No prompts found.")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.getByText("Create Your First Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("opens modal when clicking 'Create Your First Prompt'", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue([]);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Create Your First Prompt")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Create Your First Prompt"));
|
||||
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error State", () => {
|
||||
test("renders error state when API fails", async () => {
|
||||
const error = new Error("API not found");
|
||||
mockPromptsClient.list.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/Error:/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders specific error for 404", async () => {
|
||||
const error = new Error("404 Not found");
|
||||
mockPromptsClient.list.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Prompts API endpoint not found/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Prompts List", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how are you?",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_456",
|
||||
prompt: "Summarize this {{text}}",
|
||||
version: 2,
|
||||
variables: ["text"],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("renders prompts list correctly", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.getByText("prompt_456")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Hello {{name}}, how are you?")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Summarize this {{text}}")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("opens modal when clicking 'New Prompt' button", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Modal Operations", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
];
|
||||
|
||||
test("closes modal when clicking cancel", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
|
||||
// Close modal
|
||||
fireEvent.click(screen.getByText("Cancel"));
|
||||
expect(screen.queryByText("Create New Prompt")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("creates new prompt successfully", async () => {
|
||||
const newPrompt: Prompt = {
|
||||
prompt_id: "prompt_new",
|
||||
prompt: "New prompt content",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.create.mockResolvedValue(newPrompt);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
// Fill form
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "New prompt content" },
|
||||
});
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.create).toHaveBeenCalledWith({
|
||||
prompt: "New prompt content",
|
||||
variables: [],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test("handles create error gracefully", async () => {
|
||||
const error = {
|
||||
detail: {
|
||||
errors: [{ msg: "Prompt contains undeclared variables: ['test']" }],
|
||||
},
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.create.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
// Fill form
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, { target: { value: "Hello {{test}}" } });
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Prompt contains undeclared variables: ['test']")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("updates existing prompt successfully", async () => {
|
||||
const updatedPrompt: Prompt = {
|
||||
...mockPrompts[0],
|
||||
prompt: "Updated content",
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.update.mockResolvedValue(updatedPrompt);
|
||||
const { container } = render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click edit button (first button in the action cell of the first row)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const editButton = firstActionCell?.querySelector("button");
|
||||
|
||||
expect(editButton).toBeInTheDocument();
|
||||
fireEvent.click(editButton!);
|
||||
|
||||
expect(screen.getByText("Edit Prompt")).toBeInTheDocument();
|
||||
|
||||
// Update content
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, { target: { value: "Updated content" } });
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Update Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.update).toHaveBeenCalledWith("prompt_123", {
|
||||
prompt: "Updated content",
|
||||
variables: ["name"],
|
||||
version: 1,
|
||||
set_as_default: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test("deletes prompt successfully", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.delete.mockResolvedValue(undefined);
|
||||
|
||||
// Mock window.confirm
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
const { container } = render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click delete button (second button in the action cell of the first row)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.delete).toHaveBeenCalledWith("prompt_123");
|
||||
});
|
||||
|
||||
// Restore window.confirm
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
});
|
||||
233
src/llama_stack/ui/components/prompts/prompt-management.tsx
Normal file
233
src/llama_stack/ui/components/prompts/prompt-management.tsx
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Plus } from "lucide-react";
|
||||
import { PromptList } from "./prompt-list";
|
||||
import { PromptEditor } from "./prompt-editor";
|
||||
import { Prompt, PromptFormData } from "./types";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
|
||||
export function PromptManagement() {
|
||||
const [prompts, setPrompts] = useState<Prompt[]>([]);
|
||||
const [showPromptModal, setShowPromptModal] = useState(false);
|
||||
const [editingPrompt, setEditingPrompt] = useState<Prompt | undefined>();
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null); // For main page errors (loading, etc.)
|
||||
const [modalError, setModalError] = useState<string | null>(null); // For form submission errors
|
||||
const client = useAuthClient();
|
||||
|
||||
// Load prompts from API on component mount
|
||||
useEffect(() => {
|
||||
const fetchPrompts = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
const response = await client.prompts.list();
|
||||
setPrompts(response || []);
|
||||
} catch (err: unknown) {
|
||||
console.error("Failed to load prompts:", err);
|
||||
|
||||
// Handle different types of errors
|
||||
const error = err as Error & { status?: number };
|
||||
if (error?.message?.includes("404") || error?.status === 404) {
|
||||
setError(
|
||||
"Prompts API endpoint not found. Please ensure your Llama Stack server supports the prompts API."
|
||||
);
|
||||
} else if (
|
||||
error?.message?.includes("not implemented") ||
|
||||
error?.message?.includes("not supported")
|
||||
) {
|
||||
setError(
|
||||
"Prompts API is not yet implemented on this Llama Stack server."
|
||||
);
|
||||
} else {
|
||||
setError(
|
||||
`Failed to load prompts: ${error?.message || "Unknown error"}`
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchPrompts();
|
||||
}, [client]);
|
||||
|
||||
const handleSavePrompt = async (formData: PromptFormData) => {
|
||||
try {
|
||||
setModalError(null);
|
||||
|
||||
if (editingPrompt) {
|
||||
// Update existing prompt
|
||||
const response = await client.prompts.update(editingPrompt.prompt_id, {
|
||||
prompt: formData.prompt,
|
||||
variables: formData.variables,
|
||||
version: editingPrompt.version,
|
||||
set_as_default: true,
|
||||
});
|
||||
|
||||
// Update local state
|
||||
setPrompts(prev =>
|
||||
prev.map(p =>
|
||||
p.prompt_id === editingPrompt.prompt_id ? response : p
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// Create new prompt
|
||||
const response = await client.prompts.create({
|
||||
prompt: formData.prompt,
|
||||
variables: formData.variables,
|
||||
});
|
||||
|
||||
// Add to local state
|
||||
setPrompts(prev => [response, ...prev]);
|
||||
}
|
||||
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
} catch (err) {
|
||||
console.error("Failed to save prompt:", err);
|
||||
|
||||
// Extract specific error message from API response
|
||||
const error = err as Error & {
|
||||
message?: string;
|
||||
detail?: { errors?: Array<{ msg?: string }> };
|
||||
};
|
||||
|
||||
// Try to parse JSON from error message if it's a string
|
||||
let parsedError = error;
|
||||
if (typeof error?.message === "string" && error.message.includes("{")) {
|
||||
try {
|
||||
const jsonMatch = error.message.match(/\d+\s+(.+)/);
|
||||
if (jsonMatch) {
|
||||
parsedError = JSON.parse(jsonMatch[1]);
|
||||
}
|
||||
} catch {
|
||||
// If parsing fails, use original error
|
||||
}
|
||||
}
|
||||
|
||||
// Try to get the specific validation error message
|
||||
const validationError = parsedError?.detail?.errors?.[0]?.msg;
|
||||
if (validationError) {
|
||||
// Clean up validation error messages (remove "Value error, " prefix if present)
|
||||
const cleanMessage = validationError.replace(/^Value error,\s*/i, "");
|
||||
setModalError(cleanMessage);
|
||||
} else {
|
||||
// For other errors, format them nicely with line breaks
|
||||
const statusMatch = error?.message?.match(/(\d+)\s+(.+)/);
|
||||
if (statusMatch) {
|
||||
const statusCode = statusMatch[1];
|
||||
const response = statusMatch[2];
|
||||
setModalError(
|
||||
`Failed to save prompt: Status Code ${statusCode}\n\nResponse: ${response}`
|
||||
);
|
||||
} else {
|
||||
const message = error?.message || error?.detail || "Unknown error";
|
||||
setModalError(`Failed to save prompt: ${message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleEditPrompt = (prompt: Prompt) => {
|
||||
setEditingPrompt(prompt);
|
||||
setShowPromptModal(true);
|
||||
setModalError(null); // Clear any previous modal errors
|
||||
};
|
||||
|
||||
const handleDeletePrompt = async (promptId: string) => {
|
||||
try {
|
||||
setError(null);
|
||||
await client.prompts.delete(promptId);
|
||||
setPrompts(prev => prev.filter(p => p.prompt_id !== promptId));
|
||||
|
||||
// If we're deleting the currently editing prompt, close the modal
|
||||
if (editingPrompt && editingPrompt.prompt_id === promptId) {
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to delete prompt:", err);
|
||||
setError("Failed to delete prompt");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreateNew = () => {
|
||||
setEditingPrompt(undefined);
|
||||
setShowPromptModal(true);
|
||||
setModalError(null); // Clear any previous modal errors
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
};
|
||||
|
||||
const renderContent = () => {
|
||||
if (loading) {
|
||||
return <div className="text-muted-foreground">Loading prompts...</div>;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return <div className="text-destructive">Error: {error}</div>;
|
||||
}
|
||||
|
||||
if (!prompts || prompts.length === 0) {
|
||||
return (
|
||||
<div className="text-center py-12">
|
||||
<p className="text-muted-foreground mb-4">No prompts found.</p>
|
||||
<Button onClick={handleCreateNew}>
|
||||
<Plus className="h-4 w-4 mr-2" />
|
||||
Create Your First Prompt
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<PromptList
|
||||
prompts={prompts}
|
||||
onEdit={handleEditPrompt}
|
||||
onDelete={handleDeletePrompt}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h1 className="text-2xl font-semibold">Prompts</h1>
|
||||
<Button onClick={handleCreateNew} disabled={loading}>
|
||||
<Plus className="h-4 w-4 mr-2" />
|
||||
New Prompt
|
||||
</Button>
|
||||
</div>
|
||||
{renderContent()}
|
||||
|
||||
{/* Create/Edit Prompt Modal */}
|
||||
{showPromptModal && (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-background border rounded-lg shadow-lg max-w-4xl w-full mx-4 max-h-[90vh] overflow-hidden">
|
||||
<div className="p-6 border-b">
|
||||
<h2 className="text-2xl font-bold">
|
||||
{editingPrompt ? "Edit Prompt" : "Create New Prompt"}
|
||||
</h2>
|
||||
</div>
|
||||
<div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
|
||||
<PromptEditor
|
||||
prompt={editingPrompt}
|
||||
onSave={handleSavePrompt}
|
||||
onCancel={handleCancel}
|
||||
onDelete={handleDeletePrompt}
|
||||
error={modalError}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
16
src/llama_stack/ui/components/prompts/types.ts
Normal file
16
src/llama_stack/ui/components/prompts/types.ts
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
export interface Prompt {
|
||||
prompt_id: string;
|
||||
prompt: string | null;
|
||||
version: number;
|
||||
variables: string[];
|
||||
is_default: boolean;
|
||||
}
|
||||
|
||||
export interface PromptFormData {
|
||||
prompt: string;
|
||||
variables: string[];
|
||||
}
|
||||
|
||||
export interface PromptFilters {
|
||||
searchTerm?: string;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue