Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 18:00:36 +00:00)

Merge branch 'main' into add-mongodb-vector_io
Commit d0064fc915: 426 changed files with 99110 additions and 62778 deletions

60  .github/actions/install-llama-stack-client/action.yml  (vendored, new file)

@@ -0,0 +1,60 @@
name: Install llama-stack-client

description: Install llama-stack-client based on branch context and client-version input

inputs:
  client-version:
    description: 'Client version to install on non-release branches (latest or published). Ignored on release branches.'
    required: false
    default: ""

outputs:
  uv-extra-index-url:
    description: 'UV_EXTRA_INDEX_URL to use (set for release branches)'
    value: ${{ steps.configure.outputs.uv-extra-index-url }}
  install-after-sync:
    description: 'Whether to install client after uv sync'
    value: ${{ steps.configure.outputs.install-after-sync }}
  install-source:
    description: 'Where to install client from after sync'
    value: ${{ steps.configure.outputs.install-source }}

runs:
  using: "composite"
  steps:
    - name: Configure client installation
      id: configure
      shell: bash
      run: |
        # Determine the branch we're working with
        BRANCH="${{ github.base_ref || github.ref }}"
        BRANCH="${BRANCH#refs/heads/}"

        echo "Working with branch: $BRANCH"

        # On release branches: use test.pypi for uv sync, then install from git
        # On non-release branches: install based on client-version after sync
        if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
          echo "Detected release branch: $BRANCH"

          # Check if matching branch exists in client repo
          if ! git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$BRANCH" > /dev/null 2>&1; then
            echo "::error::Branch $BRANCH not found in llama-stack-client-python repository"
            echo "::error::Please create the matching release branch in llama-stack-client-python before testing"
            exit 1
          fi

          # Configure to use test.pypi as extra index (PyPI is primary)
          echo "uv-extra-index-url=https://test.pypi.org/simple/" >> $GITHUB_OUTPUT
          echo "install-after-sync=true" >> $GITHUB_OUTPUT
          echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@$BRANCH" >> $GITHUB_OUTPUT
        elif [ "${{ inputs.client-version }}" = "latest" ]; then
          # Install from main git after sync
          echo "install-after-sync=true" >> $GITHUB_OUTPUT
          echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@main" >> $GITHUB_OUTPUT
        elif [ "${{ inputs.client-version }}" = "published" ]; then
          # Use published version from PyPI (installed by sync)
          echo "install-after-sync=false" >> $GITHUB_OUTPUT
        elif [ -n "${{ inputs.client-version }}" ]; then
          echo "::error::Invalid client-version: ${{ inputs.client-version }}"
          exit 1
        fi
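
The release-branch detection above hinges on a single bash regex. A minimal sketch of how it behaves, using hypothetical branch names and runnable outside CI:

    # Sketch: which branch names the composite action treats as release branches
    for BRANCH in main release-0.3.x release-0.3.1 feature/foo; do
      if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
        echo "$BRANCH -> release branch (test.pypi extra index + matching client branch)"
      else
        echo "$BRANCH -> non-release branch (client-version input decides)"
      fi
    done

Only names of the form release-MAJOR.MINOR.x match; point-release tags such as release-0.3.1 fall through to the client-version path.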

@@ -94,7 +94,7 @@ runs:
       if: ${{ always() }}
       uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
       with:
-        name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }}
+        name: logs-${{ github.run_id }}-${{ github.run_attempt || '1' }}-${{ strategy.job-index || github.job }}-${{ github.action }}
         path: |
           *.log
         retention-days: 1

30  .github/actions/setup-runner/action.yml  (vendored)

@@ -18,25 +18,35 @@ runs:
       python-version: ${{ inputs.python-version }}
       version: 0.7.6

+  - name: Configure client installation
+    id: client-config
+    uses: ./.github/actions/install-llama-stack-client
+    with:
+      client-version: ${{ inputs.client-version }}
+
   - name: Install dependencies
     shell: bash
+    env:
+      UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
     run: |
+      # Export UV env vars for current step and persist to GITHUB_ENV for subsequent steps
+      if [ -n "$UV_EXTRA_INDEX_URL" ]; then
+        export UV_INDEX_STRATEGY=unsafe-best-match
+        echo "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" >> $GITHUB_ENV
+        echo "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" >> $GITHUB_ENV
+        echo "Exported UV environment variables for current and subsequent steps"
+      fi
+
       echo "Updating project dependencies via uv sync"
       uv sync --all-groups

       echo "Installing ad-hoc dependencies"
       uv pip install faiss-cpu

-      # Install llama-stack-client-python based on the client-version input
-      if [ "${{ inputs.client-version }}" = "latest" ]; then
-        echo "Installing latest llama-stack-client-python from main branch"
-        uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main
-      elif [ "${{ inputs.client-version }}" = "published" ]; then
-        echo "Installing published llama-stack-client-python from PyPI"
-        uv pip install llama-stack-client
-      else
-        echo "Invalid client-version: ${{ inputs.client-version }}"
-        exit 1
-      fi
+      # Install specific client version after sync if needed
+      if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
+        echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
+        uv pip install ${{ steps.client-config.outputs.install-source }}
+      fi

       echo "Installed llama packages"
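
The UV variables exported above change how uv resolves packages during the sync. A rough local equivalent, assuming a release branch so the extra index is actually set:

    # Sketch: resolve from PyPI first, but also consider release candidates on test.pypi
    export UV_EXTRA_INDEX_URL=https://test.pypi.org/simple/
    export UV_INDEX_STRATEGY=unsafe-best-match   # compare candidates across all indexes
    uv sync --all-groups
    uv pip install faiss-cpu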

@@ -42,18 +42,7 @@ runs:
   - name: Build Llama Stack
     shell: bash
     run: |
-      # Install llama-stack-client-python based on the client-version input
-      if [ "${{ inputs.client-version }}" = "latest" ]; then
-        echo "Installing latest llama-stack-client-python from main branch"
-        export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main
-      elif [ "${{ inputs.client-version }}" = "published" ]; then
-        echo "Installing published llama-stack-client-python from PyPI"
-        unset LLAMA_STACK_CLIENT_DIR
-      else
-        echo "Invalid client-version: ${{ inputs.client-version }}"
-        exit 1
-      fi
-
+      # Client is already installed by setup-runner (handles both main and release branches)
       echo "Building Llama Stack"

       LLAMA_STACK_DIR=. \

23  .github/mergify.yml  (vendored, new file)

@@ -0,0 +1,23 @@
pull_request_rules:
  - name: ping author on conflicts and add 'needs-rebase' label
    conditions:
      - conflict
      - -closed
    actions:
      label:
        add:
          - needs-rebase
      comment:
        message: >
          This pull request has merge conflicts that must be resolved before it
          can be merged. @{{author}} please rebase it.
          https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

  - name: remove 'needs-rebase' label when conflict is resolved
    conditions:
      - -conflict
      - -closed
    actions:
      label:
        remove:
          - needs-rebase

2  .github/workflows/README.md  (vendored)

@@ -4,6 +4,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl

 | Name | File | Purpose |
 | ---- | ---- | ------- |
+| Backward Compatibility Check | [backward-compat.yml](backward-compat.yml) | Check backward compatibility for run.yaml configs |
 | Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md |
 | API Conformance Tests | [conformance.yml](conformance.yml) | Run the API Conformance test suite on the changes. |
 | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |

@@ -12,7 +13,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
 | Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
 | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
 | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
-| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR |
 | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
 | Test llama stack list-deps | [providers-list-deps.yml](providers-list-deps.yml) | Test llama stack list-deps |
 | Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |

578  .github/workflows/backward-compat.yml  (vendored, new file)

@@ -0,0 +1,578 @@
name: Backward Compatibility Check

run-name: Check backward compatibility for run.yaml configs

on:
  pull_request:
    branches:
      - main
      - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+'
      - 'release-[0-9]+.[0-9]+.[0-9]+'
      - 'release-[0-9]+.[0-9]+'
    paths:
      - 'src/llama_stack/core/datatypes.py'
      - 'src/llama_stack/providers/datatypes.py'
      - 'src/llama_stack/distributions/**/run.yaml'
      - 'tests/backward_compat/**'
      - '.github/workflows/backward-compat.yml'

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  check-main-compatibility:
    name: Check Compatibility with main
    runs-on: ubuntu-latest

    steps:
      - name: Checkout PR branch
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          fetch-depth: 0 # Need full history to access main branch

      - name: Set up Python
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: '3.12'

      - name: Install uv
        uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
        with:
          enable-cache: true

      - name: Install dependencies
        run: |
          uv sync --group dev

      - name: Extract run.yaml files from main branch
        id: extract_configs
        run: |
          # Get list of run.yaml paths from main
          git fetch origin main
          CONFIG_PATHS=$(git ls-tree -r --name-only origin/main | grep "src/llama_stack/distributions/.*/run.yaml$" || true)

          if [ -z "$CONFIG_PATHS" ]; then
            echo "No run.yaml files found in main branch"
            exit 1
          fi

          # Extract all configs to a temp directory
          mkdir -p /tmp/main_configs
          echo "Extracting configs from main branch:"

          while IFS= read -r config_path; do
            if [ -z "$config_path" ]; then
              continue
            fi

            # Extract filename for storage
            filename=$(basename $(dirname "$config_path"))
            echo " - $filename (from $config_path)"

            git show origin/main:"$config_path" > "/tmp/main_configs/${filename}.yaml"
          done <<< "$CONFIG_PATHS"

          echo ""
          echo "Extracted $(ls /tmp/main_configs/*.yaml | wc -l) config files"

      - name: Test all configs from main
        id: test_configs
        continue-on-error: true
        run: |
          # Run pytest once with all configs parameterized
          if COMPAT_TEST_CONFIGS_DIR=/tmp/main_configs uv run pytest tests/backward_compat/test_run_config.py -v; then
            echo "failed=false" >> $GITHUB_OUTPUT
          else
            echo "failed=true" >> $GITHUB_OUTPUT
            exit 1
          fi

      - name: Check for breaking change acknowledgment
        id: check_ack
        if: steps.test_configs.outputs.failed == 'true'
        run: |
          echo "Breaking changes detected. Checking for acknowledgment..."

          # Check PR title for '!:' marker (conventional commits)
          PR_TITLE="${{ github.event.pull_request.title }}"
          if [[ "$PR_TITLE" =~ ^[a-z]+\!: ]]; then
            echo "✓ Breaking change acknowledged in PR title"
            echo "acknowledged=true" >> $GITHUB_OUTPUT
            exit 0
          fi

          # Check commit messages for BREAKING CHANGE:
          if git log origin/main..HEAD --format=%B | grep -q "BREAKING CHANGE:"; then
            echo "✓ Breaking change acknowledged in commit message"
            echo "acknowledged=true" >> $GITHUB_OUTPUT
            exit 0
          fi

          echo "✗ Breaking change NOT acknowledged"
          echo "acknowledged=false" >> $GITHUB_OUTPUT
        env:
          GH_TOKEN: ${{ github.token }}

      - name: Evaluate results
        if: always()
        run: |
          FAILED="${{ steps.test_configs.outputs.failed }}"
          ACKNOWLEDGED="${{ steps.check_ack.outputs.acknowledged }}"

          if [[ "$FAILED" == "true" ]]; then
            if [[ "$ACKNOWLEDGED" == "true" ]]; then
              echo ""
              echo "⚠️ WARNING: Breaking changes detected but acknowledged"
              echo ""
              echo "This PR introduces backward-incompatible changes to run.yaml."
              echo "The changes have been properly acknowledged."
              echo ""
              exit 0 # Pass the check
            else
              echo ""
              echo "❌ ERROR: Breaking changes detected without acknowledgment"
              echo ""
              echo "This PR introduces backward-incompatible changes to run.yaml"
              echo "that will break existing user configurations."
              echo ""
              echo "To acknowledge this breaking change, do ONE of:"
              echo " 1. Add '!:' to your PR title (e.g., 'feat!: change xyz')"
              echo " 2. Add the 'breaking-change' label to this PR"
              echo " 3. Include 'BREAKING CHANGE:' in a commit message"
              echo ""
              exit 1 # Fail the check
            fi
          fi

  test-integration-main:
    name: Run Integration Tests with main Config
    runs-on: ubuntu-latest

    steps:
      - name: Checkout PR branch
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          fetch-depth: 0

      - name: Extract ci-tests run.yaml from main
        run: |
          git fetch origin main
          git show origin/main:src/llama_stack/distributions/ci-tests/run.yaml > /tmp/main-ci-tests-run.yaml
          echo "Extracted ci-tests run.yaml from main branch"

      - name: Setup test environment
        uses: ./.github/actions/setup-test-environment
        with:
          python-version: '3.12'
          client-version: 'latest'
          setup: 'ollama'
          suite: 'base'
          inference-mode: 'replay'

      - name: Run integration tests with main config
        id: test_integration
        continue-on-error: true
        uses: ./.github/actions/run-and-record-tests
        with:
          stack-config: /tmp/main-ci-tests-run.yaml
          setup: 'ollama'
          inference-mode: 'replay'
          suite: 'base'

      - name: Check for breaking change acknowledgment
        id: check_ack
        if: steps.test_integration.outcome == 'failure'
        run: |
          echo "Integration tests failed. Checking for acknowledgment..."

          # Check PR title for '!:' marker (conventional commits)
          PR_TITLE="${{ github.event.pull_request.title }}"
          if [[ "$PR_TITLE" =~ ^[a-z]+\!: ]]; then
            echo "✓ Breaking change acknowledged in PR title"
            echo "acknowledged=true" >> $GITHUB_OUTPUT
            exit 0
          fi

          # Check commit messages for BREAKING CHANGE:
          if git log origin/main..HEAD --format=%B | grep -q "BREAKING CHANGE:"; then
            echo "✓ Breaking change acknowledged in commit message"
            echo "acknowledged=true" >> $GITHUB_OUTPUT
            exit 0
          fi

          echo "✗ Breaking change NOT acknowledged"
          echo "acknowledged=false" >> $GITHUB_OUTPUT
        env:
          GH_TOKEN: ${{ github.token }}

      - name: Evaluate integration test results
        if: always()
        run: |
          TEST_FAILED="${{ steps.test_integration.outcome == 'failure' }}"
          ACKNOWLEDGED="${{ steps.check_ack.outputs.acknowledged }}"

          if [[ "$TEST_FAILED" == "true" ]]; then
            if [[ "$ACKNOWLEDGED" == "true" ]]; then
              echo ""
              echo "⚠️ WARNING: Integration tests failed with main config but acknowledged"
              echo ""
              exit 0 # Pass the check
            else
              echo ""
              echo "❌ ERROR: Integration tests failed with main config without acknowledgment"
              echo ""
              echo "To acknowledge this breaking change, do ONE of:"
              echo " 1. Add '!:' to your PR title (e.g., 'feat!: change xyz')"
              echo " 2. Include 'BREAKING CHANGE:' in a commit message"
              echo ""
              exit 1 # Fail the check
            fi
          fi

  test-integration-release:
    name: Run Integration Tests with Latest Release (Informational)
    runs-on: ubuntu-latest

    steps:
      - name: Checkout PR branch
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          fetch-depth: 0

      - name: Get latest release
        id: get_release
        run: |
          # Get the latest release from GitHub
          LATEST_TAG=$(gh release list --limit 1 --json tagName --jq '.[0].tagName' 2>/dev/null || echo "")

          if [ -z "$LATEST_TAG" ]; then
            echo "No releases found, skipping release compatibility check"
            echo "has_release=false" >> $GITHUB_OUTPUT
            exit 0
          fi

          echo "Latest release: $LATEST_TAG"
          echo "has_release=true" >> $GITHUB_OUTPUT
          echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT
        env:
          GH_TOKEN: ${{ github.token }}

      - name: Extract ci-tests run.yaml from release
        if: steps.get_release.outputs.has_release == 'true'
        id: extract_config
        run: |
          RELEASE_TAG="${{ steps.get_release.outputs.tag }}"

          # Try with src/ prefix first (newer releases), then without (older releases)
          if git show "$RELEASE_TAG:src/llama_stack/distributions/ci-tests/run.yaml" > /tmp/release-ci-tests-run.yaml 2>/dev/null; then
            echo "Extracted ci-tests run.yaml from release $RELEASE_TAG (src/ path)"
            echo "has_config=true" >> $GITHUB_OUTPUT
          elif git show "$RELEASE_TAG:llama_stack/distributions/ci-tests/run.yaml" > /tmp/release-ci-tests-run.yaml 2>/dev/null; then
            echo "Extracted ci-tests run.yaml from release $RELEASE_TAG (old path)"
            echo "has_config=true" >> $GITHUB_OUTPUT
          else
            echo "::warning::ci-tests/run.yaml not found in release $RELEASE_TAG"
            echo "has_config=false" >> $GITHUB_OUTPUT
          fi

      - name: Setup test environment
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
        uses: ./.github/actions/setup-test-environment
        with:
          python-version: '3.12'
          client-version: 'latest'
          setup: 'ollama'
          suite: 'base'
          inference-mode: 'replay'

      - name: Run integration tests with release config (PR branch)
        id: test_release_pr
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
        continue-on-error: true
        uses: ./.github/actions/run-and-record-tests
        with:
          stack-config: /tmp/release-ci-tests-run.yaml
          setup: 'ollama'
          inference-mode: 'replay'
          suite: 'base'

      - name: Checkout main branch to test baseline
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
        run: |
          git checkout origin/main

      - name: Setup test environment for main
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
        uses: ./.github/actions/setup-test-environment
        with:
          python-version: '3.12'
          client-version: 'latest'
          setup: 'ollama'
          suite: 'base'
          inference-mode: 'replay'

      - name: Run integration tests with release config (main branch)
        id: test_release_main
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
        continue-on-error: true
        uses: ./.github/actions/run-and-record-tests
        with:
          stack-config: /tmp/release-ci-tests-run.yaml
          setup: 'ollama'
          inference-mode: 'replay'
          suite: 'base'

      - name: Report results and post PR comment
        if: always() && steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true'
        run: |
          RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
          PR_OUTCOME="${{ steps.test_release_pr.outcome }}"
          MAIN_OUTCOME="${{ steps.test_release_main.outcome }}"

          if [[ "$PR_OUTCOME" == "failure" && "$MAIN_OUTCOME" == "success" ]]; then
            # NEW breaking change - PR fails but main passes
            echo "::error::🚨 This PR introduces a NEW breaking change!"

            # Check if we already posted a comment (to avoid spam on every push)
            EXISTING_COMMENT=$(gh pr view ${{ github.event.pull_request.number }} --json comments --jq '.comments[] | select(.body | contains("🚨 New Breaking Change Detected") and contains("Integration tests")) | .id' | head -1)

            if [[ -z "$EXISTING_COMMENT" ]]; then
              gh pr comment ${{ github.event.pull_request.number }} --body "## 🚨 New Breaking Change Detected

              **Integration tests against release \`$RELEASE_TAG\` are now failing**

              ⚠️ This PR introduces a breaking change that affects compatibility with the latest release.

              - Users on release \`$RELEASE_TAG\` may not be able to upgrade
              - Existing configurations may break

              The tests pass on \`main\` but fail with this PR's changes.

              > **Note:** This is informational only and does not block merge.
              > Consider whether this breaking change is acceptable for users."
            else
              echo "Comment already exists, skipping to avoid spam"
            fi

            cat >> $GITHUB_STEP_SUMMARY <<EOF
          ## 🚨 NEW Breaking Change Detected

          **Integration tests against release \`$RELEASE_TAG\` FAILED**

          ⚠️ **This PR introduces a NEW breaking change**

          - Tests **PASS** on main branch ✅
          - Tests **FAIL** on PR branch ❌
          - Users on release \`$RELEASE_TAG\` may not be able to upgrade
          - Existing configurations may break

          > **Note:** This is informational only and does not block merge.
          > Consider whether this breaking change is acceptable for users.
          EOF

          elif [[ "$PR_OUTCOME" == "failure" ]]; then
            # Existing breaking change - both PR and main fail
            echo "::warning::Breaking change already exists in main branch"

            cat >> $GITHUB_STEP_SUMMARY <<EOF
          ## ⚠️ Release Compatibility Test Failed (Existing Issue)

          **Integration tests against release \`$RELEASE_TAG\` FAILED**

          - Tests **FAIL** on main branch ❌
          - Tests **FAIL** on PR branch ❌
          - This breaking change already exists in main (not introduced by this PR)

          > **Note:** This is informational only.
          EOF

          else
            # Success - tests pass
            cat >> $GITHUB_STEP_SUMMARY <<EOF
          ## ✅ Release Compatibility Test Passed

          Integration tests against release \`$RELEASE_TAG\` passed successfully.
          This PR maintains compatibility with the latest release.
          EOF
          fi
        env:
          GH_TOKEN: ${{ github.token }}

  check-schema-release-compatibility:
    name: Check Schema Compatibility with Latest Release (Informational)
    runs-on: ubuntu-latest

    steps:
      - name: Checkout PR branch
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: '3.12'

      - name: Install uv
        uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
        with:
          enable-cache: true

      - name: Install dependencies
        run: |
          uv sync --group dev

      - name: Get latest release
        id: get_release
        run: |
          # Get the latest release from GitHub
          LATEST_TAG=$(gh release list --limit 1 --json tagName --jq '.[0].tagName' 2>/dev/null || echo "")

          if [ -z "$LATEST_TAG" ]; then
            echo "No releases found, skipping release compatibility check"
            echo "has_release=false" >> $GITHUB_OUTPUT
            exit 0
          fi

          echo "Latest release: $LATEST_TAG"
          echo "has_release=true" >> $GITHUB_OUTPUT
          echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT
        env:
          GH_TOKEN: ${{ github.token }}

      - name: Extract configs from release
        if: steps.get_release.outputs.has_release == 'true'
        id: extract_release_configs
        run: |
          RELEASE_TAG="${{ steps.get_release.outputs.tag }}"

          # Get run.yaml files from the release (try both src/ and old path)
          CONFIG_PATHS=$(git ls-tree -r --name-only "$RELEASE_TAG" | grep "llama_stack/distributions/.*/run.yaml$" || true)

          if [ -z "$CONFIG_PATHS" ]; then
            echo "::warning::No run.yaml files found in release $RELEASE_TAG"
            echo "has_configs=false" >> $GITHUB_OUTPUT
            exit 0
          fi

          # Extract all configs to a temp directory
          mkdir -p /tmp/release_configs
          echo "Extracting configs from release $RELEASE_TAG:"

          while IFS= read -r config_path; do
            if [ -z "$config_path" ]; then
              continue
            fi

            filename=$(basename $(dirname "$config_path"))
            echo " - $filename (from $config_path)"

            git show "$RELEASE_TAG:$config_path" > "/tmp/release_configs/${filename}.yaml" 2>/dev/null || true
          done <<< "$CONFIG_PATHS"

          echo ""
          echo "Extracted $(ls /tmp/release_configs/*.yaml 2>/dev/null | wc -l) config files"
          echo "has_configs=true" >> $GITHUB_OUTPUT

      - name: Test against release configs (PR branch)
        id: test_schema_pr
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
        continue-on-error: true
        run: |
          RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
          COMPAT_TEST_CONFIGS_DIR=/tmp/release_configs uv run pytest tests/backward_compat/test_run_config.py -v --tb=short

      - name: Checkout main branch to test baseline
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
        run: |
          git checkout origin/main

      - name: Install dependencies for main
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
        run: |
          uv sync --group dev

      - name: Test against release configs (main branch)
        id: test_schema_main
        if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
        continue-on-error: true
        run: |
          RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
          COMPAT_TEST_CONFIGS_DIR=/tmp/release_configs uv run pytest tests/backward_compat/test_run_config.py -v --tb=short

      - name: Report results and post PR comment
        if: always() && steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true'
        run: |
          RELEASE_TAG="${{ steps.get_release.outputs.tag }}"
          PR_OUTCOME="${{ steps.test_schema_pr.outcome }}"
          MAIN_OUTCOME="${{ steps.test_schema_main.outcome }}"

          if [[ "$PR_OUTCOME" == "failure" && "$MAIN_OUTCOME" == "success" ]]; then
            # NEW breaking change - PR fails but main passes
            echo "::error::🚨 This PR introduces a NEW schema breaking change!"

            # Check if we already posted a comment (to avoid spam on every push)
            EXISTING_COMMENT=$(gh pr view ${{ github.event.pull_request.number }} --json comments --jq '.comments[] | select(.body | contains("🚨 New Schema Breaking Change Detected")) | .id' | head -1)

            if [[ -z "$EXISTING_COMMENT" ]]; then
              gh pr comment ${{ github.event.pull_request.number }} --body "## 🚨 New Schema Breaking Change Detected

              **Schema validation against release \`$RELEASE_TAG\` is now failing**

              ⚠️ This PR introduces a schema breaking change that affects compatibility with the latest release.

              - Users on release \`$RELEASE_TAG\` will not be able to upgrade
              - Existing run.yaml configurations will fail validation

              The tests pass on \`main\` but fail with this PR's changes.

              > **Note:** This is informational only and does not block merge.
              > Consider whether this breaking change is acceptable for users."
            else
              echo "Comment already exists, skipping to avoid spam"
            fi

            cat >> $GITHUB_STEP_SUMMARY <<EOF
          ## 🚨 NEW Schema Breaking Change Detected

          **Schema validation against release \`$RELEASE_TAG\` FAILED**

          ⚠️ **This PR introduces a NEW schema breaking change**

          - Tests **PASS** on main branch ✅
          - Tests **FAIL** on PR branch ❌
          - Users on release \`$RELEASE_TAG\` will not be able to upgrade
          - Existing run.yaml configurations will fail validation

          > **Note:** This is informational only and does not block merge.
          > Consider whether this breaking change is acceptable for users.
          EOF

          elif [[ "$PR_OUTCOME" == "failure" ]]; then
            # Existing breaking change - both PR and main fail
            echo "::warning::Schema breaking change already exists in main branch"

            cat >> $GITHUB_STEP_SUMMARY <<EOF
          ## ⚠️ Release Schema Compatibility Failed (Existing Issue)

          **Schema validation against release \`$RELEASE_TAG\` FAILED**

          - Tests **FAIL** on main branch ❌
          - Tests **FAIL** on PR branch ❌
          - This schema breaking change already exists in main (not introduced by this PR)

          > **Note:** This is informational only.
          EOF

          else
            # Success - tests pass
            cat >> $GITHUB_STEP_SUMMARY <<EOF
          ## ✅ Release Schema Compatibility Passed

          All run.yaml configs from release \`$RELEASE_TAG\` are compatible.
          This PR maintains backward compatibility with the latest release.
          EOF
          fi
        env:
          GH_TOKEN: ${{ github.token }}
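
Both acknowledgment checks in this workflow reduce to the same two tests. A minimal sketch, using a hypothetical PR title and runnable outside CI:

    # Sketch: how a breaking change gets acknowledged
    PR_TITLE="feat!: rename vector_io provider field"   # hypothetical title
    if [[ "$PR_TITLE" =~ ^[a-z]+\!: ]]; then
      echo "acknowledged via the conventional-commit '!' marker"
    elif git log origin/main..HEAD --format=%B | grep -q "BREAKING CHANGE:"; then
      echo "acknowledged via a commit body"
    else
      echo "not acknowledged -> the blocking check-main-compatibility job fails"
    fi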

1  .github/workflows/conformance.yml  (vendored)

@@ -22,7 +22,6 @@ on:
       - 'docs/static/stable-llama-stack-spec.yaml' # Stable APIs spec
       - 'docs/static/experimental-llama-stack-spec.yaml' # Experimental APIs spec
       - 'docs/static/deprecated-llama-stack-spec.yaml' # Deprecated APIs spec
-      - 'docs/static/llama-stack-spec.html' # Legacy HTML spec
       - '.github/workflows/conformance.yml' # This workflow itself

 concurrency:

10  .github/workflows/install-script-ci.yml  (vendored)

@@ -30,10 +30,16 @@ jobs:

       - name: Build a single provider
         run: |
+          BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=starter"
+          if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+          fi
+          if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+          fi
           docker build . \
             -f containers/Containerfile \
-            --build-arg INSTALL_MODE=editable \
-            --build-arg DISTRO_NAME=starter \
+            $BUILD_ARGS \
             --tag llama-stack:starter-ci

       - name: Run installer end-to-end
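
With both UV variables set, the assembled command above expands to roughly the following (a sketch; the index values are illustrative of a release-branch run):

    docker build . \
      -f containers/Containerfile \
      --build-arg INSTALL_MODE=editable \
      --build-arg DISTRO_NAME=starter \
      --build-arg UV_EXTRA_INDEX_URL=https://test.pypi.org/simple/ \
      --build-arg UV_INDEX_STRATEGY=unsafe-best-match \
      --tag llama-stack:starter-ci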

8  .github/workflows/integration-auth-tests.yml  (vendored)

@@ -4,9 +4,13 @@ run-name: Run the integration test suite with Kubernetes authentication

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'distributions/**'
       - 'src/llama_stack/**'

@@ -4,9 +4,13 @@ run-name: Run the integration test suite with SqlStore

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/providers/utils/sqlstore/**'
       - 'tests/integration/sqlstore/**'

11  .github/workflows/integration-tests.yml  (vendored)

@@ -4,9 +4,13 @@ run-name: Run the integration test suites from tests/integration in replay mode

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     types: [opened, synchronize, reopened]
     paths:
       - 'src/llama_stack/**'

@@ -18,6 +22,7 @@ on:
       - '.github/actions/setup-ollama/action.yml'
       - '.github/actions/setup-test-environment/action.yml'
       - '.github/actions/run-and-record-tests/action.yml'
+      - 'scripts/integration-tests.sh'
   schedule:
     # If changing the cron schedule, update the provider in the test-matrix job
     - cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC

@@ -47,7 +52,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        client-type: [library, docker]
+        client-type: [library, docker, server]
         # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
         client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}

@@ -4,9 +4,13 @@ run-name: Run the integration test suite with various VectorIO providers

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/**'
       - '!src/llama_stack/ui/**'

68  .github/workflows/pre-commit.yml  (vendored)

@@ -5,7 +5,9 @@ run-name: Run pre-commit checks
 on:
   pull_request:
   push:
-    branches: [main]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'

 concurrency:
   group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}

@@ -43,23 +45,41 @@ jobs:
           cache: 'npm'
           cache-dependency-path: 'src/llama_stack/ui/'

+      - name: Set up uv
+        uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
+
       - name: Install npm dependencies
         run: npm ci
         working-directory: src/llama_stack/ui

+      - name: Install pre-commit
+        run: python -m pip install pre-commit
+
+      - name: Cache pre-commit
+        uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
+        with:
+          path: ~/.cache/pre-commit
+          key: pre-commit-3|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
+
       - name: Run pre-commit
         id: precommit
-        uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
-        continue-on-error: true
+        run: |
+          set +e
+          pre-commit run --show-diff-on-failure --color=always --all-files 2>&1 | tee /tmp/precommit.log
+          status=${PIPESTATUS[0]}
+          echo "status=$status" >> $GITHUB_OUTPUT
+          exit 0
         env:
-          SKIP: no-commit-to-branch
+          SKIP: no-commit-to-branch,mypy
           RUFF_OUTPUT_FORMAT: github

       - name: Check pre-commit results
-        if: steps.precommit.outcome == 'failure'
+        if: steps.precommit.outputs.status != '0'
         run: |
           echo "::error::Pre-commit hooks failed. Please run 'pre-commit run --all-files' locally and commit the fixes."
-          echo "::warning::Some pre-commit hooks failed. Check the output above for details."
+          echo ""
+          echo "Failed hooks output:"
+          cat /tmp/precommit.log
           exit 1

       - name: Debug

@@ -109,3 +129,39 @@ jobs:
             echo "$unstaged_files"
             exit 1
           fi
+
+      - name: Configure client installation
+        id: client-config
+        uses: ./.github/actions/install-llama-stack-client
+
+      - name: Sync dev + type_checking dependencies
+        env:
+          UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
+        run: |
+          if [ -n "$UV_EXTRA_INDEX_URL" ]; then
+            export UV_INDEX_STRATEGY="unsafe-best-match"
+          fi
+
+          uv sync --group dev --group type_checking
+
+          # Install specific client version after sync if needed
+          if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
+            echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
+            uv pip install ${{ steps.client-config.outputs.install-source }}
+          fi
+
+      - name: Run mypy (full type_checking)
+        env:
+          UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
+        run: |
+          if [ -n "$UV_EXTRA_INDEX_URL" ]; then
+            export UV_INDEX_STRATEGY="unsafe-best-match"
+          fi
+
+          set +e
+          uv run --group dev --group type_checking mypy
+          status=$?
+          if [ $status -ne 0 ]; then
+            echo "::error::Full mypy failed. Reproduce locally with 'uv run pre-commit run mypy-full --hook-stage manual --all-files'."
+          fi
+          exit $status
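
To approximate what this workflow now runs locally, the new run steps boil down to the two commands below (a sketch assembled from the diff; hook results may differ from CI depending on your local toolchain):

    # Mirror the CI pre-commit invocation (mypy is skipped here and run separately)
    SKIP=no-commit-to-branch,mypy \
      pre-commit run --show-diff-on-failure --color=always --all-files

    # The full mypy pass that CI runs as its own step
    uv run --group dev --group type_checking mypy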

227  .github/workflows/precommit-trigger.yml  (vendored, deleted)

@@ -1,227 +0,0 @@
name: Pre-commit Bot

run-name: Pre-commit bot for PR #${{ github.event.issue.number }}

on:
  issue_comment:
    types: [created]

jobs:
  pre-commit:
    # Only run on pull request comments
    if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit')
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write

    steps:
      - name: Check comment author and get PR details
        id: check_author
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            // Get PR details
            const pr = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: context.issue.number
            });

            // Check if commenter has write access or is the PR author
            const commenter = context.payload.comment.user.login;
            const prAuthor = pr.data.user.login;

            let hasPermission = false;

            // Check if commenter is PR author
            if (commenter === prAuthor) {
              hasPermission = true;
              console.log(`Comment author ${commenter} is the PR author`);
            } else {
              // Check if commenter has write/admin access
              try {
                const permission = await github.rest.repos.getCollaboratorPermissionLevel({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  username: commenter
                });

                const level = permission.data.permission;
                hasPermission = ['write', 'admin', 'maintain'].includes(level);
                console.log(`Comment author ${commenter} has permission: ${level}`);
              } catch (error) {
                console.log(`Could not check permissions for ${commenter}: ${error.message}`);
              }
            }

            if (!hasPermission) {
              await github.rest.issues.createComment({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: context.issue.number,
                body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.`
              });
              core.setFailed(`User ${commenter} does not have permission`);
              return;
            }

            // Save PR info for later steps
            core.setOutput('pr_number', context.issue.number);
            core.setOutput('pr_head_ref', pr.data.head.ref);
            core.setOutput('pr_head_sha', pr.data.head.sha);
            core.setOutput('pr_head_repo', pr.data.head.repo.full_name);
            core.setOutput('pr_base_ref', pr.data.base.ref);
            core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name);
            core.setOutput('authorized', 'true');

      - name: React to comment
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.reactions.createForIssueComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              comment_id: context.payload.comment.id,
              content: 'rocket'
            });

      - name: Comment starting
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...`
            });

      - name: Checkout PR branch (same-repo)
        if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false'
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          ref: ${{ steps.check_author.outputs.pr_head_ref }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Checkout PR branch (fork)
        if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true'
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          repository: ${{ steps.check_author.outputs.pr_head_repo }}
          ref: ${{ steps.check_author.outputs.pr_head_ref }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Verify checkout
        if: steps.check_author.outputs.authorized == 'true'
        run: |
          echo "Current SHA: $(git rev-parse HEAD)"
          echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}"
          if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then
            echo "::error::Checked out SHA does not match expected SHA"
            exit 1
          fi

      - name: Set up Python
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: '3.12'
          cache: pip
          cache-dependency-path: |
            **/requirements*.txt
            .pre-commit-config.yaml

      - name: Set up Node.js
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0
        with:
          node-version: '20'
          cache: 'npm'
          cache-dependency-path: 'src/llama_stack/ui/'

      - name: Install npm dependencies
        if: steps.check_author.outputs.authorized == 'true'
        run: npm ci
        working-directory: src/llama_stack/ui

      - name: Run pre-commit
        if: steps.check_author.outputs.authorized == 'true'
        id: precommit
        uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
        continue-on-error: true
        env:
          SKIP: no-commit-to-branch
          RUFF_OUTPUT_FORMAT: github

      - name: Check for changes
        if: steps.check_author.outputs.authorized == 'true'
        id: changes
        run: |
          if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
            echo "has_changes=true" >> $GITHUB_OUTPUT
            echo "Changes detected after pre-commit"
          else
            echo "has_changes=false" >> $GITHUB_OUTPUT
            echo "No changes after pre-commit"
          fi

      - name: Commit and push changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
        run: |
          git config --local user.email "github-actions[bot]@users.noreply.github.com"
          git config --local user.name "github-actions[bot]"

          git add -A
          git commit -m "style: apply pre-commit fixes

          🤖 Applied by @github-actions bot via pre-commit workflow"

          # Push changes
          git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }}

      - name: Comment success with changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.`
            });

      - name: Comment success without changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.`
            });

      - name: Comment failure
        if: failure()
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.`
            });
38
.github/workflows/providers-build.yml
vendored
@ -72,10 +72,16 @@ jobs:
       - name: Build container image
         if: matrix.image-type == 'container'
         run: |
+          BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=${{ matrix.distro }}"
+          if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+          fi
+          if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+          fi
           docker build . \
             -f containers/Containerfile \
-            --build-arg INSTALL_MODE=editable \
-            --build-arg DISTRO_NAME=${{ matrix.distro }} \
+            $BUILD_ARGS \
             --tag llama-stack:${{ matrix.distro }}-ci

       - name: Print dependencies in the image

@ -108,12 +114,18 @@ jobs:
       - name: Build container image
         run: |
           BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "python:3.12-slim"' src/llama_stack/distributions/ci-tests/build.yaml)
+          BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
+          BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
+          BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
+          if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+          fi
+          if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+          fi
           docker build . \
             -f containers/Containerfile \
-            --build-arg INSTALL_MODE=editable \
-            --build-arg DISTRO_NAME=ci-tests \
-            --build-arg BASE_IMAGE="$BASE_IMAGE" \
-            --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
+            $BUILD_ARGS \
             -t llama-stack:ci-tests

       - name: Inspect the container image entrypoint

@ -148,12 +160,18 @@ jobs:
       - name: Build UBI9 container image
         run: |
           BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "registry.access.redhat.com/ubi9:latest"' src/llama_stack/distributions/ci-tests/build.yaml)
+          BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
+          BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
+          BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
+          if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+          fi
+          if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+            BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+          fi
           docker build . \
             -f containers/Containerfile \
-            --build-arg INSTALL_MODE=editable \
-            --build-arg DISTRO_NAME=ci-tests \
-            --build-arg BASE_IMAGE="$BASE_IMAGE" \
-            --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
+            $BUILD_ARGS \
             -t llama-stack:ci-tests-ubi9

       - name: Inspect UBI9 image
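The `BUILD_ARGS` accumulation above only forwards the UV index settings to `docker build` when they are present in the environment. A rough sketch of what the expansion looks like when CI exports the test.pypi index (the index URL and strategy values are taken from the test-pypi path in the Containerfile changes further down; the `starter` distro name is the Containerfile's default and is used here only for illustration):

```bash
# Illustrative only: expand BUILD_ARGS the same way the workflow step does.
export UV_EXTRA_INDEX_URL="https://test.pypi.org/simple/"
export UV_INDEX_STRATEGY="unsafe-best-match"

BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=starter"
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
  BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
fi
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
  BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
fi

# Resulting invocation (echoed rather than executed here):
echo docker build . -f containers/Containerfile $BUILD_ARGS --tag llama-stack:starter-ci
```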
2
.github/workflows/python-build-test.yml
vendored
@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

       - name: Install uv
-        uses: astral-sh/setup-uv@2ddd2b9cb38ad8efd50337e8ab201519a34c9f24 # v7.1.1
+        uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true
8
.github/workflows/unit-tests.yml
vendored
@ -4,9 +4,13 @@ run-name: Run the unit test suite

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/**'
       - '!src/llama_stack/ui/**'
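Both triggers now also fire on maintenance branches matching `release-[0-9]+.[0-9]+.x`. A quick way to sanity-check which branch names an equivalent pattern would pick up, expressed here as a bash regex; this is only an approximation, since GitHub's branch-filter glob syntax is what actually applies in the workflow:

```bash
# Rough local check; GitHub's own filter globbing is authoritative.
matches_release_branch() {
  [[ "$1" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]
}

for branch in release-0.2.x release-1.10.x main release-0.2.1; do
  if matches_release_branch "$branch"; then
    echo "$branch: matches"
  else
    echo "$branch: does not match"
  fi
done
```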
3
.gitignore
vendored
@ -32,3 +32,6 @@ CLAUDE.md
 docs/.docusaurus/
 docs/node_modules/
 docs/static/imported-files/
+docs/docs/api-deprecated/
+docs/docs/api-experimental/
+docs/docs/api/
@ -52,22 +52,19 @@ repos:
         additional_dependencies:
           - black==24.3.0

-  - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.7.20
-    hooks:
-      - id: uv-lock
-
-  - repo: local
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.18.2
     hooks:
       - id: mypy
-        name: mypy
         additional_dependencies:
-          - uv==0.7.8
-        entry: uv run --group dev --group type_checking mypy
-        language: python
-        types: [python]
+          - uv==0.6.2
+          - mypy
+          - pytest
+          - rich
+          - types-requests
+          - pydantic
         pass_filenames: false
-        require_serial: true

 # - repo: https://github.com/tcort/markdown-link-check
 #   rev: v3.11.2

@ -77,11 +74,26 @@ repos:
   - repo: local
     hooks:
+      - id: uv-lock
+        name: uv-lock
+        additional_dependencies:
+          - uv==0.7.20
+        entry: ./scripts/uv-run-with-index.sh lock
+        language: python
+        pass_filenames: false
+        require_serial: true
+        files: ^(pyproject\.toml|uv\.lock)$
+      - id: mypy-full
+        name: mypy (full type_checking)
+        entry: ./scripts/uv-run-with-index.sh run --group dev --group type_checking mypy
+        language: system
+        pass_filenames: false
+        stages: [manual]
       - id: distro-codegen
         name: Distribution Template Codegen
         additional_dependencies:
           - uv==0.7.8
-        entry: uv run --group codegen ./scripts/distro_codegen.py
+        entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/distro_codegen.py
         language: python
         pass_filenames: false
         require_serial: true

@ -90,7 +102,7 @@ repos:
         name: Provider Codegen
         additional_dependencies:
           - uv==0.7.8
-        entry: uv run --group codegen ./scripts/provider_codegen.py
+        entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/provider_codegen.py
         language: python
         pass_filenames: false
         require_serial: true

@ -99,7 +111,7 @@ repos:
         name: API Spec Codegen
         additional_dependencies:
           - uv==0.7.8
-        entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
+        entry: sh -c './scripts/uv-run-with-index.sh run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
         language: python
         pass_filenames: false
         require_serial: true

@ -140,7 +152,7 @@ repos:
         name: Generate CI documentation
         additional_dependencies:
           - uv==0.7.8
-        entry: uv run ./scripts/gen-ci-docs.py
+        entry: ./scripts/uv-run-with-index.sh run ./scripts/gen-ci-docs.py
         language: python
         pass_filenames: false
         require_serial: true

@ -171,6 +183,23 @@ repos:
             exit 1
           fi
           exit 0
+      - id: fips-compliance
+        name: Ensure llama-stack remains FIPS compliant
+        entry: bash
+        language: system
+        types: [python]
+        pass_filenames: true
+        exclude: '^tests/.*$' # Exclude test dir as some safety tests used MD5
+        args:
+          - -c
+          - |
+            grep -EnH '^[^#]*\b(md5|sha1|uuid3|uuid5)\b' "$@" && {
+              echo;
+              echo "❌ Do not use any of the following functions: hashlib.md5, hashlib.sha1, uuid.uuid3, uuid.uuid5"
+              echo "   These functions are not FIPS-compliant"
+              echo;
+              exit 1;
+            } || true

 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
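Several hooks above now route through `./scripts/uv-run-with-index.sh`, whose body is not part of the excerpt shown here. As a rough mental model only (the real script may differ), a wrapper like this would forward its arguments to `uv` while making sure the optional `UV_EXTRA_INDEX_URL` / `UV_INDEX_STRATEGY` variables that the release-branch CI sets are exported, since uv reads both from the environment:

```bash
#!/usr/bin/env bash
# Hypothetical sketch only; not the actual scripts/uv-run-with-index.sh from the repo.
set -euo pipefail

if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
  export UV_EXTRA_INDEX_URL
fi
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
  export UV_INDEX_STRATEGY
fi

# "./scripts/uv-run-with-index.sh lock"           -> "uv lock"
# "./scripts/uv-run-with-index.sh run mypy ..."   -> "uv run mypy ..."
exec uv "$@"
```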
@ -61,6 +61,18 @@ uv run pre-commit run --all-files -v

 The `-v` (verbose) parameter is optional but often helpful for getting more information about any issues that the pre-commit checks identify.

+To run the expanded mypy configuration that CI enforces, use:
+
+```bash
+uv run pre-commit run mypy-full --hook-stage manual --all-files
+```
+
+or invoke mypy directly with all optional dependencies:
+
+```bash
+uv run --group dev --group type_checking mypy
+```
+
 ```{caution}
 Before pushing your changes, make sure that the pre-commit hooks have passed successfully.
 ```
|
|
@ -1,610 +0,0 @@
|
||||||
# yaml-language-server: $schema=https://app.stainlessapi.com/config-internal.schema.json
|
|
||||||
|
|
||||||
organization:
|
|
||||||
# Name of your organization or company, used to determine the name of the client
|
|
||||||
# and headings.
|
|
||||||
name: llama-stack-client
|
|
||||||
docs: https://llama-stack.readthedocs.io/en/latest/
|
|
||||||
contact: llamastack@meta.com
|
|
||||||
security:
|
|
||||||
- {}
|
|
||||||
- BearerAuth: []
|
|
||||||
security_schemes:
|
|
||||||
BearerAuth:
|
|
||||||
type: http
|
|
||||||
scheme: bearer
|
|
||||||
# `targets` define the output targets and their customization options, such as
|
|
||||||
# whether to emit the Node SDK and what it's package name should be.
|
|
||||||
targets:
|
|
||||||
node:
|
|
||||||
package_name: llama-stack-client
|
|
||||||
production_repo: llamastack/llama-stack-client-typescript
|
|
||||||
publish:
|
|
||||||
npm: false
|
|
||||||
python:
|
|
||||||
package_name: llama_stack_client
|
|
||||||
production_repo: llamastack/llama-stack-client-python
|
|
||||||
options:
|
|
||||||
use_uv: true
|
|
||||||
publish:
|
|
||||||
pypi: true
|
|
||||||
project_name: llama_stack_client
|
|
||||||
kotlin:
|
|
||||||
reverse_domain: com.llama_stack_client.api
|
|
||||||
production_repo: null
|
|
||||||
publish:
|
|
||||||
maven: false
|
|
||||||
go:
|
|
||||||
package_name: llama-stack-client
|
|
||||||
production_repo: llamastack/llama-stack-client-go
|
|
||||||
options:
|
|
||||||
enable_v2: true
|
|
||||||
back_compat_use_shared_package: false
|
|
||||||
|
|
||||||
# `client_settings` define settings for the API client, such as extra constructor
|
|
||||||
# arguments (used for authentication), retry behavior, idempotency, etc.
|
|
||||||
client_settings:
|
|
||||||
default_env_prefix: LLAMA_STACK_CLIENT
|
|
||||||
opts:
|
|
||||||
api_key:
|
|
||||||
type: string
|
|
||||||
read_env: LLAMA_STACK_CLIENT_API_KEY
|
|
||||||
auth: { security_scheme: BearerAuth }
|
|
||||||
nullable: true
|
|
||||||
|
|
||||||
# `environments` are a map of the name of the environment (e.g. "sandbox",
|
|
||||||
# "production") to the corresponding url to use.
|
|
||||||
environments:
|
|
||||||
production: http://any-hosted-llama-stack.com
|
|
||||||
|
|
||||||
# `pagination` defines [pagination schemes] which provides a template to match
|
|
||||||
# endpoints and generate next-page and auto-pagination helpers in the SDKs.
|
|
||||||
pagination:
|
|
||||||
- name: datasets_iterrows
|
|
||||||
type: offset
|
|
||||||
request:
|
|
||||||
dataset_id:
|
|
||||||
type: string
|
|
||||||
start_index:
|
|
||||||
type: integer
|
|
||||||
x-stainless-pagination-property:
|
|
||||||
purpose: offset_count_param
|
|
||||||
limit:
|
|
||||||
type: integer
|
|
||||||
response:
|
|
||||||
data:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: object
|
|
||||||
next_index:
|
|
||||||
type: integer
|
|
||||||
x-stainless-pagination-property:
|
|
||||||
purpose: offset_count_start_field
|
|
||||||
- name: openai_cursor_page
|
|
||||||
type: cursor
|
|
||||||
request:
|
|
||||||
limit:
|
|
||||||
type: integer
|
|
||||||
after:
|
|
||||||
type: string
|
|
||||||
x-stainless-pagination-property:
|
|
||||||
purpose: next_cursor_param
|
|
||||||
response:
|
|
||||||
data:
|
|
||||||
type: array
|
|
||||||
items: {}
|
|
||||||
has_more:
|
|
||||||
type: boolean
|
|
||||||
last_id:
|
|
||||||
type: string
|
|
||||||
x-stainless-pagination-property:
|
|
||||||
purpose: next_cursor_field
|
|
||||||
# `resources` define the structure and organziation for your API, such as how
|
|
||||||
# methods and models are grouped together and accessed. See the [configuration
|
|
||||||
# guide] for more information.
|
|
||||||
#
|
|
||||||
# [configuration guide]:
|
|
||||||
# https://app.stainlessapi.com/docs/guides/configure#resources
|
|
||||||
resources:
|
|
||||||
$shared:
|
|
||||||
models:
|
|
||||||
agent_config: AgentConfig
|
|
||||||
interleaved_content_item: InterleavedContentItem
|
|
||||||
interleaved_content: InterleavedContent
|
|
||||||
param_type: ParamType
|
|
||||||
safety_violation: SafetyViolation
|
|
||||||
sampling_params: SamplingParams
|
|
||||||
scoring_result: ScoringResult
|
|
||||||
message: Message
|
|
||||||
user_message: UserMessage
|
|
||||||
completion_message: CompletionMessage
|
|
||||||
tool_response_message: ToolResponseMessage
|
|
||||||
system_message: SystemMessage
|
|
||||||
tool_call: ToolCall
|
|
||||||
query_result: RAGQueryResult
|
|
||||||
document: RAGDocument
|
|
||||||
query_config: RAGQueryConfig
|
|
||||||
response_format: ResponseFormat
|
|
||||||
toolgroups:
|
|
||||||
models:
|
|
||||||
tool_group: ToolGroup
|
|
||||||
list_tool_groups_response: ListToolGroupsResponse
|
|
||||||
methods:
|
|
||||||
register: post /v1/toolgroups
|
|
||||||
get: get /v1/toolgroups/{toolgroup_id}
|
|
||||||
list: get /v1/toolgroups
|
|
||||||
unregister: delete /v1/toolgroups/{toolgroup_id}
|
|
||||||
tools:
|
|
||||||
methods:
|
|
||||||
get: get /v1/tools/{tool_name}
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/tools
|
|
||||||
paginated: false
|
|
||||||
|
|
||||||
tool_runtime:
|
|
||||||
models:
|
|
||||||
tool_def: ToolDef
|
|
||||||
tool_invocation_result: ToolInvocationResult
|
|
||||||
methods:
|
|
||||||
list_tools:
|
|
||||||
endpoint: get /v1/tool-runtime/list-tools
|
|
||||||
paginated: false
|
|
||||||
invoke_tool: post /v1/tool-runtime/invoke
|
|
||||||
subresources:
|
|
||||||
rag_tool:
|
|
||||||
methods:
|
|
||||||
insert: post /v1/tool-runtime/rag-tool/insert
|
|
||||||
query: post /v1/tool-runtime/rag-tool/query
|
|
||||||
|
|
||||||
responses:
|
|
||||||
models:
|
|
||||||
response_object_stream: OpenAIResponseObjectStream
|
|
||||||
response_object: OpenAIResponseObject
|
|
||||||
methods:
|
|
||||||
create:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1/responses
|
|
||||||
streaming:
|
|
||||||
stream_event_model: responses.response_object_stream
|
|
||||||
param_discriminator: stream
|
|
||||||
retrieve: get /v1/responses/{response_id}
|
|
||||||
list:
|
|
||||||
type: http
|
|
||||||
endpoint: get /v1/responses
|
|
||||||
delete:
|
|
||||||
type: http
|
|
||||||
endpoint: delete /v1/responses/{response_id}
|
|
||||||
subresources:
|
|
||||||
input_items:
|
|
||||||
methods:
|
|
||||||
list:
|
|
||||||
type: http
|
|
||||||
endpoint: get /v1/responses/{response_id}/input_items
|
|
||||||
|
|
||||||
conversations:
|
|
||||||
models:
|
|
||||||
conversation_object: Conversation
|
|
||||||
methods:
|
|
||||||
create:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1/conversations
|
|
||||||
retrieve: get /v1/conversations/{conversation_id}
|
|
||||||
update:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1/conversations/{conversation_id}
|
|
||||||
delete:
|
|
||||||
type: http
|
|
||||||
endpoint: delete /v1/conversations/{conversation_id}
|
|
||||||
subresources:
|
|
||||||
items:
|
|
||||||
methods:
|
|
||||||
get:
|
|
||||||
type: http
|
|
||||||
endpoint: get /v1/conversations/{conversation_id}/items/{item_id}
|
|
||||||
list:
|
|
||||||
type: http
|
|
||||||
endpoint: get /v1/conversations/{conversation_id}/items
|
|
||||||
create:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1/conversations/{conversation_id}/items
|
|
||||||
|
|
||||||
inspect:
|
|
||||||
models:
|
|
||||||
healthInfo: HealthInfo
|
|
||||||
providerInfo: ProviderInfo
|
|
||||||
routeInfo: RouteInfo
|
|
||||||
versionInfo: VersionInfo
|
|
||||||
methods:
|
|
||||||
health: get /v1/health
|
|
||||||
version: get /v1/version
|
|
||||||
|
|
||||||
embeddings:
|
|
||||||
models:
|
|
||||||
create_embeddings_response: OpenAIEmbeddingsResponse
|
|
||||||
methods:
|
|
||||||
create: post /v1/embeddings
|
|
||||||
|
|
||||||
chat:
|
|
||||||
models:
|
|
||||||
chat_completion_chunk: OpenAIChatCompletionChunk
|
|
||||||
subresources:
|
|
||||||
completions:
|
|
||||||
methods:
|
|
||||||
create:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1/chat/completions
|
|
||||||
streaming:
|
|
||||||
stream_event_model: chat.chat_completion_chunk
|
|
||||||
param_discriminator: stream
|
|
||||||
list:
|
|
||||||
type: http
|
|
||||||
endpoint: get /v1/chat/completions
|
|
||||||
retrieve:
|
|
||||||
type: http
|
|
||||||
endpoint: get /v1/chat/completions/{completion_id}
|
|
||||||
completions:
|
|
||||||
methods:
|
|
||||||
create:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1/completions
|
|
||||||
streaming:
|
|
||||||
param_discriminator: stream
|
|
||||||
|
|
||||||
vector_io:
|
|
||||||
models:
|
|
||||||
queryChunksResponse: QueryChunksResponse
|
|
||||||
methods:
|
|
||||||
insert: post /v1/vector-io/insert
|
|
||||||
query: post /v1/vector-io/query
|
|
||||||
|
|
||||||
vector_stores:
|
|
||||||
models:
|
|
||||||
vector_store: VectorStoreObject
|
|
||||||
list_vector_stores_response: VectorStoreListResponse
|
|
||||||
vector_store_delete_response: VectorStoreDeleteResponse
|
|
||||||
vector_store_search_response: VectorStoreSearchResponsePage
|
|
||||||
methods:
|
|
||||||
create: post /v1/vector_stores
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/vector_stores
|
|
||||||
retrieve: get /v1/vector_stores/{vector_store_id}
|
|
||||||
update: post /v1/vector_stores/{vector_store_id}
|
|
||||||
delete: delete /v1/vector_stores/{vector_store_id}
|
|
||||||
search: post /v1/vector_stores/{vector_store_id}/search
|
|
||||||
subresources:
|
|
||||||
files:
|
|
||||||
models:
|
|
||||||
vector_store_file: VectorStoreFileObject
|
|
||||||
methods:
|
|
||||||
list: get /v1/vector_stores/{vector_store_id}/files
|
|
||||||
retrieve: get /v1/vector_stores/{vector_store_id}/files/{file_id}
|
|
||||||
update: post /v1/vector_stores/{vector_store_id}/files/{file_id}
|
|
||||||
delete: delete /v1/vector_stores/{vector_store_id}/files/{file_id}
|
|
||||||
create: post /v1/vector_stores/{vector_store_id}/files
|
|
||||||
content: get /v1/vector_stores/{vector_store_id}/files/{file_id}/content
|
|
||||||
file_batches:
|
|
||||||
models:
|
|
||||||
vector_store_file_batches: VectorStoreFileBatchObject
|
|
||||||
list_vector_store_files_in_batch_response: VectorStoreFilesListInBatchResponse
|
|
||||||
methods:
|
|
||||||
create: post /v1/vector_stores/{vector_store_id}/file_batches
|
|
||||||
retrieve: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}
|
|
||||||
list_files: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files
|
|
||||||
cancel: post /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel
|
|
||||||
|
|
||||||
models:
|
|
||||||
models:
|
|
||||||
model: Model
|
|
||||||
list_models_response: ListModelsResponse
|
|
||||||
methods:
|
|
||||||
retrieve: get /v1/models/{model_id}
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/models
|
|
||||||
paginated: false
|
|
||||||
register: post /v1/models
|
|
||||||
unregister: delete /v1/models/{model_id}
|
|
||||||
subresources:
|
|
||||||
openai:
|
|
||||||
methods:
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/models
|
|
||||||
paginated: false
|
|
||||||
|
|
||||||
providers:
|
|
||||||
models:
|
|
||||||
list_providers_response: ListProvidersResponse
|
|
||||||
methods:
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/providers
|
|
||||||
paginated: false
|
|
||||||
retrieve: get /v1/providers/{provider_id}
|
|
||||||
|
|
||||||
routes:
|
|
||||||
models:
|
|
||||||
list_routes_response: ListRoutesResponse
|
|
||||||
methods:
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/inspect/routes
|
|
||||||
paginated: false
|
|
||||||
|
|
||||||
|
|
||||||
moderations:
|
|
||||||
models:
|
|
||||||
create_response: ModerationObject
|
|
||||||
methods:
|
|
||||||
create: post /v1/moderations
|
|
||||||
|
|
||||||
|
|
||||||
safety:
|
|
||||||
models:
|
|
||||||
run_shield_response: RunShieldResponse
|
|
||||||
methods:
|
|
||||||
run_shield: post /v1/safety/run-shield
|
|
||||||
|
|
||||||
|
|
||||||
shields:
|
|
||||||
models:
|
|
||||||
shield: Shield
|
|
||||||
list_shields_response: ListShieldsResponse
|
|
||||||
methods:
|
|
||||||
retrieve: get /v1/shields/{identifier}
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/shields
|
|
||||||
paginated: false
|
|
||||||
register: post /v1/shields
|
|
||||||
delete: delete /v1/shields/{identifier}
|
|
||||||
|
|
||||||
synthetic_data_generation:
|
|
||||||
models:
|
|
||||||
syntheticDataGenerationResponse: SyntheticDataGenerationResponse
|
|
||||||
methods:
|
|
||||||
generate: post /v1/synthetic-data-generation/generate
|
|
||||||
|
|
||||||
telemetry:
|
|
||||||
models:
|
|
||||||
span_with_status: SpanWithStatus
|
|
||||||
trace: Trace
|
|
||||||
query_spans_response: QuerySpansResponse
|
|
||||||
event: Event
|
|
||||||
query_condition: QueryCondition
|
|
||||||
methods:
|
|
||||||
query_traces:
|
|
||||||
endpoint: post /v1alpha/telemetry/traces
|
|
||||||
skip_test_reason: 'unsupported query params in java / kotlin'
|
|
||||||
get_span_tree: post /v1alpha/telemetry/spans/{span_id}/tree
|
|
||||||
query_spans:
|
|
||||||
endpoint: post /v1alpha/telemetry/spans
|
|
||||||
skip_test_reason: 'unsupported query params in java / kotlin'
|
|
||||||
query_metrics:
|
|
||||||
endpoint: post /v1alpha/telemetry/metrics/{metric_name}
|
|
||||||
skip_test_reason: 'unsupported query params in java / kotlin'
|
|
||||||
# log_event: post /v1alpha/telemetry/events
|
|
||||||
save_spans_to_dataset: post /v1alpha/telemetry/spans/export
|
|
||||||
get_span: get /v1alpha/telemetry/traces/{trace_id}/spans/{span_id}
|
|
||||||
get_trace: get /v1alpha/telemetry/traces/{trace_id}
|
|
||||||
|
|
||||||
scoring:
|
|
||||||
methods:
|
|
||||||
score: post /v1/scoring/score
|
|
||||||
score_batch: post /v1/scoring/score-batch
|
|
||||||
scoring_functions:
|
|
||||||
methods:
|
|
||||||
retrieve: get /v1/scoring-functions/{scoring_fn_id}
|
|
||||||
list:
|
|
||||||
endpoint: get /v1/scoring-functions
|
|
||||||
paginated: false
|
|
||||||
register: post /v1/scoring-functions
|
|
||||||
models:
|
|
||||||
scoring_fn: ScoringFn
|
|
||||||
scoring_fn_params: ScoringFnParams
|
|
||||||
list_scoring_functions_response: ListScoringFunctionsResponse
|
|
||||||
|
|
||||||
benchmarks:
|
|
||||||
methods:
|
|
||||||
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}
|
|
||||||
list:
|
|
||||||
endpoint: get /v1alpha/eval/benchmarks
|
|
||||||
paginated: false
|
|
||||||
register: post /v1alpha/eval/benchmarks
|
|
||||||
models:
|
|
||||||
benchmark: Benchmark
|
|
||||||
list_benchmarks_response: ListBenchmarksResponse
|
|
||||||
|
|
||||||
files:
|
|
||||||
methods:
|
|
||||||
create: post /v1/files
|
|
||||||
list: get /v1/files
|
|
||||||
retrieve: get /v1/files/{file_id}
|
|
||||||
delete: delete /v1/files/{file_id}
|
|
||||||
content: get /v1/files/{file_id}/content
|
|
||||||
models:
|
|
||||||
file: OpenAIFileObject
|
|
||||||
list_files_response: ListOpenAIFileResponse
|
|
||||||
delete_file_response: OpenAIFileDeleteResponse
|
|
||||||
|
|
||||||
alpha:
|
|
||||||
subresources:
|
|
||||||
inference:
|
|
||||||
methods:
|
|
||||||
rerank: post /v1alpha/inference/rerank
|
|
||||||
|
|
||||||
post_training:
|
|
||||||
models:
|
|
||||||
algorithm_config: AlgorithmConfig
|
|
||||||
post_training_job: PostTrainingJob
|
|
||||||
list_post_training_jobs_response: ListPostTrainingJobsResponse
|
|
||||||
methods:
|
|
||||||
preference_optimize: post /v1alpha/post-training/preference-optimize
|
|
||||||
supervised_fine_tune: post /v1alpha/post-training/supervised-fine-tune
|
|
||||||
subresources:
|
|
||||||
job:
|
|
||||||
methods:
|
|
||||||
artifacts: get /v1alpha/post-training/job/artifacts
|
|
||||||
cancel: post /v1alpha/post-training/job/cancel
|
|
||||||
status: get /v1alpha/post-training/job/status
|
|
||||||
list:
|
|
||||||
endpoint: get /v1alpha/post-training/jobs
|
|
||||||
paginated: false
|
|
||||||
|
|
||||||
eval:
|
|
||||||
methods:
|
|
||||||
evaluate_rows: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
|
|
||||||
run_eval: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
|
|
||||||
evaluate_rows_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations
|
|
||||||
run_eval_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs
|
|
||||||
|
|
||||||
subresources:
|
|
||||||
jobs:
|
|
||||||
methods:
|
|
||||||
cancel: delete /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
|
||||||
status: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
|
||||||
retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result
|
|
||||||
models:
|
|
||||||
evaluate_response: EvaluateResponse
|
|
||||||
benchmark_config: BenchmarkConfig
|
|
||||||
job: Job
|
|
||||||
|
|
||||||
agents:
|
|
||||||
methods:
|
|
||||||
create: post /v1alpha/agents
|
|
||||||
list: get /v1alpha/agents
|
|
||||||
retrieve: get /v1alpha/agents/{agent_id}
|
|
||||||
delete: delete /v1alpha/agents/{agent_id}
|
|
||||||
models:
|
|
||||||
inference_step: InferenceStep
|
|
||||||
tool_execution_step: ToolExecutionStep
|
|
||||||
tool_response: ToolResponse
|
|
||||||
shield_call_step: ShieldCallStep
|
|
||||||
memory_retrieval_step: MemoryRetrievalStep
|
|
||||||
subresources:
|
|
||||||
session:
|
|
||||||
models:
|
|
||||||
session: Session
|
|
||||||
methods:
|
|
||||||
list: get /v1alpha/agents/{agent_id}/sessions
|
|
||||||
create: post /v1alpha/agents/{agent_id}/session
|
|
||||||
delete: delete /v1alpha/agents/{agent_id}/session/{session_id}
|
|
||||||
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}
|
|
||||||
steps:
|
|
||||||
methods:
|
|
||||||
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}
|
|
||||||
turn:
|
|
||||||
models:
|
|
||||||
turn: Turn
|
|
||||||
turn_response_event: AgentTurnResponseEvent
|
|
||||||
agent_turn_response_stream_chunk: AgentTurnResponseStreamChunk
|
|
||||||
methods:
|
|
||||||
create:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn
|
|
||||||
streaming:
|
|
||||||
stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
|
|
||||||
param_discriminator: stream
|
|
||||||
retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}
|
|
||||||
resume:
|
|
||||||
type: http
|
|
||||||
endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume
|
|
||||||
streaming:
|
|
||||||
stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk
|
|
||||||
param_discriminator: stream
|
|
||||||
|
|
||||||
beta:
|
|
||||||
subresources:
|
|
||||||
datasets:
|
|
||||||
models:
|
|
||||||
list_datasets_response: ListDatasetsResponse
|
|
||||||
methods:
|
|
||||||
register: post /v1beta/datasets
|
|
||||||
retrieve: get /v1beta/datasets/{dataset_id}
|
|
||||||
list:
|
|
||||||
endpoint: get /v1beta/datasets
|
|
||||||
paginated: false
|
|
||||||
unregister: delete /v1beta/datasets/{dataset_id}
|
|
||||||
iterrows: get /v1beta/datasetio/iterrows/{dataset_id}
|
|
||||||
appendrows: post /v1beta/datasetio/append-rows/{dataset_id}
|
|
||||||
|
|
||||||
|
|
||||||
settings:
|
|
||||||
license: MIT
|
|
||||||
unwrap_response_fields: [ data ]
|
|
||||||
|
|
||||||
openapi:
|
|
||||||
transformations:
|
|
||||||
- command: renameValue
|
|
||||||
reason: pydantic reserved name
|
|
||||||
args:
|
|
||||||
filter:
|
|
||||||
only:
|
|
||||||
- '$.components.schemas.InferenceStep.properties.model_response'
|
|
||||||
rename:
|
|
||||||
python:
|
|
||||||
property_name: 'inference_model_response'
|
|
||||||
|
|
||||||
# - command: renameValue
|
|
||||||
# reason: pydantic reserved name
|
|
||||||
# args:
|
|
||||||
# filter:
|
|
||||||
# only:
|
|
||||||
# - '$.components.schemas.Model.properties.model_type'
|
|
||||||
# rename:
|
|
||||||
# python:
|
|
||||||
# property_name: 'type'
|
|
||||||
- command: mergeObject
|
|
||||||
reason: Better return_type using enum
|
|
||||||
args:
|
|
||||||
target:
|
|
||||||
- '$.components.schemas'
|
|
||||||
object:
|
|
||||||
ReturnType:
|
|
||||||
additionalProperties: false
|
|
||||||
properties:
|
|
||||||
type:
|
|
||||||
enum:
|
|
||||||
- string
|
|
||||||
- number
|
|
||||||
- boolean
|
|
||||||
- array
|
|
||||||
- object
|
|
||||||
- json
|
|
||||||
- union
|
|
||||||
- chat_completion_input
|
|
||||||
- completion_input
|
|
||||||
- agent_turn_input
|
|
||||||
required:
|
|
||||||
- type
|
|
||||||
type: object
|
|
||||||
- command: replaceProperties
|
|
||||||
reason: Replace return type properties with better model (see above)
|
|
||||||
args:
|
|
||||||
filter:
|
|
||||||
only:
|
|
||||||
- '$.components.schemas.ScoringFn.properties.return_type'
|
|
||||||
- '$.components.schemas.RegisterScoringFunctionRequest.properties.return_type'
|
|
||||||
value:
|
|
||||||
$ref: '#/components/schemas/ReturnType'
|
|
||||||
- command: oneOfToAnyOf
|
|
||||||
reason: Prism (mock server) doesn't like one of our requests as it technically matches multiple variants
|
|
||||||
- reason: For better names
|
|
||||||
command: extractToRefs
|
|
||||||
args:
|
|
||||||
ref:
|
|
||||||
target: '$.components.schemas.ToolCallDelta.properties.tool_call'
|
|
||||||
name: '#/components/schemas/ToolCallOrString'
|
|
||||||
|
|
||||||
# `readme` is used to configure the code snippets that will be rendered in the
|
|
||||||
# README.md of various SDKs. In particular, you can change the `headline`
|
|
||||||
# snippet's endpoint and the arguments to call it with.
|
|
||||||
readme:
|
|
||||||
example_requests:
|
|
||||||
default:
|
|
||||||
type: request
|
|
||||||
endpoint: post /v1/chat/completions
|
|
||||||
params: &ref_0 {}
|
|
||||||
headline:
|
|
||||||
type: request
|
|
||||||
endpoint: post /v1/models
|
|
||||||
params: *ref_0
|
|
||||||
pagination:
|
|
||||||
type: request
|
|
||||||
endpoint: post /v1/chat/completions
|
|
||||||
params: {}
|
|
||||||
File diff suppressed because it is too large
@ -19,6 +19,8 @@ ARG KEEP_WORKSPACE=""
 ARG DISTRO_NAME="starter"
 ARG RUN_CONFIG_PATH=""
 ARG UV_HTTP_TIMEOUT=500
+ARG UV_EXTRA_INDEX_URL=""
+ARG UV_INDEX_STRATEGY=""
 ENV UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT}
 ENV PYTHONDONTWRITEBYTECODE=1
 ENV PIP_DISABLE_PIP_VERSION_CHECK=1

@ -45,7 +47,7 @@ RUN set -eux; \
         exit 1; \
     fi

-RUN pip install --no-cache uv
+RUN pip install --no-cache-dir uv
 ENV UV_SYSTEM_PYTHON=1

 ENV INSTALL_MODE=${INSTALL_MODE}

@ -62,47 +64,60 @@ COPY . /workspace

 # Install the client package if it is provided
 # NOTE: this is installed before llama-stack since llama-stack depends on llama-stack-client-python
+# Unset UV index env vars to ensure we only use PyPI for the client
 RUN set -eux; \
+    unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
     if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then \
         if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then \
             echo "LLAMA_STACK_CLIENT_DIR is set but $LLAMA_STACK_CLIENT_DIR does not exist" >&2; \
             exit 1; \
         fi; \
-        uv pip install --no-cache -e "$LLAMA_STACK_CLIENT_DIR"; \
+        uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"; \
     fi;

 # Install llama-stack
+# Use UV_EXTRA_INDEX_URL inline only for editable install with RC dependencies
 RUN set -eux; \
+    SAVED_UV_EXTRA_INDEX_URL="${UV_EXTRA_INDEX_URL:-}"; \
+    SAVED_UV_INDEX_STRATEGY="${UV_INDEX_STRATEGY:-}"; \
+    unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
     if [ "$INSTALL_MODE" = "editable" ]; then \
         if [ ! -d "$LLAMA_STACK_DIR" ]; then \
             echo "INSTALL_MODE=editable requires LLAMA_STACK_DIR to point to a directory inside the build context" >&2; \
             exit 1; \
         fi; \
-        uv pip install --no-cache -e "$LLAMA_STACK_DIR"; \
-    elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
-        uv pip install --no-cache fastapi libcst; \
-        if [ -n "$TEST_PYPI_VERSION" ]; then \
-            uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
+        if [ -n "$SAVED_UV_EXTRA_INDEX_URL" ] && [ -n "$SAVED_UV_INDEX_STRATEGY" ]; then \
+            UV_EXTRA_INDEX_URL="$SAVED_UV_EXTRA_INDEX_URL" UV_INDEX_STRATEGY="$SAVED_UV_INDEX_STRATEGY" \
+                uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
         else \
-            uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
+            uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
+        fi; \
+    elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
+        uv pip install --no-cache-dir fastapi libcst; \
+        if [ -n "$TEST_PYPI_VERSION" ]; then \
+            uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
+        else \
+            uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
         fi; \
     else \
         if [ -n "$PYPI_VERSION" ]; then \
-            uv pip install --no-cache "llama-stack==$PYPI_VERSION"; \
+            uv pip install --no-cache-dir "llama-stack==$PYPI_VERSION"; \
         else \
-            uv pip install --no-cache llama-stack; \
+            uv pip install --no-cache-dir llama-stack; \
         fi; \
     fi;

 # Install the dependencies for the distribution
+# Explicitly unset UV index env vars to ensure we only use PyPI for distribution deps
 RUN set -eux; \
+    unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
     if [ -z "$DISTRO_NAME" ]; then \
         echo "DISTRO_NAME must be provided" >&2; \
         exit 1; \
     fi; \
     deps="$(llama stack list-deps "$DISTRO_NAME")"; \
     if [ -n "$deps" ]; then \
-        printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache; \
+        printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache-dir; \
     fi

 # Cleanup
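The editable-install step above uses a save/unset/restore dance so that the extra index is only consulted for the llama-stack install itself, while the client and distribution dependencies keep resolving against PyPI alone. A condensed, illustrative sketch of that pattern outside the Containerfile (variable names mirror the ones above; the trailing editable install is a stand-in for whatever command needs the index):

```bash
#!/usr/bin/env bash
# Illustration of the save/unset/scoped-restore pattern used in the Containerfile.
set -eux

SAVED_UV_EXTRA_INDEX_URL="${UV_EXTRA_INDEX_URL:-}"
SAVED_UV_INDEX_STRATEGY="${UV_INDEX_STRATEGY:-}"
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY   # steps below default to PyPI only

# ... steps that must not see the extra index run here ...

if [ -n "$SAVED_UV_EXTRA_INDEX_URL" ] && [ -n "$SAVED_UV_INDEX_STRATEGY" ]; then
  # Re-expose the index only for this single command, not for the rest of the build.
  UV_EXTRA_INDEX_URL="$SAVED_UV_EXTRA_INDEX_URL" UV_INDEX_STRATEGY="$SAVED_UV_INDEX_STRATEGY" \
    uv pip install --no-cache-dir -e .
else
  uv pip install --no-cache-dir -e .
fi
```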
@ -23,5 +23,4 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s
 We are working on adding a few more APIs to complete the application lifecycle. These will include:
 - **Batch Inference**: run inference on a dataset of inputs
 - **Batch Agents**: run agents on a dataset of inputs
-- **Synthetic Data Generation**: generate synthetic data for model development
 - **Batches**: OpenAI-compatible batch management for inference
@ -79,6 +79,33 @@ docker run \
   --port $LLAMA_STACK_PORT
 ```

+### Via Docker with Custom Run Configuration
+
+You can also run the Docker container with a custom run configuration file by mounting it into the container:
+
+```bash
+# Set the path to your custom run.yaml file
+CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
+LLAMA_STACK_PORT=8321
+
+docker run \
+  -it \
+  --pull always \
+  --gpu all \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  -v ~/.llama:/root/.llama \
+  -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
+  -e RUN_CONFIG_PATH=/app/custom-run.yaml \
+  llamastack/distribution-meta-reference-gpu \
+  --port $LLAMA_STACK_PORT
+```
+
+**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
+
+Available run configurations for this distribution:
+- `run.yaml`
+- `run-with-safety.yaml`
+
 ### Via venv

 Make sure you have the Llama Stack CLI available.
@ -127,13 +127,39 @@ docker run \
   -it \
   --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  -v ./run.yaml:/root/my-run.yaml \
+  -v ~/.llama:/root/.llama \
   -e NVIDIA_API_KEY=$NVIDIA_API_KEY \
   llamastack/distribution-nvidia \
-  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT
 ```

+### Via Docker with Custom Run Configuration
+
+You can also run the Docker container with a custom run configuration file by mounting it into the container:
+
+```bash
+# Set the path to your custom run.yaml file
+CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
+LLAMA_STACK_PORT=8321
+
+docker run \
+  -it \
+  --pull always \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  -v ~/.llama:/root/.llama \
+  -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
+  -e RUN_CONFIG_PATH=/app/custom-run.yaml \
+  -e NVIDIA_API_KEY=$NVIDIA_API_KEY \
+  llamastack/distribution-nvidia \
+  --port $LLAMA_STACK_PORT
+```
+
+**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
+
+Available run configurations for this distribution:
+- `run.yaml`
+- `run-with-safety.yaml`
+
 ### Via venv

 If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment.
27
docs/docs/providers/files/remote_openai.mdx
Normal file
@ -0,0 +1,27 @@
---
description: "OpenAI Files API provider for managing files through OpenAI's native file storage service."
sidebar_label: Remote - Openai
title: remote::openai
---

# remote::openai

## Description

OpenAI Files API provider for managing files through OpenAI's native file storage service.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `api_key` | `<class 'str'>` | No | | OpenAI API key for authentication |
| `metadata_store` | `<class 'llama_stack.core.storage.datatypes.SqlStoreReference'>` | No | | SQL store configuration for file metadata |

## Sample Configuration

```yaml
api_key: ${env.OPENAI_API_KEY}
metadata_store:
  table_name: openai_files_metadata
  backend: sql_default
```
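The sample configuration reads the key from `${env.OPENAI_API_KEY}`, so the variable has to be present in the environment of whatever process launches the stack. A minimal sketch, where the config path is just a placeholder for a run configuration that includes this provider:

```bash
# Placeholder key and config path; substitute your own values.
export OPENAI_API_KEY="sk-..."

# Start a stack whose run configuration wires in the remote::openai files provider.
llama stack run ./my-run.yaml
```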
@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
 | `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
+| `rerank_model_to_url` | `dict[str, str` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |

 ## Sample Configuration
1036
docs/notebooks/llamastack_agents_getting_started_examples.ipynb
Normal file
File diff suppressed because it is too large
@ -84,7 +84,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo
     )

     yaml_filename = f"{filename_prefix}llama-stack-spec.yaml"
-    html_filename = f"{filename_prefix}llama-stack-spec.html"

     with open(output_dir / yaml_filename, "w", encoding="utf-8") as fp:
         y = yaml.YAML()

@ -102,11 +101,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo
             fp,
         )

-    with open(output_dir / html_filename, "w") as fp:
-        spec.write_html(fp, pretty_print=True)
-
-    print(f"Generated {yaml_filename} and {html_filename}")
-
 def main(output_dir: str):
     output_dir = Path(output_dir)
     if not output_dir.exists():
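With the HTML writer removed, the generator only emits the YAML spec files. Regenerating them locally goes through the same wrapper the "API Spec Codegen" pre-commit hook uses (the command below is the hook's entry from the config above, minus its output redirect; where the files land is determined by the generator script itself):

```bash
# Regenerate the OpenAPI YAML specs the same way the "API Spec Codegen" hook does.
./scripts/uv-run-with-index.sh run ./docs/openapi_generator/run_openapi_generator.sh
```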
@ -242,15 +242,6 @@ const sidebars: SidebarsConfig = {
         'providers/eval/remote_nvidia'
       ],
     },
-    {
-      type: 'category',
-      label: 'Telemetry',
-      collapsed: true,
-      items: [
-        'providers/telemetry/index',
-        'providers/telemetry/inline_meta-reference'
-      ],
-    },
     {
       type: 'category',
       label: 'Batches',
13582
docs/static/deprecated-llama-stack-spec.html
vendored
File diff suppressed because it is too large
534
docs/static/deprecated-llama-stack-spec.yaml
vendored
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
deprecated: true
|
deprecated: true
|
||||||
|
/v1/openai/v1/batches:
|
||||||
|
get:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: A list of batch objects.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ListBatchesResponse'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: List all batches for the current user.
|
||||||
|
description: List all batches for the current user.
|
||||||
|
parameters:
|
||||||
|
- name: after
|
||||||
|
in: query
|
||||||
|
description: >-
|
||||||
|
A cursor for pagination; returns batches after this batch ID.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: limit
|
||||||
|
in: query
|
||||||
|
description: >-
|
||||||
|
Number of batches to return (default 20, max 100).
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
deprecated: true
|
||||||
|
post:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: The created batch object.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Batch'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: >-
|
||||||
|
Create a new batch for processing multiple API requests.
|
||||||
|
description: >-
|
||||||
|
Create a new batch for processing multiple API requests.
|
||||||
|
parameters: []
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/CreateBatchRequest'
|
||||||
|
required: true
|
||||||
|
deprecated: true
|
||||||
|
/v1/openai/v1/batches/{batch_id}:
|
||||||
|
get:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: The batch object.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Batch'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: >-
|
||||||
|
Retrieve information about a specific batch.
|
||||||
|
description: >-
|
||||||
|
Retrieve information about a specific batch.
|
||||||
|
parameters:
|
||||||
|
- name: batch_id
|
||||||
|
in: path
|
||||||
|
description: The ID of the batch to retrieve.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
deprecated: true
|
||||||
|
/v1/openai/v1/batches/{batch_id}/cancel:
|
||||||
|
post:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: The updated batch object.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Batch'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: Cancel a batch that is in progress.
|
||||||
|
description: Cancel a batch that is in progress.
|
||||||
|
parameters:
|
||||||
|
- name: batch_id
|
||||||
|
in: path
|
||||||
|
description: The ID of the batch to cancel.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
deprecated: true
|
||||||
/v1/openai/v1/chat/completions:
|
/v1/openai/v1/chat/completions:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
|
|
@ -1426,31 +1561,6 @@ paths:
             schema:
               type: string
       deprecated: true
-  /v1/openai/v1/models:
-    get:
-      responses:
-        '200':
-          description: A OpenAIListModelsResponse.
-          content:
-            application/json:
-              schema:
-                $ref: '#/components/schemas/OpenAIListModelsResponse'
-        '400':
-          $ref: '#/components/responses/BadRequest400'
-        '429':
-          $ref: >-
-            #/components/responses/TooManyRequests429
-        '500':
-          $ref: >-
-            #/components/responses/InternalServerError500
-        default:
-          $ref: '#/components/responses/DefaultError'
-      tags:
-        - Models
-      summary: List models using the OpenAI API.
-      description: List models using the OpenAI API.
-      parameters: []
-      deprecated: true
   /v1/openai/v1/moderations:
     post:
       responses:
@ -4736,6 +4846,331 @@ components:
|
||||||
title: Job
|
title: Job
|
||||||
description: >-
|
description: >-
|
||||||
A job execution instance with status tracking.
|
A job execution instance with status tracking.
|
||||||
|
ListBatchesResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: list
|
||||||
|
default: list
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
completion_window:
|
||||||
|
type: string
|
||||||
|
created_at:
|
||||||
|
type: integer
|
||||||
|
endpoint:
|
||||||
|
type: string
|
||||||
|
input_file_id:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: batch
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- validating
|
||||||
|
- failed
|
||||||
|
- in_progress
|
||||||
|
- finalizing
|
||||||
|
- completed
|
||||||
|
- expired
|
||||||
|
- cancelling
|
||||||
|
- cancelled
|
||||||
|
cancelled_at:
|
||||||
|
type: integer
|
||||||
|
cancelling_at:
|
||||||
|
type: integer
|
||||||
|
completed_at:
|
||||||
|
type: integer
|
||||||
|
error_file_id:
|
||||||
|
type: string
|
||||||
|
errors:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
code:
|
||||||
|
type: string
|
||||||
|
line:
|
||||||
|
type: integer
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
param:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: BatchError
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: Errors
|
||||||
|
expired_at:
|
||||||
|
type: integer
|
||||||
|
expires_at:
|
||||||
|
type: integer
|
||||||
|
failed_at:
|
||||||
|
type: integer
|
||||||
|
finalizing_at:
|
||||||
|
type: integer
|
||||||
|
in_progress_at:
|
||||||
|
type: integer
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
output_file_id:
|
||||||
|
type: string
|
||||||
|
request_counts:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
completed:
|
||||||
|
type: integer
|
||||||
|
failed:
|
||||||
|
type: integer
|
||||||
|
total:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- completed
|
||||||
|
- failed
|
||||||
|
- total
|
||||||
|
title: BatchRequestCounts
|
||||||
|
usage:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
input_tokens:
|
||||||
|
type: integer
|
||||||
|
input_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
cached_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- cached_tokens
|
||||||
|
title: InputTokensDetails
|
||||||
|
output_tokens:
|
||||||
|
type: integer
|
||||||
|
output_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
reasoning_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- reasoning_tokens
|
||||||
|
title: OutputTokensDetails
|
||||||
|
total_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- input_tokens
|
||||||
|
- input_tokens_details
|
||||||
|
- output_tokens
|
||||||
|
- output_tokens_details
|
||||||
|
- total_tokens
|
||||||
|
title: BatchUsage
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- completion_window
|
||||||
|
- created_at
|
||||||
|
- endpoint
|
||||||
|
- input_file_id
|
||||||
|
- object
|
||||||
|
- status
|
||||||
|
title: Batch
|
||||||
|
first_id:
|
||||||
|
type: string
|
||||||
|
last_id:
|
||||||
|
type: string
|
||||||
|
has_more:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- object
|
||||||
|
- data
|
||||||
|
- has_more
|
||||||
|
title: ListBatchesResponse
|
||||||
|
description: >-
|
||||||
|
Response containing a list of batch objects.
|
||||||
|
CreateBatchRequest:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
input_file_id:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
The ID of an uploaded file containing requests for the batch.
|
||||||
|
endpoint:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
The endpoint to be used for all requests in the batch.
|
||||||
|
completion_window:
|
||||||
|
type: string
|
||||||
|
const: 24h
|
||||||
|
description: >-
|
||||||
|
The time window within which the batch should be processed.
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
description: Optional metadata for the batch.
|
||||||
|
idempotency_key:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
Optional idempotency key. When provided, enables idempotent behavior.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- input_file_id
|
||||||
|
- endpoint
|
||||||
|
- completion_window
|
||||||
|
title: CreateBatchRequest
|
||||||
|
Batch:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
completion_window:
|
||||||
|
type: string
|
||||||
|
created_at:
|
||||||
|
type: integer
|
||||||
|
endpoint:
|
||||||
|
type: string
|
||||||
|
input_file_id:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: batch
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- validating
|
||||||
|
- failed
|
||||||
|
- in_progress
|
||||||
|
- finalizing
|
||||||
|
- completed
|
||||||
|
- expired
|
||||||
|
- cancelling
|
||||||
|
- cancelled
|
||||||
|
cancelled_at:
|
||||||
|
type: integer
|
||||||
|
cancelling_at:
|
||||||
|
type: integer
|
||||||
|
completed_at:
|
||||||
|
type: integer
|
||||||
|
error_file_id:
|
||||||
|
type: string
|
||||||
|
errors:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
code:
|
||||||
|
type: string
|
||||||
|
line:
|
||||||
|
type: integer
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
param:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: BatchError
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: Errors
|
||||||
|
expired_at:
|
||||||
|
type: integer
|
||||||
|
expires_at:
|
||||||
|
type: integer
|
||||||
|
failed_at:
|
||||||
|
type: integer
|
||||||
|
finalizing_at:
|
||||||
|
type: integer
|
||||||
|
in_progress_at:
|
||||||
|
type: integer
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
output_file_id:
|
||||||
|
type: string
|
||||||
|
request_counts:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
completed:
|
||||||
|
type: integer
|
||||||
|
failed:
|
||||||
|
type: integer
|
||||||
|
total:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- completed
|
||||||
|
- failed
|
||||||
|
- total
|
||||||
|
title: BatchRequestCounts
|
||||||
|
usage:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
input_tokens:
|
||||||
|
type: integer
|
||||||
|
input_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
cached_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- cached_tokens
|
||||||
|
title: InputTokensDetails
|
||||||
|
output_tokens:
|
||||||
|
type: integer
|
||||||
|
output_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
reasoning_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- reasoning_tokens
|
||||||
|
title: OutputTokensDetails
|
||||||
|
total_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- input_tokens
|
||||||
|
- input_tokens_details
|
||||||
|
- output_tokens
|
||||||
|
- output_tokens_details
|
||||||
|
- total_tokens
|
||||||
|
title: BatchUsage
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- completion_window
|
||||||
|
- created_at
|
||||||
|
- endpoint
|
||||||
|
- input_file_id
|
||||||
|
- object
|
||||||
|
- status
|
||||||
|
title: Batch
|
||||||
Order:
|
Order:
|
||||||
type: string
|
type: string
|
||||||
enum:
|
enum:
|
||||||
|
|
@ -6056,38 +6491,6 @@ components:
|
||||||
Response:
|
Response:
|
||||||
type: object
|
type: object
|
||||||
title: Response
|
title: Response
|
||||||
OpenAIModel:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
id:
|
|
||||||
type: string
|
|
||||||
object:
|
|
||||||
type: string
|
|
||||||
const: model
|
|
||||||
default: model
|
|
||||||
created:
|
|
||||||
type: integer
|
|
||||||
owned_by:
|
|
||||||
type: string
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- id
|
|
||||||
- object
|
|
||||||
- created
|
|
||||||
- owned_by
|
|
||||||
title: OpenAIModel
|
|
||||||
description: A model from OpenAI.
|
|
||||||
OpenAIListModelsResponse:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
data:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/OpenAIModel'
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- data
|
|
||||||
title: OpenAIListModelsResponse
|
|
||||||
RunModerationRequest:
|
RunModerationRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -10263,6 +10666,19 @@ tags:
|
||||||
|
|
||||||
- **Responses API**: Use the stable v1 Responses API endpoints
|
- **Responses API**: Use the stable v1 Responses API endpoints
|
||||||
x-displayName: Agents
|
x-displayName: Agents
|
||||||
|
- name: Batches
|
||||||
|
description: >-
|
||||||
|
The API is designed to allow use of openai client libraries for seamless integration.
|
||||||
|
|
||||||
|
|
||||||
|
This API provides the following extensions:
|
||||||
|
- idempotent batch creation
|
||||||
|
|
||||||
|
Note: This API is currently under active development and may undergo changes.
|
||||||
|
x-displayName: >-
|
||||||
|
The Batches API enables efficient processing of multiple requests in a single
|
||||||
|
operation, particularly useful for processing large datasets, batch evaluation
|
||||||
|
workflows, and cost-effective inference at scale.
|
||||||
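Aside, not part of this diff: the Batches tag above says the API is meant to be driven with standard openai client libraries. A minimal sketch of what that could look like against a Llama Stack server; the base URL, port, dummy API key, file ID, and endpoint value are all assumptions for illustration.

    # Hypothetical usage sketch: drive the Llama Stack Batches API with the openai client.
    # Assumes a server at http://localhost:8321/v1 and a previously uploaded JSONL input file.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

    # Create a batch of chat-completion requests; "24h" is the only completion_window the spec allows.
    batch = client.batches.create(
        input_file_id="file-abc123",          # assumed ID from the Files API
        endpoint="/v1/chat/completions",      # assumed target endpoint for the batch
        completion_window="24h",
        metadata={"project": "demo"},
        extra_body={"idempotency_key": "demo-run-1"},  # the Llama Stack extension described above
    )
    print(batch.id, batch.status)

    # Page through existing batches (maps to GET /v1/batches with after/limit).
    for b in client.batches.list(limit=20):
        print(b.id, b.status)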
- name: Benchmarks
|
- name: Benchmarks
|
||||||
description: ''
|
description: ''
|
||||||
- name: DatasetIO
|
- name: DatasetIO
|
||||||
|
|
@ -10295,8 +10711,6 @@ tags:
|
||||||
- Rerank models: these models reorder the documents based on their relevance
|
- Rerank models: these models reorder the documents based on their relevance
|
||||||
to a query.
|
to a query.
|
||||||
x-displayName: Inference
|
x-displayName: Inference
|
||||||
- name: Models
|
|
||||||
description: ''
|
|
||||||
- name: PostTraining (Coming Soon)
|
- name: PostTraining (Coming Soon)
|
||||||
description: ''
|
description: ''
|
||||||
- name: Safety
|
- name: Safety
|
||||||
|
|
@ -10308,13 +10722,13 @@ x-tagGroups:
|
||||||
- name: Operations
|
- name: Operations
|
||||||
tags:
|
tags:
|
||||||
- Agents
|
- Agents
|
||||||
|
- Batches
|
||||||
- Benchmarks
|
- Benchmarks
|
||||||
- DatasetIO
|
- DatasetIO
|
||||||
- Datasets
|
- Datasets
|
||||||
- Eval
|
- Eval
|
||||||
- Files
|
- Files
|
||||||
- Inference
|
- Inference
|
||||||
- Models
|
|
||||||
- PostTraining (Coming Soon)
|
- PostTraining (Coming Soon)
|
||||||
- Safety
|
- Safety
|
||||||
- VectorIO
|
- VectorIO
|
||||||
|
|
|
||||||
docs/static/experimental-llama-stack-spec.html (vendored, 5552 changed lines): diff suppressed because it is too large.
docs/static/llama-stack-spec.html (vendored, 1075 changed lines): diff suppressed because it is too large.
docs/static/llama-stack-spec.yaml (vendored, 873 changed lines):
@@ -12,6 +12,141 @@ info:
|
||||||
servers:
|
servers:
|
||||||
- url: http://any-hosted-llama-stack.com
|
- url: http://any-hosted-llama-stack.com
|
||||||
paths:
|
paths:
|
||||||
|
/v1/batches:
|
||||||
|
get:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: A list of batch objects.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/ListBatchesResponse'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: List all batches for the current user.
|
||||||
|
description: List all batches for the current user.
|
||||||
|
parameters:
|
||||||
|
- name: after
|
||||||
|
in: query
|
||||||
|
description: >-
|
||||||
|
A cursor for pagination; returns batches after this batch ID.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: limit
|
||||||
|
in: query
|
||||||
|
description: >-
|
||||||
|
Number of batches to return (default 20, max 100).
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
deprecated: false
|
||||||
|
post:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: The created batch object.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Batch'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: >-
|
||||||
|
Create a new batch for processing multiple API requests.
|
||||||
|
description: >-
|
||||||
|
Create a new batch for processing multiple API requests.
|
||||||
|
parameters: []
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/CreateBatchRequest'
|
||||||
|
required: true
|
||||||
|
deprecated: false
|
||||||
|
/v1/batches/{batch_id}:
|
||||||
|
get:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: The batch object.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Batch'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: >-
|
||||||
|
Retrieve information about a specific batch.
|
||||||
|
description: >-
|
||||||
|
Retrieve information about a specific batch.
|
||||||
|
parameters:
|
||||||
|
- name: batch_id
|
||||||
|
in: path
|
||||||
|
description: The ID of the batch to retrieve.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
deprecated: false
|
||||||
|
/v1/batches/{batch_id}/cancel:
|
||||||
|
post:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: The updated batch object.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/Batch'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Batches
|
||||||
|
summary: Cancel a batch that is in progress.
|
||||||
|
description: Cancel a batch that is in progress.
|
||||||
|
parameters:
|
||||||
|
- name: batch_id
|
||||||
|
in: path
|
||||||
|
description: The ID of the batch to cancel.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
deprecated: false
|
||||||
/v1/chat/completions:
|
/v1/chat/completions:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
|
|
@ -818,7 +953,22 @@ paths:
|
||||||
List routes.
|
List routes.
|
||||||
|
|
||||||
List all available API routes with their methods and implementing providers.
|
List all available API routes with their methods and implementing providers.
|
||||||
parameters: []
|
parameters:
|
||||||
|
- name: api_filter
|
||||||
|
in: query
|
||||||
|
description: >-
|
||||||
|
Optional filter to control which routes are returned. Can be an API level
|
||||||
|
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||||
|
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||||
|
returns only non-deprecated v1 routes.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- v1
|
||||||
|
- v1alpha
|
||||||
|
- v1beta
|
||||||
|
- deprecated
|
||||||
deprecated: false
|
deprecated: false
|
||||||
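Aside, not part of this diff: the new api_filter query parameter added above can be exercised with a plain HTTP GET. The /v1/inspect/routes path, the port, and the response field names below are assumptions, only the parameter name and its allowed values come from the spec.

    # Hypothetical sketch: list only deprecated routes via the new api_filter parameter.
    import httpx

    resp = httpx.get(
        "http://localhost:8321/v1/inspect/routes",   # assumed location of the route listing
        params={"api_filter": "deprecated"},          # or "v1", "v1alpha", "v1beta"
    )
    resp.raise_for_status()
    for route in resp.json()["data"]:                 # "data"/"route"/"method" keys are assumed
        print(route["method"], route["route"])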
/v1/models:
|
/v1/models:
|
||||||
get:
|
get:
|
||||||
|
|
@ -976,6 +1126,31 @@ paths:
|
||||||
$ref: '#/components/schemas/RunModerationRequest'
|
$ref: '#/components/schemas/RunModerationRequest'
|
||||||
required: true
|
required: true
|
||||||
deprecated: false
|
deprecated: false
|
||||||
|
/v1/openai/v1/models:
|
||||||
|
get:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: A OpenAIListModelsResponse.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/OpenAIListModelsResponse'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Models
|
||||||
|
summary: List models using the OpenAI API.
|
||||||
|
description: List models using the OpenAI API.
|
||||||
|
parameters: []
|
||||||
|
deprecated: false
|
||||||
/v1/prompts:
|
/v1/prompts:
|
||||||
get:
|
get:
|
||||||
responses:
|
responses:
|
||||||
|
|
@ -1832,40 +2007,6 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
type: string
|
type: string
|
||||||
deprecated: false
|
deprecated: false
|
||||||
/v1/synthetic-data-generation/generate:
|
|
||||||
post:
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: >-
|
|
||||||
Response containing filtered synthetic data samples and optional statistics
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/SyntheticDataGenerationResponse'
|
|
||||||
'400':
|
|
||||||
$ref: '#/components/responses/BadRequest400'
|
|
||||||
'429':
|
|
||||||
$ref: >-
|
|
||||||
#/components/responses/TooManyRequests429
|
|
||||||
'500':
|
|
||||||
$ref: >-
|
|
||||||
#/components/responses/InternalServerError500
|
|
||||||
default:
|
|
||||||
$ref: '#/components/responses/DefaultError'
|
|
||||||
tags:
|
|
||||||
- SyntheticDataGeneration (Coming Soon)
|
|
||||||
summary: >-
|
|
||||||
Generate synthetic data based on input dialogs and apply filtering.
|
|
||||||
description: >-
|
|
||||||
Generate synthetic data based on input dialogs and apply filtering.
|
|
||||||
parameters: []
|
|
||||||
requestBody:
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/SyntheticDataGenerateRequest'
|
|
||||||
required: true
|
|
||||||
deprecated: false
|
|
||||||
/v1/tool-runtime/invoke:
|
/v1/tool-runtime/invoke:
|
||||||
post:
|
post:
|
||||||
responses:
|
responses:
|
||||||
|
|
@ -2999,6 +3140,331 @@ components:
|
||||||
title: Error
|
title: Error
|
||||||
description: >-
|
description: >-
|
||||||
Error response from the API. Roughly follows RFC 7807.
|
Error response from the API. Roughly follows RFC 7807.
|
||||||
|
ListBatchesResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: list
|
||||||
|
default: list
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
completion_window:
|
||||||
|
type: string
|
||||||
|
created_at:
|
||||||
|
type: integer
|
||||||
|
endpoint:
|
||||||
|
type: string
|
||||||
|
input_file_id:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: batch
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- validating
|
||||||
|
- failed
|
||||||
|
- in_progress
|
||||||
|
- finalizing
|
||||||
|
- completed
|
||||||
|
- expired
|
||||||
|
- cancelling
|
||||||
|
- cancelled
|
||||||
|
cancelled_at:
|
||||||
|
type: integer
|
||||||
|
cancelling_at:
|
||||||
|
type: integer
|
||||||
|
completed_at:
|
||||||
|
type: integer
|
||||||
|
error_file_id:
|
||||||
|
type: string
|
||||||
|
errors:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
code:
|
||||||
|
type: string
|
||||||
|
line:
|
||||||
|
type: integer
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
param:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: BatchError
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: Errors
|
||||||
|
expired_at:
|
||||||
|
type: integer
|
||||||
|
expires_at:
|
||||||
|
type: integer
|
||||||
|
failed_at:
|
||||||
|
type: integer
|
||||||
|
finalizing_at:
|
||||||
|
type: integer
|
||||||
|
in_progress_at:
|
||||||
|
type: integer
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
output_file_id:
|
||||||
|
type: string
|
||||||
|
request_counts:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
completed:
|
||||||
|
type: integer
|
||||||
|
failed:
|
||||||
|
type: integer
|
||||||
|
total:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- completed
|
||||||
|
- failed
|
||||||
|
- total
|
||||||
|
title: BatchRequestCounts
|
||||||
|
usage:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
input_tokens:
|
||||||
|
type: integer
|
||||||
|
input_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
cached_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- cached_tokens
|
||||||
|
title: InputTokensDetails
|
||||||
|
output_tokens:
|
||||||
|
type: integer
|
||||||
|
output_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
reasoning_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- reasoning_tokens
|
||||||
|
title: OutputTokensDetails
|
||||||
|
total_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- input_tokens
|
||||||
|
- input_tokens_details
|
||||||
|
- output_tokens
|
||||||
|
- output_tokens_details
|
||||||
|
- total_tokens
|
||||||
|
title: BatchUsage
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- completion_window
|
||||||
|
- created_at
|
||||||
|
- endpoint
|
||||||
|
- input_file_id
|
||||||
|
- object
|
||||||
|
- status
|
||||||
|
title: Batch
|
||||||
|
first_id:
|
||||||
|
type: string
|
||||||
|
last_id:
|
||||||
|
type: string
|
||||||
|
has_more:
|
||||||
|
type: boolean
|
||||||
|
default: false
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- object
|
||||||
|
- data
|
||||||
|
- has_more
|
||||||
|
title: ListBatchesResponse
|
||||||
|
description: >-
|
||||||
|
Response containing a list of batch objects.
|
||||||
|
CreateBatchRequest:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
input_file_id:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
The ID of an uploaded file containing requests for the batch.
|
||||||
|
endpoint:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
The endpoint to be used for all requests in the batch.
|
||||||
|
completion_window:
|
||||||
|
type: string
|
||||||
|
const: 24h
|
||||||
|
description: >-
|
||||||
|
The time window within which the batch should be processed.
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
description: Optional metadata for the batch.
|
||||||
|
idempotency_key:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
Optional idempotency key. When provided, enables idempotent behavior.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- input_file_id
|
||||||
|
- endpoint
|
||||||
|
- completion_window
|
||||||
|
title: CreateBatchRequest
|
||||||
|
Batch:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
completion_window:
|
||||||
|
type: string
|
||||||
|
created_at:
|
||||||
|
type: integer
|
||||||
|
endpoint:
|
||||||
|
type: string
|
||||||
|
input_file_id:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: batch
|
||||||
|
status:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- validating
|
||||||
|
- failed
|
||||||
|
- in_progress
|
||||||
|
- finalizing
|
||||||
|
- completed
|
||||||
|
- expired
|
||||||
|
- cancelling
|
||||||
|
- cancelled
|
||||||
|
cancelled_at:
|
||||||
|
type: integer
|
||||||
|
cancelling_at:
|
||||||
|
type: integer
|
||||||
|
completed_at:
|
||||||
|
type: integer
|
||||||
|
error_file_id:
|
||||||
|
type: string
|
||||||
|
errors:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
code:
|
||||||
|
type: string
|
||||||
|
line:
|
||||||
|
type: integer
|
||||||
|
message:
|
||||||
|
type: string
|
||||||
|
param:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: BatchError
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
additionalProperties: false
|
||||||
|
title: Errors
|
||||||
|
expired_at:
|
||||||
|
type: integer
|
||||||
|
expires_at:
|
||||||
|
type: integer
|
||||||
|
failed_at:
|
||||||
|
type: integer
|
||||||
|
finalizing_at:
|
||||||
|
type: integer
|
||||||
|
in_progress_at:
|
||||||
|
type: integer
|
||||||
|
metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
type: string
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
output_file_id:
|
||||||
|
type: string
|
||||||
|
request_counts:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
completed:
|
||||||
|
type: integer
|
||||||
|
failed:
|
||||||
|
type: integer
|
||||||
|
total:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- completed
|
||||||
|
- failed
|
||||||
|
- total
|
||||||
|
title: BatchRequestCounts
|
||||||
|
usage:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
input_tokens:
|
||||||
|
type: integer
|
||||||
|
input_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
cached_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- cached_tokens
|
||||||
|
title: InputTokensDetails
|
||||||
|
output_tokens:
|
||||||
|
type: integer
|
||||||
|
output_tokens_details:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
reasoning_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- reasoning_tokens
|
||||||
|
title: OutputTokensDetails
|
||||||
|
total_tokens:
|
||||||
|
type: integer
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- input_tokens
|
||||||
|
- input_tokens_details
|
||||||
|
- output_tokens
|
||||||
|
- output_tokens_details
|
||||||
|
- total_tokens
|
||||||
|
title: BatchUsage
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- completion_window
|
||||||
|
- created_at
|
||||||
|
- endpoint
|
||||||
|
- input_file_id
|
||||||
|
- object
|
||||||
|
- status
|
||||||
|
title: Batch
|
||||||
Order:
|
Order:
|
||||||
type: string
|
type: string
|
||||||
enum:
|
enum:
|
||||||
|
|
@ -5341,6 +5807,48 @@ components:
|
||||||
- metadata
|
- metadata
|
||||||
title: ModerationObjectResults
|
title: ModerationObjectResults
|
||||||
description: A moderation object.
|
description: A moderation object.
|
||||||
|
OpenAIModel:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: string
|
||||||
|
object:
|
||||||
|
type: string
|
||||||
|
const: model
|
||||||
|
default: model
|
||||||
|
created:
|
||||||
|
type: integer
|
||||||
|
owned_by:
|
||||||
|
type: string
|
||||||
|
custom_metadata:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- id
|
||||||
|
- object
|
||||||
|
- created
|
||||||
|
- owned_by
|
||||||
|
title: OpenAIModel
|
||||||
|
description: A model from OpenAI.
|
||||||
|
OpenAIListModelsResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/OpenAIModel'
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- data
|
||||||
|
title: OpenAIListModelsResponse
|
||||||
Prompt:
|
Prompt:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -8265,45 +8773,29 @@ components:
|
||||||
required:
|
required:
|
||||||
- shield_id
|
- shield_id
|
||||||
title: RegisterShieldRequest
|
title: RegisterShieldRequest
|
||||||
CompletionMessage:
|
InvokeToolRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
role:
|
tool_name:
|
||||||
type: string
|
type: string
|
||||||
const: assistant
|
description: The name of the tool to invoke.
|
||||||
default: assistant
|
kwargs:
|
||||||
|
type: object
|
||||||
|
additionalProperties:
|
||||||
|
oneOf:
|
||||||
|
- type: 'null'
|
||||||
|
- type: boolean
|
||||||
|
- type: number
|
||||||
|
- type: string
|
||||||
|
- type: array
|
||||||
|
- type: object
|
||||||
description: >-
|
description: >-
|
||||||
Must be "assistant" to identify this as the model's response
|
A dictionary of arguments to pass to the tool.
|
||||||
content:
|
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
|
||||||
description: The content of the model's response
|
|
||||||
stop_reason:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- end_of_turn
|
|
||||||
- end_of_message
|
|
||||||
- out_of_tokens
|
|
||||||
description: >-
|
|
||||||
Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`:
|
|
||||||
The model finished generating the entire response. - `StopReason.end_of_message`:
|
|
||||||
The model finished generating but generated a partial response -- usually,
|
|
||||||
a tool call. The user may call the tool and continue the conversation
|
|
||||||
with the tool's response. - `StopReason.out_of_tokens`: The model ran
|
|
||||||
out of token budget.
|
|
||||||
tool_calls:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/ToolCall'
|
|
||||||
description: >-
|
|
||||||
List of tool calls. Each tool call is a ToolCall object.
|
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- role
|
- tool_name
|
||||||
- content
|
- kwargs
|
||||||
- stop_reason
|
title: InvokeToolRequest
|
||||||
title: CompletionMessage
|
|
||||||
description: >-
|
|
||||||
A message containing the model's (assistant) response in a chat conversation.
|
|
||||||
ImageContentItem:
|
ImageContentItem:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -8350,41 +8842,6 @@ components:
|
||||||
mapping:
|
mapping:
|
||||||
image: '#/components/schemas/ImageContentItem'
|
image: '#/components/schemas/ImageContentItem'
|
||||||
text: '#/components/schemas/TextContentItem'
|
text: '#/components/schemas/TextContentItem'
|
||||||
Message:
|
|
||||||
oneOf:
|
|
||||||
- $ref: '#/components/schemas/UserMessage'
|
|
||||||
- $ref: '#/components/schemas/SystemMessage'
|
|
||||||
- $ref: '#/components/schemas/ToolResponseMessage'
|
|
||||||
- $ref: '#/components/schemas/CompletionMessage'
|
|
||||||
discriminator:
|
|
||||||
propertyName: role
|
|
||||||
mapping:
|
|
||||||
user: '#/components/schemas/UserMessage'
|
|
||||||
system: '#/components/schemas/SystemMessage'
|
|
||||||
tool: '#/components/schemas/ToolResponseMessage'
|
|
||||||
assistant: '#/components/schemas/CompletionMessage'
|
|
||||||
SystemMessage:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
role:
|
|
||||||
type: string
|
|
||||||
const: system
|
|
||||||
default: system
|
|
||||||
description: >-
|
|
||||||
Must be "system" to identify this as a system message
|
|
||||||
content:
|
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
|
||||||
description: >-
|
|
||||||
The content of the "system prompt". If multiple system messages are provided,
|
|
||||||
they are concatenated. The underlying Llama Stack code may also add other
|
|
||||||
system messages (for example, for formatting tool definitions).
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- role
|
|
||||||
- content
|
|
||||||
title: SystemMessage
|
|
||||||
description: >-
|
|
||||||
A system message providing instructions or context to the model.
|
|
||||||
TextContentItem:
|
TextContentItem:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -8403,179 +8860,6 @@ components:
|
||||||
- text
|
- text
|
||||||
title: TextContentItem
|
title: TextContentItem
|
||||||
description: A text content item
|
description: A text content item
|
||||||
ToolCall:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
call_id:
|
|
||||||
type: string
|
|
||||||
tool_name:
|
|
||||||
oneOf:
|
|
||||||
- type: string
|
|
||||||
enum:
|
|
||||||
- brave_search
|
|
||||||
- wolfram_alpha
|
|
||||||
- photogen
|
|
||||||
- code_interpreter
|
|
||||||
title: BuiltinTool
|
|
||||||
- type: string
|
|
||||||
arguments:
|
|
||||||
type: string
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- call_id
|
|
||||||
- tool_name
|
|
||||||
- arguments
|
|
||||||
title: ToolCall
|
|
||||||
ToolResponseMessage:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
role:
|
|
||||||
type: string
|
|
||||||
const: tool
|
|
||||||
default: tool
|
|
||||||
description: >-
|
|
||||||
Must be "tool" to identify this as a tool response
|
|
||||||
call_id:
|
|
||||||
type: string
|
|
||||||
description: >-
|
|
||||||
Unique identifier for the tool call this response is for
|
|
||||||
content:
|
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
|
||||||
description: The response content from the tool
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- role
|
|
||||||
- call_id
|
|
||||||
- content
|
|
||||||
title: ToolResponseMessage
|
|
||||||
description: >-
|
|
||||||
A message representing the result of a tool invocation.
|
|
||||||
URL:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
uri:
|
|
||||||
type: string
|
|
||||||
description: The URL string pointing to the resource
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- uri
|
|
||||||
title: URL
|
|
||||||
description: A URL reference to external content.
|
|
||||||
UserMessage:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
role:
|
|
||||||
type: string
|
|
||||||
const: user
|
|
||||||
default: user
|
|
||||||
description: >-
|
|
||||||
Must be "user" to identify this as a user message
|
|
||||||
content:
|
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
|
||||||
description: >-
|
|
||||||
The content of the message, which can include text and other media
|
|
||||||
context:
|
|
||||||
$ref: '#/components/schemas/InterleavedContent'
|
|
||||||
description: >-
|
|
||||||
(Optional) This field is used internally by Llama Stack to pass RAG context.
|
|
||||||
This field may be removed in the API in the future.
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- role
|
|
||||||
- content
|
|
||||||
title: UserMessage
|
|
||||||
description: >-
|
|
||||||
A message from the user in a chat conversation.
|
|
||||||
SyntheticDataGenerateRequest:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
dialogs:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/Message'
|
|
||||||
description: >-
|
|
||||||
List of conversation messages to use as input for synthetic data generation
|
|
||||||
filtering_function:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- none
|
|
||||||
- random
|
|
||||||
- top_k
|
|
||||||
- top_p
|
|
||||||
- top_k_top_p
|
|
||||||
- sigmoid
|
|
||||||
description: >-
|
|
||||||
Type of filtering to apply to generated synthetic data samples
|
|
||||||
model:
|
|
||||||
type: string
|
|
||||||
description: >-
|
|
||||||
(Optional) The identifier of the model to use. The model must be registered
|
|
||||||
with Llama Stack and available via the /models endpoint
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- dialogs
|
|
||||||
- filtering_function
|
|
||||||
title: SyntheticDataGenerateRequest
|
|
||||||
SyntheticDataGenerationResponse:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
synthetic_data:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: object
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
description: >-
|
|
||||||
List of generated synthetic data samples that passed the filtering criteria
|
|
||||||
statistics:
|
|
||||||
type: object
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
description: >-
|
|
||||||
(Optional) Statistical information about the generation process and filtering
|
|
||||||
results
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- synthetic_data
|
|
||||||
title: SyntheticDataGenerationResponse
|
|
||||||
description: >-
|
|
||||||
Response from the synthetic data generation. Batch of (prompt, response, score)
|
|
||||||
tuples that pass the threshold.
|
|
||||||
InvokeToolRequest:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
tool_name:
|
|
||||||
type: string
|
|
||||||
description: The name of the tool to invoke.
|
|
||||||
kwargs:
|
|
||||||
type: object
|
|
||||||
additionalProperties:
|
|
||||||
oneOf:
|
|
||||||
- type: 'null'
|
|
||||||
- type: boolean
|
|
||||||
- type: number
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
- type: object
|
|
||||||
description: >-
|
|
||||||
A dictionary of arguments to pass to the tool.
|
|
||||||
additionalProperties: false
|
|
||||||
required:
|
|
||||||
- tool_name
|
|
||||||
- kwargs
|
|
||||||
title: InvokeToolRequest
|
|
||||||
ToolInvocationResult:
|
ToolInvocationResult:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -8606,6 +8890,17 @@ components:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
title: ToolInvocationResult
|
title: ToolInvocationResult
|
||||||
description: Result of a tool invocation.
|
description: Result of a tool invocation.
|
||||||
|
URL:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
uri:
|
||||||
|
type: string
|
||||||
|
description: The URL string pointing to the resource
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- uri
|
||||||
|
title: URL
|
||||||
|
description: A URL reference to external content.
|
||||||
ToolDef:
|
ToolDef:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
@ -9045,6 +9340,10 @@ components:
|
||||||
description: >-
|
description: >-
|
||||||
The content of the chunk, which can be interleaved text, images, or other
|
The content of the chunk, which can be interleaved text, images, or other
|
||||||
types.
|
types.
|
||||||
|
chunk_id:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
Unique identifier for the chunk. Must be provided explicitly.
|
||||||
metadata:
|
metadata:
|
||||||
type: object
|
type: object
|
||||||
additionalProperties:
|
additionalProperties:
|
||||||
|
|
@ -9065,10 +9364,6 @@ components:
|
||||||
description: >-
|
description: >-
|
||||||
Optional embedding for the chunk. If not provided, it will be computed
|
Optional embedding for the chunk. If not provided, it will be computed
|
||||||
later.
|
later.
|
||||||
stored_chunk_id:
|
|
||||||
type: string
|
|
||||||
description: >-
|
|
||||||
The chunk ID that is stored in the vector database. Used for backend functionality.
|
|
||||||
chunk_metadata:
|
chunk_metadata:
|
||||||
$ref: '#/components/schemas/ChunkMetadata'
|
$ref: '#/components/schemas/ChunkMetadata'
|
||||||
description: >-
|
description: >-
|
||||||
|
|
@ -9077,6 +9372,7 @@ components:
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- content
|
- content
|
||||||
|
- chunk_id
|
||||||
- metadata
|
- metadata
|
||||||
title: Chunk
|
title: Chunk
|
||||||
description: >-
|
description: >-
|
||||||
|
|
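Aside, not part of this diff: since chunk_id is now required on Chunk (and stored_chunk_id is gone), any payload inserted into a vector store has to carry an explicit identifier. A minimal sketch of the new shape; the field names come from the schema above, the values are made up.

    # Chunk payloads must now include an explicit chunk_id.
    chunk = {
        "chunk_id": "doc-42-para-3",   # required as of this change
        "content": "Llama Stack supports pluggable vector_io providers.",
        "metadata": {"document_id": "doc-42"},
    }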
@ -10143,6 +10439,19 @@ tags:
|
||||||
|
|
||||||
- `background`
|
- `background`
|
||||||
x-displayName: Agents
|
x-displayName: Agents
|
||||||
|
- name: Batches
|
||||||
|
description: >-
|
||||||
|
The API is designed to allow use of openai client libraries for seamless integration.
|
||||||
|
|
||||||
|
|
||||||
|
This API provides the following extensions:
|
||||||
|
- idempotent batch creation
|
||||||
|
|
||||||
|
Note: This API is currently under active development and may undergo changes.
|
||||||
|
x-displayName: >-
|
||||||
|
The Batches API enables efficient processing of multiple requests in a single
|
||||||
|
operation, particularly useful for processing large datasets, batch evaluation
|
||||||
|
workflows, and cost-effective inference at scale.
|
||||||
- name: Conversations
|
- name: Conversations
|
||||||
description: >-
|
description: >-
|
||||||
Protocol for conversation management operations.
|
Protocol for conversation management operations.
|
||||||
|
|
@ -10193,8 +10502,6 @@ tags:
|
||||||
description: ''
|
description: ''
|
||||||
- name: Shields
|
- name: Shields
|
||||||
description: ''
|
description: ''
|
||||||
- name: SyntheticDataGeneration (Coming Soon)
|
|
||||||
description: ''
|
|
||||||
- name: ToolGroups
|
- name: ToolGroups
|
||||||
description: ''
|
description: ''
|
||||||
- name: ToolRuntime
|
- name: ToolRuntime
|
||||||
|
|
@ -10205,6 +10512,7 @@ x-tagGroups:
|
||||||
- name: Operations
|
- name: Operations
|
||||||
tags:
|
tags:
|
||||||
- Agents
|
- Agents
|
||||||
|
- Batches
|
||||||
- Conversations
|
- Conversations
|
||||||
- Files
|
- Files
|
||||||
- Inference
|
- Inference
|
||||||
|
|
@ -10216,7 +10524,6 @@ x-tagGroups:
|
||||||
- Scoring
|
- Scoring
|
||||||
- ScoringFunctions
|
- ScoringFunctions
|
||||||
- Shields
|
- Shields
|
||||||
- SyntheticDataGeneration (Coming Soon)
|
|
||||||
- ToolGroups
|
- ToolGroups
|
||||||
- ToolRuntime
|
- ToolRuntime
|
||||||
- VectorIO
|
- VectorIO
|
||||||
|
|
|
||||||
docs/static/stainless-llama-stack-spec.html (vendored, 18091 changed lines): diff suppressed because it is too large.
docs/static/stainless-llama-stack-spec.yaml (vendored, 1006 changed lines): diff suppressed because it is too large.
@@ -7,7 +7,7 @@ required-version = ">=0.7.0"

 [project]
 name = "llama_stack"
-version = "0.3.0"
+version = "0.4.0.dev0"
 authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
 description = "Llama Stack"
 readme = "README.md"

@@ -284,7 +284,6 @@ exclude = [
     "^src/llama_stack/models/llama/llama3/interface\\.py$",
     "^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
     "^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
-    "^src/llama_stack/providers/inline/agents/meta_reference/",
     "^src/llama_stack/providers/inline/datasetio/localfs/",
     "^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
@@ -215,6 +215,16 @@ build_image() {
         --build-arg "LLAMA_STACK_DIR=/workspace"
     )

+    # Pass UV index configuration for release branches
+    if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
+        echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
+        build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
+    fi
+    if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
+        echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
+        build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
+    fi
+
     if ! "${build_cmd[@]}"; then
         echo "❌ Failed to build Docker image"
         exit 1
@@ -102,7 +102,6 @@ while [[ $# -gt 0 ]]; do
     esac
 done
-

 # Validate required parameters
 if [[ -z "$STACK_CONFIG" && "$COLLECT_ONLY" == false ]]; then
     echo "Error: --stack-config is required"

@@ -187,11 +186,35 @@ if ! command -v pytest &> /dev/null; then
     exit 1
 fi

+# Helper function to find next available port
+find_available_port() {
+    local start_port=$1
+    local port=$start_port
+    for ((i=0; i<100; i++)); do
+        if ! lsof -Pi :$port -sTCP:LISTEN -t >/dev/null 2>&1; then
+            echo $port
+            return 0
+        fi
+        ((port++))
+    done
+    echo "Failed to find available port starting from $start_port" >&2
+    return 1
+}
+
 # Start Llama Stack Server if needed
 if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
+    # Find an available port for the server
+    LLAMA_STACK_PORT=$(find_available_port 8321)
+    if [[ $? -ne 0 ]]; then
+        echo "Error: $LLAMA_STACK_PORT"
+        exit 1
+    fi
+    export LLAMA_STACK_PORT
+    echo "Will use port: $LLAMA_STACK_PORT"
+
     stop_server() {
         echo "Stopping Llama Stack Server..."
-        pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
+        pids=$(lsof -i :$LLAMA_STACK_PORT | awk 'NR>1 {print $2}')
         if [[ -n "$pids" ]]; then
             echo "Killing Llama Stack Server processes: $pids"
             kill -9 $pids
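Aside, not part of this diff: the find_available_port helper above probes ports by shelling out to lsof. A rough Python equivalent of the same scan, purely as an illustration of the idea, would be:

    # Illustrative translation of find_available_port: probe ports by attempting to bind.
    import socket

    def find_available_port(start_port: int, attempts: int = 100) -> int:
        for port in range(start_port, start_port + attempts):
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                try:
                    sock.bind(("127.0.0.1", port))   # bind succeeds only if the port is free
                    return port
                except OSError:
                    continue
        raise RuntimeError(f"no free port found starting from {start_port}")

    print(find_available_port(8321))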
@@ -201,20 +224,25 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
         echo "Llama Stack Server stopped"
     }

-    # check if server is already running
-    if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
-        echo "Llama Stack Server is already running, skipping start"
-    else
     echo "=== Starting Llama Stack Server ==="
     export LLAMA_STACK_LOG_WIDTH=120

+    # Configure telemetry collector for server mode
+    # Use a fixed port for the OTEL collector so the server can connect to it
+    COLLECTOR_PORT=4317
+    export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
+    export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}"
+    export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
+    export OTEL_BSP_SCHEDULE_DELAY="200"
+    export OTEL_BSP_EXPORT_TIMEOUT="2000"
+
     # remove "server:" from STACK_CONFIG
     stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
     nohup llama stack run $stack_config >server.log 2>&1 &

-    echo "Waiting for Llama Stack Server to start..."
+    echo "Waiting for Llama Stack Server to start on port $LLAMA_STACK_PORT..."
     for i in {1..30}; do
-        if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
+        if curl -s http://localhost:$LLAMA_STACK_PORT/v1/health 2>/dev/null | grep -q "OK"; then
             echo "✅ Llama Stack Server started successfully"
             break
         fi

@@ -227,7 +255,6 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
         sleep 1
     done
     echo ""
-    fi

     trap stop_server EXIT ERR INT TERM
 fi
@@ -251,7 +278,14 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then

     # Extract distribution name from docker:distro format
     DISTRO=$(echo "$STACK_CONFIG" | sed 's/^docker://')
-    export LLAMA_STACK_PORT=8321
+    # Find an available port for the docker container
+    LLAMA_STACK_PORT=$(find_available_port 8321)
+    if [[ $? -ne 0 ]]; then
+        echo "Error: $LLAMA_STACK_PORT"
+        exit 1
+    fi
+    export LLAMA_STACK_PORT
+    echo "Will use port: $LLAMA_STACK_PORT"

     echo "=== Building Docker Image for distribution: $DISTRO ==="
     containerfile="$ROOT_DIR/containers/Containerfile"

@@ -271,6 +305,16 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
         --build-arg "LLAMA_STACK_DIR=/workspace"
     )

+    # Pass UV index configuration for release branches
+    if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
+        echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
+        build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
+    fi
+    if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
+        echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
+        build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
+    fi
+
     if ! "${build_cmd[@]}"; then
         echo "❌ Failed to build Docker image"
         exit 1
@@ -284,10 +328,15 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
     docker stop "$container_name" 2>/dev/null || true
     docker rm "$container_name" 2>/dev/null || true

+    # Configure telemetry collector port shared between host and container
+    COLLECTOR_PORT=4317
+    export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
+
     # Build environment variables for docker run
     DOCKER_ENV_VARS=""
     DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_INFERENCE_MODE=$INFERENCE_MODE"
     DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server"
+    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:${COLLECTOR_PORT}"

     # Pass through API keys if they exist
     [ -n "${TOGETHER_API_KEY:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e TOGETHER_API_KEY=$TOGETHER_API_KEY"

@@ -308,8 +357,20 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
     fi
     echo "Using image: $IMAGE_NAME"

-    docker run -d --network host --name "$container_name" \
-        -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+    # On macOS/Darwin, --network host doesn't work as expected due to Docker running in a VM
+    # Use regular port mapping instead
+    NETWORK_MODE=""
+    PORT_MAPPINGS=""
+    if [[ "$(uname)" != "Darwin" ]] && [[ "$(uname)" != *"MINGW"* ]]; then
+        NETWORK_MODE="--network host"
+    else
+        # On non-Linux (macOS, Windows), need explicit port mappings for both app and telemetry
+        PORT_MAPPINGS="-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT -p $COLLECTOR_PORT:$COLLECTOR_PORT"
+        echo "Using bridge networking with port mapping (non-Linux)"
+    fi
+
+    docker run -d $NETWORK_MODE --name "$container_name" \
+        $PORT_MAPPINGS \
         $DOCKER_ENV_VARS \
         "$IMAGE_NAME" \
         --port $LLAMA_STACK_PORT
@@ -411,17 +472,13 @@ elif [ $exit_code -eq 5 ]; then
 else
     echo "❌ Tests failed"
     echo ""
-    echo "=== Dumping last 100 lines of logs for debugging ==="
-

     # Output server or container logs based on stack config
     if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then
-        echo "--- Last 100 lines of server.log ---"
-        tail -100 server.log
+        echo "--- Server side failures can be located inside server.log (available from artifacts on CI) ---"
     elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then
         docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log"
         if [[ -f "$docker_log_file" ]]; then
-            echo "--- Last 100 lines of $docker_log_file ---"
-            tail -100 "$docker_log_file"
+            echo "--- Server side failures can be located inside $docker_log_file (available from artifacts on CI) ---"
         fi
     fi
scripts/uv-run-with-index.sh (new executable file, 42 lines)
@@ -0,0 +1,42 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+set -euo pipefail
+
+# Detect current branch and target branch
+# In GitHub Actions, use GITHUB_REF/GITHUB_BASE_REF
+if [[ -n "${GITHUB_REF:-}" ]]; then
+    BRANCH="${GITHUB_REF#refs/heads/}"
+else
+    BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "")
+fi
+
+# For PRs, check the target branch
+if [[ -n "${GITHUB_BASE_REF:-}" ]]; then
+    TARGET_BRANCH="${GITHUB_BASE_REF}"
+else
+    TARGET_BRANCH=$(git rev-parse --abbrev-ref HEAD@{upstream} 2>/dev/null | sed 's|origin/||' || echo "")
+fi
+
+# Check if on a release branch or targeting one, or LLAMA_STACK_RELEASE_MODE is set
+IS_RELEASE=false
+if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
+    IS_RELEASE=true
+elif [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
+    IS_RELEASE=true
+elif [[ "${LLAMA_STACK_RELEASE_MODE:-}" == "true" ]]; then
+    IS_RELEASE=true
+fi
+
+# On release branches, use test.pypi as extra index for RC versions
+if [[ "$IS_RELEASE" == "true" ]]; then
+    export UV_EXTRA_INDEX_URL="https://test.pypi.org/simple/"
+    export UV_INDEX_STRATEGY="unsafe-best-match"
+fi
+
+# Run uv with all arguments passed through
+exec uv "$@"
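Aside, not part of this diff: the release detection in the script above hinges on the release-X.Y.x branch naming pattern. A quick check of that regex, purely as an illustration:

    # The same pattern the script uses to decide whether test.pypi should be added as an extra index.
    import re

    RELEASE_BRANCH = re.compile(r"^release-[0-9]+\.[0-9]+\.x$")

    for branch in ["release-0.4.x", "release-1.12.x", "main", "release-0.4.0"]:
        print(branch, bool(RELEASE_BRANCH.match(branch)))
    # release-0.4.x and release-1.12.x match; main and release-0.4.0 do not.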
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from collections.abc import Sequence
 from typing import Annotated, Any, Literal

 from pydantic import BaseModel, Field, model_validator
|
|
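Aside, not part of this diff: the list to Sequence change in the hunks below keeps these fields covariant, so callers can pass lists of subclasses (or tuples) without static type errors. A small sketch of the effect, assuming only pydantic itself; the class names are made up:

    # Why Sequence instead of list: Sequence[Base] accepts list[Sub] and tuples; list[Base] does not (for type checkers).
    from collections.abc import Sequence
    from pydantic import BaseModel

    class Item(BaseModel):
        text: str

    class SpecialItem(Item):
        score: float = 0.0

    class Response(BaseModel):
        output: Sequence[Item]

    # Both validate; with output: list[Item] the second still validates at runtime,
    # but a static type checker would reject passing list[SpecialItem] where list[Item] is expected.
    Response(output=[Item(text="a"), Item(text="b")])
    Response(output=(SpecialItem(text="c", score=1.0),))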
@@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel):
     scenarios.
     """

-    content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]
+    content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent]
     role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"]
     type: Literal["message"] = "message"

@@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel):
     """

     id: str
-    queries: list[str]
+    queries: Sequence[str]
     status: str
     type: Literal["file_search_call"] = "file_search_call"
-    results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None
+    results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None


 @json_schema_type

@@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel):
     id: str
     model: str
     object: Literal["response"] = "response"
-    output: list[OpenAIResponseOutput]
+    output: Sequence[OpenAIResponseOutput]
     parallel_tool_calls: bool = False
     previous_response_id: str | None = None
     prompt: OpenAIResponsePrompt | None = None

@@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel):
     # before the field was added. New responses will have this set always.
     text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
     top_p: float | None = None
-    tools: list[OpenAIResponseTool] | None = None
+    tools: Sequence[OpenAIResponseTool] | None = None
     truncation: str | None = None
     usage: OpenAIResponseUsage | None = None
     instructions: str | None = None

@@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel):
     :param object: Object type identifier, always "list"
     """

-    data: list[OpenAIResponseInput]
+    data: Sequence[OpenAIResponseInput]
     object: Literal["list"] = "list"


@@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject):
     :param input: List of input items that led to this response
     """

-    input: list[OpenAIResponseInput]
+    input: Sequence[OpenAIResponseInput]

     def to_response_object(self) -> OpenAIResponseObject:
         """Convert to OpenAIResponseObject by excluding input field."""

@@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel):
:param object: Object type identifier, always "list"
|
:param object: Object type identifier, always "list"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: list[OpenAIResponseObjectWithInput]
|
data: Sequence[OpenAIResponseObjectWithInput]
|
||||||
has_more: bool
|
has_more: bool
|
||||||
first_id: str
|
first_id: str
|
||||||
last_id: str
|
last_id: str
|
||||||
|
|
|
||||||
|
|
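The `list` to `Sequence` switch above is a typing change only. A minimal sketch of the effect, using a stand-in element type rather than the real response models (assumed for illustration):

```python
from collections.abc import Sequence

from pydantic import BaseModel


class Item(BaseModel):
    # Stand-in for OpenAIResponseInput; the real models live in the API package.
    text: str


class ListResponse(BaseModel):
    # Sequence is covariant and accepts tuples as well as lists, so callers are not
    # forced to materialize a list just to satisfy the type checker.
    data: Sequence[Item]
    object: str = "list"


ListResponse(data=[Item(text="a")])   # validates
ListResponse(data=(Item(text="b"),))  # also validates; list[Item] would flag this statically
```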
@@ -4,14 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Protocol, runtime_checkable
+from typing import Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel
 
-from llama_stack.apis.version import LLAMA_STACK_API_V1
+from llama_stack.apis.version import (
+    LLAMA_STACK_API_V1,
+)
 from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.schema_utils import json_schema_type, webmethod
 
+# Valid values for the route filter parameter.
+# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
+# Special filter value: "deprecated" (shows deprecated routes regardless of level)
+ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]
+
 
 @json_schema_type
 class RouteInfo(BaseModel):

@@ -64,11 +71,12 @@ class Inspect(Protocol):
     """
 
     @webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
-    async def list_routes(self) -> ListRoutesResponse:
+    async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
         """List routes.
 
         List all available API routes with their methods and implementing providers.
 
+        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
         :returns: Response containing information about all available routes.
         """
         ...
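For context, a hedged sketch of calling the extended route listing endpoint over HTTP; the local server address and the assumption that `api_filter` is sent as a query parameter are illustrative, not confirmed by this diff:

```python
import requests

BASE_URL = "http://localhost:8321"  # assumed local stack server

# Default behaviour: non-deprecated v1 routes only
v1_routes = requests.get(f"{BASE_URL}/v1/inspect/routes", timeout=10).json()

# Deprecated routes across all levels
deprecated = requests.get(
    f"{BASE_URL}/v1/inspect/routes",
    params={"api_filter": "deprecated"},
    timeout=10,
).json()

# Response shape assumed to mirror ListRoutesResponse (a "data" list of RouteInfo)
for route in deprecated.get("data", []):
    print(route["method"], route["route"])
```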
@@ -90,12 +90,14 @@ class OpenAIModel(BaseModel):
     :object: The object type, which will be "model"
     :created: The Unix timestamp in seconds when the model was created
     :owned_by: The owner of the model
+    :custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata
     """
 
     id: str
     object: Literal["model"] = "model"
     created: int
     owned_by: str
+    custom_metadata: dict[str, Any] | None = None
 
 
 class OpenAIListModelsResponse(BaseModel):

@@ -113,7 +115,7 @@ class Models(Protocol):
     """
     ...
 
-    @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True)
+    @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1)
     async def openai_list_models(self) -> OpenAIListModelsResponse:
         """List models using the OpenAI API.
 
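As a rough illustration of the extended schema, here is a sketch constructing the model record directly; the import path and the concrete identifier and metadata values are assumptions:

```python
import time

from llama_stack.apis.models import OpenAIModel  # import path assumed

model = OpenAIModel(
    id="together/meta-llama/Llama-3.3-70B-Instruct",  # placeholder identifier
    created=int(time.time()),
    owned_by="llama_stack",
    custom_metadata={
        "model_type": "llm",
        "provider_id": "together",
        "provider_resource_id": "meta-llama/Llama-3.3-70B-Instruct",
    },
)
print(model.model_dump_json(exclude_none=True))
```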
@@ -1,7 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .synthetic_data_generation import *

@@ -1,77 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from enum import Enum
-from typing import Any, Protocol
-
-from pydantic import BaseModel
-
-from llama_stack.apis.inference import Message
-from llama_stack.apis.version import LLAMA_STACK_API_V1
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-
-class FilteringFunction(Enum):
-    """The type of filtering function.
-
-    :cvar none: No filtering applied, accept all generated synthetic data
-    :cvar random: Random sampling of generated data points
-    :cvar top_k: Keep only the top-k highest scoring synthetic data samples
-    :cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold
-    :cvar top_k_top_p: Combined top-k and top-p filtering strategy
-    :cvar sigmoid: Apply sigmoid function for probability-based filtering
-    """
-
-    none = "none"
-    random = "random"
-    top_k = "top_k"
-    top_p = "top_p"
-    top_k_top_p = "top_k_top_p"
-    sigmoid = "sigmoid"
-
-
-@json_schema_type
-class SyntheticDataGenerationRequest(BaseModel):
-    """Request to generate synthetic data. A small batch of prompts and a filtering function
-
-    :param dialogs: List of conversation messages to use as input for synthetic data generation
-    :param filtering_function: Type of filtering to apply to generated synthetic data samples
-    :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
-    """
-
-    dialogs: list[Message]
-    filtering_function: FilteringFunction = FilteringFunction.none
-    model: str | None = None
-
-
-@json_schema_type
-class SyntheticDataGenerationResponse(BaseModel):
-    """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.
-
-    :param synthetic_data: List of generated synthetic data samples that passed the filtering criteria
-    :param statistics: (Optional) Statistical information about the generation process and filtering results
-    """
-
-    synthetic_data: list[dict[str, Any]]
-    statistics: dict[str, Any] | None = None
-
-
-class SyntheticDataGeneration(Protocol):
-    @webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1)
-    def synthetic_data_generate(
-        self,
-        dialogs: list[Message],
-        filtering_function: FilteringFunction = FilteringFunction.none,
-        model: str | None = None,
-    ) -> SyntheticDataGenerationResponse:
-        """Generate synthetic data based on input dialogs and apply filtering.
-
-        :param dialogs: List of conversation messages to use as input for synthetic data generation
-        :param filtering_function: Type of filtering to apply to generated synthetic data samples
-        :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint
-        :returns: Response containing filtered synthetic data samples and optional statistics
-        """
-        ...
@@ -8,7 +8,6 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import uuid
 from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from fastapi import Body

@@ -18,7 +17,6 @@ from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_stores import VectorStore
 from llama_stack.apis.version import LLAMA_STACK_API_V1
 from llama_stack.core.telemetry.trace_protocol import trace_protocol
-from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
 from llama_stack.schema_utils import json_schema_type, webmethod
 from llama_stack.strong_typing.schema import register_schema

@@ -61,38 +59,19 @@ class Chunk(BaseModel):
     """
     A chunk of content that can be inserted into a vector database.
     :param content: The content of the chunk, which can be interleaved text, images, or other types.
-    :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
+    :param chunk_id: Unique identifier for the chunk. Must be provided explicitly.
     :param metadata: Metadata associated with the chunk that will be used in the model context during inference.
-    :param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality.
+    :param embedding: Optional embedding for the chunk. If not provided, it will be computed later.
     :param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference.
         The `chunk_metadata` is required backend functionality.
     """
 
     content: InterleavedContent
+    chunk_id: str
     metadata: dict[str, Any] = Field(default_factory=dict)
     embedding: list[float] | None = None
-    # The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id"
-    stored_chunk_id: str | None = Field(default=None, alias="chunk_id")
     chunk_metadata: ChunkMetadata | None = None
 
-    model_config = {"populate_by_name": True}
-
-    def model_post_init(self, __context):
-        # Extract chunk_id from metadata if present
-        if self.metadata and "chunk_id" in self.metadata:
-            self.stored_chunk_id = self.metadata.pop("chunk_id")
-
-    @property
-    def chunk_id(self) -> str:
-        """Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set."""
-        if self.stored_chunk_id:
-            return self.stored_chunk_id
-
-        if "document_id" in self.metadata:
-            return generate_chunk_id(self.metadata["document_id"], str(self.content))
-
-        return generate_chunk_id(str(uuid.uuid4()), str(self.content))
-
     @property
     def document_id(self) -> str | None:
         """Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence."""
|
||||||
import os
|
import os
|
||||||
import ssl
|
import ssl
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import yaml
|
import yaml
|
||||||
|
from termcolor import cprint
|
||||||
|
|
||||||
from llama_stack.cli.stack.utils import ImageType
|
from llama_stack.cli.stack.utils import ImageType
|
||||||
from llama_stack.cli.subcommand import Subcommand
|
from llama_stack.cli.subcommand import Subcommand
|
||||||
from llama_stack.core.datatypes import StackRunConfig
|
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||||
|
from llama_stack.core.distribution import get_provider_registry
|
||||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||||
|
from llama_stack.core.storage.datatypes import (
|
||||||
|
InferenceStoreReference,
|
||||||
|
KVStoreReference,
|
||||||
|
ServerStoresConfig,
|
||||||
|
SqliteKVStoreConfig,
|
||||||
|
SqliteSqlStoreConfig,
|
||||||
|
SqlStoreReference,
|
||||||
|
StorageConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||||
|
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||||
from llama_stack.log import LoggingConfig, get_logger
|
from llama_stack.log import LoggingConfig, get_logger
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
@ -68,6 +82,12 @@ class StackRun(Subcommand):
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Start the UI server",
|
help="Start the UI server",
|
||||||
)
|
)
|
||||||
|
self.parser.add_argument(
|
||||||
|
"--providers",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
|
||||||
|
)
|
||||||
|
|
||||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -93,6 +113,55 @@ class StackRun(Subcommand):
|
||||||
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.parser.error(str(e))
|
self.parser.error(str(e))
|
||||||
|
elif args.providers:
|
||||||
|
provider_list: dict[str, list[Provider]] = dict()
|
||||||
|
for api_provider in args.providers.split(","):
|
||||||
|
if "=" not in api_provider:
|
||||||
|
cprint(
|
||||||
|
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
api, provider_type = api_provider.split("=")
|
||||||
|
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||||
|
if providers_for_api is None:
|
||||||
|
cprint(
|
||||||
|
f"{api} is not a valid API.",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
if provider_type in providers_for_api:
|
||||||
|
config_type = instantiate_class_type(providers_for_api[provider_type].config_class)
|
||||||
|
if config_type is not None and hasattr(config_type, "sample_run_config"):
|
||||||
|
config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run")
|
||||||
|
else:
|
||||||
|
config = {}
|
||||||
|
provider = Provider(
|
||||||
|
provider_type=provider_type,
|
||||||
|
config=config,
|
||||||
|
provider_id=provider_type.split("::")[1],
|
||||||
|
)
|
||||||
|
provider_list.setdefault(api, []).append(provider)
|
||||||
|
else:
|
||||||
|
cprint(
|
||||||
|
f"{provider} is not a valid provider for the {api} API.",
|
||||||
|
color="red",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
run_config = self._generate_run_config_from_providers(providers=provider_list)
|
||||||
|
config_dict = run_config.model_dump(mode="json")
|
||||||
|
|
||||||
|
# Write config to disk in providers-run directory
|
||||||
|
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||||
|
config_file = distro_dir / "run.yaml"
|
||||||
|
|
||||||
|
logger.info(f"Writing generated config to: {config_file}")
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
config_file = None
|
config_file = None
|
||||||
|
|
||||||
|
|
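A small, illustrative sketch of the `--providers` value format the new branch parses (provider names are examples only; the real command also resolves each provider against the registry):

```python
def parse_providers(spec: str) -> dict[str, list[str]]:
    """Map 'api1=provider1,api1=provider2,api2=provider3' to {api: [providers...]}."""
    providers: dict[str, list[str]] = {}
    for entry in spec.split(","):
        api, _, provider_type = entry.partition("=")
        if not provider_type:
            raise ValueError(f"Malformed entry {entry!r}; expected api=provider")
        providers.setdefault(api, []).append(provider_type)
    return providers


# e.g. llama stack run --providers inference=remote::ollama,vector_io=inline::faiss
print(parse_providers("inference=remote::ollama,vector_io=inline::faiss"))
# {'inference': ['remote::ollama'], 'vector_io': ['inline::faiss']}
```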
@@ -106,7 +175,8 @@ class StackRun(Subcommand):
 
             try:
                 config = parse_and_maybe_upgrade_config(config_dict)
-                if not os.path.exists(str(config.external_providers_dir)):
+                # Create external_providers_dir if it's specified and doesn't exist
+                if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)):
                     os.makedirs(str(config.external_providers_dir), exist_ok=True)
             except AttributeError as e:
                 self.parser.error(f"failed to parse config file '{config_file}':\n {e}")

@@ -127,7 +197,7 @@ class StackRun(Subcommand):
             config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
 
         port = args.port or config.server.port
-        host = config.server.host or ["::", "0.0.0.0"]
+        host = config.server.host or "0.0.0.0"
 
         # Set the config file in environment so create_app can find it
         os.environ["LLAMA_STACK_CONFIG"] = str(config_file)

@@ -139,6 +209,7 @@ class StackRun(Subcommand):
             "lifespan": "on",
             "log_level": logger.getEffectiveLevel(),
             "log_config": logger_config,
+            "workers": config.server.workers,
         }
 
         keyfile = config.server.tls_keyfile

@@ -212,3 +283,44 @@ class StackRun(Subcommand):
             )
         except Exception as e:
             logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
+
+    def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
+        apis = list(providers.keys())
+        distro_dir = DISTRIBS_BASE_DIR / "providers-run"
+        # need somewhere to put the storage.
+        os.makedirs(distro_dir, exist_ok=True)
+        storage = StorageConfig(
+            backends={
+                "kv_default": SqliteKVStoreConfig(
+                    db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
+                ),
+                "sql_default": SqliteSqlStoreConfig(
+                    db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
+                ),
+            },
+            stores=ServerStoresConfig(
+                metadata=KVStoreReference(
+                    backend="kv_default",
+                    namespace="registry",
+                ),
+                inference=InferenceStoreReference(
+                    backend="sql_default",
+                    table_name="inference_store",
+                ),
+                conversations=SqlStoreReference(
+                    backend="sql_default",
+                    table_name="openai_conversations",
+                ),
+                prompts=KVStoreReference(
+                    backend="kv_default",
+                    namespace="prompts",
+                ),
+            ),
+        )
+
+        return StackRunConfig(
+            image_name="providers-run",
+            apis=apis,
+            providers=providers,
+            storage=storage,
+        )
@@ -17,7 +17,6 @@ from llama_stack.core.distribution import (
     get_provider_registry,
 )
 from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
-from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.prompt_for_config import prompt_for_config
 from llama_stack.log import get_logger

@@ -194,19 +193,11 @@ def upgrade_from_routing_table(
 
 
 def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
-    version = config_dict.get("version", None)
-    if version == LLAMA_STACK_RUN_CONFIG_VERSION:
-        processed_config_dict = replace_env_vars(config_dict)
-        return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
-
     if "routing_table" in config_dict:
         logger.info("Upgrading config...")
         config_dict = upgrade_from_routing_table(config_dict)
 
     config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
 
-    if not config_dict.get("external_providers_dir", None):
-        config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
-
     processed_config_dict = replace_env_vars(config_dict)
     return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
@@ -473,6 +473,10 @@ class ServerConfig(BaseModel):
         "- true: Enable localhost CORS for development\n"
         "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
     )
+    workers: int = Field(
+        default=1,
+        description="Number of workers to use for the server",
+    )
 
 
 class StackRunConfig(BaseModel):
@@ -15,6 +15,7 @@ from llama_stack.apis.inspect import (
     RouteInfo,
     VersionInfo,
 )
+from llama_stack.apis.version import LLAMA_STACK_API_V1
 from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.core.external import load_external_apis
 from llama_stack.core.server.routes import get_all_api_routes

@@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect):
     async def initialize(self) -> None:
         pass
 
-    async def list_routes(self) -> ListRoutesResponse:
+    async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
         run_config: StackRunConfig = self.config.run_config
 
+        # Helper function to determine if a route should be included based on api_filter
+        def should_include_route(webmethod) -> bool:
+            if api_filter is None:
+                # Default: only non-deprecated v1 APIs
+                return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
+            elif api_filter == "deprecated":
+                # Special filter: show deprecated routes regardless of their actual level
+                return bool(webmethod.deprecated)
+            else:
+                # Filter by API level (non-deprecated routes only)
+                return not webmethod.deprecated and webmethod.level == api_filter
+
         ret = []
         external_apis = load_external_apis(run_config)
         all_endpoints = get_all_api_routes(external_apis)

@@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect):
                         method=next(iter([m for m in e.methods if m != "HEAD"])),
                         provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
                     )
-                    for e, _ in endpoints
-                    if e.methods is not None
+                    for e, webmethod in endpoints
+                    if e.methods is not None and should_include_route(webmethod)
                 ]
             )
         else:

@@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect):
                         method=next(iter([m for m in e.methods if m != "HEAD"])),
                         provider_types=[p.provider_type for p in providers],
                     )
-                    for e, _ in endpoints
-                    if e.methods is not None
+                    for e, webmethod in endpoints
+                    if e.methods is not None and should_include_route(webmethod)
                 ]
             )
 
@@ -6,7 +6,7 @@
 
 import asyncio
 import time
-from collections.abc import AsyncGenerator, AsyncIterator
+from collections.abc import AsyncIterator
 from datetime import UTC, datetime
 from typing import Annotated, Any
 

@@ -15,20 +15,10 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC
 from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
 from pydantic import TypeAdapter
 
-from llama_stack.apis.common.content_types import (
-    InterleavedContent,
-)
 from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
 from llama_stack.apis.inference import (
-    ChatCompletionResponse,
-    ChatCompletionResponseEventType,
-    ChatCompletionResponseStreamChunk,
-    CompletionMessage,
-    CompletionResponse,
-    CompletionResponseStreamChunk,
     Inference,
     ListOpenAIChatCompletionResponse,
-    Message,
     OpenAIAssistantMessageParam,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,

@@ -45,15 +35,13 @@ from llama_stack.apis.inference import (
     OpenAIMessageParam,
     Order,
     RerankResponse,
-    StopReason,
-    ToolPromptFormat,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletionContentPartImageParam,
     OpenAIChatCompletionContentPartTextParam,
 )
-from llama_stack.apis.models import Model, ModelType
+from llama_stack.apis.models import ModelType
-from llama_stack.core.telemetry.telemetry import MetricEvent, MetricInResponse
+from llama_stack.core.telemetry.telemetry import MetricEvent
 from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat

@@ -153,35 +141,6 @@ class InferenceRouter(Inference):
         )
         return metric_events
 
-    async def _compute_and_log_token_usage(
-        self,
-        prompt_tokens: int,
-        completion_tokens: int,
-        total_tokens: int,
-        model: Model,
-    ) -> list[MetricInResponse]:
-        metrics = self._construct_metrics(
-            prompt_tokens, completion_tokens, total_tokens, model.model_id, model.provider_id
-        )
-        if self.telemetry_enabled:
-            for metric in metrics:
-                enqueue_event(metric)
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
-
-    async def _count_tokens(
-        self,
-        messages: list[Message] | InterleavedContent,
-        tool_prompt_format: ToolPromptFormat | None = None,
-    ) -> int | None:
-        if not hasattr(self, "formatter") or self.formatter is None:
-            return None
-
-        if isinstance(messages, list):
-            encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
-        else:
-            encoded = self.formatter.encode_content(messages)
-        return len(encoded.tokens) if encoded and encoded.tokens else 0
-
     async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]:
         model = await self.routing_table.get_object_by_identifier("model", model_id)
         if model:

@@ -375,121 +334,6 @@ class InferenceRouter(Inference):
         )
         return health_statuses
 
-    async def stream_tokens_and_compute_metrics(
-        self,
-        response,
-        prompt_tokens,
-        fully_qualified_model_id: str,
-        provider_id: str,
-        tool_prompt_format: ToolPromptFormat | None = None,
-    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]:
-        completion_text = ""
-        async for chunk in response:
-            complete = False
-            if hasattr(chunk, "event"):  # only ChatCompletions have .event
-                if chunk.event.event_type == ChatCompletionResponseEventType.progress:
-                    if chunk.event.delta.type == "text":
-                        completion_text += chunk.event.delta.text
-                if chunk.event.event_type == ChatCompletionResponseEventType.complete:
-                    complete = True
-                    completion_tokens = await self._count_tokens(
-                        [
-                            CompletionMessage(
-                                content=completion_text,
-                                stop_reason=StopReason.end_of_turn,
-                            )
-                        ],
-                        tool_prompt_format=tool_prompt_format,
-                    )
-            else:
-                if hasattr(chunk, "delta"):
-                    completion_text += chunk.delta
-                if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry_enabled:
-                    complete = True
-                    completion_tokens = await self._count_tokens(completion_text)
-            # if we are done receiving tokens
-            if complete:
-                total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-
-                # Create a separate span for streaming completion metrics
-                if self.telemetry_enabled:
-                    # Log metrics in the new span context
-                    completion_metrics = self._construct_metrics(
-                        prompt_tokens=prompt_tokens,
-                        completion_tokens=completion_tokens,
-                        total_tokens=total_tokens,
-                        fully_qualified_model_id=fully_qualified_model_id,
-                        provider_id=provider_id,
-                    )
-                    for metric in completion_metrics:
-                        if metric.metric in [
-                            "completion_tokens",
-                            "total_tokens",
-                        ]:  # Only log completion and total tokens
-                            enqueue_event(metric)
-
-                    # Return metrics in response
-                    async_metrics = [
-                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
-                    ]
-                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
-                else:
-                    # Fallback if no telemetry
-                    completion_metrics = self._construct_metrics(
-                        prompt_tokens or 0,
-                        completion_tokens or 0,
-                        total_tokens,
-                        fully_qualified_model_id=fully_qualified_model_id,
-                        provider_id=provider_id,
-                    )
-                    async_metrics = [
-                        MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics
-                    ]
-                    chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics
-            yield chunk
-
-    async def count_tokens_and_compute_metrics(
-        self,
-        response: ChatCompletionResponse | CompletionResponse,
-        prompt_tokens,
-        fully_qualified_model_id: str,
-        provider_id: str,
-        tool_prompt_format: ToolPromptFormat | None = None,
-    ):
-        if isinstance(response, ChatCompletionResponse):
-            content = [response.completion_message]
-        else:
-            content = response.content
-        completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format)
-        total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
-
-        # Create a separate span for completion metrics
-        if self.telemetry_enabled:
-            # Log metrics in the new span context
-            completion_metrics = self._construct_metrics(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=total_tokens,
-                fully_qualified_model_id=fully_qualified_model_id,
-                provider_id=provider_id,
-            )
-            for metric in completion_metrics:
-                if metric.metric in ["completion_tokens", "total_tokens"]:  # Only log completion and total tokens
-                    enqueue_event(metric)
-
-            # Return metrics in response
-            return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics]
-
-        # Fallback if no telemetry
-        metrics = self._construct_metrics(
-            prompt_tokens or 0,
-            completion_tokens or 0,
-            total_tokens,
-            fully_qualified_model_id=fully_qualified_model_id,
-            provider_id=provider_id,
-        )
-        return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics]
-
     async def stream_tokens_and_compute_metrics_openai_chat(
         self,
         response: AsyncIterator[OpenAIChatCompletionChunk],
@@ -13,6 +13,8 @@ from llama_stack.core.datatypes import (
     ModelWithOwner,
     RegistryEntrySource,
 )
+from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData
+from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl, lookup_model

@@ -42,19 +44,104 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
 
         await self.update_registered_models(provider_id, models)
 
+    async def _get_dynamic_models_from_provider_data(self) -> list[Model]:
+        """
+        Fetch models from providers that have credentials in the current request's provider_data.
+
+        This allows users to see models available to them from providers that require
+        per-request API keys (via X-LlamaStack-Provider-Data header).
+
+        Returns models with fully qualified identifiers (provider_id/model_id) but does NOT
+        cache them in the registry since they are user-specific.
+        """
+        provider_data = PROVIDER_DATA_VAR.get()
+        if not provider_data:
+            return []
+
+        dynamic_models = []
+
+        for provider_id, provider in self.impls_by_provider_id.items():
+            # Check if this provider supports provider_data
+            if not isinstance(provider, NeedsRequestProviderData):
+                continue
+
+            # Check if provider has a validator (some providers like ollama don't need per-request credentials)
+            spec = getattr(provider, "__provider_spec__", None)
+            if not spec or not getattr(spec, "provider_data_validator", None):
+                continue
+
+            # Validate provider_data silently - we're speculatively checking all providers
+            # so validation failures are expected when user didn't provide keys for this provider
+            try:
+                validator = instantiate_class_type(spec.provider_data_validator)
+                validator(**provider_data)
+            except Exception:
+                # User didn't provide credentials for this provider - skip silently
+                continue
+
+            # Validation succeeded! User has credentials for this provider
+            # Now try to list models
+            try:
+                models = await provider.list_models()
+                if not models:
+                    continue
+
+                # Ensure models have fully qualified identifiers with provider_id prefix
+                for model in models:
+                    # Only add prefix if model identifier doesn't already have it
+                    if not model.identifier.startswith(f"{provider_id}/"):
+                        model.identifier = f"{provider_id}/{model.provider_resource_id}"
+
+                    dynamic_models.append(model)
+
+                logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data")
+
+            except Exception as e:
+                logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}")
+                continue
+
+        return dynamic_models
+
     async def list_models(self) -> ListModelsResponse:
-        return ListModelsResponse(data=await self.get_all_with_type("model"))
+        # Get models from registry
+        registry_models = await self.get_all_with_type("model")
+
+        # Get additional models available via provider_data (user-specific, not cached)
+        dynamic_models = await self._get_dynamic_models_from_provider_data()
+
+        # Combine, avoiding duplicates (registry takes precedence)
+        registry_identifiers = {m.identifier for m in registry_models}
+        unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
+
+        return ListModelsResponse(data=registry_models + unique_dynamic_models)
 
     async def openai_list_models(self) -> OpenAIListModelsResponse:
-        models = await self.get_all_with_type("model")
+        # Get models from registry
+        registry_models = await self.get_all_with_type("model")
+
+        # Get additional models available via provider_data (user-specific, not cached)
+        dynamic_models = await self._get_dynamic_models_from_provider_data()
+
+        # Combine, avoiding duplicates (registry takes precedence)
+        registry_identifiers = {m.identifier for m in registry_models}
+        unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers]
+
+        all_models = registry_models + unique_dynamic_models
+
         openai_models = [
             OpenAIModel(
                 id=model.identifier,
                 object="model",
                 created=int(time.time()),
                 owned_by="llama_stack",
+                custom_metadata={
+                    "model_type": model.model_type,
+                    "provider_id": model.provider_id,
+                    "provider_resource_id": model.provider_resource_id,
+                    **model.metadata,
+                },
             )
-            for model in models
+            for model in all_models
         ]
         return OpenAIListModelsResponse(data=openai_models)
 
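A hedged sketch of how a caller can surface these per-request models: send credentials in the `X-LlamaStack-Provider-Data` header named in the docstring above. The server address, the credential key, and the JSON serialization of the header are assumptions:

```python
import json

import requests

BASE_URL = "http://localhost:8321"  # assumed local stack server
provider_data = {"together_api_key": "sk-..."}  # placeholder credential key/value

resp = requests.get(
    f"{BASE_URL}/v1/models",
    headers={"X-LlamaStack-Provider-Data": json.dumps(provider_data)},
    timeout=10,
)
# Dynamic models come back with provider-qualified identifiers and are not cached
for model in resp.json().get("data", []):
    print(model["identifier"])
```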
@@ -14,6 +14,7 @@ from typing import Any
 import yaml
 
 from llama_stack.apis.agents import Agents
+from llama_stack.apis.batches import Batches
 from llama_stack.apis.benchmarks import Benchmarks
 from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.datasetio import DatasetIO

@@ -30,7 +31,6 @@ from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
 from llama_stack.apis.shields import Shields
-from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration
 from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl

@@ -63,8 +63,8 @@ class LlamaStack(
     Providers,
     Inference,
     Agents,
+    Batches,
     Safety,
-    SyntheticDataGeneration,
     Datasets,
     PostTraining,
     VectorIO,
@@ -152,6 +152,37 @@ docker run \
   --port $LLAMA_STACK_PORT
 ```
 
+### Via Docker with Custom Run Configuration
+
+You can also run the Docker container with a custom run configuration file by mounting it into the container:
+
+```bash
+# Set the path to your custom run.yaml file
+CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
+
+docker run -it \
+  --pull always \
+  --network host \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  -v $HOME/.llama:/root/.llama \
+  -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
+  -e RUN_CONFIG_PATH=/app/custom-run.yaml \
+  -e INFERENCE_MODEL=$INFERENCE_MODEL \
+  -e DEH_URL=$DEH_URL \
+  -e CHROMA_URL=$CHROMA_URL \
+  llamastack/distribution-{{ name }} \
+  --port $LLAMA_STACK_PORT
+```
+
+**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
+
+{% if run_configs %}
+Available run configurations for this distribution:
+{% for config in run_configs %}
+- `{{ config }}`
+{% endfor %}
+{% endif %}
+
 ### Via Conda
 
 Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.
@@ -68,6 +68,36 @@ docker run \
   --port $LLAMA_STACK_PORT
 ```
 
+### Via Docker with Custom Run Configuration
+
+You can also run the Docker container with a custom run configuration file by mounting it into the container:
+
+```bash
+# Set the path to your custom run.yaml file
+CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
+LLAMA_STACK_PORT=8321
+
+docker run \
+  -it \
+  --pull always \
+  --gpu all \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  -v ~/.llama:/root/.llama \
+  -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
+  -e RUN_CONFIG_PATH=/app/custom-run.yaml \
+  llamastack/distribution-{{ name }} \
+  --port $LLAMA_STACK_PORT
+```
+
+**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
+
+{% if run_configs %}
+Available run configurations for this distribution:
+{% for config in run_configs %}
+- `{{ config }}`
+{% endfor %}
+{% endif %}
+
 ### Via venv
 
 Make sure you have the Llama Stack CLI available.
@@ -117,13 +117,42 @@ docker run \
   -it \
   --pull always \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-  -v ./run.yaml:/root/my-run.yaml \
+  -v ~/.llama:/root/.llama \
   -e NVIDIA_API_KEY=$NVIDIA_API_KEY \
   llamastack/distribution-{{ name }} \
-  --config /root/my-run.yaml \
   --port $LLAMA_STACK_PORT
 ```
 
+### Via Docker with Custom Run Configuration
+
+You can also run the Docker container with a custom run configuration file by mounting it into the container:
+
+```bash
+# Set the path to your custom run.yaml file
+CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml
+LLAMA_STACK_PORT=8321
+
+docker run \
+  -it \
+  --pull always \
+  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
+  -v ~/.llama:/root/.llama \
+  -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \
+  -e RUN_CONFIG_PATH=/app/custom-run.yaml \
+  -e NVIDIA_API_KEY=$NVIDIA_API_KEY \
+  llamastack/distribution-{{ name }} \
+  --port $LLAMA_STACK_PORT
+```
+
+**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use.
+
+{% if run_configs %}
+Available run configurations for this distribution:
+{% for config in run_configs %}
+- `{{ config }}`
+{% endfor %}
+{% endif %}
+
 ### Via venv
 
 If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment.
@@ -424,6 +424,7 @@ class DistributionTemplate(BaseModel):
                 providers_table=providers_table,
                 run_config_env_vars=self.run_config_env_vars,
                 default_models=default_models,
+                run_configs=list(self.run_configs.keys()),
             )
         return ""
 
@@ -11,6 +11,7 @@ import uuid
 import warnings
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
+from typing import Any, cast

 import httpx

@@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin):
         )

     def turn_to_messages(self, turn: Turn) -> list[Message]:
-        messages = []
+        messages: list[Message] = []

         # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
         tool_call_ids = set()
         for step in turn.steps:
-            if step.step_type == StepType.tool_execution.value:
+            if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
                 for response in step.tool_responses:
                     tool_call_ids.add(response.call_id)

@@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin):
             messages.append(msg)

         for step in turn.steps:
-            if step.step_type == StepType.inference.value:
+            if step.step_type == StepType.inference.value and isinstance(step, InferenceStep):
                 messages.append(step.model_response)
-            elif step.step_type == StepType.tool_execution.value:
+            elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep):
                 for response in step.tool_responses:
                     messages.append(
                         ToolResponseMessage(

@@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin):
                             content=response.content,
                         )
                     )
-            elif step.step_type == StepType.shield_call.value:
-                if step.violation:
+            elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep):
+                if step.violation and step.violation.user_message:
                     # CompletionMessage itself in the ShieldResponse
                     messages.append(
                         CompletionMessage(

@@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin):
         return await self.storage.create_session(name)

     async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
-        messages = []
+        messages: list[Message] = []
         if self.agent_config.instructions != "":
             messages.append(SystemMessage(content=self.agent_config.instructions))

@@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin):

         steps = []
         messages = await self.get_messages_from_turns(turns)

         if is_resume:
+            assert isinstance(request, AgentTurnResumeRequest)
             tool_response_messages = [
                 ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
             ]
@@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin):
             in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
                 request.session_id, request.turn_id
             )
-            now = datetime.now(UTC).isoformat()
+            now_dt = datetime.now(UTC)
             tool_execution_step = ToolExecutionStep(
                 step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
                 turn_id=request.turn_id,
                 tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
                 tool_responses=request.tool_responses,
-                completed_at=now,
-                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
+                completed_at=now_dt,
+                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt),
             )
             steps.append(tool_execution_step)
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.tool_execution.value,
+                        step_type=StepType.tool_execution,
                         step_id=tool_execution_step.step_id,
                         step_details=tool_execution_step,
                     )
                 )
             )
-            input_messages = last_turn.input_messages
+            # Cast needed due to list invariance - last_turn.input_messages is the right type
+            input_messages = last_turn.input_messages  # type: ignore[assignment]

-            turn_id = request.turn_id
+            actual_turn_id = request.turn_id
             start_time = last_turn.started_at
         else:
+            assert isinstance(request, AgentTurnCreateRequest)
             messages.extend(request.messages)
-            start_time = datetime.now(UTC).isoformat()
-            input_messages = request.messages
+            start_time = datetime.now(UTC)
+            # Cast needed due to list invariance - request.messages is the right type
+            input_messages = request.messages  # type: ignore[assignment]
+            # Use the generated turn_id from beginning of function
+            actual_turn_id = turn_id if turn_id else str(uuid.uuid4())

         output_message = None
+        req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None
+        req_sampling = (
+            self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams()
+        )

         async for chunk in self.run(
             session_id=request.session_id,
-            turn_id=turn_id,
+            turn_id=actual_turn_id,
             input_messages=messages,
-            sampling_params=self.agent_config.sampling_params,
+            sampling_params=req_sampling,
             stream=request.stream,
-            documents=request.documents if not is_resume else None,
+            documents=req_documents,
         ):
             if isinstance(chunk, CompletionMessage):
                 output_message = chunk
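The `# type: ignore[assignment]` comments in the hunk above work around mypy's list invariance. A minimal sketch of the underlying rule, using illustrative class names rather than the real message types:

```python
class Message: ...
class UserMessage(Message): ...


def send(batch: list[Message]) -> None: ...


user_only: list[UserMessage] = [UserMessage()]

# mypy rejects this call: list[UserMessage] is not a subtype of list[Message],
# because `send` could legally append a non-UserMessage into the caller's list.
send(user_only)  # error: Argument 1 has incompatible type "list[UserMessage]"

# Safe alternatives: copy into the wider type, or accept Sequence[Message],
# which is covariant and read-only.
widened: list[Message] = list(user_only)
send(widened)
```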
@@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin):

             assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
             event = chunk.event
-            if event.payload.event_type == AgentTurnResponseEventType.step_complete.value:
-                steps.append(event.payload.step_details)
+            if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr(
+                event.payload, "step_details"
+            ):
+                step_details = event.payload.step_details
+                steps.append(step_details)

             yield chunk

         assert output_message is not None

         turn = Turn(
-            turn_id=turn_id,
+            turn_id=actual_turn_id,
             session_id=request.session_id,
-            input_messages=input_messages,
+            input_messages=input_messages,  # type: ignore[arg-type]
             output_message=output_message,
             started_at=start_time,
-            completed_at=datetime.now(UTC).isoformat(),
+            completed_at=datetime.now(UTC),
             steps=steps,
         )
         await self.storage.add_turn_to_session(request.session_id, turn)
@@ -345,9 +361,9 @@ class ChatAgent(ShieldRunnerMixin):
        # return a "final value" for the `yield from` statement. we simulate that by yielding a
        # final boolean (to see whether an exception happened) and then explicitly testing for it.

-        if len(self.input_shields) > 0:
+        if self.input_shields:
             async for res in self.run_multiple_shields_wrapper(
-                turn_id, input_messages, self.input_shields, "user-input"
+                turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
             ):
                 if isinstance(res, bool):
                     return
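The comment above about simulating a `yield from` return value by yielding a final boolean can be seen in isolation with a rough, self-contained sketch (the names here are illustrative, not from the codebase):

```python
import asyncio
from collections.abc import AsyncGenerator


async def run_with_signal() -> AsyncGenerator[str | bool, None]:
    """Yield events, then a final bool indicating whether an exception happened."""
    try:
        yield "shield-start"
        # ... guarded work would go here ...
        yield "shield-ok"
    except Exception:
        yield True  # an exception happened: tell the caller to stop
        return
    yield False  # clean finish


async def caller() -> None:
    async for item in run_with_signal():
        if isinstance(item, bool):
            if item:
                return  # abort the turn
            break  # success: continue with the next phase
        print("event:", item)


asyncio.run(caller())
```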
@@ -374,9 +390,9 @@ class ChatAgent(ShieldRunnerMixin):
         # for output shields run on the full input and output combination
         messages = input_messages + [final_response]

-        if len(self.output_shields) > 0:
+        if self.output_shields:
             async for res in self.run_multiple_shields_wrapper(
-                turn_id, messages, self.output_shields, "assistant-output"
+                turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
             ):
                 if isinstance(res, bool):
                     return

@@ -388,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
         shields: list[str],
         touchpoint: str,
     ) -> AsyncGenerator:
@@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin):
             return

         step_id = str(uuid.uuid4())
-        shield_call_start_time = datetime.now(UTC).isoformat()
+        shield_call_start_time = datetime.now(UTC)
         try:
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.shield_call.value,
+                        step_type=StepType.shield_call,
                         step_id=step_id,
                         metadata=dict(touchpoint=touchpoint),
                     )

@@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin):
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.shield_call.value,
+                        step_type=StepType.shield_call,
                         step_id=step_id,
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
                             violation=e.violation,
                             started_at=shield_call_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
+                            completed_at=datetime.now(UTC),
                         ),
                     )
                 )

@@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin):
         yield AgentTurnResponseStreamChunk(
             event=AgentTurnResponseEvent(
                 payload=AgentTurnResponseStepCompletePayload(
-                    step_type=StepType.shield_call.value,
+                    step_type=StepType.shield_call,
                     step_id=step_id,
                     step_details=ShieldCallStep(
                         step_id=step_id,
                         turn_id=turn_id,
                         violation=None,
                         started_at=shield_call_start_time,
-                        completed_at=datetime.now(UTC).isoformat(),
+                        completed_at=datetime.now(UTC),
                     ),
                 )
             )
@@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin):
                 else:
                     self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id)

-        output_attachments = []
+        output_attachments: list[Attachment] = []

         n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0

         # Build a map of custom tools to their definitions for faster lookup
         client_tools = {}
+        if self.agent_config.client_tools:
             for tool in self.agent_config.client_tools:
                 client_tools[tool.name] = tool
         while True:
             step_id = str(uuid.uuid4())
-            inference_start_time = datetime.now(UTC).isoformat()
+            inference_start_time = datetime.now(UTC)
             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepStartPayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                     )
                 )
@@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin):
             else:
                 return value

-        def _add_type(openai_msg: dict) -> OpenAIMessageParam:
+        def _add_type(openai_msg: Any) -> OpenAIMessageParam:
             # Serialize any nested Pydantic models to plain dicts
             openai_msg = _serialize_nested(openai_msg)

@@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin):
             messages=openai_messages,
             tools=openai_tools if openai_tools else None,
             tool_choice=tool_choice,
-            response_format=self.agent_config.response_format,
+            response_format=self.agent_config.response_format,  # type: ignore[arg-type]
             temperature=temperature,
             top_p=top_p,
             max_tokens=max_tokens,

@@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin):

         # Convert OpenAI stream back to Llama Stack format
         response_stream = convert_openai_chat_completion_stream(
-            openai_stream, enable_incremental_tool_calls=True
+            openai_stream,  # type: ignore[arg-type]
+            enable_incremental_tool_calls=True,
         )

         async for chunk in response_stream:
@@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.inference.value,
+                            step_type=StepType.inference,
                             step_id=step_id,
                             delta=delta,
                         )

@@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.inference.value,
+                            step_type=StepType.inference,
                             step_id=step_id,
                             delta=delta,
                         )

@@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin):
                 output_attr = json.dumps(
                     {
                         "content": content,
-                        "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
+                        "tool_calls": [
+                            json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall)
+                        ],
                     }
                 )
                 span.set_attribute("output", output_attr)
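A small sketch of why the tool calls above are round-tripped through `model_dump_json` before being placed in the span attribute (the model and its fields below are made up for illustration, and this assumes Pydantic v2):

```python
import json
from datetime import UTC, datetime

from pydantic import BaseModel


class ToolCallExample(BaseModel):
    call_id: str
    issued_at: datetime  # not directly JSON-serializable


call = ToolCallExample(call_id="abc", issued_at=datetime.now(UTC))

# model_dump_json() applies Pydantic's JSON encoders (e.g. datetime -> ISO string),
# so loading it back yields plain JSON-safe dicts suitable for span attributes.
payload = json.dumps({"tool_calls": [json.loads(call.model_dump_json())]})
print(payload)
```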
@@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin):
             if tool_calls:
                 content = ""

+            # Filter out string tool calls for CompletionMessage (only keep ToolCall objects)
+            valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)]
             message = CompletionMessage(
                 content=content,
                 stop_reason=stop_reason,
-                tool_calls=tool_calls,
+                tool_calls=valid_tool_calls if valid_tool_calls else None,
             )

             yield AgentTurnResponseStreamChunk(
                 event=AgentTurnResponseEvent(
                     payload=AgentTurnResponseStepCompletePayload(
-                        step_type=StepType.inference.value,
+                        step_type=StepType.inference,
                         step_id=step_id,
                         step_details=InferenceStep(
                             # somewhere deep, we are re-assigning message or closing over some

@@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin):
                             turn_id=turn_id,
                             model_response=copy.deepcopy(message),
                             started_at=inference_start_time,
-                            completed_at=datetime.now(UTC).isoformat(),
+                            completed_at=datetime.now(UTC),
                         ),
                     )
                 )
             )

-            if n_iter >= self.agent_config.max_infer_iters:
+            max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10
+            if n_iter >= max_iters:
                 logger.info(f"done with MAX iterations ({n_iter}), exiting.")
                 # NOTE: mark end_of_turn to indicate to client that we are done with the turn
                 # Do not continue the tool call loop after this point
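A tiny standalone sketch of why the fallback above compares against `None` explicitly instead of using `or` (illustrative only):

```python
def resolve_max_iters(configured: int | None) -> int:
    # `configured or 10` would also override a legitimate 0, so compare to None explicitly.
    return configured if configured is not None else 10


assert resolve_max_iters(None) == 10
assert resolve_max_iters(0) == 0
assert resolve_max_iters(3) == 3
```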
@@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin):
                 yield message
                 break

-            if len(message.tool_calls) == 0:
+            if not message.tool_calls or len(message.tool_calls) == 0:
                 if stop_reason == StopReason.end_of_turn:
                     # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                     if len(output_attachments) > 0:
                         if isinstance(message.content, list):
-                            message.content += output_attachments
+                            # List invariance - attachments are compatible at runtime
+                            message.content += output_attachments  # type: ignore[arg-type]
                         else:
-                            message.content = [message.content] + output_attachments
+                            # List invariance - attachments are compatible at runtime
+                            message.content = [message.content] + output_attachments  # type: ignore[assignment]
                     yield message
                 else:
                     logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")

@@ -725,6 +750,7 @@ class ChatAgent(ShieldRunnerMixin):
             non_client_tool_calls = []

             # Separate client and non-client tool calls
+            if message.tool_calls:
                 for tool_call in message.tool_calls:
                     if tool_call.tool_name in client_tools:
                         client_tool_calls.append(tool_call)

@@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepStartPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                         )
                     )
@@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin):
                 yield AgentTurnResponseStreamChunk(
                     event=AgentTurnResponseEvent(
                         payload=AgentTurnResponseStepProgressPayload(
-                            step_type=StepType.tool_execution.value,
+                            step_type=StepType.tool_execution,
                             step_id=step_id,
                             delta=ToolCallDelta(
                                 parse_status=ToolCallParseStatus.in_progress,

@@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin):
                     if self.telemetry_enabled
                     else {},
                 ) as span:
-                    tool_execution_start_time = datetime.now(UTC).isoformat()
+                    tool_execution_start_time = datetime.now(UTC)
                     tool_result = await self.execute_tool_call_maybe(
                         session_id,
                         tool_call,

@@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin):
                        )
                    ],
                    started_at=tool_execution_start_time,
-                   completed_at=datetime.now(UTC).isoformat(),
+                   completed_at=datetime.now(UTC),
                )

                # Yield the step completion event
                yield AgentTurnResponseStreamChunk(
                    event=AgentTurnResponseEvent(
                        payload=AgentTurnResponseStepCompletePayload(
-                           step_type=StepType.tool_execution.value,
+                           step_type=StepType.tool_execution,
                            step_id=step_id,
                            step_details=tool_execution_step,
                        )
@@ -833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin):
                     turn_id=turn_id,
                     tool_calls=client_tool_calls,
                     tool_responses=[],
-                    started_at=datetime.now(UTC).isoformat(),
+                    started_at=datetime.now(UTC),
                 ),
             )

@@ -868,9 +894,10 @@ class ChatAgent(ShieldRunnerMixin):

         toolgroup_to_args = toolgroup_to_args or {}

-        tool_name_to_def = {}
+        tool_name_to_def: dict[str, ToolDefinition] = {}
         tool_name_to_args = {}

+        if self.agent_config.client_tools:
             for tool_def in self.agent_config.client_tools:
                 if tool_name_to_def.get(tool_def.name, None):
                     raise ValueError(f"Tool {tool_def.name} already exists")
@@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin):
                 else:
                     identifier = None

-                if tool_name_to_def.get(identifier, None):
-                    raise ValueError(f"Tool {identifier} already exists")
                 if identifier:
-                    tool_name_to_def[identifier] = ToolDefinition(
-                        tool_name=identifier,
+                    # Convert BuiltinTool to string for dictionary key
+                    identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier
+                    if tool_name_to_def.get(identifier_str, None):
+                        raise ValueError(f"Tool {identifier_str} already exists")
+                    tool_name_to_def[identifier_str] = ToolDefinition(
+                        tool_name=identifier_str,
                         description=tool_def.description,
                         input_schema=tool_def.input_schema,
                     )
-                    tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {})
+                    tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {})

         self.tool_defs, self.tool_name_to_args = (
             list(tool_name_to_def.values()),
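A minimal sketch of the enum-to-string key normalization this hunk introduces, using a stand-in enum rather than the real `BuiltinTool`:

```python
from enum import Enum


class BuiltinToolExample(Enum):  # stand-in for the real BuiltinTool enum
    web_search = "web_search"


def normalize(identifier: BuiltinToolExample | str) -> str:
    # Enum members and their string values hash differently, so mixing them as
    # dict keys would silently create duplicate entries; normalize to str first.
    return identifier.value if isinstance(identifier, BuiltinToolExample) else identifier


registry: dict[str, str] = {}
for ident in (BuiltinToolExample.web_search, "my_custom_tool"):
    key = normalize(ident)
    if key in registry:
        raise ValueError(f"Tool {key} already exists")
    registry[key] = "definition"
print(registry)
```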
@@ -966,7 +995,9 @@ class ChatAgent(ShieldRunnerMixin):
         except json.JSONDecodeError as e:
             raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e

-        result = await self.tool_runtime_api.invoke_tool(
+        result = cast(
+            ToolInvocationResult,
+            await self.tool_runtime_api.invoke_tool(
                 tool_name=tool_name_str,
                 kwargs={
                     "session_id": session_id,

@@ -974,6 +1005,7 @@ class ChatAgent(ShieldRunnerMixin):
                     **args,
                     **self.tool_name_to_args.get(tool_name_str, {}),
                 },
+            ),
         )
         logger.debug(f"tool call {tool_name_str} completed with result: {result}")
         return result
@@ -983,7 +1015,7 @@ async def load_data_from_url(url: str) -> str:
     if url.startswith("http"):
         async with httpx.AsyncClient() as client:
             r = await client.get(url)
-            resp = r.text
+            resp: str = r.text
             return resp
     raise ValueError(f"Unexpected URL: {type(url)}")

@@ -1017,7 +1049,7 @@ def _interpret_content_as_attachment(
         snippet = match.group(1)
         data = json.loads(snippet)
         return Attachment(
-            url=URL(uri="file://" + data["filepath"]),
+            content=URL(uri="file://" + data["filepath"]),
             mime_type=data["mimetype"],
         )
@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
|
||||||
Document,
|
Document,
|
||||||
ListOpenAIResponseInputItem,
|
ListOpenAIResponseInputItem,
|
||||||
ListOpenAIResponseObject,
|
ListOpenAIResponseObject,
|
||||||
|
OpenAIDeleteResponseObject,
|
||||||
OpenAIResponseInput,
|
OpenAIResponseInput,
|
||||||
OpenAIResponseInputTool,
|
OpenAIResponseInputTool,
|
||||||
OpenAIResponseObject,
|
OpenAIResponseObject,
|
||||||
|
|
@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
persistence_store=(
|
persistence_store=(
|
||||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||||
),
|
),
|
||||||
created_at=agent_info.created_at,
|
created_at=agent_info.created_at.isoformat(),
|
||||||
policy=self.policy,
|
policy=self.policy,
|
||||||
telemetry_enabled=self.telemetry_enabled,
|
telemetry_enabled=self.telemetry_enabled,
|
||||||
)
|
)
|
||||||
|
|
@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
messages: list[UserMessage | ToolResponseMessage],
|
messages: list[UserMessage | ToolResponseMessage],
|
||||||
toolgroups: list[AgentToolGroup] | None = None,
|
|
||||||
documents: list[Document] | None = None,
|
|
||||||
stream: bool | None = False,
|
stream: bool | None = False,
|
||||||
|
documents: list[Document] | None = None,
|
||||||
|
toolgroups: list[AgentToolGroup] | None = None,
|
||||||
tool_config: ToolConfig | None = None,
|
tool_config: ToolConfig | None = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
request = AgentTurnCreateRequest(
|
request = AgentTurnCreateRequest(
|
||||||
|
|
@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||||
|
if turn is None:
|
||||||
|
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
|
||||||
return turn
|
return turn
|
||||||
|
|
||||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||||
|
|
@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
||||||
async def get_agents_session(
|
async def get_agents_session(
|
||||||
self,
|
self,
|
||||||
agent_id: str,
|
|
||||||
session_id: str,
|
session_id: str,
|
||||||
|
agent_id: str,
|
||||||
turn_ids: list[str] | None = None,
|
turn_ids: list[str] | None = None,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
|
||||||
session_info = await agent.storage.get_session_info(session_id)
|
session_info = await agent.storage.get_session_info(session_id)
|
||||||
|
if session_info is None:
|
||||||
|
raise ValueError(f"Session {session_id} not found")
|
||||||
turns = await agent.storage.get_session_turns(session_id)
|
turns = await agent.storage.get_session_turns(session_id)
|
||||||
if turn_ids:
|
if turn_ids:
|
||||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||||
|
|
@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
started_at=session_info.started_at,
|
started_at=session_info.started_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
|
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
|
||||||
agent = await self._get_agent_impl(agent_id)
|
agent = await self._get_agent_impl(agent_id)
|
||||||
|
|
||||||
# Delete turns first, then the session
|
# Delete turns first, then the session
|
||||||
|
|
@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
agent_id=agent_id,
|
agent_id=agent_id,
|
||||||
agent_config=chat_agent.agent_config,
|
agent_config=chat_agent.agent_config,
|
||||||
created_at=chat_agent.created_at,
|
created_at=datetime.fromisoformat(chat_agent.created_at),
|
||||||
)
|
)
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
|
|
@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
self,
|
self,
|
||||||
response_id: str,
|
response_id: str,
|
||||||
) -> OpenAIResponseObject:
|
) -> OpenAIResponseObject:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.get_openai_response(response_id)
|
return await self.openai_responses_impl.get_openai_response(response_id)
|
||||||
|
|
||||||
async def create_openai_response(
|
async def create_openai_response(
|
||||||
|
|
@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
max_infer_iters: int | None = 10,
|
max_infer_iters: int | None = 10,
|
||||||
guardrails: list[ResponseGuardrail] | None = None,
|
guardrails: list[ResponseGuardrail] | None = None,
|
||||||
) -> OpenAIResponseObject:
|
) -> OpenAIResponseObject:
|
||||||
return await self.openai_responses_impl.create_openai_response(
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
|
result = await self.openai_responses_impl.create_openai_response(
|
||||||
input,
|
input,
|
||||||
model,
|
model,
|
||||||
prompt,
|
prompt,
|
||||||
|
|
@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
max_infer_iters,
|
max_infer_iters,
|
||||||
guardrails,
|
guardrails,
|
||||||
)
|
)
|
||||||
|
return result # type: ignore[no-any-return]
|
||||||
|
|
||||||
async def list_openai_responses(
|
async def list_openai_responses(
|
||||||
self,
|
self,
|
||||||
|
|
@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
order: Order | None = Order.desc,
|
order: Order | None = Order.desc,
|
||||||
) -> ListOpenAIResponseObject:
|
) -> ListOpenAIResponseObject:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)
|
||||||
|
|
||||||
async def list_openai_response_input_items(
|
async def list_openai_response_input_items(
|
||||||
|
|
@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
limit: int | None = 20,
|
limit: int | None = 20,
|
||||||
order: Order | None = Order.desc,
|
order: Order | None = Order.desc,
|
||||||
) -> ListOpenAIResponseInputItem:
|
) -> ListOpenAIResponseInputItem:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.list_openai_response_input_items(
|
return await self.openai_responses_impl.list_openai_response_input_items(
|
||||||
response_id, after, before, include, limit, order
|
response_id, after, before, include, limit, order
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_openai_response(self, response_id: str) -> None:
|
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
|
||||||
|
assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
|
||||||
return await self.openai_responses_impl.delete_openai_response(response_id)
|
return await self.openai_responses_impl.delete_openai_response(response_id)
|
||||||
|
|
|
||||||
|
|
@@ -6,12 +6,14 @@

 import json
 import uuid
+from dataclasses import dataclass
 from datetime import UTC, datetime

 from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
 from llama_stack.apis.common.errors import SessionNotFoundError
 from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
-from llama_stack.core.access_control.datatypes import AccessRule
+from llama_stack.core.access_control.conditions import User as ProtocolUser
+from llama_stack.core.access_control.datatypes import AccessRule, Action
 from llama_stack.core.datatypes import User
 from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.log import get_logger

@@ -33,6 +35,15 @@ class AgentInfo(AgentConfig):
     created_at: datetime


+@dataclass
+class SessionResource:
+    """Concrete implementation of ProtectedResource for session access control."""
+
+    type: str
+    identifier: str
+    owner: ProtocolUser  # Use the protocol type for structural compatibility
+
+
 class AgentPersistence:
     def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
         self.agent_id = agent_id
@@ -53,8 +64,15 @@ class AgentPersistence:
             turns=[],
             identifier=name,  # should this be qualified in any way?
         )
-        if not is_action_allowed(self.policy, "create", session_info, user):
-            raise AccessDeniedError("create", session_info, user)
+        # Only perform access control if we have an authenticated user
+        if user is not None and session_info.identifier is not None:
+            resource = SessionResource(
+                type=session_info.type,
+                identifier=session_info.identifier,
+                owner=user,
+            )
+            if not is_action_allowed(self.policy, Action.CREATE, resource, user):
+                raise AccessDeniedError(Action.CREATE, resource, user)

         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
@@ -62,7 +80,7 @@ class AgentPersistence:
         )
         return session_id

-    async def get_session_info(self, session_id: str) -> AgentSessionInfo:
+    async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
         value = await self.kvstore.get(
             key=f"session:{self.agent_id}:{session_id}",
         )

@@ -83,7 +101,22 @@ class AgentPersistence:
         if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
             return True

-        return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
+        # Get current user - if None, skip access control (e.g., in tests)
+        user = get_authenticated_user()
+        if user is None:
+            return True
+
+        # Access control requires identifier and owner to be set
+        if session_info.identifier is None or session_info.owner is None:
+            return True
+
+        # At this point, both identifier and owner are guaranteed to be non-None
+        resource = SessionResource(
+            type=session_info.type,
+            identifier=session_info.identifier,
+            owner=session_info.owner,
+        )
+        return is_action_allowed(self.policy, Action.READ, resource, user)

     async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
         """Get session info if the user has access to it. For internal use by sub-session methods."""
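A rough sketch of the structural-typing idea behind `SessionResource`, using stand-in protocol and policy pieces rather than the real llama_stack types:

```python
from dataclasses import dataclass
from typing import Protocol


class Resource(Protocol):  # stand-in for the access-control resource protocol
    type: str
    identifier: str


@dataclass
class SessionResourceSketch:
    type: str
    identifier: str


def allowed(resource: Resource, action: str) -> bool:
    # A real policy engine would evaluate access rules; this stub allows reads only.
    return action == "read"


# The dataclass is accepted wherever the protocol is expected: structural
# compatibility, no inheritance from the protocol class required.
session = SessionResourceSketch(type="session", identifier="sess-123")
print(allowed(session, "read"))
```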
||||||
|
|
@ -91,7 +91,8 @@ class OpenAIResponsesImpl:
|
||||||
input: str | list[OpenAIResponseInput],
|
input: str | list[OpenAIResponseInput],
|
||||||
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
previous_response: _OpenAIResponseObjectWithInputAndMessages,
|
||||||
):
|
):
|
||||||
new_input_items = previous_response.input.copy()
|
# Convert Sequence to list for mutation
|
||||||
|
new_input_items = list(previous_response.input)
|
||||||
new_input_items.extend(previous_response.output)
|
new_input_items.extend(previous_response.output)
|
||||||
|
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
|
|
@ -107,7 +108,7 @@ class OpenAIResponsesImpl:
|
||||||
tools: list[OpenAIResponseInputTool] | None,
|
tools: list[OpenAIResponseInputTool] | None,
|
||||||
previous_response_id: str | None,
|
previous_response_id: str | None,
|
||||||
conversation: str | None,
|
conversation: str | None,
|
||||||
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]:
|
) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]:
|
||||||
"""Process input with optional previous response context.
|
"""Process input with optional previous response context.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -208,6 +209,9 @@ class OpenAIResponsesImpl:
|
||||||
messages: list[OpenAIMessageParam],
|
messages: list[OpenAIMessageParam],
|
||||||
) -> None:
|
) -> None:
|
||||||
new_input_id = f"msg_{uuid.uuid4()}"
|
new_input_id = f"msg_{uuid.uuid4()}"
|
||||||
|
# Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues
|
||||||
|
input_items_data: list[OpenAIResponseInput] = []
|
||||||
|
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
# synthesize a message from the input string
|
# synthesize a message from the input string
|
||||||
input_content = OpenAIResponseInputMessageContentText(text=input)
|
input_content = OpenAIResponseInputMessageContentText(text=input)
|
||||||
|
|
@ -219,7 +223,6 @@ class OpenAIResponsesImpl:
|
||||||
input_items_data = [input_content_item]
|
input_items_data = [input_content_item]
|
||||||
else:
|
else:
|
||||||
# we already have a list of messages
|
# we already have a list of messages
|
||||||
input_items_data = []
|
|
||||||
for input_item in input:
|
for input_item in input:
|
||||||
if isinstance(input_item, OpenAIResponseMessage):
|
if isinstance(input_item, OpenAIResponseMessage):
|
||||||
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
# These may or may not already have an id, so dump to dict, check for id, and add if missing
|
||||||
|
|
@ -251,7 +254,7 @@ class OpenAIResponsesImpl:
|
||||||
tools: list[OpenAIResponseInputTool] | None = None,
|
tools: list[OpenAIResponseInputTool] | None = None,
|
||||||
include: list[str] | None = None,
|
include: list[str] | None = None,
|
||||||
max_infer_iters: int | None = 10,
|
max_infer_iters: int | None = 10,
|
||||||
guardrails: list[ResponseGuardrailSpec] | None = None,
|
guardrails: list[str | ResponseGuardrailSpec] | None = None,
|
||||||
):
|
):
|
||||||
stream = bool(stream)
|
stream = bool(stream)
|
||||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
|
||||||
|
|
@ -289,7 +292,8 @@ class OpenAIResponsesImpl:
|
||||||
failed_response = None
|
failed_response = None
|
||||||
|
|
||||||
async for stream_chunk in stream_gen:
|
async for stream_chunk in stream_gen:
|
||||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
match stream_chunk.type:
|
||||||
|
case "response.completed" | "response.incomplete":
|
||||||
if final_response is not None:
|
if final_response is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The response stream produced multiple terminal responses! "
|
"The response stream produced multiple terminal responses! "
|
||||||
|
|
@ -297,8 +301,10 @@ class OpenAIResponsesImpl:
|
||||||
)
|
)
|
||||||
final_response = stream_chunk.response
|
final_response = stream_chunk.response
|
||||||
final_event_type = stream_chunk.type
|
final_event_type = stream_chunk.type
|
||||||
elif stream_chunk.type == "response.failed":
|
case "response.failed":
|
||||||
failed_response = stream_chunk.response
|
failed_response = stream_chunk.response
|
||||||
|
case _:
|
||||||
|
pass # Other event types don't have .response
|
||||||
|
|
||||||
if failed_response is not None:
|
if failed_response is not None:
|
||||||
error_message = (
|
error_message = (
|
||||||
|
|
@ -326,6 +332,11 @@ class OpenAIResponsesImpl:
|
||||||
max_infer_iters: int | None = 10,
|
max_infer_iters: int | None = 10,
|
||||||
guardrail_ids: list[str] | None = None,
|
guardrail_ids: list[str] | None = None,
|
||||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
|
# These should never be None when called from create_openai_response (which sets defaults)
|
||||||
|
# but we assert here to help mypy understand the types
|
||||||
|
assert text is not None, "text must not be None"
|
||||||
|
assert max_infer_iters is not None, "max_infer_iters must not be None"
|
||||||
|
|
||||||
# Input preprocessing
|
# Input preprocessing
|
||||||
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
all_input, messages, tool_context = await self._process_input_with_previous_response(
|
||||||
input, tools, previous_response_id, conversation
|
input, tools, previous_response_id, conversation
|
||||||
|
|
@ -368,16 +379,19 @@ class OpenAIResponsesImpl:
|
||||||
final_response = None
|
final_response = None
|
||||||
failed_response = None
|
failed_response = None
|
||||||
|
|
||||||
output_items = []
|
# Type as ConversationItem to avoid list invariance issues
|
||||||
|
output_items: list[ConversationItem] = []
|
||||||
async for stream_chunk in orchestrator.create_response():
|
async for stream_chunk in orchestrator.create_response():
|
||||||
if stream_chunk.type in {"response.completed", "response.incomplete"}:
|
match stream_chunk.type:
|
||||||
|
case "response.completed" | "response.incomplete":
|
||||||
final_response = stream_chunk.response
|
final_response = stream_chunk.response
|
||||||
elif stream_chunk.type == "response.failed":
|
case "response.failed":
|
||||||
failed_response = stream_chunk.response
|
failed_response = stream_chunk.response
|
||||||
|
case "response.output_item.done":
|
||||||
if stream_chunk.type == "response.output_item.done":
|
|
||||||
item = stream_chunk.item
|
item = stream_chunk.item
|
||||||
output_items.append(item)
|
output_items.append(item)
|
||||||
|
case _:
|
||||||
|
pass # Other event types
|
||||||
|
|
||||||
# Store and sync before yielding terminal events
|
# Store and sync before yielding terminal events
|
||||||
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
|
# This ensures the storage/syncing happens even if the consumer breaks after receiving the event
|
||||||
|
|
@ -410,7 +424,8 @@ class OpenAIResponsesImpl:
|
||||||
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
|
self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Sync content and response messages to the conversation."""
|
"""Sync content and response messages to the conversation."""
|
||||||
conversation_items = []
|
# Type as ConversationItem union to avoid list invariance issues
|
||||||
|
conversation_items: list[ConversationItem] = []
|
||||||
|
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
conversation_items.append(
|
conversation_items.append(
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,7 @@ class StreamingResponseOrchestrator:
|
||||||
text: OpenAIResponseText,
|
text: OpenAIResponseText,
|
||||||
max_infer_iters: int,
|
max_infer_iters: int,
|
||||||
tool_executor, # Will be the tool execution logic from the main class
|
tool_executor, # Will be the tool execution logic from the main class
|
||||||
instructions: str,
|
instructions: str | None,
|
||||||
safety_api,
|
safety_api,
|
||||||
guardrail_ids: list[str] | None = None,
|
guardrail_ids: list[str] | None = None,
|
||||||
prompt: OpenAIResponsePrompt | None = None,
|
prompt: OpenAIResponsePrompt | None = None,
|
||||||
|
|
@ -128,7 +128,9 @@ class StreamingResponseOrchestrator:
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.sequence_number = 0
|
self.sequence_number = 0
|
||||||
# Store MCP tool mapping that gets built during tool processing
|
# Store MCP tool mapping that gets built during tool processing
|
||||||
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
|
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = (
|
||||||
|
ctx.tool_context.previous_tools if ctx.tool_context else {}
|
||||||
|
)
|
||||||
# Track final messages after all tool executions
|
# Track final messages after all tool executions
|
||||||
self.final_messages: list[OpenAIMessageParam] = []
|
self.final_messages: list[OpenAIMessageParam] = []
|
||||||
# mapping for annotations
|
# mapping for annotations
|
||||||
|
|
@ -229,7 +231,8 @@ class StreamingResponseOrchestrator:
|
||||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||||
model=self.ctx.model,
|
model=self.ctx.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
tools=self.ctx.chat_tools,
|
# Pydantic models are dict-compatible but mypy treats them as distinct types
|
||||||
|
tools=self.ctx.chat_tools, # type: ignore[arg-type]
|
||||||
stream=True,
|
stream=True,
|
||||||
temperature=self.ctx.temperature,
|
temperature=self.ctx.temperature,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
|
@ -272,7 +275,12 @@ class StreamingResponseOrchestrator:
|
||||||
|
|
||||||
# Handle choices with no tool calls
|
# Handle choices with no tool calls
|
||||||
for choice in current_response.choices:
|
for choice in current_response.choices:
|
||||||
if not (choice.message.tool_calls and self.ctx.response_tools):
|
has_tool_calls = (
|
||||||
|
isinstance(choice.message, OpenAIAssistantMessageParam)
|
||||||
|
and choice.message.tool_calls
|
||||||
|
and self.ctx.response_tools
|
||||||
|
)
|
||||||
|
if not has_tool_calls:
|
||||||
output_messages.append(
|
output_messages.append(
|
||||||
await convert_chat_choice_to_response_message(
|
await convert_chat_choice_to_response_message(
|
||||||
choice,
|
choice,
|
||||||
|
|
@ -722,7 +730,10 @@ class StreamingResponseOrchestrator:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Accumulate arguments for final response (only for subsequent chunks)
|
# Accumulate arguments for final response (only for subsequent chunks)
|
||||||
if not is_new_tool_call:
|
if not is_new_tool_call and response_tool_call is not None:
|
||||||
|
# Both should have functions since we're inside the tool_call.function check above
|
||||||
|
assert response_tool_call.function is not None
|
||||||
|
assert tool_call.function is not None
|
||||||
response_tool_call.function.arguments = (
|
response_tool_call.function.arguments = (
|
||||||
response_tool_call.function.arguments or ""
|
response_tool_call.function.arguments or ""
|
||||||
) + tool_call.function.arguments
|
) + tool_call.function.arguments
|
||||||
|
|
@ -747,10 +758,13 @@ class StreamingResponseOrchestrator:
|
||||||
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
for tool_call_index in sorted(chat_response_tool_calls.keys()):
|
||||||
tool_call = chat_response_tool_calls[tool_call_index]
|
tool_call = chat_response_tool_calls[tool_call_index]
|
||||||
# Ensure that arguments, if sent back to the inference provider, are not None
|
# Ensure that arguments, if sent back to the inference provider, are not None
|
||||||
|
if tool_call.function:
|
||||||
tool_call.function.arguments = tool_call.function.arguments or "{}"
|
tool_call.function.arguments = tool_call.function.arguments or "{}"
|
||||||
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
tool_call_item_id = tool_call_item_ids[tool_call_index]
|
||||||
final_arguments = tool_call.function.arguments
|
final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}"
|
||||||
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
|
func = chat_response_tool_calls[tool_call_index].function
|
||||||
|
|
||||||
|
tool_call_name = func.name if func else ""
|
||||||
|
|
||||||
# Check if this is an MCP tool call
|
# Check if this is an MCP tool call
|
||||||
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
|
||||||
|
|
@ -894,12 +908,11 @@ class StreamingResponseOrchestrator:
|
||||||
|
|
||||||
self.sequence_number += 1
|
self.sequence_number += 1
|
||||||
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
|
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
|
||||||
item = OpenAIResponseOutputMessageMCPCall(
|
item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall(
|
||||||
arguments="",
|
arguments="",
|
||||||
name=tool_call.function.name,
|
name=tool_call.function.name,
|
||||||
id=matching_item_id,
|
id=matching_item_id,
|
||||||
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
|
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
|
||||||
status="in_progress",
|
|
||||||
)
|
)
|
||||||
elif tool_call.function.name == "web_search":
|
elif tool_call.function.name == "web_search":
|
||||||
item = OpenAIResponseOutputMessageWebSearchToolCall(
|
item = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||||
|
|
@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator:
|
||||||
description=tool.description,
|
description=tool.description,
|
||||||
input_schema=tool.input_schema,
|
input_schema=tool.input_schema,
|
||||||
)
|
)
|
||||||
return convert_tooldef_to_openai_tool(tool_def)
|
return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||||
|
|
||||||
# Initialize chat_tools if not already set
|
# Initialize chat_tools if not already set
|
||||||
if self.ctx.chat_tools is None:
|
if self.ctx.chat_tools is None:
|
||||||
|
|
@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator:
|
||||||
|
|
||||||
for input_tool in tools:
|
for input_tool in tools:
|
||||||
if input_tool.type == "function":
|
if input_tool.type == "function":
|
||||||
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
|
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition
|
||||||
elif input_tool.type in WebSearchToolTypes:
|
elif input_tool.type in WebSearchToolTypes:
|
||||||
tool_name = "web_search"
|
tool_name = "web_search"
|
||||||
# Need to access tool_groups_api from tool_executor
|
# Need to access tool_groups_api from tool_executor
|
||||||
|
|
@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator:
|
||||||
if isinstance(mcp_tool.allowed_tools, list):
|
if isinstance(mcp_tool.allowed_tools, list):
|
||||||
always_allowed = mcp_tool.allowed_tools
|
always_allowed = mcp_tool.allowed_tools
|
||||||
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
|
||||||
always_allowed = mcp_tool.allowed_tools.always
|
# AllowedToolsFilter only has tool_names field (not allowed/disallowed)
|
||||||
never_allowed = mcp_tool.allowed_tools.never
|
always_allowed = mcp_tool.allowed_tools.tool_names
|
||||||
|
|
||||||
# Call list_mcp_tools
|
# Call list_mcp_tools
|
||||||
tool_defs = None
|
tool_defs = None
|
||||||
|
|
@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator:
|
||||||
openai_tool = convert_tooldef_to_chat_tool(t)
|
openai_tool = convert_tooldef_to_chat_tool(t)
|
||||||
if self.ctx.chat_tools is None:
|
if self.ctx.chat_tools is None:
|
||||||
self.ctx.chat_tools = []
|
self.ctx.chat_tools = []
|
||||||
self.ctx.chat_tools.append(openai_tool)
|
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||||
|
|
||||||
# Add to MCP tool mapping
|
# Add to MCP tool mapping
|
||||||
if t.name in self.mcp_tool_to_server:
|
if t.name in self.mcp_tool_to_server:
|
||||||
|
|
@ -1120,12 +1133,16 @@ class StreamingResponseOrchestrator:
|
||||||
self, output_messages: list[OpenAIResponseOutput]
|
self, output_messages: list[OpenAIResponseOutput]
|
||||||
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
) -> AsyncIterator[OpenAIResponseObjectStream]:
|
||||||
# Handle all mcp tool lists from previous response that are still valid:
|
# Handle all mcp tool lists from previous response that are still valid:
|
||||||
|
# tool_context can be None when no tools are provided in the response request
|
||||||
|
if self.ctx.tool_context:
|
||||||
for tool in self.ctx.tool_context.previous_tool_listings:
|
for tool in self.ctx.tool_context.previous_tool_listings:
|
||||||
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
async for evt in self._reuse_mcp_list_tools(tool, output_messages):
|
||||||
yield evt
|
yield evt
|
||||||
# Process all remaining tools (including MCP tools) and emit streaming events
|
# Process all remaining tools (including MCP tools) and emit streaming events
|
||||||
if self.ctx.tool_context.tools_to_process:
|
if self.ctx.tool_context.tools_to_process:
|
||||||
async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages):
|
async for stream_event in self._process_new_tools(
|
||||||
|
self.ctx.tool_context.tools_to_process, output_messages
|
||||||
|
):
|
||||||
yield stream_event
|
yield stream_event
|
||||||
|
|
||||||
def _approval_required(self, tool_name: str) -> bool:
|
def _approval_required(self, tool_name: str) -> bool:
|
||||||
|
|
@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator:
|
||||||
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
openai_tool = convert_tooldef_to_openai_tool(tool_def)
|
||||||
if self.ctx.chat_tools is None:
|
if self.ctx.chat_tools is None:
|
||||||
self.ctx.chat_tools = []
|
self.ctx.chat_tools = []
|
||||||
self.ctx.chat_tools.append(openai_tool)
|
self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict
|
||||||
|
|
||||||
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
|
||||||
id=f"mcp_list_{uuid.uuid4()}",
|
id=f"mcp_list_{uuid.uuid4()}",
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseInputToolFileSearch,
|
OpenAIResponseInputToolFileSearch,
|
||||||
|
|
@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
|
||||||
OpenAIResponseOutputMessageFileSearchToolCall,
|
OpenAIResponseOutputMessageFileSearchToolCall,
|
||||||
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
OpenAIResponseOutputMessageFileSearchToolCallResults,
|
||||||
|
OpenAIResponseOutputMessageMCPCall,
|
||||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
|
@ -67,7 +69,7 @@ class ToolExecutor:
|
||||||
) -> AsyncIterator[ToolExecutionResult]:
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
tool_call_id = tool_call.id
|
tool_call_id = tool_call.id
|
||||||
function = tool_call.function
|
function = tool_call.function
|
||||||
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
|
tool_kwargs = json.loads(function.arguments) if function and function.arguments else {}
|
||||||
|
|
||||||
if not function or not tool_call_id or not function.name:
|
if not function or not tool_call_id or not function.name:
|
||||||
yield ToolExecutionResult(sequence_number=sequence_number)
|
yield ToolExecutionResult(sequence_number=sequence_number)
|
||||||
|
|
@ -84,7 +86,16 @@ class ToolExecutor:
|
||||||
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
|
||||||
|
|
||||||
# Emit completion events for tool execution
|
# Emit completion events for tool execution
|
||||||
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
|
has_error = bool(
|
||||||
|
error_exc
|
||||||
|
or (
|
||||||
|
result
|
||||||
|
and (
|
||||||
|
((error_code := getattr(result, "error_code", None)) and error_code > 0)
|
||||||
|
or getattr(result, "error_message", None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
async for event_result in self._emit_completion_events(
|
async for event_result in self._emit_completion_events(
|
||||||
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
|
||||||
):
|
):
|
||||||
|
|
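For readability, the inlined check in the hunk above can be restated as a small standalone helper. This is only an illustrative sketch: `_has_error` is a hypothetical name, and it assumes `result` is either `None` or an object that may carry `error_code` / `error_message` attributes, which is exactly what the `getattr` calls are guarding against.

```python
from typing import Any


def _has_error(error_exc: Exception | None, result: Any) -> bool:
    # An exception, a positive error_code, or a non-empty error_message
    # all count as a failed tool execution.
    if error_exc:
        return True
    if result is None:
        return False
    code = getattr(result, "error_code", None)
    return bool((code and code > 0) or getattr(result, "error_message", None))
```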
@ -101,7 +112,9 @@ class ToolExecutor:
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
final_output_message=output_message,
|
final_output_message=output_message,
|
||||||
final_input_message=input_message,
|
final_input_message=input_message,
|
||||||
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
|
citation_files=(
|
||||||
|
metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _execute_knowledge_search_via_vector_store(
|
async def _execute_knowledge_search_via_vector_store(
|
||||||
|
|
@ -188,8 +201,9 @@ class ToolExecutor:
|
||||||
|
|
||||||
citation_files[file_id] = filename
|
citation_files[file_id] = filename
|
||||||
|
|
||||||
|
# Cast to proper InterleavedContent type (list invariance)
|
||||||
return ToolInvocationResult(
|
return ToolInvocationResult(
|
||||||
content=content_items,
|
content=content_items, # type: ignore[arg-type]
|
||||||
metadata={
|
metadata={
|
||||||
"document_ids": [r.file_id for r in search_results],
|
"document_ids": [r.file_id for r in search_results],
|
||||||
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
"chunks": [r.content[0].text if r.content else "" for r in search_results],
|
||||||
|
|
@ -209,51 +223,60 @@ class ToolExecutor:
|
||||||
) -> AsyncIterator[ToolExecutionResult]:
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
"""Emit progress events for tool execution start."""
|
"""Emit progress events for tool execution start."""
|
||||||
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
# Emit in_progress event based on tool type (only for tools with specific streaming events)
|
||||||
progress_event = None
|
|
||||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
yield ToolExecutionResult(
|
||||||
|
stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
|
),
|
||||||
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
elif function_name == "web_search":
|
elif function_name == "web_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
yield ToolExecutionResult(
|
||||||
|
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
|
),
|
||||||
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
yield ToolExecutionResult(
|
||||||
|
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
|
),
|
||||||
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
if progress_event:
|
|
||||||
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
|
|
||||||
|
|
||||||
# For web search, emit searching event
|
# For web search, emit searching event
|
||||||
if function_name == "web_search":
|
if function_name == "web_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
yield ToolExecutionResult(
|
||||||
|
stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
|
),
|
||||||
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
|
||||||
|
|
||||||
# For file search, emit searching event
|
# For file search, emit searching event
|
||||||
if function_name == "knowledge_search":
|
if function_name == "knowledge_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
yield ToolExecutionResult(
|
||||||
|
stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
|
),
|
||||||
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
|
|
||||||
|
|
||||||
async def _execute_tool(
|
async def _execute_tool(
|
||||||
self,
|
self,
|
||||||
|
|
@ -261,7 +284,7 @@ class ToolExecutor:
|
||||||
tool_kwargs: dict,
|
tool_kwargs: dict,
|
||||||
ctx: ChatCompletionContext,
|
ctx: ChatCompletionContext,
|
||||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
) -> tuple[Exception | None, any]:
|
) -> tuple[Exception | None, Any]:
|
||||||
"""Execute the tool and return error exception and result."""
|
"""Execute the tool and return error exception and result."""
|
||||||
error_exc = None
|
error_exc = None
|
||||||
result = None
|
result = None
|
||||||
|
|
@ -284,10 +307,14 @@ class ToolExecutor:
|
||||||
kwargs=tool_kwargs,
|
kwargs=tool_kwargs,
|
||||||
)
|
)
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
response_file_search_tool = next(
|
response_file_search_tool = (
|
||||||
|
next(
|
||||||
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
if ctx.response_tools
|
||||||
|
else None
|
||||||
|
)
|
||||||
if response_file_search_tool:
|
if response_file_search_tool:
|
||||||
# Use vector_stores.search API instead of knowledge_search tool
|
# Use vector_stores.search API instead of knowledge_search tool
|
||||||
# to support filters and ranking_options
|
# to support filters and ranking_options
|
||||||
|
|
@ -322,35 +349,34 @@ class ToolExecutor:
|
||||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
) -> AsyncIterator[ToolExecutionResult]:
|
) -> AsyncIterator[ToolExecutionResult]:
|
||||||
"""Emit completion or failure events for tool execution."""
|
"""Emit completion or failure events for tool execution."""
|
||||||
completion_event = None
|
|
||||||
|
|
||||||
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
if mcp_tool_to_server and function_name in mcp_tool_to_server:
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
if has_error:
|
if has_error:
|
||||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number)
|
||||||
else:
|
else:
|
||||||
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number)
|
||||||
elif function_name == "web_search":
|
elif function_name == "web_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number)
|
||||||
elif function_name == "knowledge_search":
|
elif function_name == "knowledge_search":
|
||||||
sequence_number += 1
|
sequence_number += 1
|
||||||
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
|
||||||
item_id=item_id,
|
item_id=item_id,
|
||||||
output_index=output_index,
|
output_index=output_index,
|
||||||
sequence_number=sequence_number,
|
sequence_number=sequence_number,
|
||||||
)
|
)
|
||||||
|
yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number)
|
||||||
if completion_event:
|
|
||||||
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
|
|
||||||
|
|
||||||
async def _build_result_messages(
|
async def _build_result_messages(
|
||||||
self,
|
self,
|
||||||
|
|
@ -360,21 +386,18 @@ class ToolExecutor:
|
||||||
tool_kwargs: dict,
|
tool_kwargs: dict,
|
||||||
ctx: ChatCompletionContext,
|
ctx: ChatCompletionContext,
|
||||||
error_exc: Exception | None,
|
error_exc: Exception | None,
|
||||||
result: any,
|
result: Any,
|
||||||
has_error: bool,
|
has_error: bool,
|
||||||
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
|
||||||
) -> tuple[any, any]:
|
) -> tuple[Any, Any]:
|
||||||
"""Build output and input messages from tool execution results."""
|
"""Build output and input messages from tool execution results."""
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build output message
|
# Build output message
|
||||||
|
message: Any
|
||||||
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
if mcp_tool_to_server and function.name in mcp_tool_to_server:
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
|
||||||
OpenAIResponseOutputMessageMCPCall,
|
|
||||||
)
|
|
||||||
|
|
||||||
message = OpenAIResponseOutputMessageMCPCall(
|
message = OpenAIResponseOutputMessageMCPCall(
|
||||||
id=item_id,
|
id=item_id,
|
||||||
arguments=function.arguments,
|
arguments=function.arguments,
|
||||||
|
|
@ -383,10 +406,14 @@ class ToolExecutor:
|
||||||
)
|
)
|
||||||
if error_exc:
|
if error_exc:
|
||||||
message.error = str(error_exc)
|
message.error = str(error_exc)
|
||||||
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
|
elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or (
|
||||||
message.error = f"Error (code {result.error_code}): {result.error_message}"
|
result and getattr(result, "error_message", None)
|
||||||
elif result and result.content:
|
):
|
||||||
message.output = interleaved_content_as_str(result.content)
|
ec = getattr(result, "error_code", "unknown")
|
||||||
|
em = getattr(result, "error_message", "")
|
||||||
|
message.error = f"Error (code {ec}): {em}"
|
||||||
|
elif result and (content := getattr(result, "content", None)):
|
||||||
|
message.output = interleaved_content_as_str(content)
|
||||||
else:
|
else:
|
||||||
if function.name == "web_search":
|
if function.name == "web_search":
|
||||||
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
message = OpenAIResponseOutputMessageWebSearchToolCall(
|
||||||
|
|
@ -401,17 +428,17 @@ class ToolExecutor:
|
||||||
queries=[tool_kwargs.get("query", "")],
|
queries=[tool_kwargs.get("query", "")],
|
||||||
status="completed",
|
status="completed",
|
||||||
)
|
)
|
||||||
if result and "document_ids" in result.metadata:
|
if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata:
|
||||||
message.results = []
|
message.results = []
|
||||||
for i, doc_id in enumerate(result.metadata["document_ids"]):
|
for i, doc_id in enumerate(metadata["document_ids"]):
|
||||||
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
|
text = metadata["chunks"][i] if "chunks" in metadata else None
|
||||||
score = result.metadata["scores"][i] if "scores" in result.metadata else None
|
score = metadata["scores"][i] if "scores" in metadata else None
|
||||||
message.results.append(
|
message.results.append(
|
||||||
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
OpenAIResponseOutputMessageFileSearchToolCallResults(
|
||||||
file_id=doc_id,
|
file_id=doc_id,
|
||||||
filename=doc_id,
|
filename=doc_id,
|
||||||
text=text,
|
text=text if text is not None else "",
|
||||||
score=score,
|
score=score if score is not None else 0.0,
|
||||||
attributes={},
|
attributes={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -421,27 +448,32 @@ class ToolExecutor:
|
||||||
raise ValueError(f"Unknown tool {function.name} called")
|
raise ValueError(f"Unknown tool {function.name} called")
|
||||||
|
|
||||||
# Build input message
|
# Build input message
|
||||||
input_message = None
|
input_message: OpenAIToolMessageParam | None = None
|
||||||
if result and result.content:
|
if result and (result_content := getattr(result, "content", None)):
|
||||||
if isinstance(result.content, str):
|
# all the mypy contortions here are still unsatisfactory with random Any typing
|
||||||
content = result.content
|
if isinstance(result_content, str):
|
||||||
elif isinstance(result.content, list):
|
msg_content: str | list[Any] = result_content
|
||||||
content = []
|
elif isinstance(result_content, list):
|
||||||
for item in result.content:
|
content_list: list[Any] = []
|
||||||
|
for item in result_content:
|
||||||
|
part: Any
|
||||||
if isinstance(item, TextContentItem):
|
if isinstance(item, TextContentItem):
|
||||||
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
|
||||||
elif isinstance(item, ImageContentItem):
|
elif isinstance(item, ImageContentItem):
|
||||||
if item.image.data:
|
if item.image.data:
|
||||||
url = f"data:image;base64,{item.image.data}"
|
url_value = f"data:image;base64,{item.image.data}"
|
||||||
else:
|
else:
|
||||||
url = item.image.url
|
url_value = str(item.image.url) if item.image.url else ""
|
||||||
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
|
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown result content type: {type(item)}")
|
raise ValueError(f"Unknown result content type: {type(item)}")
|
||||||
content.append(part)
|
content_list.append(part)
|
||||||
|
msg_content = content_list
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown result content type: {type(result.content)}")
|
raise ValueError(f"Unknown result content type: {type(result_content)}")
|
||||||
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
|
# OpenAIToolMessageParam accepts str | list[TextParam] but we may have images
|
||||||
|
# This is runtime-safe as the API accepts it, but mypy complains
|
||||||
|
input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
text = str(error_exc) if error_exc else "Tool execution failed"
|
text = str(error_exc) if error_exc else "Tool execution failed"
|
||||||
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionToolParam
|
from openai.types.chat import ChatCompletionToolParam
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -100,17 +101,19 @@ class ToolContext(BaseModel):
|
||||||
if isinstance(tool, OpenAIResponseToolMCP):
|
if isinstance(tool, OpenAIResponseToolMCP):
|
||||||
previous_tools_by_label[tool.server_label] = tool
|
previous_tools_by_label[tool.server_label] = tool
|
||||||
# collect tool definitions which are the same in current and previous requests:
|
# collect tool definitions which are the same in current and previous requests:
|
||||||
tools_to_process = []
|
tools_to_process: list[OpenAIResponseInputTool] = []
|
||||||
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
matched: dict[str, OpenAIResponseInputToolMCP] = {}
|
||||||
for tool in self.current_tools:
|
# Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union)
|
||||||
|
# which differ only in MCP type (InputToolMCP vs ToolMCP). Code is correct.
|
||||||
|
for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment]
|
||||||
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label:
|
||||||
previous_tool = previous_tools_by_label[tool.server_label]
|
previous_tool = previous_tools_by_label[tool.server_label]
|
||||||
if previous_tool.allowed_tools == tool.allowed_tools:
|
if previous_tool.allowed_tools == tool.allowed_tools:
|
||||||
matched[tool.server_label] = tool
|
matched[tool.server_label] = tool
|
||||||
else:
|
else:
|
||||||
tools_to_process.append(tool)
|
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
tools_to_process.append(tool)
|
tools_to_process.append(tool) # type: ignore[arg-type]
|
||||||
# tools that are not the same or were not previously defined need to be processed:
|
# tools that are not the same or were not previously defined need to be processed:
|
||||||
self.tools_to_process = tools_to_process
|
self.tools_to_process = tools_to_process
|
||||||
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
# for all matched definitions, get the mcp_list_tools objects from the previous output:
|
||||||
|
|
@ -119,9 +122,11 @@ class ToolContext(BaseModel):
|
||||||
]
|
]
|
||||||
# reconstruct the tool to server mappings that can be reused:
|
# reconstruct the tool to server mappings that can be reused:
|
||||||
for listing in self.previous_tool_listings:
|
for listing in self.previous_tool_listings:
|
||||||
|
# listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool]
|
||||||
definition = matched[listing.server_label]
|
definition = matched[listing.server_label]
|
||||||
for tool in listing.tools:
|
for mcp_tool in listing.tools:
|
||||||
self.previous_tools[tool.name] = definition
|
# mcp_tool is MCPListToolsTool which has a name: str field
|
||||||
|
self.previous_tools[mcp_tool.name] = definition
|
||||||
|
|
||||||
def available_tools(self) -> list[OpenAIResponseTool]:
|
def available_tools(self) -> list[OpenAIResponseTool]:
|
||||||
if not self.current_tools:
|
if not self.current_tools:
|
||||||
|
|
@ -139,6 +144,8 @@ class ToolContext(BaseModel):
|
||||||
server_label=tool.server_label,
|
server_label=tool.server_label,
|
||||||
allowed_tools=tool.allowed_tools,
|
allowed_tools=tool.allowed_tools,
|
||||||
)
|
)
|
||||||
|
# Exhaustive check - all tool types should be handled above
|
||||||
|
raise AssertionError(f"Unexpected tool type: {type(tool)}")
|
||||||
|
|
||||||
return [convert_tool(tool) for tool in self.current_tools]
|
return [convert_tool(tool) for tool in self.current_tools]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||||
from llama_stack.apis.agents.openai_responses import (
|
from llama_stack.apis.agents.openai_responses import (
|
||||||
|
|
@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message(
|
||||||
|
|
||||||
return OpenAIResponseMessage(
|
return OpenAIResponseMessage(
|
||||||
id=message_id or f"msg_{uuid.uuid4()}",
|
id=message_id or f"msg_{uuid.uuid4()}",
|
||||||
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)],
|
content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))],
|
||||||
status="completed",
|
status="completed",
|
||||||
role="assistant",
|
role="assistant",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def convert_response_content_to_chat_content(
|
async def convert_response_content_to_chat_content(
|
||||||
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
|
content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent],
|
||||||
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
) -> str | list[OpenAIChatCompletionContentPartParam]:
|
||||||
"""
|
"""
|
||||||
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
|
||||||
|
|
@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content(
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
return content
|
return content
|
||||||
|
|
||||||
converted_parts = []
|
# Type with union to avoid list invariance issues
|
||||||
|
converted_parts: list[OpenAIChatCompletionContentPartParam] = []
|
||||||
for content_part in content:
|
for content_part in content:
|
||||||
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
if isinstance(content_part, OpenAIResponseInputMessageContentText):
|
||||||
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
|
||||||
|
|
@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||||
|
# Output can be None, use empty string as fallback
|
||||||
|
output_content = input_item.output if input_item.output is not None else ""
|
||||||
messages.append(
|
messages.append(
|
||||||
OpenAIToolMessageParam(
|
OpenAIToolMessageParam(
|
||||||
content=input_item.output,
|
content=output_content,
|
||||||
tool_call_id=input_item.id,
|
tool_call_id=input_item.id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages(
|
||||||
):
|
):
|
||||||
# these are handled by the responses impl itself and not pass through to chat completions
|
# these are handled by the responses impl itself and not pass through to chat completions
|
||||||
pass
|
pass
|
||||||
else:
|
elif isinstance(input_item, OpenAIResponseMessage):
|
||||||
|
# Narrow type to OpenAIResponseMessage which has content and role attributes
|
||||||
content = await convert_response_content_to_chat_content(input_item.content)
|
content = await convert_response_content_to_chat_content(input_item.content)
|
||||||
message_type = await get_message_type_by_role(input_item.role)
|
message_type = await get_message_type_by_role(input_item.role)
|
||||||
if message_type is None:
|
if message_type is None:
|
||||||
|
|
@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages(
|
||||||
last_user_content = getattr(last_user_msg, "content", None)
|
last_user_content = getattr(last_user_msg, "content", None)
|
||||||
if last_user_content == content:
|
if last_user_content == content:
|
||||||
continue # Skip duplicate user message
|
continue # Skip duplicate user message
|
||||||
messages.append(message_type(content=content))
|
# Dynamic message type call - different message types have different content expectations
|
||||||
|
messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type]
|
||||||
if len(tool_call_results):
|
if len(tool_call_results):
|
||||||
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
# Check if unpaired function_call_outputs reference function_calls from previous messages
|
||||||
if previous_messages:
|
if previous_messages:
|
||||||
|
|
@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format(
|
||||||
if text.format["type"] == "json_object":
|
if text.format["type"] == "json_object":
|
||||||
return OpenAIResponseFormatJSONObject()
|
return OpenAIResponseFormatJSONObject()
|
||||||
if text.format["type"] == "json_schema":
|
if text.format["type"] == "json_schema":
|
||||||
|
# Assert name exists for json_schema format
|
||||||
|
assert text.format.get("name"), "json_schema format requires a name"
|
||||||
|
schema_name: str = text.format["name"] # type: ignore[assignment]
|
||||||
return OpenAIResponseFormatJSONSchema(
|
return OpenAIResponseFormatJSONSchema(
|
||||||
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
|
json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"])
|
||||||
)
|
)
|
||||||
raise ValueError(f"Unsupported text format: {text.format}")
|
raise ValueError(f"Unsupported text format: {text.format}")
|
||||||
|
|
||||||
|
|
@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None
|
||||||
"assistant": OpenAIAssistantMessageParam,
|
"assistant": OpenAIAssistantMessageParam,
|
||||||
"developer": OpenAIDeveloperMessageParam,
|
"developer": OpenAIDeveloperMessageParam,
|
||||||
}
|
}
|
||||||
return role_to_type.get(role)
|
return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass
|
||||||
|
|
||||||
|
|
||||||
def _extract_citations_from_text(
|
def _extract_citations_from_text(
|
||||||
|
|
@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
||||||
|
|
||||||
# Look up shields to get their provider_resource_id (actual model ID)
|
# Look up shields to get their provider_resource_id (actual model ID)
|
||||||
model_ids = []
|
model_ids = []
|
||||||
shields_list = await safety_api.routing_table.list_shields()
|
# TODO: list_shields not in Safety interface but available at runtime via API routing
|
||||||
|
shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined]
|
||||||
|
|
||||||
for guardrail_id in guardrail_ids:
|
for guardrail_id in guardrail_ids:
|
||||||
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
|
||||||
|
|
@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
||||||
for result in response.results:
|
for result in response.results:
|
||||||
if result.flagged:
|
if result.flagged:
|
||||||
message = result.user_message or "Content blocked by safety guardrails"
|
message = result.user_message or "Content blocked by safety guardrails"
|
||||||
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
|
flagged_categories = (
|
||||||
|
[cat for cat, flagged in result.categories.items() if flagged] if result.categories else []
|
||||||
|
)
|
||||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||||
|
|
||||||
if flagged_categories:
|
if flagged_categories:
|
||||||
|
|
@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
# No violations found
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
||||||
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
|
"""Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects."""
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import OpenAIMessageParam
|
||||||
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||||
from llama_stack.core.telemetry import tracing
|
from llama_stack.core.telemetry import tracing
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
@ -31,7 +31,7 @@ class ShieldRunnerMixin:
|
||||||
self.input_shields = input_shields
|
self.input_shields = input_shields
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
|
async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
|
||||||
async def run_shield_with_span(identifier: str):
|
async def run_shield_with_span(identifier: str):
|
||||||
async with tracing.span(f"run_shield_{identifier}"):
|
async with tracing.span(f"run_shield_{identifier}"):
|
||||||
return await self.safety_api.run_shield(
|
return await self.safety_api.run_shield(
|
||||||
|
|
|
||||||
|
|
@ -28,4 +28,13 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
|
||||||
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
|
||||||
),
|
),
|
||||||
|
RemoteProviderSpec(
|
||||||
|
api=Api.files,
|
||||||
|
provider_type="remote::openai",
|
||||||
|
adapter_type="openai",
|
||||||
|
pip_packages=["openai"] + sql_store_pip_packages,
|
||||||
|
module="llama_stack.providers.remote.files.openai",
|
||||||
|
config_class="llama_stack.providers.remote.files.openai.config.OpenAIFilesImplConfig",
|
||||||
|
description="OpenAI Files API provider for managing files through OpenAI's native file storage service.",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
19
src/llama_stack/providers/remote/files/openai/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import AccessRule, Api
|
||||||
|
|
||||||
|
from .config import OpenAIFilesImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: OpenAIFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
|
||||||
|
from .files import OpenAIFilesImpl
|
||||||
|
|
||||||
|
impl = OpenAIFilesImpl(config, policy or [])
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
28
src/llama_stack/providers/remote/files/openai/config.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.core.storage.datatypes import SqlStoreReference
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFilesImplConfig(BaseModel):
|
||||||
|
"""Configuration for OpenAI Files API provider."""
|
||||||
|
|
||||||
|
api_key: str = Field(description="OpenAI API key for authentication")
|
||||||
|
metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"api_key": "${env.OPENAI_API_KEY}",
|
||||||
|
"metadata_store": SqlStoreReference(
|
||||||
|
backend="sql_default",
|
||||||
|
table_name="openai_files_metadata",
|
||||||
|
).model_dump(exclude_none=True),
|
||||||
|
}
|
||||||
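For orientation, a minimal sketch of wiring the new provider up programmatically, mirroring `sample_run_config` above and the `get_adapter_impl` entry point from `__init__.py`. The API key is a placeholder (run configs typically use `${env.OPENAI_API_KEY}`), and the snippet assumes a `sql_default` SQL backend is registered in the stack's storage configuration.

```python
import asyncio

from llama_stack.core.storage.datatypes import SqlStoreReference
from llama_stack.providers.remote.files.openai import get_adapter_impl
from llama_stack.providers.remote.files.openai.config import OpenAIFilesImplConfig

config = OpenAIFilesImplConfig(
    api_key="sk-...",  # placeholder; normally injected via ${env.OPENAI_API_KEY}
    metadata_store=SqlStoreReference(
        backend="sql_default",  # assumes this backend exists in the run's storage config
        table_name="openai_files_metadata",
    ),
)

# Bootstraps the provider the same way the registry does via get_adapter_impl.
impl = asyncio.run(get_adapter_impl(config, deps={}, policy=[]))
```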
239
src/llama_stack/providers/remote/files/openai/files.py
Normal file
|
|
@ -0,0 +1,239 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Annotated, Any
|
||||||
|
|
||||||
|
from fastapi import Depends, File, Form, Response, UploadFile
|
||||||
|
|
||||||
|
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||||
|
from llama_stack.apis.common.responses import Order
|
||||||
|
from llama_stack.apis.files import (
|
||||||
|
ExpiresAfter,
|
||||||
|
Files,
|
||||||
|
ListOpenAIFileResponse,
|
||||||
|
OpenAIFileDeleteResponse,
|
||||||
|
OpenAIFileObject,
|
||||||
|
OpenAIFilePurpose,
|
||||||
|
)
|
||||||
|
from llama_stack.core.datatypes import AccessRule
|
||||||
|
from llama_stack.providers.utils.files.form_data import parse_expires_after
|
||||||
|
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||||
|
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
|
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from .config import OpenAIFilesImplConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _make_file_object(
|
||||||
|
*,
|
||||||
|
id: str,
|
||||||
|
filename: str,
|
||||||
|
purpose: str,
|
||||||
|
bytes: int,
|
||||||
|
created_at: int,
|
||||||
|
expires_at: int,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
"""
|
||||||
|
Construct an OpenAIFileObject and normalize expires_at.
|
||||||
|
|
||||||
|
If expires_at is greater than the max we treat it as no-expiration and
|
||||||
|
return None for expires_at.
|
||||||
|
"""
|
||||||
|
obj = OpenAIFileObject(
|
||||||
|
id=id,
|
||||||
|
filename=filename,
|
||||||
|
purpose=OpenAIFilePurpose(purpose),
|
||||||
|
bytes=bytes,
|
||||||
|
created_at=created_at,
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
|
||||||
|
obj.expires_at = None # type: ignore
|
||||||
|
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIFilesImpl(Files):
|
||||||
|
"""OpenAI Files API implementation."""
|
||||||
|
|
||||||
|
def __init__(self, config: OpenAIFilesImplConfig, policy: list[AccessRule]) -> None:
|
||||||
|
self._config = config
|
||||||
|
self.policy = policy
|
||||||
|
self._client: OpenAI | None = None
|
||||||
|
self._sql_store: AuthorizedSqlStore | None = None
|
||||||
|
|
||||||
|
def _now(self) -> int:
|
||||||
|
"""Return current UTC timestamp as int seconds."""
|
||||||
|
return int(datetime.now(UTC).timestamp())
|
||||||
|
|
||||||
|
async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
|
||||||
|
where: dict[str, str | dict] = {"id": file_id}
|
||||||
|
if not return_expired:
|
||||||
|
where["expires_at"] = {">": self._now()}
|
||||||
|
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
|
||||||
|
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||||
|
return row
|
||||||
|
|
||||||
|
async def _delete_file(self, file_id: str) -> None:
|
||||||
|
"""Delete a file from OpenAI and the database."""
|
||||||
|
try:
|
||||||
|
self.client.files.delete(file_id)
|
||||||
|
except Exception as e:
|
||||||
|
# If file doesn't exist on OpenAI side, just remove from metadata store
|
||||||
|
if "not found" not in str(e).lower():
|
||||||
|
raise RuntimeError(f"Failed to delete file from OpenAI: {e}") from e
|
||||||
|
|
||||||
|
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||||
|
|
||||||
|
async def _delete_if_expired(self, file_id: str) -> None:
|
||||||
|
"""If the file exists and is expired, delete it."""
|
||||||
|
if row := await self._get_file(file_id, return_expired=True):
|
||||||
|
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
|
||||||
|
await self._delete_file(file_id)
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
self._client = OpenAI(api_key=self._config.api_key)
|
||||||
|
|
||||||
|
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
|
||||||
|
await self._sql_store.create_table(
|
||||||
|
"openai_files",
|
||||||
|
{
|
||||||
|
"id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
|
||||||
|
"filename": ColumnType.STRING,
|
||||||
|
"purpose": ColumnType.STRING,
|
||||||
|
"bytes": ColumnType.INTEGER,
|
||||||
|
"created_at": ColumnType.INTEGER,
|
||||||
|
"expires_at": ColumnType.INTEGER,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> OpenAI:
|
||||||
|
assert self._client is not None, "Provider not initialized"
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sql_store(self) -> AuthorizedSqlStore:
|
||||||
|
assert self._sql_store is not None, "Provider not initialized"
|
||||||
|
return self._sql_store
|
||||||
|
|
||||||
|
async def openai_upload_file(
|
||||||
|
self,
|
||||||
|
file: Annotated[UploadFile, File()],
|
||||||
|
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||||
|
expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None,
|
||||||
|
) -> OpenAIFileObject:
|
||||||
|
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||||
|
content = await file.read()
|
||||||
|
file_size = len(content)
|
||||||
|
|
||||||
|
created_at = self._now()
|
||||||
|
|
||||||
|
expires_at = created_at + ExpiresAfter.MAX * 42
|
||||||
|
if purpose == OpenAIFilePurpose.BATCH:
|
||||||
|
expires_at = created_at + ExpiresAfter.MAX
|
||||||
|
|
||||||
|
if expires_after is not None:
|
||||||
|
expires_at = created_at + expires_after.seconds
|
||||||
|
|
||||||
|
try:
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
file_obj = BytesIO(content)
|
||||||
|
file_obj.name = filename
|
||||||
|
|
||||||
|
response = self.client.files.create(
|
||||||
|
file=file_obj,
|
||||||
|
purpose=purpose.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id = response.id
|
||||||
|
|
||||||
|
entry: dict[str, Any] = {
|
||||||
|
"id": file_id,
|
||||||
|
"filename": filename,
|
||||||
|
"purpose": purpose.value,
|
||||||
|
"bytes": file_size,
|
||||||
|
"created_at": created_at,
|
||||||
|
"expires_at": expires_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.sql_store.insert("openai_files", entry)
|
||||||
|
|
||||||
|
return _make_file_object(**entry)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to upload file to OpenAI: {e}") from e
|
||||||
|
|
||||||
|
async def openai_list_files(
|
||||||
|
self,
|
||||||
|
after: str | None = None,
|
||||||
|
limit: int | None = 10000,
|
||||||
|
order: Order | None = Order.desc,
|
||||||
|
purpose: OpenAIFilePurpose | None = None,
|
||||||
|
) -> ListOpenAIFileResponse:
|
||||||
|
if not order:
|
||||||
|
order = Order.desc
|
||||||
|
|
||||||
|
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
|
||||||
|
if purpose:
|
||||||
|
where_conditions["purpose"] = purpose.value
|
||||||
|
|
||||||
|
paginated_result = await self.sql_store.fetch_all(
|
||||||
|
table="openai_files",
|
||||||
|
where=where_conditions,
|
||||||
|
order_by=[("created_at", order.value)],
|
||||||
|
cursor=("id", after) if after else None,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
files = [_make_file_object(**row) for row in paginated_result.data]
|
||||||
|
|
||||||
|
return ListOpenAIFileResponse(
|
||||||
|
data=files,
|
||||||
|
has_more=paginated_result.has_more,
|
||||||
|
first_id=files[0].id if files else "",
|
||||||
|
last_id=files[-1].id if files else "",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||||
|
await self._delete_if_expired(file_id)
|
||||||
|
row = await self._get_file(file_id)
|
||||||
|
return _make_file_object(**row)
|
||||||
|
|
||||||
|
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||||
|
await self._delete_if_expired(file_id)
|
||||||
|
_ = await self._get_file(file_id)
|
||||||
|
await self._delete_file(file_id)
|
||||||
|
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||||
|
|
||||||
|
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||||
|
await self._delete_if_expired(file_id)
|
||||||
|
|
||||||
|
row = await self._get_file(file_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.files.content(file_id)
|
||||||
|
file_content = response.content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if "not found" in str(e).lower():
|
||||||
|
await self._delete_file(file_id)
|
||||||
|
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||||
|
raise RuntimeError(f"Failed to download file from OpenAI: {e}") from e
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=file_content,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'},
|
||||||
|
)
|
||||||
|
|
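The implementation above delegates to the official `openai` SDK; for context, the direct SDK calls it wraps look roughly like this (placeholder API key and content; `purpose="batch"` matches the `OpenAIFilePurpose.BATCH` value used above):

```python
from io import BytesIO

from openai import OpenAI

client = OpenAI(api_key="sk-...")  # placeholder key

# Upload: the adapter wraps the uploaded bytes in BytesIO and sets a filename.
buf = BytesIO(b'{"example": true}\n')
buf.name = "uploaded_file"
created = client.files.create(file=buf, purpose="batch")

# Download and delete, mirroring openai_retrieve_file_content / openai_delete_file.
raw_bytes = client.files.content(created.id).content
client.files.delete(created.id)
```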
@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
|
||||||
return "https://api.anthropic.com/v1"
|
return "https://api.anthropic.com/v1"
|
||||||
|
|
||||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||||
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]
|
api_key = self._get_api_key_from_config_or_provider_data()
|
||||||
|
return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()]
|
||||||
|
|
|
||||||
|
|
@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin):
|
||||||
|
|
||||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||||
# Filter out None values from endpoint names
|
# Filter out None values from endpoint names
|
||||||
|
api_token = self._get_api_key_from_config_or_provider_data()
|
||||||
return [
|
return [
|
||||||
endpoint.name # type: ignore[misc]
|
endpoint.name # type: ignore[misc]
|
||||||
for endpoint in WorkspaceClient(
|
for endpoint in WorkspaceClient(
|
||||||
host=self.config.url, token=self.get_api_key()
|
host=self.config.url, token=api_token
|
||||||
).serving_endpoints.list() # TODO: this is not async
|
).serving_endpoints.list() # TODO: this is not async
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
|
||||||
|
|
||||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Rerank Example
|
||||||
|
|
||||||
|
The following example shows how to rerank documents using an NVIDIA NIM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
rerank_response = client.alpha.inference.rerank(
|
||||||
|
model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||||
|
query="query",
|
||||||
|
items=[
|
||||||
|
"item_1",
|
||||||
|
"item_2",
|
||||||
|
"item_3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, result in enumerate(rerank_response):
|
||||||
|
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
|
||||||
|
```
|
||||||
|
|
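The adapter implementation shown later also accepts an optional `max_num_results`; assuming the client surface forwards it (not verified here), the top-k truncation can be requested directly:

```python
# Sketch: same call as above, limited to the two highest-scoring items.
rerank_response = client.alpha.inference.rerank(
    model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
    query="query",
    items=["item_1", "item_2", "item_3"],
    max_num_results=2,  # assumed to map onto the adapter's optional parameter
)
```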
@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||||
Attributes:
|
Attributes:
|
||||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||||
api_key (str): The access key for the hosted NIM endpoints
|
api_key (str): The access key for the hosted NIM endpoints
|
||||||
|
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
|
||||||
|
|
||||||
There are two ways to access NVIDIA NIMs -
|
There are two ways to access NVIDIA NIMs -
|
||||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||||
|
|
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
||||||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||||
)
|
)
|
||||||
|
rerank_model_to_url: dict[str, str] = Field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||||
|
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||||
|
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||||
|
},
|
||||||
|
description="Mapping of rerank model identifiers to their API endpoints. ",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
|
|
|
||||||
|
|
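As a sketch, the new mapping can also be overridden when constructing the config directly. The import path and the custom endpoint URL below are assumptions for illustration; `url` and `api_key` are the existing fields documented above, and note that supplying `rerank_model_to_url` replaces the defaults rather than merging with them.

```python
# Import path assumed from the usual provider layout; adjust if it differs.
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig

config = NVIDIAConfig(
    url="https://integrate.api.nvidia.com",
    api_key="nvapi-...",  # placeholder hosted-NIM key
    rerank_model_to_url={
        # Hypothetical self-hosted rerank NIM endpoint.
        "my-org/custom-rerank-model": "https://nim.example.com/v1/ranking",
    },
)
```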
@ -5,6 +5,19 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
RerankData,
|
||||||
|
RerankResponse,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference.inference import (
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.models import Model, ModelType
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
|
|
@@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         :return: The NVIDIA API base URL
         """
         return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Return both dynamic model IDs and statically configured rerank model IDs.
+        """
+        dynamic_ids: Iterable[str] = []
+        try:
+            dynamic_ids = await super().list_provider_model_ids()
+        except Exception:
+            # If the dynamic listing fails, proceed with just configured rerank IDs
+            dynamic_ids = []
+
+        configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
+        return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids))  # remove duplicates
+
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        """
+        Classify rerank models from config; otherwise use the base behavior.
+        """
+        if identifier in self.config.rerank_model_to_url:
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.rerank,
+            )
+        return super().construct_model_from_identifier(identifier)
+
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        provider_model_id = await self._get_provider_model_id(model)
+
+        ranking_url = self.get_base_url()
+
+        if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
+            ranking_url = self.config.rerank_model_to_url[provider_model_id]
+
+        logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
+
+        # Convert query to text format
+        if isinstance(query, str):
+            query_text = query
+        elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
+            query_text = query.text
+        else:
+            raise ValueError("Query must be a string or text content part")
+
+        # Convert items to text format
+        passages = []
+        for item in items:
+            if isinstance(item, str):
+                passages.append({"text": item})
+            elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
+                passages.append({"text": item.text})
+            else:
+                raise ValueError("Items must be strings or text content parts")
+
+        payload = {
+            "model": provider_model_id,
+            "query": {"text": query_text},
+            "passages": passages,
+        }
+
+        headers = {
+            "Authorization": f"Bearer {self.get_api_key()}",
+            "Content-Type": "application/json",
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(ranking_url, headers=headers, json=payload) as response:
+                    if response.status != 200:
+                        response_text = await response.text()
+                        raise ConnectionError(
+                            f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
+                        )
+
+                    result = await response.json()
+                    rankings = result.get("rankings", [])
+
+                    # Convert to RerankData format
+                    rerank_data = []
+                    for ranking in rankings:
+                        rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
+
+                    # Apply max_num_results limit
+                    if max_num_results is not None:
+                        rerank_data = rerank_data[:max_num_results]
+
+                    return RerankResponse(data=rerank_data)
+
+        except aiohttp.ClientError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
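
For reference, a minimal standalone sketch of the hosted reranking call that the new rerank() method wraps, using the same payload and response shape shown above. The NVIDIA_API_KEY environment variable, the helper name, and the example passages are assumptions made for illustration only; they are not part of this patch.

# Standalone sketch (not part of the patch): call the hosted reranking endpoint
# with the same request/response shape the adapter builds above.
import asyncio
import os

import aiohttp

RERANK_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking"  # from rerank_model_to_url above


async def rerank_passages(query: str, passages: list[str]) -> list[tuple[int, float]]:
    payload = {
        "model": "nv-rerank-qa-mistral-4b:1",
        "query": {"text": query},
        "passages": [{"text": p} for p in passages],
    }
    headers = {
        # Assumption: the API key is exported as NVIDIA_API_KEY for this sketch.
        "Authorization": f"Bearer {os.environ['NVIDIA_API_KEY']}",
        "Content-Type": "application/json",
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(RERANK_URL, headers=headers, json=payload) as resp:
            resp.raise_for_status()
            result = await resp.json()
    # Each ranking carries the passage index and a relevance logit, as consumed above.
    return [(r["index"], r["logit"]) for r in result.get("rankings", [])]


if __name__ == "__main__":
    ranked = asyncio.run(rerank_passages("what does WAL mode do?", ["SQLite docs", "Postgres docs"]))
    print(ranked)
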
@@ -35,6 +35,7 @@ class InferenceStore:
         self.reference = reference
         self.sql_store = None
         self.policy = policy
+        self.enable_write_queue = True
 
         # Async write queue and worker control
         self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None
@@ -47,14 +48,13 @@ class InferenceStore:
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)
 
-        # Disable write queue for SQLite to avoid concurrency issues
-        backend_name = self.reference.backend
-        backend_config = _SQLSTORE_BACKENDS.get(backend_name)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{backend_name}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE
+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
+        backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
+            self.enable_write_queue = False
+            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")
 
         await self.sql_store.create_table(
             "chat_completions",
             {
@@ -70,8 +70,9 @@ class InferenceStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.info("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )
 
     async def shutdown(self) -> None:
         if not self._worker_tasks:
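
As context for the enable_write_queue toggle above, here is a small illustrative sketch of the bounded asyncio write-queue pattern the store keeps for non-SQLite backends: a fixed number of worker tasks drain a bounded queue, and shutdown waits for pending writes. All names are invented for the example; this is not the store's actual implementation.

import asyncio


class TinyWriteQueue:
    def __init__(self, num_writers: int = 2, max_size: int = 100) -> None:
        self._queue: asyncio.Queue[str] = asyncio.Queue(maxsize=max_size)
        self._workers = [asyncio.ensure_future(self._worker()) for _ in range(num_writers)]

    async def _worker(self) -> None:
        while True:
            item = await self._queue.get()
            try:
                await asyncio.sleep(0)  # stand-in for the actual database write
                print(f"wrote {item}")
            finally:
                self._queue.task_done()

    async def submit(self, item: str) -> None:
        await self._queue.put(item)  # blocks (backpressure) when the queue is full

    async def shutdown(self) -> None:
        await self._queue.join()  # wait for pending writes to drain
        for worker in self._workers:
            worker.cancel()
        await asyncio.gather(*self._workers, return_exceptions=True)


async def main() -> None:
    store_queue = TinyWriteQueue()
    for i in range(5):
        await store_queue.submit(f"chat_completion_{i}")
    await store_queue.shutdown()


asyncio.run(main())
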
@@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin(
         return schema
 
     async def _get_params(self, request: ChatCompletionRequest) -> dict:
-        input_dict = {}
+        from typing import Any
+
+        input_dict: dict[str, Any] = {}
 
         input_dict["messages"] = [
             await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages
@@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin(
                     f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported."
                 )
 
-            fmt = fmt.json_schema
-            name = fmt["title"]
-            del fmt["title"]
-            fmt["additionalProperties"] = False
+            # Convert to dict for manipulation
+            fmt_dict = dict(fmt.json_schema)
+            name = fmt_dict["title"]
+            del fmt_dict["title"]
+            fmt_dict["additionalProperties"] = False
 
             # Apply additionalProperties: False recursively to all objects
-            fmt = self._add_additional_properties_recursive(fmt)
+            fmt_dict = self._add_additional_properties_recursive(fmt_dict)
 
             input_dict["response_format"] = {
                 "type": "json_schema",
                 "json_schema": {
                     "name": name,
-                    "schema": fmt,
+                    "schema": fmt_dict,
                     "strict": self.json_schema_strict,
                 },
             }
         if request.tools:
            input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
-            if request.tool_config.tool_choice:
-                input_dict["tool_choice"] = (
-                    request.tool_config.tool_choice.value
-                    if isinstance(request.tool_config.tool_choice, ToolChoice)
-                    else request.tool_config.tool_choice
-                )
+            if request.tool_config and (tool_choice := request.tool_config.tool_choice):
+                input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice
 
         return {
             "model": request.model,
@@ -176,9 +175,9 @@ class LiteLLMOpenAIMixin(
     def get_api_key(self) -> str:
         provider_data = self.get_request_provider_data()
         key_field = self.provider_data_api_key_field
-        if provider_data and getattr(provider_data, key_field, None):
-            api_key = getattr(provider_data, key_field)
-        else:
+        if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)):
+            return str(api_key)  # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection
+
         api_key = self.api_key_from_config
         if not api_key:
             raise ValueError(
@@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {params.model} has no provider_resource_id")
+        provider_resource_id = model_obj.provider_resource_id
 
         # Convert input to list if it's a string
         input_list = [params.input] if isinstance(params.input, str) else params.input
@@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin(
         # Call litellm embedding function
         # litellm.drop_params = True
         response = litellm.embedding(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             input=input_list,
             api_key=self.get_api_key(),
             api_base=self.api_base,
@@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin(
 
         return OpenAIEmbeddingsResponse(
             data=data,
-            model=model_obj.provider_resource_id,
+            model=provider_resource_id,
             usage=usage,
         )
 
@@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin(
         self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {params.model} has no provider_resource_id")
+        provider_resource_id = model_obj.provider_resource_id
 
         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             prompt=params.prompt,
             best_of=params.best_of,
             echo=params.echo,
@@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.atext_completion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.atext_completion(**request_params)  # type: ignore[no-any-return] # external lib lacks type stubs
 
     async def openai_chat_completion(
         self,
@@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin(
             elif "include_usage" not in stream_options:
                 stream_options = {**stream_options, "include_usage": True}
 
+        if not self.model_store:
+            raise ValueError("Model store is not initialized")
+
         model_obj = await self.model_store.get_model(params.model)
+        if model_obj.provider_resource_id is None:
+            raise ValueError(f"Model {params.model} has no provider_resource_id")
+        provider_resource_id = model_obj.provider_resource_id
 
         request_params = await prepare_openai_completion_params(
-            model=self.get_litellm_model_name(model_obj.provider_resource_id),
+            model=self.get_litellm_model_name(provider_resource_id),
             messages=params.messages,
             frequency_penalty=params.frequency_penalty,
             function_call=params.function_call,
@@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin(
             api_key=self.get_api_key(),
             api_base=self.api_base,
         )
-        return await litellm.acompletion(**request_params)
+        # LiteLLM returns compatible type but mypy can't verify external library
+        return await litellm.acompletion(**request_params)  # type: ignore[no-any-return] # external lib lacks type stubs
 
     async def check_model_availability(self, model: str) -> bool:
         """
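
A short sketch of the structured-output response_format payload that _get_params assembles above, built here from a plain Pydantic model. The close_schema helper is a simplified stand-in for the mixin's _add_additional_properties_recursive, and the Answer model is an assumption for illustration.

from typing import Any

from pydantic import BaseModel


class Answer(BaseModel):
    city: str
    population: int


def close_schema(schema: dict[str, Any]) -> dict[str, Any]:
    # Mark every object as closed so strict JSON-schema mode can be used.
    if schema.get("type") == "object":
        schema["additionalProperties"] = False
    for value in schema.get("properties", {}).values():
        if isinstance(value, dict):
            close_schema(value)
    return schema


fmt_dict = dict(Answer.model_json_schema())
name = fmt_dict.pop("title")
response_format = {
    "type": "json_schema",
    "json_schema": {"name": name, "schema": close_schema(fmt_dict), "strict": True},
}
print(response_format)
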
@@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils")
 
 
 class RemoteInferenceProviderConfig(BaseModel):
-    allowed_models: list[str] | None = Field(  # TODO: make this non-optional and give a list() default
+    allowed_models: list[str] | None = Field(
         default=None,
         description="List of models that should be registered with the model registry. If None, all models are allowed.",
     )
@@ -161,7 +161,9 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict:
     if isinstance(params.strategy, GreedySamplingStrategy):
         options["temperature"] = 0.0
     elif isinstance(params.strategy, TopPSamplingStrategy):
+        if params.strategy.temperature is not None:
             options["temperature"] = params.strategy.temperature
+        if params.strategy.top_p is not None:
             options["top_p"] = params.strategy.top_p
     elif isinstance(params.strategy, TopKSamplingStrategy):
         options["top_k"] = params.strategy.top_k
@@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict:
 
 def text_from_choice(choice) -> str:
     if hasattr(choice, "delta") and choice.delta:
-        return choice.delta.content
+        return choice.delta.content  # type: ignore[no-any-return] # external OpenAI types lack precise annotations
 
     if hasattr(choice, "message"):
-        return choice.message.content
+        return choice.message.content  # type: ignore[no-any-return] # external OpenAI types lack precise annotations
 
-    return choice.text
+    return choice.text  # type: ignore[no-any-return] # external OpenAI types lack precise annotations
 
 
 def get_stop_reason(finish_reason: str) -> StopReason:
@@ -216,7 +218,7 @@ def convert_openai_completion_logprobs(
 ) -> list[TokenLogProbs] | None:
     if not logprobs:
         return None
-    if hasattr(logprobs, "top_logprobs"):
+    if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
         return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
 
     # Together supports logprobs with top_k=1 only. This means for each token position,
@@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA
     if isinstance(logprobs, float):
         # Adapt response from Together CompletionChoicesChunk
         return [TokenLogProbs(logprobs_by_token={text: logprobs})]
-    if hasattr(logprobs, "top_logprobs"):
+    if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs:
         return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs]
     return None
 
@@ -245,23 +247,24 @@ def process_completion_response(
     response: OpenAICompatCompletionResponse,
 ) -> CompletionResponse:
     choice = response.choices[0]
+    text = choice.text or ""
     # drop suffix <eot_id> if present and return stop reason as end of turn
-    if choice.text.endswith("<|eot_id|>"):
+    if text.endswith("<|eot_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_turn,
-            content=choice.text[: -len("<|eot_id|>")],
+            content=text[: -len("<|eot_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     # drop suffix <eom_id> if present and return stop reason as end of message
-    if choice.text.endswith("<|eom_id|>"):
+    if text.endswith("<|eom_id|>"):
         return CompletionResponse(
             stop_reason=StopReason.end_of_message,
-            content=choice.text[: -len("<|eom_id|>")],
+            content=text[: -len("<|eom_id|>")],
             logprobs=convert_openai_completion_logprobs(choice.logprobs),
         )
     return CompletionResponse(
-        stop_reason=get_stop_reason(choice.finish_reason),
-        content=choice.text,
+        stop_reason=get_stop_reason(choice.finish_reason or "stop"),
+        content=text,
         logprobs=convert_openai_completion_logprobs(choice.logprobs),
     )
 
 
@@ -272,10 +275,10 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
     if choice.finish_reason == "tool_calls":
-        if not choice.message or not choice.message.tool_calls:
+        if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls:  # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
             raise ValueError("Tool calls are not present in the response")
 
-        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]
+        tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls]  # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed
         if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls):
             # If we couldn't parse a tool call, jsonify the tool calls and return them
             return ChatCompletionResponse(
@@ -287,9 +290,11 @@ def process_chat_completion_response(
             )
         else:
             # Otherwise, return tool calls as normal
+            # Filter to only valid ToolCall objects
+            valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)]
             return ChatCompletionResponse(
                 completion_message=CompletionMessage(
-                    tool_calls=tool_calls,
+                    tool_calls=valid_tool_calls,
                     stop_reason=StopReason.end_of_turn,
                     # Content is not optional
                     content="",
@@ -299,7 +304,7 @@ def process_chat_completion_response(
 
     # TODO: This does not work well with tool calls for vLLM remote provider
     # Ref: https://github.com/meta-llama/llama-stack/issues/1058
-    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason))
+    raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop"))
 
     # NOTE: If we do not set tools in chat-completion request, we should not
     # expect the ToolCall in the response. Instead, we should return the raw
@@ -324,8 +329,8 @@ def process_chat_completion_response(
 
     return ChatCompletionResponse(
         completion_message=CompletionMessage(
-            content=raw_message.content,
-            stop_reason=raw_message.stop_reason,
+            content=raw_message.content,  # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent]
+            stop_reason=raw_message.stop_reason or StopReason.end_of_turn,
             tool_calls=raw_message.tool_calls,
         ),
         logprobs=None,
@@ -448,7 +453,7 @@ async def process_chat_completion_stream_response(
         )
 
     # parse tool calls and report errors
-    message = decode_assistant_message(buffer, stop_reason)
+    message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn)
 
     parsed_tool_calls = len(message.tool_calls) > 0
     if ipython and not parsed_tool_calls:
@@ -463,7 +468,7 @@ async def process_chat_completion_stream_response(
             )
         )
 
-    request_tools = {t.tool_name: t for t in request.tools}
+    request_tools = {t.tool_name: t for t in (request.tools or [])}
     for tool_call in message.tool_calls:
         if tool_call.tool_name in request_tools:
             yield ChatCompletionResponseStreamChunk(
@@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
     }
 
     if hasattr(message, "tool_calls") and message.tool_calls:
-        result["tool_calls"] = []
+        tool_calls_list = []
         for tc in message.tool_calls:
             # The tool.tool_name can be a str or a BuiltinTool enum. If
             # it's the latter, convert to a string.
@@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
 
-            result["tool_calls"].append(
+            tool_calls_list.append(
                 {
                     "id": tc.call_id,
                     "type": "function",
@@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
                     },
                 }
             )
+        result["tool_calls"] = tool_calls_list  # type: ignore[assignment] # dict allows Any value, stricter type expected
     return result
 
 
@@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new(
                 ),
             )
         elif isinstance(content_, list):
-            return [await impl(item) for item in content_]
+            return [await impl(item) for item in content_]  # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing
         else:
             raise ValueError(f"Unsupported content type: {type(content_)}")
 
@@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new(
         else:
             return [ret]
 
-    out: OpenAIChatCompletionMessage = None
+    out: OpenAIChatCompletionMessage
     if isinstance(message, UserMessage):
         out = OpenAIChatCompletionUserMessage(
             role="user",
@@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new(
                 ),
                 type="function",
             )
-            for tool in message.tool_calls
+            for tool in (message.tool_calls or [])
         ]
         params = {}
         if tool_calls:
@@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new(
         out = OpenAIChatCompletionAssistantMessage(
             role="assistant",
             content=await _convert_message_content(message.content),
-            **params,
+            **params,  # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field
        )
     elif isinstance(message, ToolResponseMessage):
         out = OpenAIChatCompletionToolMessage(
             role="tool",
             tool_call_id=message.call_id,
-            content=await _convert_message_content(message.content),
+            content=await _convert_message_content(message.content),  # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
         )
     elif isinstance(message, SystemMessage):
         out = OpenAIChatCompletionSystemMessage(
             role="system",
-            content=await _convert_message_content(message.content),
+            content=await _convert_message_content(message.content),  # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement
         )
     else:
         raise ValueError(f"Unsupported message type: {type(message)}")
@@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict:
     function = out["function"]
 
     if isinstance(tool.tool_name, BuiltinTool):
-        function["name"] = tool.tool_name.value
+        function["name"] = tool.tool_name.value  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
     else:
-        function["name"] = tool.tool_name
+        function["name"] = tool.tool_name  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
 
     if tool.description:
-        function["description"] = tool.description
+        function["description"] = tool.description  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
 
     if tool.input_schema:
         # Pass through the entire JSON Schema as-is
-        function["parameters"] = tool.input_schema
+        function["parameters"] = tool.input_schema  # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str]
 
     # NOTE: OpenAI does not support output_schema, so we drop it here
     # It's stored in LlamaStack for validation and other provider usage
@@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None
     tool_config = ToolConfig()
     if tool_choice:
         try:
-            tool_choice = ToolChoice(tool_choice)
+            tool_choice = ToolChoice(tool_choice)  # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception
         except ValueError:
             pass
-        tool_config.tool_choice = tool_choice
+        tool_config.tool_choice = tool_choice  # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type
     return tool_config
 
 
 def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> list[ToolDefinition]:
-    lls_tools = []
+    lls_tools: list[ToolDefinition] = []
     if not tools:
         return lls_tools
 
@@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) ->
 
 
 def _convert_openai_request_response_format(
-    response_format: OpenAIResponseFormatParam = None,
+    response_format: OpenAIResponseFormatParam | None = None,
 ):
     if not response_format:
         return None
     # response_format can be a dict or a pydantic model
-    response_format = dict(response_format)
-    if response_format.get("type", "") == "json_schema":
+    response_format_dict = dict(response_format)  # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion
+    if response_format_dict.get("type", "") == "json_schema":
         return JsonSchemaResponseFormat(
-            type="json_schema",
-            json_schema=response_format.get("json_schema", {}).get("schema", ""),
+            type="json_schema",  # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type
+            json_schema=response_format_dict.get("json_schema", {}).get("schema", ""),
         )
     return None
 
 
@@ -938,16 +944,15 @@ def _convert_openai_sampling_params(
 
     # Map an explicit temperature of 0 to greedy sampling
     if temperature == 0:
-        strategy = GreedySamplingStrategy()
+        sampling_params.strategy = GreedySamplingStrategy()
     else:
         # OpenAI defaults to 1.0 for temperature and top_p if unset
         if temperature is None:
             temperature = 1.0
         if top_p is None:
             top_p = 1.0
-        strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)
+        sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p)  # type: ignore[assignment] # SamplingParams.strategy union accepts this type
 
-    sampling_params.strategy = strategy
     return sampling_params
 
 
@@ -957,23 +962,24 @@ def openai_messages_to_messages(
     """
     Convert a list of OpenAIChatCompletionMessage into a list of Message.
     """
-    converted_messages = []
+    converted_messages: list[Message] = []
     for message in messages:
+        converted_message: Message
         if message.role == "system":
-            converted_message = SystemMessage(content=openai_content_to_content(message.content))
+            converted_message = SystemMessage(content=openai_content_to_content(message.content))  # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
        elif message.role == "user":
-            converted_message = UserMessage(content=openai_content_to_content(message.content))
+            converted_message = UserMessage(content=openai_content_to_content(message.content))  # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
         elif message.role == "assistant":
             converted_message = CompletionMessage(
-                content=openai_content_to_content(message.content),
-                tool_calls=_convert_openai_tool_calls(message.tool_calls),
+                content=openai_content_to_content(message.content),  # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
+                tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [],  # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function
                 stop_reason=StopReason.end_of_turn,
             )
         elif message.role == "tool":
             converted_message = ToolResponseMessage(
                 role="tool",
                 call_id=message.tool_call_id,
-                content=openai_content_to_content(message.content),
+                content=openai_content_to_content(message.content),  # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types
             )
         else:
             raise ValueError(f"Unknown role {message.role}")
@@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten
         return [openai_content_to_content(c) for c in content]
     elif hasattr(content, "type"):
         if content.type == "text":
-            return TextContentItem(type="text", text=content.text)
+            return TextContentItem(type="text", text=content.text)  # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
         elif content.type == "image_url":
-            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
+            return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))  # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track
         else:
             raise ValueError(f"Unknown content type: {content.type}")
     else:
@@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice(
         completion_message=CompletionMessage(
             content=choice.message.content or "",  # CompletionMessage content is not optional
             stop_reason=_convert_openai_finish_reason(choice.finish_reason),
-            tool_calls=_convert_openai_tool_calls(choice.message.tool_calls),
+            tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [],  # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union
         ),
-        logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),
+        logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)),  # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection
     )
 
 
@@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream(
         choice = chunk.choices[0]  # assuming only one choice per chunk
 
         # we assume there's only one finish_reason in the stream
-        stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason
+        stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason
         logprobs = getattr(choice, "logprobs", None)
 
         # if there's a tool call, emit an event for each tool in the list
@@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream(
                     event=ChatCompletionResponseEvent(
                         event_type=event_type,
                         delta=TextDelta(text=choice.delta.content),
-                        logprobs=_convert_openai_logprobs(logprobs),
+                        logprobs=_convert_openai_logprobs(logprobs),  # type: ignore[arg-type] # logprobs type broadened from getattr result
                     )
                 )
 
@@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream(
                         event=ChatCompletionResponseEvent(
                             event_type=event_type,
                             delta=ToolCallDelta(
-                                tool_call=_convert_openai_tool_calls([tool_call])[0],
+                                tool_call=_convert_openai_tool_calls([tool_call])[0],  # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call
                                 parse_status=ToolCallParseStatus.succeeded,
                             ),
-                            logprobs=_convert_openai_logprobs(logprobs),
+                            logprobs=_convert_openai_logprobs(logprobs),  # type: ignore[arg-type] # logprobs type broadened from getattr result
                         )
                     )
                 else:
@@ -1125,11 +1131,14 @@ async def convert_openai_chat_completion_stream(
                     if tool_call.function.name:
                         buffer["name"] = tool_call.function.name
                         delta = f"{buffer['name']}("
+                        if buffer["content"] is not None:
                             buffer["content"] += delta
 
                     if tool_call.function.arguments:
                         delta = tool_call.function.arguments
+                        if buffer["arguments"] is not None and delta:
                             buffer["arguments"] += delta
+                        if buffer["content"] is not None and delta:
                             buffer["content"] += delta
 
                     yield ChatCompletionResponseStreamChunk(
@@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream(
                             tool_call=delta,
                             parse_status=ToolCallParseStatus.in_progress,
                         ),
-                        logprobs=_convert_openai_logprobs(logprobs),
+                        logprobs=_convert_openai_logprobs(logprobs),  # type: ignore[arg-type] # logprobs type broadened from getattr result
                     )
                 )
         elif choice.delta.content:
@@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream(
                 event=ChatCompletionResponseEvent(
                     event_type=event_type,
                     delta=TextDelta(text=choice.delta.content or ""),
-                    logprobs=_convert_openai_logprobs(logprobs),
+                    logprobs=_convert_openai_logprobs(logprobs),  # type: ignore[arg-type] # logprobs type broadened from getattr result
                )
             )
 
@@ -1155,6 +1164,7 @@ async def convert_openai_chat_completion_stream(
         logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
         if buffer["name"]:
             delta = ")"
+            if buffer["content"] is not None:
                 buffer["content"] += delta
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
@@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream(
             )
 
             try:
-                tool_call = ToolCall(
-                    call_id=buffer["call_id"],
-                    tool_name=buffer["name"],
-                    arguments=buffer["arguments"],
+                parsed_tool_call = ToolCall(
+                    call_id=buffer["call_id"] or "",
+                    tool_name=buffer["name"] or "",
+                    arguments=buffer["arguments"] or "",
                 )
                 yield ChatCompletionResponseStreamChunk(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
-                            tool_call=tool_call,
+                            tool_call=parsed_tool_call,  # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
                             parse_status=ToolCallParseStatus.succeeded,
                         ),
                         stop_reason=stop_reason,
@@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream(
                     event=ChatCompletionResponseEvent(
                         event_type=ChatCompletionResponseEventType.progress,
                         delta=ToolCallDelta(
-                            tool_call=buffer["content"],
+                            tool_call=buffer["content"],  # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall]
                             parse_status=ToolCallParseStatus.failed,
                         ),
                         stop_reason=stop_reason,
@@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
         top_p: float | None = None,
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        messages = openai_messages_to_messages(messages)
+        messages = openai_messages_to_messages(messages)  # type: ignore[assignment] # converted from OpenAI to LlamaStack message format
         response_format = _convert_openai_request_response_format(response_format)
         sampling_params = _convert_openai_sampling_params(
             max_tokens=max_tokens,
@@ -1259,15 +1269,15 @@ class OpenAIChatCompletionToLlamaStackMixin:
         )
         tool_config = _convert_openai_request_tool_config(tool_choice)
 
-        tools = _convert_openai_request_tools(tools)
+        tools = _convert_openai_request_tools(tools)  # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format
         if tool_config.tool_choice == ToolChoice.none:
-            tools = []
+            tools = []  # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type
 
         outstanding_responses = []
         # "n" is the number of completions to generate per prompt
         n = n or 1
         for _i in range(0, n):
-            response = self.chat_completion(
+            response = self.chat_completion(  # type: ignore[attr-defined] # mixin expects class to implement chat_completion
                 model_id=model,
                 messages=messages,
                 sampling_params=sampling_params,
@@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
             outstanding_responses.append(response)
 
         if stream:
-            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)
+            return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses)  # type: ignore[no-any-return] # mixin async generator return type too complex for mypy
 
         return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response(
             self, model, outstanding_responses
@@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin:
             response = await outstanding_response
             async for chunk in response:
                 event = chunk.event
-                finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason)
+                finish_reason = (
+                    _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None
+                )
 
                 if isinstance(event.delta, TextDelta):
                     text_delta = event.delta.text
                     delta = OpenAIChoiceDelta(content=text_delta)
                     yield OpenAIChatCompletionChunk(
                         id=id,
-                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],
+                        choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)],  # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
                         created=int(time.time()),
                         model=model,
                         object="chat.completion.chunk",
@@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin:
                 elif isinstance(event.delta, ToolCallDelta):
                     if event.delta.parse_status == ToolCallParseStatus.succeeded:
                         tool_call = event.delta.tool_call
+                        if isinstance(tool_call, str):
+                            continue
 
                         # First chunk includes full structure
                         openai_tool_call = OpenAIChoiceDeltaToolCall(
                             index=0,
                             id=tool_call.call_id,
                             function=OpenAIChoiceDeltaToolCallFunction(
-                                name=tool_call.tool_name,
+                                name=tool_call.tool_name
+                                if isinstance(tool_call.tool_name, str)
+                                else tool_call.tool_name.value,  # type: ignore[arg-type] # enum .value extraction on Union confuses mypy
                                 arguments="",
                             ),
                         )
@@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                         yield OpenAIChatCompletionChunk(
                             id=id,
                             choices=[
-                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                                OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)  # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
                             ],
                             created=int(time.time()),
                             model=model,
@@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
                     yield OpenAIChatCompletionChunk(
                         id=id,
                         choices=[
-                            OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
+                            OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)  # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union
                         ],
                         created=int(time.time()),
                         model=model,
@@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
     async def _process_non_stream_response(
         self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]]
     ) -> OpenAIChatCompletion:
-        choices = []
+        choices: list[OpenAIChatCompletionChoice] = []
        for outstanding_response in outstanding_responses:
             response = await outstanding_response
             completion_message = response.completion_message
@@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin:
 
             choice = OpenAIChatCompletionChoice(
                 index=len(choices),
-                message=message,
+                message=message,  # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type
                 finish_reason=finish_reason,
             )
-            choices.append(choice)
+            choices.append(choice)  # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch
 
         return OpenAIChatCompletion(
             id=f"chatcmpl-{uuid.uuid4()}",
-            choices=choices,
+            choices=choices,  # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible
             created=int(time.time()),
             model=model,
             object="chat.completion",
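
A compact sketch of the streamed tool-call buffering idea used above: fragments of the call id, name, and arguments arrive across chunks and are only assembled into a complete call once the stream ends, with empty-string fallbacks as in the parsed ToolCall. The helper and its dict shapes are invented for illustration.

import json


def accumulate_tool_call(chunks: list[dict[str, str | None]]) -> dict[str, str]:
    buffer: dict[str, str | None] = {"call_id": None, "name": None, "arguments": ""}
    for chunk in chunks:
        if chunk.get("id"):
            buffer["call_id"] = chunk["id"]
        if chunk.get("name"):
            buffer["name"] = chunk["name"]
        if chunk.get("arguments"):
            buffer["arguments"] = (buffer["arguments"] or "") + (chunk["arguments"] or "")
    # Fall back to empty strings the same way the parsed ToolCall does above.
    return {
        "call_id": buffer["call_id"] or "",
        "tool_name": buffer["name"] or "",
        "arguments": buffer["arguments"] or "",
    }


streamed = [
    {"id": "call_1", "name": "get_weather", "arguments": '{"city": '},
    {"arguments": '"Paris"}'},
]
call = accumulate_tool_call(streamed)
assert json.loads(call["arguments"]) == {"city": "Paris"}
print(call)
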
@@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
     # This is set in list_models() and used in check_model_availability()
     _model_cache: dict[str, Model] = {}
 
-    # List of allowed models for this provider, if empty all models allowed
-    allowed_models: list[str] = []
-
     # Optional field name in provider data to look for API key, which takes precedence
     provider_data_api_key_field: str | None = None
 
@@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         for provider_model_id in provider_models_ids:
             if not isinstance(provider_model_id, str):
                 raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string")
-            if self.allowed_models and provider_model_id not in self.allowed_models:
+            if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models:
                 logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models list")
                 continue
             model = self.construct_model_from_identifier(provider_model_id)
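
A small sketch of the allowed_models semantics after this change: None disables filtering entirely, while a list (including an empty one) acts as an allowlist. The helper below is illustrative only and does not exist in the codebase.

def filter_models(provider_model_ids: list[str], allowed_models: list[str] | None) -> list[str]:
    # None means "no filtering"; a list is an explicit allowlist.
    if allowed_models is None:
        return list(provider_model_ids)
    return [m for m in provider_model_ids if m in allowed_models]


assert filter_models(["a", "b"], None) == ["a", "b"]
assert filter_models(["a", "b"], []) == []
assert filter_models(["a", "b"], ["b"]) == ["b"]
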
@@ -196,6 +196,7 @@ def make_overlapped_chunks(
         chunks.append(
             Chunk(
                 content=chunk,
+                chunk_id=chunk_id,
                 metadata=chunk_metadata,
                 chunk_metadata=backend_chunk_metadata,
             )
@@ -70,13 +70,13 @@ class ResponsesStore:
         base_store = sqlstore_impl(self.reference)
         self.sql_store = AuthorizedSqlStore(base_store, self.policy)

+        # Disable write queue for SQLite since WAL mode handles concurrency
+        # Keep it enabled for other backends (like Postgres) for performance
         backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend)
-        if backend_config is None:
-            raise ValueError(
-                f"Unregistered SQL backend '{self.reference.backend}'. Registered backends: {sorted(_SQLSTORE_BACKENDS)}"
-            )
-        if backend_config.type == StorageBackendType.SQL_SQLITE:
+        if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE:
             self.enable_write_queue = False
+            logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)")

         await self.sql_store.create_table(
             "openai_responses",
             {

@@ -99,8 +99,9 @@ class ResponsesStore:
             self._queue = asyncio.Queue(maxsize=self._max_write_queue_size)
             for _ in range(self._num_writers):
                 self._worker_tasks.append(asyncio.create_task(self._worker_loop()))
-        else:
-            logger.debug("Write queue disabled for SQLite to avoid concurrency issues")
+            logger.debug(
+                f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}"
+            )

     async def shutdown(self) -> None:
         if not self._worker_tasks:

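Note: this hunk only rewords the log message; _worker_loop itself is not shown in the diff. As a rough, generic sketch of the bounded write-queue pattern implied here (hypothetical names, not the ResponsesStore implementation):

    # Generic asyncio write-queue sketch; writers drain a bounded queue so callers
    # do not block on the underlying database write.
    import asyncio

    async def worker_loop(queue: asyncio.Queue) -> None:
        while True:
            item = await queue.get()
            try:
                pass  # perform the actual write for `item` here
            finally:
                queue.task_done()

    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue(maxsize=100)
        workers = [asyncio.create_task(worker_loop(queue)) for _ in range(2)]
        await queue.put({"id": "resp_1"})
        await queue.join()  # wait for pending writes to drain
        for task in workers:
            task.cancel()
        await asyncio.gather(*workers, return_exceptions=True)

    asyncio.run(main())
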
@@ -17,6 +17,7 @@ from sqlalchemy import (
     String,
     Table,
     Text,
+    event,
     inspect,
     select,
     text,

@@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore):
         self.metadata = MetaData()

     def create_engine(self) -> AsyncEngine:
-        return create_async_engine(self.config.engine_str, pool_pre_ping=True)
+        # Configure connection args for better concurrency support
+        connect_args = {}
+        if "sqlite" in self.config.engine_str:
+            # SQLite-specific optimizations for concurrent access
+            # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases
+            connect_args["timeout"] = 5.0
+            connect_args["check_same_thread"] = False  # Allow usage across asyncio tasks
+
+        engine = create_async_engine(
+            self.config.engine_str,
+            pool_pre_ping=True,
+            connect_args=connect_args,
+        )
+
+        # Enable WAL mode for SQLite to support concurrent readers and writers
+        if "sqlite" in self.config.engine_str:
+
+            @event.listens_for(engine.sync_engine, "connect")
+            def set_sqlite_pragma(dbapi_conn, connection_record):
+                cursor = dbapi_conn.cursor()
+                # Enable Write-Ahead Logging for better concurrency
+                cursor.execute("PRAGMA journal_mode=WAL")
+                # Set busy timeout to 5 seconds (retry instead of immediate failure)
+                # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue
+                cursor.execute("PRAGMA busy_timeout=5000")
+                # Use NORMAL synchronous mode for better performance (still safe with WAL)
+                cursor.execute("PRAGMA synchronous=NORMAL")
+                cursor.close()
+
+        return engine

     async def create_table(
         self,

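Note: the PRAGMAs set in the connect listener above can be verified against a SQLite file with the standard library alone, independent of SQLAlchemy (the file name below is arbitrary):

    # Standalone check that the pragmas behave as described: WAL journal mode
    # persists on the database file, busy_timeout and synchronous are set
    # per connection.
    import sqlite3

    conn = sqlite3.connect("example.db")
    print(conn.execute("PRAGMA journal_mode=WAL").fetchone())  # ('wal',)
    conn.execute("PRAGMA busy_timeout=5000")
    conn.execute("PRAGMA synchronous=NORMAL")
    print(conn.execute("PRAGMA busy_timeout").fetchone())      # (5000,)
    print(conn.execute("PRAGMA synchronous").fetchone())       # (1,) == NORMAL
    conn.close()
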
@@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]:
     return list_type # type: ignore[no-any-return]


+def is_generic_sequence(typ: object) -> bool:
+    "True if the specified type is a generic Sequence, i.e. `Sequence[T]`."
+    import collections.abc
+
+    typ = unwrap_annotated_type(typ)
+    return typing.get_origin(typ) is collections.abc.Sequence
+
+
+def unwrap_generic_sequence(typ: object) -> type:
+    """
+    Extracts the item type of a Sequence type.
+
+    :param typ: The Sequence type `Sequence[T]`.
+    :returns: The item type `T`.
+    """
+
+    return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type]
+
+
+def _unwrap_generic_sequence(typ: object) -> type:
+    "Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)."
+
+    (sequence_type,) = typing.get_args(typ) # unpack single tuple element
+    return sequence_type # type: ignore[no-any-return]
+
+
 def is_generic_set(typ: object) -> TypeGuard[type[set]]:
     "True if the specified type is a generic set, i.e. `Set[T]`."

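Note: the new helpers are thin wrappers over standard typing introspection. The core checks, shown standalone (the Annotated unwrapping done by the real helpers is omitted):

    import collections.abc
    import typing
    from collections.abc import Sequence

    print(typing.get_origin(Sequence[int]) is collections.abc.Sequence)  # True
    print(typing.get_origin(list[int]) is collections.abc.Sequence)      # False
    (item_type,) = typing.get_args(Sequence[int])  # unpack single tuple element
    print(item_type)                               # <class 'int'>
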
@@ -18,10 +18,12 @@ from .inspection import (
     TypeLike,
     is_generic_dict,
     is_generic_list,
+    is_generic_sequence,
     is_type_optional,
     is_type_union,
     unwrap_generic_dict,
     unwrap_generic_list,
+    unwrap_generic_sequence,
     unwrap_optional_type,
     unwrap_union_types,
 )

@@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str:
     if metadata is not None:
         # type is Annotated[T, ...]
         arg = typing.get_args(data_type)[0]
-        return python_type_to_name(arg)
+        return python_type_to_name(arg, force=force)

     if force:
         # generic types
         if is_type_optional(data_type, strict=True):
-            inner_name = python_type_to_name(unwrap_optional_type(data_type))
+            inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True)
             return f"Optional__{inner_name}"
         elif is_generic_list(data_type):
-            item_name = python_type_to_name(unwrap_generic_list(data_type))
+            item_name = python_type_to_name(unwrap_generic_list(data_type), force=True)
+            return f"List__{item_name}"
+        elif is_generic_sequence(data_type):
+            # Treat Sequence the same as List for schema generation purposes
+            item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True)
             return f"List__{item_name}"
         elif is_generic_dict(data_type):
             key_type, value_type = unwrap_generic_dict(data_type)
-            key_name = python_type_to_name(key_type)
-            value_name = python_type_to_name(value_type)
+            key_name = python_type_to_name(key_type, force=True)
+            value_name = python_type_to_name(value_type, force=True)
             return f"Dict__{key_name}__{value_name}"
         elif is_type_union(data_type):
             member_types = unwrap_union_types(data_type)
-            member_names = "__".join(python_type_to_name(member_type) for member_type in member_types)
+            member_names = "__".join(python_type_to_name(member_type, force=True) for member_type in member_types)
             return f"Union__{member_names}"

     # named system or user-defined type

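Note: with force=True now propagated through every recursive call, nested generics compose into a single synthetic name. The expected naming under the branches above, written out by hand for illustration (not generated by importing the module):

    expected = {
        "list[int]": "List__int",
        "Sequence[int]": "List__int",  # Sequence now shares the List__ prefix
        "dict[str, list[int]]": "Dict__str__List__int",
        "int | None": "Optional__int",
        "int | str": "Union__int__str",
    }
    for annotation, name in expected.items():
        print(f"{annotation:25} -> {name}")
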
@@ -111,7 +111,7 @@ def get_class_property_docstrings(
 def docstring_to_schema(data_type: type) -> Schema:
     short_description, long_description = get_class_docstrings(data_type)
     schema: Schema = {
-        "title": python_type_to_name(data_type),
+        "title": python_type_to_name(data_type, force=True),
     }

     description = "\n".join(filter(None, [short_description, long_description]))

@@ -417,6 +417,10 @@ class JsonSchemaGenerator:
         if origin_type is list:
             (list_type,) = typing.get_args(typ) # unpack single tuple element
             return {"type": "array", "items": self.type_to_schema(list_type)}
+        elif origin_type is collections.abc.Sequence:
+            # Treat Sequence the same as list for JSON schema (both are arrays)
+            (sequence_type,) = typing.get_args(typ) # unpack single tuple element
+            return {"type": "array", "items": self.type_to_schema(sequence_type)}
         elif origin_type is dict:
             key_type, value_type = typing.get_args(typ)
             if not (key_type is str or key_type is int or is_type_enum(key_type)):

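Note: with the new branch above, Sequence[T] serializes to the same array schema as list[T]. A standalone sketch of that mapping (the real generator recurses through type_to_schema for the item type; a small primitive table stands in for it here):

    import collections.abc
    import typing
    from collections.abc import Sequence

    PRIMITIVES = {int: "integer", str: "string", float: "number", bool: "boolean"}

    def array_schema(typ: object) -> dict:
        origin = typing.get_origin(typ)
        if origin is list or origin is collections.abc.Sequence:
            (item_type,) = typing.get_args(typ)  # unpack single tuple element
            return {"type": "array", "items": {"type": PRIMITIVES[item_type]}}
        raise TypeError(f"not a list/Sequence type: {typ!r}")

    print(array_schema(list[int]))      # {'type': 'array', 'items': {'type': 'integer'}}
    print(array_schema(Sequence[str]))  # {'type': 'array', 'items': {'type': 'string'}}
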
@@ -51,7 +51,11 @@ async function proxyRequest(request: NextRequest, method: string) {
   );

   // Create response with same status and headers
-  const proxyResponse = new NextResponse(responseText, {
+  // Handle 204 No Content responses specially
+  const proxyResponse =
+    response.status === 204
+      ? new NextResponse(null, { status: 204 })
+      : new NextResponse(responseText, {
     status: response.status,
     statusText: response.statusText,
   });

5  src/llama_stack/ui/app/prompts/page.tsx  Normal file

@@ -0,0 +1,5 @@
+import { PromptManagement } from "@/components/prompts";
+
+export default function PromptsPage() {
+  return <PromptManagement />;
+}

@@ -8,6 +8,7 @@ import {
   MessageCircle,
   Settings2,
   Compass,
+  FileText,
 } from "lucide-react";
 import Link from "next/link";
 import { usePathname } from "next/navigation";

@@ -50,6 +51,11 @@ const manageItems = [
     url: "/logs/vector-stores",
     icon: Database,
   },
+  {
+    title: "Prompts",
+    url: "/prompts",
+    icon: FileText,
+  },
   {
     title: "Documentation",
     url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",

4  src/llama_stack/ui/components/prompts/index.ts  Normal file

@@ -0,0 +1,4 @@
+export { PromptManagement } from "./prompt-management";
+export { PromptList } from "./prompt-list";
+export { PromptEditor } from "./prompt-editor";
+export * from "./types";

309  src/llama_stack/ui/components/prompts/prompt-editor.test.tsx  Normal file

@@ -0,0 +1,309 @@
import React from "react";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import { PromptEditor } from "./prompt-editor";
|
||||||
|
import type { Prompt, PromptFormData } from "./types";
|
||||||
|
|
||||||
|
describe("PromptEditor", () => {
|
||||||
|
const mockOnSave = jest.fn();
|
||||||
|
const mockOnCancel = jest.fn();
|
||||||
|
const mockOnDelete = jest.fn();
|
||||||
|
|
||||||
|
const defaultProps = {
|
||||||
|
onSave: mockOnSave,
|
||||||
|
onCancel: mockOnCancel,
|
||||||
|
onDelete: mockOnDelete,
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Create Mode", () => {
|
||||||
|
test("renders create form correctly", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Preview")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Create Prompt")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Cancel")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows preview placeholder when no content", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText("Enter content to preview the compiled prompt")
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("submits form with correct data", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||||
|
fireEvent.change(promptInput, {
|
||||||
|
target: { value: "Hello {{name}}, welcome!" },
|
||||||
|
});
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Create Prompt"));
|
||||||
|
|
||||||
|
expect(mockOnSave).toHaveBeenCalledWith({
|
||||||
|
prompt: "Hello {{name}}, welcome!",
|
||||||
|
variables: [],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("prevents submission with empty prompt", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Create Prompt"));
|
||||||
|
|
||||||
|
expect(mockOnSave).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Edit Mode", () => {
|
||||||
|
const mockPrompt: Prompt = {
|
||||||
|
prompt_id: "prompt_123",
|
||||||
|
prompt: "Hello {{name}}, how is {{weather}}?",
|
||||||
|
version: 1,
|
||||||
|
variables: ["name", "weather"],
|
||||||
|
is_default: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
test("renders edit form with existing data", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?")
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview
|
||||||
|
expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview
|
||||||
|
expect(screen.getByText("Update Prompt")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("submits updated data correctly", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||||
|
|
||||||
|
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||||
|
fireEvent.change(promptInput, {
|
||||||
|
target: { value: "Updated: Hello {{name}}!" },
|
||||||
|
});
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Update Prompt"));
|
||||||
|
|
||||||
|
expect(mockOnSave).toHaveBeenCalledWith({
|
||||||
|
prompt: "Updated: Hello {{name}}!",
|
||||||
|
variables: ["name", "weather"],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Variables Management", () => {
|
||||||
|
test("adds new variable", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
const variableInput = screen.getByPlaceholderText(
|
||||||
|
"Add variable name (e.g. user_name, topic)"
|
||||||
|
);
|
||||||
|
fireEvent.change(variableInput, { target: { value: "testVar" } });
|
||||||
|
fireEvent.click(screen.getByText("Add"));
|
||||||
|
|
||||||
|
expect(screen.getByText("testVar")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("prevents adding duplicate variables", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
const variableInput = screen.getByPlaceholderText(
|
||||||
|
"Add variable name (e.g. user_name, topic)"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add first variable
|
||||||
|
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||||
|
fireEvent.click(screen.getByText("Add"));
|
||||||
|
|
||||||
|
// Try to add same variable again
|
||||||
|
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||||
|
|
||||||
|
// Button should be disabled
|
||||||
|
expect(screen.getByText("Add")).toBeDisabled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("removes variable", () => {
|
||||||
|
const mockPrompt: Prompt = {
|
||||||
|
prompt_id: "prompt_123",
|
||||||
|
prompt: "Hello {{name}}",
|
||||||
|
version: 1,
|
||||||
|
variables: ["name", "location"],
|
||||||
|
is_default: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||||
|
|
||||||
|
// Check that both variables are present initially
|
||||||
|
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||||
|
expect(screen.getAllByText("location").length).toBeGreaterThan(0);
|
||||||
|
|
||||||
|
// Remove the location variable by clicking the X button with the specific title
|
||||||
|
const removeLocationButton = screen.getByTitle(
|
||||||
|
"Remove location variable"
|
||||||
|
);
|
||||||
|
fireEvent.click(removeLocationButton);
|
||||||
|
|
||||||
|
// Name should still be there, location should be gone from the variables section
|
||||||
|
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||||
|
expect(
|
||||||
|
screen.queryByTitle("Remove location variable")
|
||||||
|
).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("adds variable on Enter key", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
const variableInput = screen.getByPlaceholderText(
|
||||||
|
"Add variable name (e.g. user_name, topic)"
|
||||||
|
);
|
||||||
|
fireEvent.change(variableInput, { target: { value: "enterVar" } });
|
||||||
|
|
||||||
|
// Simulate Enter key press
|
||||||
|
fireEvent.keyPress(variableInput, {
|
||||||
|
key: "Enter",
|
||||||
|
code: "Enter",
|
||||||
|
charCode: 13,
|
||||||
|
preventDefault: jest.fn(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Check if the variable was added by looking for the badge
|
||||||
|
expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Preview Functionality", () => {
|
||||||
|
test("shows live preview with variables", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
// Add prompt content
|
||||||
|
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||||
|
fireEvent.change(promptInput, {
|
||||||
|
target: { value: "Hello {{name}}, welcome to {{place}}!" },
|
||||||
|
});
|
||||||
|
|
||||||
|
// Add variables
|
||||||
|
const variableInput = screen.getByPlaceholderText(
|
||||||
|
"Add variable name (e.g. user_name, topic)"
|
||||||
|
);
|
||||||
|
fireEvent.change(variableInput, { target: { value: "name" } });
|
||||||
|
fireEvent.click(screen.getByText("Add"));
|
||||||
|
|
||||||
|
fireEvent.change(variableInput, { target: { value: "place" } });
|
||||||
|
fireEvent.click(screen.getByText("Add"));
|
||||||
|
|
||||||
|
// Check that preview area shows the content
|
||||||
|
expect(screen.getByText("Compiled Prompt")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows variable value inputs in preview", () => {
|
||||||
|
const mockPrompt: Prompt = {
|
||||||
|
prompt_id: "prompt_123",
|
||||||
|
prompt: "Hello {{name}}",
|
||||||
|
version: 1,
|
||||||
|
variables: ["name"],
|
||||||
|
is_default: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Variable Values")).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByPlaceholderText("Enter value for name")
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows color legend for variable states", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
// Add content to show preview
|
||||||
|
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||||
|
fireEvent.change(promptInput, {
|
||||||
|
target: { value: "Hello {{name}}" },
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByText("Used")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Unused")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Undefined")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Error Handling", () => {
|
||||||
|
test("displays error message", () => {
|
||||||
|
const errorMessage = "Prompt contains undeclared variables";
|
||||||
|
render(<PromptEditor {...defaultProps} error={errorMessage} />);
|
||||||
|
|
||||||
|
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Delete Functionality", () => {
|
||||||
|
const mockPrompt: Prompt = {
|
||||||
|
prompt_id: "prompt_123",
|
||||||
|
prompt: "Hello {{name}}",
|
||||||
|
version: 1,
|
||||||
|
variables: ["name"],
|
||||||
|
is_default: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
test("shows delete button in edit mode", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("hides delete button in create mode", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("calls onDelete with confirmation", () => {
|
||||||
|
const originalConfirm = window.confirm;
|
||||||
|
window.confirm = jest.fn(() => true);
|
||||||
|
|
||||||
|
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||||
|
|
||||||
|
expect(window.confirm).toHaveBeenCalledWith(
|
||||||
|
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||||
|
);
|
||||||
|
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||||
|
|
||||||
|
window.confirm = originalConfirm;
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not delete when confirmation is cancelled", () => {
|
||||||
|
const originalConfirm = window.confirm;
|
||||||
|
window.confirm = jest.fn(() => false);
|
||||||
|
|
||||||
|
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||||
|
|
||||||
|
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
window.confirm = originalConfirm;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Cancel Functionality", () => {
|
||||||
|
test("calls onCancel when cancel button is clicked", () => {
|
||||||
|
render(<PromptEditor {...defaultProps} />);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByText("Cancel"));
|
||||||
|
|
||||||
|
expect(mockOnCancel).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
346  src/llama_stack/ui/components/prompts/prompt-editor.tsx  Normal file

@@ -0,0 +1,346 @@
"use client";
|
||||||
|
|
||||||
|
import { useState, useEffect } from "react";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { Label } from "@/components/ui/label";
|
||||||
|
import { Textarea } from "@/components/ui/textarea";
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import {
|
||||||
|
Card,
|
||||||
|
CardContent,
|
||||||
|
CardDescription,
|
||||||
|
CardHeader,
|
||||||
|
CardTitle,
|
||||||
|
} from "@/components/ui/card";
|
||||||
|
import { Separator } from "@/components/ui/separator";
|
||||||
|
import { X, Plus, Save, Trash2 } from "lucide-react";
|
||||||
|
import { Prompt, PromptFormData } from "./types";
|
||||||
|
|
||||||
|
interface PromptEditorProps {
|
||||||
|
prompt?: Prompt;
|
||||||
|
onSave: (prompt: PromptFormData) => void;
|
||||||
|
onCancel: () => void;
|
||||||
|
onDelete?: (promptId: string) => void;
|
||||||
|
error?: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PromptEditor({
|
||||||
|
prompt,
|
||||||
|
onSave,
|
||||||
|
onCancel,
|
||||||
|
onDelete,
|
||||||
|
error,
|
||||||
|
}: PromptEditorProps) {
|
||||||
|
const [formData, setFormData] = useState<PromptFormData>({
|
||||||
|
prompt: "",
|
||||||
|
variables: [],
|
||||||
|
});
|
||||||
|
|
||||||
|
const [newVariable, setNewVariable] = useState("");
|
||||||
|
const [variableValues, setVariableValues] = useState<Record<string, string>>(
|
||||||
|
{}
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (prompt) {
|
||||||
|
setFormData({
|
||||||
|
prompt: prompt.prompt || "",
|
||||||
|
variables: prompt.variables || [],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [prompt]);
|
||||||
|
|
||||||
|
const handleSubmit = (e: React.FormEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
if (!formData.prompt.trim()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
onSave(formData);
|
||||||
|
};
|
||||||
|
|
||||||
|
const addVariable = () => {
|
||||||
|
if (
|
||||||
|
newVariable.trim() &&
|
||||||
|
!formData.variables.includes(newVariable.trim())
|
||||||
|
) {
|
||||||
|
setFormData(prev => ({
|
||||||
|
...prev,
|
||||||
|
variables: [...prev.variables, newVariable.trim()],
|
||||||
|
}));
|
||||||
|
setNewVariable("");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const removeVariable = (variableToRemove: string) => {
|
||||||
|
setFormData(prev => ({
|
||||||
|
...prev,
|
||||||
|
variables: prev.variables.filter(
|
||||||
|
variable => variable !== variableToRemove
|
||||||
|
),
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderPreview = () => {
|
||||||
|
const text = formData.prompt;
|
||||||
|
if (!text) return text;
|
||||||
|
|
||||||
|
// Split text by variable patterns and process each part
|
||||||
|
const parts = text.split(/(\{\{\s*\w+\s*\}\})/g);
|
||||||
|
|
||||||
|
return parts.map((part, index) => {
|
||||||
|
const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/);
|
||||||
|
if (variableMatch) {
|
||||||
|
const variableName = variableMatch[1];
|
||||||
|
const isDefined = formData.variables.includes(variableName);
|
||||||
|
const value = variableValues[variableName];
|
||||||
|
|
||||||
|
if (!isDefined) {
|
||||||
|
// Variable not in variables list - likely a typo/bug (RED)
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
key={index}
|
||||||
|
className="bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200 px-1 rounded font-medium"
|
||||||
|
>
|
||||||
|
{part}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
} else if (value && value.trim()) {
|
||||||
|
// Variable defined and has value - show the value (GREEN)
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
key={index}
|
||||||
|
className="bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200 px-1 rounded font-medium"
|
||||||
|
>
|
||||||
|
{value}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Variable defined but empty (YELLOW)
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
key={index}
|
||||||
|
className="bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200 px-1 rounded font-medium"
|
||||||
|
>
|
||||||
|
{part}
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return part;
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const updateVariableValue = (variable: string, value: string) => {
|
||||||
|
setVariableValues(prev => ({
|
||||||
|
...prev,
|
||||||
|
[variable]: value,
|
||||||
|
}));
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<form onSubmit={handleSubmit} className="space-y-6">
|
||||||
|
{error && (
|
||||||
|
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-md">
|
||||||
|
<p className="text-destructive text-sm">{error}</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
|
||||||
|
{/* Form Section */}
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div>
|
||||||
|
<Label htmlFor="prompt">Prompt Content *</Label>
|
||||||
|
<Textarea
|
||||||
|
id="prompt"
|
||||||
|
value={formData.prompt}
|
||||||
|
onChange={e =>
|
||||||
|
setFormData(prev => ({ ...prev, prompt: e.target.value }))
|
||||||
|
}
|
||||||
|
placeholder="Enter your prompt content here. Use {{variable_name}} for dynamic variables."
|
||||||
|
className="min-h-32 font-mono mt-2"
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
<p className="text-xs text-muted-foreground mt-2">
|
||||||
|
Use double curly braces around variable names, e.g.,{" "}
|
||||||
|
{`{{user_name}}`} or {`{{topic}}`}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-3">
|
||||||
|
<Label className="text-sm font-medium">Variables</Label>
|
||||||
|
|
||||||
|
<div className="flex gap-2 mt-2">
|
||||||
|
<Input
|
||||||
|
value={newVariable}
|
||||||
|
onChange={e => setNewVariable(e.target.value)}
|
||||||
|
placeholder="Add variable name (e.g. user_name, topic)"
|
||||||
|
onKeyPress={e =>
|
||||||
|
e.key === "Enter" && (e.preventDefault(), addVariable())
|
||||||
|
}
|
||||||
|
className="flex-1"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
onClick={addVariable}
|
||||||
|
size="sm"
|
||||||
|
disabled={
|
||||||
|
!newVariable.trim() ||
|
||||||
|
formData.variables.includes(newVariable.trim())
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<Plus className="h-4 w-4" />
|
||||||
|
Add
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{formData.variables.length > 0 && (
|
||||||
|
<div className="border rounded-lg p-3 bg-muted/20">
|
||||||
|
<div className="flex flex-wrap gap-2">
|
||||||
|
{formData.variables.map(variable => (
|
||||||
|
<Badge
|
||||||
|
key={variable}
|
||||||
|
variant="secondary"
|
||||||
|
className="text-sm px-2 py-1"
|
||||||
|
>
|
||||||
|
{variable}
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => removeVariable(variable)}
|
||||||
|
className="ml-2 hover:text-destructive transition-colors"
|
||||||
|
title={`Remove ${variable} variable`}
|
||||||
|
>
|
||||||
|
<X className="h-3 w-3" />
|
||||||
|
</button>
|
||||||
|
</Badge>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<p className="text-xs text-muted-foreground">
|
||||||
|
Variables that can be used in the prompt template. Each variable
|
||||||
|
should match a {`{{variable}}`} placeholder in the content above.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Preview Section */}
|
||||||
|
<div className="space-y-4">
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle className="text-lg">Preview</CardTitle>
|
||||||
|
<CardDescription>
|
||||||
|
Live preview of compiled prompt and variable substitution.
|
||||||
|
</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="space-y-4">
|
||||||
|
{formData.prompt ? (
|
||||||
|
<>
|
||||||
|
{/* Variable Values */}
|
||||||
|
{formData.variables.length > 0 && (
|
||||||
|
<div className="space-y-3">
|
||||||
|
<Label className="text-sm font-medium">
|
||||||
|
Variable Values
|
||||||
|
</Label>
|
||||||
|
<div className="space-y-2">
|
||||||
|
{formData.variables.map(variable => (
|
||||||
|
<div
|
||||||
|
key={variable}
|
||||||
|
className="grid grid-cols-2 gap-3 items-center"
|
||||||
|
>
|
||||||
|
<div className="text-sm font-mono text-muted-foreground">
|
||||||
|
{variable}
|
||||||
|
</div>
|
||||||
|
<Input
|
||||||
|
id={`var-${variable}`}
|
||||||
|
value={variableValues[variable] || ""}
|
||||||
|
onChange={e =>
|
||||||
|
updateVariableValue(variable, e.target.value)
|
||||||
|
}
|
||||||
|
placeholder={`Enter value for ${variable}`}
|
||||||
|
className="text-sm"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<Separator />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Live Preview */}
|
||||||
|
<div>
|
||||||
|
<Label className="text-sm font-medium mb-2 block">
|
||||||
|
Compiled Prompt
|
||||||
|
</Label>
|
||||||
|
<div className="bg-muted/50 p-4 rounded-lg border">
|
||||||
|
<div className="text-sm leading-relaxed whitespace-pre-wrap">
|
||||||
|
{renderPreview()}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex flex-wrap gap-4 mt-2 text-xs">
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<div className="w-3 h-3 bg-green-500 dark:bg-green-400 border rounded"></div>
|
||||||
|
<span className="text-muted-foreground">Used</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<div className="w-3 h-3 bg-yellow-500 dark:bg-yellow-400 border rounded"></div>
|
||||||
|
<span className="text-muted-foreground">Unused</span>
|
||||||
|
</div>
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<div className="w-3 h-3 bg-red-500 dark:bg-red-400 border rounded"></div>
|
||||||
|
<span className="text-muted-foreground">Undefined</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<div className="text-center py-8">
|
||||||
|
<div className="text-muted-foreground text-sm">
|
||||||
|
Enter content to preview the compiled prompt
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-muted-foreground mt-2">
|
||||||
|
Use {`{{variable_name}}`} to add dynamic variables
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Separator />
|
||||||
|
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<div>
|
||||||
|
{prompt && onDelete && (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="destructive"
|
||||||
|
onClick={() => {
|
||||||
|
if (
|
||||||
|
confirm(
|
||||||
|
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
onDelete(prompt.prompt_id);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<Trash2 className="h-4 w-4 mr-2" />
|
||||||
|
Delete Prompt
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Button type="button" variant="outline" onClick={onCancel}>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
<Button type="submit">
|
||||||
|
<Save className="h-4 w-4 mr-2" />
|
||||||
|
{prompt ? "Update" : "Create"} Prompt
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
);
|
||||||
|
}
|
||||||
259  src/llama_stack/ui/components/prompts/prompt-list.test.tsx  Normal file

@@ -0,0 +1,259 @@
import React from "react";
|
||||||
|
import { render, screen, fireEvent } from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import { PromptList } from "./prompt-list";
|
||||||
|
import type { Prompt } from "./types";
|
||||||
|
|
||||||
|
describe("PromptList", () => {
|
||||||
|
const mockOnEdit = jest.fn();
|
||||||
|
const mockOnDelete = jest.fn();
|
||||||
|
|
||||||
|
const defaultProps = {
|
||||||
|
prompts: [],
|
||||||
|
onEdit: mockOnEdit,
|
||||||
|
onDelete: mockOnDelete,
|
||||||
|
};
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Empty State", () => {
|
||||||
|
test("renders empty message when no prompts", () => {
|
||||||
|
render(<PromptList {...defaultProps} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("No prompts yet")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows filtered empty message when search has no results", () => {
|
||||||
|
const prompts: Prompt[] = [
|
||||||
|
{
|
||||||
|
prompt_id: "prompt_123",
|
||||||
|
prompt: "Hello world",
|
||||||
|
version: 1,
|
||||||
|
variables: [],
|
||||||
|
is_default: false,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<PromptList {...defaultProps} prompts={prompts} />);
|
||||||
|
|
||||||
|
// Search for something that doesn't exist
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
fireEvent.change(searchInput, { target: { value: "nonexistent" } });
|
||||||
|
|
||||||
|
expect(
|
||||||
|
screen.getByText("No prompts match your filters")
|
||||||
|
).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Prompts Display", () => {
|
||||||
|
const mockPrompts: Prompt[] = [
|
||||||
|
{
|
||||||
|
prompt_id: "prompt_123",
|
||||||
|
prompt: "Hello {{name}}, how are you?",
|
||||||
|
version: 1,
|
||||||
|
variables: ["name"],
|
||||||
|
is_default: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
prompt_id: "prompt_456",
|
||||||
|
prompt: "Summarize this {{text}} in {{length}} words",
|
||||||
|
version: 2,
|
||||||
|
variables: ["text", "length"],
|
||||||
|
is_default: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
prompt_id: "prompt_789",
|
||||||
|
prompt: "Simple prompt with no variables",
|
||||||
|
version: 1,
|
||||||
|
variables: [],
|
||||||
|
is_default: false,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
test("renders prompts table with correct headers", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
expect(screen.getByText("ID")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Content")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Version")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Actions")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders prompt data correctly", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
// Check prompt IDs
|
||||||
|
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("prompt_456")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("prompt_789")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Check content
|
||||||
|
expect(
|
||||||
|
screen.getByText("Hello {{name}}, how are you?")
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText("Summarize this {{text}} in {{length}} words")
|
||||||
|
).toBeInTheDocument();
|
||||||
|
expect(
|
||||||
|
screen.getByText("Simple prompt with no variables")
|
||||||
|
).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Check versions
|
||||||
|
expect(screen.getAllByText("1")).toHaveLength(2); // Two prompts with version 1
|
||||||
|
expect(screen.getByText("2")).toBeInTheDocument();
|
||||||
|
|
||||||
|
// Check default badge
|
||||||
|
expect(screen.getByText("Default")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders variables correctly", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
// Check variables display
|
||||||
|
expect(screen.getByText("name")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("text")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("length")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("None")).toBeInTheDocument(); // For prompt with no variables
|
||||||
|
});
|
||||||
|
|
||||||
|
test("prompt ID links are clickable and call onEdit", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
// Click on the first prompt ID link
|
||||||
|
const promptLink = screen.getByRole("button", { name: "prompt_123" });
|
||||||
|
fireEvent.click(promptLink);
|
||||||
|
|
||||||
|
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("edit buttons call onEdit", () => {
|
||||||
|
const { container } = render(
|
||||||
|
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||||
|
);
|
||||||
|
|
||||||
|
// Find the action buttons in the table - they should be in the last column
|
||||||
|
const actionCells = container.querySelectorAll("td:last-child");
|
||||||
|
const firstActionCell = actionCells[0];
|
||||||
|
const editButton = firstActionCell?.querySelector("button");
|
||||||
|
|
||||||
|
expect(editButton).toBeInTheDocument();
|
||||||
|
fireEvent.click(editButton!);
|
||||||
|
|
||||||
|
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
test("delete buttons call onDelete with confirmation", () => {
|
||||||
|
const originalConfirm = window.confirm;
|
||||||
|
window.confirm = jest.fn(() => true);
|
||||||
|
|
||||||
|
const { container } = render(
|
||||||
|
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||||
|
);
|
||||||
|
|
||||||
|
// Find the delete button (second button in the first action cell)
|
||||||
|
const actionCells = container.querySelectorAll("td:last-child");
|
||||||
|
const firstActionCell = actionCells[0];
|
||||||
|
const buttons = firstActionCell?.querySelectorAll("button");
|
||||||
|
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||||
|
|
||||||
|
expect(deleteButton).toBeInTheDocument();
|
||||||
|
fireEvent.click(deleteButton!);
|
||||||
|
|
||||||
|
expect(window.confirm).toHaveBeenCalledWith(
|
||||||
|
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||||
|
);
|
||||||
|
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||||
|
|
||||||
|
window.confirm = originalConfirm;
|
||||||
|
});
|
||||||
|
|
||||||
|
test("delete does not execute when confirmation is cancelled", () => {
|
||||||
|
const originalConfirm = window.confirm;
|
||||||
|
window.confirm = jest.fn(() => false);
|
||||||
|
|
||||||
|
const { container } = render(
|
||||||
|
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||||
|
);
|
||||||
|
|
||||||
|
const actionCells = container.querySelectorAll("td:last-child");
|
||||||
|
const firstActionCell = actionCells[0];
|
||||||
|
const buttons = firstActionCell?.querySelectorAll("button");
|
||||||
|
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||||
|
|
||||||
|
expect(deleteButton).toBeInTheDocument();
|
||||||
|
fireEvent.click(deleteButton!);
|
||||||
|
|
||||||
|
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
window.confirm = originalConfirm;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Search Functionality", () => {
|
||||||
|
const mockPrompts: Prompt[] = [
|
||||||
|
{
|
||||||
|
prompt_id: "user_greeting",
|
||||||
|
prompt: "Hello {{name}}, welcome!",
|
||||||
|
version: 1,
|
||||||
|
variables: ["name"],
|
||||||
|
is_default: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
prompt_id: "system_summary",
|
||||||
|
prompt: "Summarize the following text",
|
||||||
|
version: 1,
|
||||||
|
variables: [],
|
||||||
|
is_default: false,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
test("filters prompts by prompt ID", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||||
|
|
||||||
|
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("filters prompts by content", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
fireEvent.change(searchInput, { target: { value: "welcome" } });
|
||||||
|
|
||||||
|
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("search is case insensitive", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
fireEvent.change(searchInput, { target: { value: "HELLO" } });
|
||||||
|
|
||||||
|
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("clearing search shows all prompts", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
|
||||||
|
// Filter first
|
||||||
|
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||||
|
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||||
|
|
||||||
|
// Clear search
|
||||||
|
fireEvent.change(searchInput, { target: { value: "" } });
|
||||||
|
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("system_summary")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
164  src/llama_stack/ui/components/prompts/prompt-list.tsx  Normal file

@@ -0,0 +1,164 @@
"use client";
|
||||||
|
|
||||||
|
import { useState } from "react";
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
Table,
|
||||||
|
TableBody,
|
||||||
|
TableCell,
|
||||||
|
TableHead,
|
||||||
|
TableHeader,
|
||||||
|
TableRow,
|
||||||
|
} from "@/components/ui/table";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { Edit, Search, Trash2 } from "lucide-react";
|
||||||
|
import { Prompt, PromptFilters } from "./types";
|
||||||
|
|
||||||
|
interface PromptListProps {
|
||||||
|
prompts: Prompt[];
|
||||||
|
onEdit: (prompt: Prompt) => void;
|
||||||
|
onDelete: (promptId: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PromptList({ prompts, onEdit, onDelete }: PromptListProps) {
|
||||||
|
const [filters, setFilters] = useState<PromptFilters>({});
|
||||||
|
|
||||||
|
const filteredPrompts = prompts.filter(prompt => {
|
||||||
|
if (
|
||||||
|
filters.searchTerm &&
|
||||||
|
!(
|
||||||
|
prompt.prompt
|
||||||
|
?.toLowerCase()
|
||||||
|
.includes(filters.searchTerm.toLowerCase()) ||
|
||||||
|
prompt.prompt_id
|
||||||
|
.toLowerCase()
|
||||||
|
.includes(filters.searchTerm.toLowerCase())
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
{/* Filters */}
|
||||||
|
<div className="flex flex-col sm:flex-row gap-4">
|
||||||
|
<div className="relative flex-1">
|
||||||
|
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
|
||||||
|
<Input
|
||||||
|
placeholder="Search prompts..."
|
||||||
|
value={filters.searchTerm || ""}
|
||||||
|
onChange={e =>
|
||||||
|
setFilters(prev => ({ ...prev, searchTerm: e.target.value }))
|
||||||
|
}
|
||||||
|
className="pl-10"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Prompts Table */}
|
||||||
|
<div className="overflow-auto">
|
||||||
|
<Table>
|
||||||
|
<TableHeader>
|
||||||
|
<TableRow>
|
||||||
|
<TableHead>ID</TableHead>
|
||||||
|
<TableHead>Content</TableHead>
|
||||||
|
<TableHead>Variables</TableHead>
|
||||||
|
<TableHead>Version</TableHead>
|
||||||
|
<TableHead>Actions</TableHead>
|
||||||
|
</TableRow>
|
||||||
|
</TableHeader>
|
||||||
|
<TableBody>
|
||||||
|
{filteredPrompts.map(prompt => (
|
||||||
|
<TableRow key={prompt.prompt_id}>
|
||||||
|
<TableCell className="max-w-48">
|
||||||
|
<Button
|
||||||
|
variant="link"
|
||||||
|
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300 max-w-full justify-start"
|
||||||
|
onClick={() => onEdit(prompt)}
|
||||||
|
title={prompt.prompt_id}
|
||||||
|
>
|
||||||
|
<div className="truncate">{prompt.prompt_id}</div>
|
||||||
|
</Button>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="max-w-64">
|
||||||
|
<div
|
||||||
|
className="font-mono text-xs text-muted-foreground truncate"
|
||||||
|
title={prompt.prompt || "No content"}
|
||||||
|
>
|
||||||
|
{prompt.prompt || "No content"}
|
||||||
|
</div>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
{prompt.variables.length > 0 ? (
|
||||||
|
<div className="flex flex-wrap gap-1">
|
||||||
|
{prompt.variables.map(variable => (
|
||||||
|
<Badge
|
||||||
|
key={variable}
|
||||||
|
variant="outline"
|
||||||
|
className="text-xs"
|
||||||
|
>
|
||||||
|
{variable}
|
||||||
|
</Badge>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<span className="text-muted-foreground text-sm">None</span>
|
||||||
|
)}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-sm">
|
||||||
|
{prompt.version}
|
||||||
|
{prompt.is_default && (
|
||||||
|
<Badge variant="secondary" className="text-xs ml-2">
|
||||||
|
Default
|
||||||
|
</Badge>
|
||||||
|
)}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
<div className="flex gap-1">
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
onClick={() => onEdit(prompt)}
|
||||||
|
className="h-8 w-8 p-0"
|
||||||
|
>
|
||||||
|
<Edit className="h-3 w-3" />
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
size="sm"
|
||||||
|
variant="outline"
|
||||||
|
onClick={() => {
|
||||||
|
if (
|
||||||
|
confirm(
|
||||||
|
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
onDelete(prompt.prompt_id);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
className="h-8 w-8 p-0 text-destructive hover:text-destructive"
|
||||||
|
>
|
||||||
|
<Trash2 className="h-3 w-3" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
))}
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{filteredPrompts.length === 0 && (
|
||||||
|
<div className="text-center py-12">
|
||||||
|
<div className="text-muted-foreground">
|
||||||
|
{prompts.length === 0
|
||||||
|
? "No prompts yet"
|
||||||
|
: "No prompts match your filters"}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
304  src/llama_stack/ui/components/prompts/prompt-management.test.tsx  Normal file

@@ -0,0 +1,304 @@
import React from "react";
|
||||||
|
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import { PromptManagement } from "./prompt-management";
|
||||||
|
import type { Prompt } from "./types";
|
||||||
|
|
||||||
|
// Mock the auth client
|
||||||
|
const mockPromptsClient = {
|
||||||
|
list: jest.fn(),
|
||||||
|
create: jest.fn(),
|
||||||
|
update: jest.fn(),
|
||||||
|
delete: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
jest.mock("@/hooks/use-auth-client", () => ({
|
||||||
|
useAuthClient: () => ({
|
||||||
|
prompts: mockPromptsClient,
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe("PromptManagement", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Loading State", () => {
|
||||||
|
test("renders loading state initially", () => {
|
||||||
|
mockPromptsClient.list.mockReturnValue(new Promise(() => {})); // Never resolves
|
      render(<PromptManagement />);

      expect(screen.getByText("Loading prompts...")).toBeInTheDocument();
      expect(screen.getByText("Prompts")).toBeInTheDocument();
    });
  });

  describe("Empty State", () => {
    test("renders empty state when no prompts", async () => {
      mockPromptsClient.list.mockResolvedValue([]);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("No prompts found.")).toBeInTheDocument();
      });

      expect(screen.getByText("Create Your First Prompt")).toBeInTheDocument();
    });

    test("opens modal when clicking 'Create Your First Prompt'", async () => {
      mockPromptsClient.list.mockResolvedValue([]);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(
          screen.getByText("Create Your First Prompt")
        ).toBeInTheDocument();
      });

      fireEvent.click(screen.getByText("Create Your First Prompt"));

      expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
    });
  });

  describe("Error State", () => {
    test("renders error state when API fails", async () => {
      const error = new Error("API not found");
      mockPromptsClient.list.mockRejectedValue(error);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText(/Error:/)).toBeInTheDocument();
      });
    });

    test("renders specific error for 404", async () => {
      const error = new Error("404 Not found");
      mockPromptsClient.list.mockRejectedValue(error);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(
          screen.getByText(/Prompts API endpoint not found/)
        ).toBeInTheDocument();
      });
    });
  });

  describe("Prompts List", () => {
    const mockPrompts: Prompt[] = [
      {
        prompt_id: "prompt_123",
        prompt: "Hello {{name}}, how are you?",
        version: 1,
        variables: ["name"],
        is_default: true,
      },
      {
        prompt_id: "prompt_456",
        prompt: "Summarize this {{text}}",
        version: 2,
        variables: ["text"],
        is_default: false,
      },
    ];

    test("renders prompts list correctly", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      expect(screen.getByText("prompt_456")).toBeInTheDocument();
      expect(
        screen.getByText("Hello {{name}}, how are you?")
      ).toBeInTheDocument();
      expect(screen.getByText("Summarize this {{text}}")).toBeInTheDocument();
    });

    test("opens modal when clicking 'New Prompt' button", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      fireEvent.click(screen.getByText("New Prompt"));

      expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
    });
  });

  describe("Modal Operations", () => {
    const mockPrompts: Prompt[] = [
      {
        prompt_id: "prompt_123",
        prompt: "Hello {{name}}",
        version: 1,
        variables: ["name"],
        is_default: true,
      },
    ];

    test("closes modal when clicking cancel", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Open modal
      fireEvent.click(screen.getByText("New Prompt"));
      expect(screen.getByText("Create New Prompt")).toBeInTheDocument();

      // Close modal
      fireEvent.click(screen.getByText("Cancel"));
      expect(screen.queryByText("Create New Prompt")).not.toBeInTheDocument();
    });

    test("creates new prompt successfully", async () => {
      const newPrompt: Prompt = {
        prompt_id: "prompt_new",
        prompt: "New prompt content",
        version: 1,
        variables: [],
        is_default: false,
      };

      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.create.mockResolvedValue(newPrompt);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Open modal
      fireEvent.click(screen.getByText("New Prompt"));

      // Fill form
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, {
        target: { value: "New prompt content" },
      });

      // Submit form
      fireEvent.click(screen.getByText("Create Prompt"));

      await waitFor(() => {
        expect(mockPromptsClient.create).toHaveBeenCalledWith({
          prompt: "New prompt content",
          variables: [],
        });
      });
    });

    test("handles create error gracefully", async () => {
      const error = {
        detail: {
          errors: [{ msg: "Prompt contains undeclared variables: ['test']" }],
        },
      };

      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.create.mockRejectedValue(error);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Open modal
      fireEvent.click(screen.getByText("New Prompt"));

      // Fill form
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, { target: { value: "Hello {{test}}" } });

      // Submit form
      fireEvent.click(screen.getByText("Create Prompt"));

      await waitFor(() => {
        expect(
          screen.getByText("Prompt contains undeclared variables: ['test']")
        ).toBeInTheDocument();
      });
    });

    test("updates existing prompt successfully", async () => {
      const updatedPrompt: Prompt = {
        ...mockPrompts[0],
        prompt: "Updated content",
      };

      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.update.mockResolvedValue(updatedPrompt);
      const { container } = render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Click edit button (first button in the action cell of the first row)
      const actionCells = container.querySelectorAll("td:last-child");
      const firstActionCell = actionCells[0];
      const editButton = firstActionCell?.querySelector("button");

      expect(editButton).toBeInTheDocument();
      fireEvent.click(editButton!);

      expect(screen.getByText("Edit Prompt")).toBeInTheDocument();

      // Update content
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, { target: { value: "Updated content" } });

      // Submit form
      fireEvent.click(screen.getByText("Update Prompt"));

      await waitFor(() => {
        expect(mockPromptsClient.update).toHaveBeenCalledWith("prompt_123", {
          prompt: "Updated content",
          variables: ["name"],
          version: 1,
          set_as_default: true,
        });
      });
    });

    test("deletes prompt successfully", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.delete.mockResolvedValue(undefined);

      // Mock window.confirm
      const originalConfirm = window.confirm;
      window.confirm = jest.fn(() => true);

      const { container } = render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Click delete button (second button in the action cell of the first row)
      const actionCells = container.querySelectorAll("td:last-child");
      const firstActionCell = actionCells[0];
      const buttons = firstActionCell?.querySelectorAll("button");
      const deleteButton = buttons?.[1]; // Second button should be delete

      expect(deleteButton).toBeInTheDocument();
      fireEvent.click(deleteButton!);

      await waitFor(() => {
        expect(mockPromptsClient.delete).toHaveBeenCalledWith("prompt_123");
      });

      // Restore window.confirm
      window.confirm = originalConfirm;
    });
  });
});
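The top of this test file (imports and mock wiring) falls outside the hunk shown above. A minimal sketch of the setup these assertions appear to assume follows; the `mockPromptsClient` object and the mocked `useAuthClient` hook are inferred from how the tests use them, not confirmed by the diff:

```tsx
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptManagement } from "./prompt-management";
import type { Prompt } from "./types";

// Inferred mock: the component resolves its API client via useAuthClient(),
// so the tests stub that hook and drive client.prompts.* directly.
const mockPromptsClient = {
  list: jest.fn(),
  create: jest.fn(),
  update: jest.fn(),
  delete: jest.fn(),
};

jest.mock("@/hooks/use-auth-client", () => ({
  useAuthClient: () => ({ prompts: mockPromptsClient }),
}));

beforeEach(() => {
  jest.clearAllMocks();
});
```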
233
src/llama_stack/ui/components/prompts/prompt-management.tsx
Normal file
@@ -0,0 +1,233 @@
"use client";

import { useState, useEffect } from "react";
import { Button } from "@/components/ui/button";
import { Plus } from "lucide-react";
import { PromptList } from "./prompt-list";
import { PromptEditor } from "./prompt-editor";
import { Prompt, PromptFormData } from "./types";
import { useAuthClient } from "@/hooks/use-auth-client";

export function PromptManagement() {
  const [prompts, setPrompts] = useState<Prompt[]>([]);
  const [showPromptModal, setShowPromptModal] = useState(false);
  const [editingPrompt, setEditingPrompt] = useState<Prompt | undefined>();
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState<string | null>(null); // For main page errors (loading, etc.)
  const [modalError, setModalError] = useState<string | null>(null); // For form submission errors
  const client = useAuthClient();

  // Load prompts from API on component mount
  useEffect(() => {
    const fetchPrompts = async () => {
      try {
        setLoading(true);
        setError(null);

        const response = await client.prompts.list();
        setPrompts(response || []);
      } catch (err: unknown) {
        console.error("Failed to load prompts:", err);

        // Handle different types of errors
        const error = err as Error & { status?: number };
        if (error?.message?.includes("404") || error?.status === 404) {
          setError(
            "Prompts API endpoint not found. Please ensure your Llama Stack server supports the prompts API."
          );
        } else if (
          error?.message?.includes("not implemented") ||
          error?.message?.includes("not supported")
        ) {
          setError(
            "Prompts API is not yet implemented on this Llama Stack server."
          );
        } else {
          setError(
            `Failed to load prompts: ${error?.message || "Unknown error"}`
          );
        }
      } finally {
        setLoading(false);
      }
    };

    fetchPrompts();
  }, [client]);

  const handleSavePrompt = async (formData: PromptFormData) => {
    try {
      setModalError(null);

      if (editingPrompt) {
        // Update existing prompt
        const response = await client.prompts.update(editingPrompt.prompt_id, {
          prompt: formData.prompt,
          variables: formData.variables,
          version: editingPrompt.version,
          set_as_default: true,
        });

        // Update local state
        setPrompts(prev =>
          prev.map(p =>
            p.prompt_id === editingPrompt.prompt_id ? response : p
          )
        );
      } else {
        // Create new prompt
        const response = await client.prompts.create({
          prompt: formData.prompt,
          variables: formData.variables,
        });

        // Add to local state
        setPrompts(prev => [response, ...prev]);
      }

      setShowPromptModal(false);
      setEditingPrompt(undefined);
    } catch (err) {
      console.error("Failed to save prompt:", err);

      // Extract specific error message from API response
      const error = err as Error & {
        message?: string;
        detail?: { errors?: Array<{ msg?: string }> };
      };

      // Try to parse JSON from error message if it's a string
      let parsedError = error;
      if (typeof error?.message === "string" && error.message.includes("{")) {
        try {
          const jsonMatch = error.message.match(/\d+\s+(.+)/);
          if (jsonMatch) {
            parsedError = JSON.parse(jsonMatch[1]);
          }
        } catch {
          // If parsing fails, use original error
        }
      }

      // Try to get the specific validation error message
      const validationError = parsedError?.detail?.errors?.[0]?.msg;
      if (validationError) {
        // Clean up validation error messages (remove "Value error, " prefix if present)
        const cleanMessage = validationError.replace(/^Value error,\s*/i, "");
        setModalError(cleanMessage);
      } else {
        // For other errors, format them nicely with line breaks
        const statusMatch = error?.message?.match(/(\d+)\s+(.+)/);
        if (statusMatch) {
          const statusCode = statusMatch[1];
          const response = statusMatch[2];
          setModalError(
            `Failed to save prompt: Status Code ${statusCode}\n\nResponse: ${response}`
          );
        } else {
          const message = error?.message || error?.detail || "Unknown error";
          setModalError(`Failed to save prompt: ${message}`);
        }
      }
    }
  };

  const handleEditPrompt = (prompt: Prompt) => {
    setEditingPrompt(prompt);
    setShowPromptModal(true);
    setModalError(null); // Clear any previous modal errors
  };

  const handleDeletePrompt = async (promptId: string) => {
    try {
      setError(null);
      await client.prompts.delete(promptId);
      setPrompts(prev => prev.filter(p => p.prompt_id !== promptId));

      // If we're deleting the currently editing prompt, close the modal
      if (editingPrompt && editingPrompt.prompt_id === promptId) {
        setShowPromptModal(false);
        setEditingPrompt(undefined);
      }
    } catch (err) {
      console.error("Failed to delete prompt:", err);
      setError("Failed to delete prompt");
    }
  };

  const handleCreateNew = () => {
    setEditingPrompt(undefined);
    setShowPromptModal(true);
    setModalError(null); // Clear any previous modal errors
  };

  const handleCancel = () => {
    setShowPromptModal(false);
    setEditingPrompt(undefined);
  };

  const renderContent = () => {
    if (loading) {
      return <div className="text-muted-foreground">Loading prompts...</div>;
    }

    if (error) {
      return <div className="text-destructive">Error: {error}</div>;
    }

    if (!prompts || prompts.length === 0) {
      return (
        <div className="text-center py-12">
          <p className="text-muted-foreground mb-4">No prompts found.</p>
          <Button onClick={handleCreateNew}>
            <Plus className="h-4 w-4 mr-2" />
            Create Your First Prompt
          </Button>
        </div>
      );
    }

    return (
      <PromptList
        prompts={prompts}
        onEdit={handleEditPrompt}
        onDelete={handleDeletePrompt}
      />
    );
  };

  return (
    <div className="space-y-4">
      <div className="flex items-center justify-between">
        <h1 className="text-2xl font-semibold">Prompts</h1>
        <Button onClick={handleCreateNew} disabled={loading}>
          <Plus className="h-4 w-4 mr-2" />
          New Prompt
        </Button>
      </div>
      {renderContent()}

      {/* Create/Edit Prompt Modal */}
      {showPromptModal && (
        <div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
          <div className="bg-background border rounded-lg shadow-lg max-w-4xl w-full mx-4 max-h-[90vh] overflow-hidden">
            <div className="p-6 border-b">
              <h2 className="text-2xl font-bold">
                {editingPrompt ? "Edit Prompt" : "Create New Prompt"}
              </h2>
            </div>
            <div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
              <PromptEditor
                prompt={editingPrompt}
                onSave={handleSavePrompt}
                onCancel={handleCancel}
                onDelete={handleDeletePrompt}
                error={modalError}
              />
            </div>
          </div>
        </div>
      )}
    </div>
  );
}
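`PromptList` and `PromptEditor` are imported from sibling files that are not shown in this part of the diff. Judging purely from how they are rendered above, their prop contracts would look roughly like the sketch below; the interface names and exact types are an assumption, only the prop names come from the JSX usage:

```ts
import type { Prompt, PromptFormData } from "./types";

// Assumed shape of the list component's props (from <PromptList ... /> above).
export interface PromptListProps {
  prompts: Prompt[];
  onEdit: (prompt: Prompt) => void;
  onDelete: (promptId: string) => void;
}

// Assumed shape of the editor's props (from <PromptEditor ... /> above);
// prompt is undefined when creating, and error carries modalError.
export interface PromptEditorProps {
  prompt?: Prompt;
  onSave: (formData: PromptFormData) => void | Promise<void>;
  onCancel: () => void;
  onDelete: (promptId: string) => void;
  error?: string | null;
}
```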
16
src/llama_stack/ui/components/prompts/types.ts
Normal file
@@ -0,0 +1,16 @@
export interface Prompt {
  prompt_id: string;
  prompt: string | null;
  version: number;
  variables: string[];
  is_default: boolean;
}

export interface PromptFormData {
  prompt: string;
  variables: string[];
}

export interface PromptFilters {
  searchTerm?: string;
}
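`PromptFilters` is declared above but not referenced anywhere in the hunks shown. Purely as an illustration of how it could be consumed (this helper is not part of the patch), a client-side filter over the loaded prompts might look like:

```ts
import type { Prompt, PromptFilters } from "./types";

// Hypothetical helper: keep prompts whose id or content contains the
// search term, case-insensitively; an empty filter returns everything.
export function filterPrompts(
  prompts: Prompt[],
  filters: PromptFilters
): Prompt[] {
  const term = filters.searchTerm?.trim().toLowerCase();
  if (!term) return prompts;
  return prompts.filter(
    p =>
      p.prompt_id.toLowerCase().includes(term) ||
      (p.prompt ?? "").toLowerCase().includes(term)
  );
}
```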
Some files were not shown because too many files have changed in this diff.