diff --git a/.github/actions/install-llama-stack-client/action.yml b/.github/actions/install-llama-stack-client/action.yml new file mode 100644 index 000000000..3c1c77d9c --- /dev/null +++ b/.github/actions/install-llama-stack-client/action.yml @@ -0,0 +1,60 @@ +name: Install llama-stack-client +description: Install llama-stack-client based on branch context and client-version input + +inputs: + client-version: + description: 'Client version to install on non-release branches (latest or published). Ignored on release branches.' + required: false + default: "" + +outputs: + uv-extra-index-url: + description: 'UV_EXTRA_INDEX_URL to use (set for release branches)' + value: ${{ steps.configure.outputs.uv-extra-index-url }} + install-after-sync: + description: 'Whether to install client after uv sync' + value: ${{ steps.configure.outputs.install-after-sync }} + install-source: + description: 'Where to install client from after sync' + value: ${{ steps.configure.outputs.install-source }} + +runs: + using: "composite" + steps: + - name: Configure client installation + id: configure + shell: bash + run: | + # Determine the branch we're working with + BRANCH="${{ github.base_ref || github.ref }}" + BRANCH="${BRANCH#refs/heads/}" + + echo "Working with branch: $BRANCH" + + # On release branches: use test.pypi for uv sync, then install from git + # On non-release branches: install based on client-version after sync + if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then + echo "Detected release branch: $BRANCH" + + # Check if matching branch exists in client repo + if ! 
git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$BRANCH" > /dev/null 2>&1; then + echo "::error::Branch $BRANCH not found in llama-stack-client-python repository" + echo "::error::Please create the matching release branch in llama-stack-client-python before testing" + exit 1 + fi + + # Configure to use test.pypi as extra index (PyPI is primary) + echo "uv-extra-index-url=https://test.pypi.org/simple/" >> $GITHUB_OUTPUT + echo "install-after-sync=true" >> $GITHUB_OUTPUT + echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@$BRANCH" >> $GITHUB_OUTPUT + elif [ "${{ inputs.client-version }}" = "latest" ]; then + # Install from main git after sync + echo "install-after-sync=true" >> $GITHUB_OUTPUT + echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@main" >> $GITHUB_OUTPUT + elif [ "${{ inputs.client-version }}" = "published" ]; then + # Use published version from PyPI (installed by sync) + echo "install-after-sync=false" >> $GITHUB_OUTPUT + elif [ -n "${{ inputs.client-version }}" ]; then + echo "::error::Invalid client-version: ${{ inputs.client-version }}" + exit 1 + fi diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index ac600d570..ec4d7f977 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -94,7 +94,7 @@ runs: if: ${{ always() }} uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: - name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }} + name: logs-${{ github.run_id }}-${{ github.run_attempt || '1' }}-${{ strategy.job-index || github.job }}-${{ github.action }} path: | *.log retention-days: 1 diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml index 905d6b73a..3237abb67 100644 --- 
a/.github/actions/setup-runner/action.yml +++ b/.github/actions/setup-runner/action.yml @@ -18,25 +18,35 @@ runs: python-version: ${{ inputs.python-version }} version: 0.7.6 + - name: Configure client installation + id: client-config + uses: ./.github/actions/install-llama-stack-client + with: + client-version: ${{ inputs.client-version }} + - name: Install dependencies shell: bash + env: + UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }} run: | + # Export UV env vars for current step and persist to GITHUB_ENV for subsequent steps + if [ -n "$UV_EXTRA_INDEX_URL" ]; then + export UV_INDEX_STRATEGY=unsafe-best-match + echo "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" >> $GITHUB_ENV + echo "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" >> $GITHUB_ENV + echo "Exported UV environment variables for current and subsequent steps" + fi + echo "Updating project dependencies via uv sync" uv sync --all-groups echo "Installing ad-hoc dependencies" uv pip install faiss-cpu - # Install llama-stack-client-python based on the client-version input - if [ "${{ inputs.client-version }}" = "latest" ]; then - echo "Installing latest llama-stack-client-python from main branch" - uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main - elif [ "${{ inputs.client-version }}" = "published" ]; then - echo "Installing published llama-stack-client-python from PyPI" - uv pip install llama-stack-client - else - echo "Invalid client-version: ${{ inputs.client-version }}" - exit 1 + # Install specific client version after sync if needed + if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then + echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}" + uv pip install ${{ steps.client-config.outputs.install-source }} fi echo "Installed llama packages" diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index ee9011ed8..7b306fef5 100644 
--- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -42,18 +42,7 @@ runs: - name: Build Llama Stack shell: bash run: | - # Install llama-stack-client-python based on the client-version input - if [ "${{ inputs.client-version }}" = "latest" ]; then - echo "Installing latest llama-stack-client-python from main branch" - export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main - elif [ "${{ inputs.client-version }}" = "published" ]; then - echo "Installing published llama-stack-client-python from PyPI" - unset LLAMA_STACK_CLIENT_DIR - else - echo "Invalid client-version: ${{ inputs.client-version }}" - exit 1 - fi - + # Client is already installed by setup-runner (handles both main and release branches) echo "Building Llama Stack" LLAMA_STACK_DIR=. \ diff --git a/.github/mergify.yml b/.github/mergify.yml new file mode 100644 index 000000000..a96191958 --- /dev/null +++ b/.github/mergify.yml @@ -0,0 +1,23 @@ +pull_request_rules: +- name: ping author on conflicts and add 'needs-rebase' label + conditions: + - conflict + - -closed + actions: + label: + add: + - needs-rebase + comment: + message: > + This pull request has merge conflicts that must be resolved before it + can be merged. @{{author}} please rebase it. + https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork + +- name: remove 'needs-rebase' label when conflict is resolved + conditions: + - -conflict + - -closed + actions: + label: + remove: + - needs-rebase diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 00a8f54ac..88b2d5106 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -4,6 +4,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). 
Below is a tabl | Name | File | Purpose | | ---- | ---- | ------- | +| Backward Compatibility Check | [backward-compat.yml](backward-compat.yml) | Check backward compatibility for run.yaml configs | | Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md | | API Conformance Tests | [conformance.yml](conformance.yml) | Run the API Conformance test suite on the changes. | | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | @@ -12,7 +13,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl | Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | -| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | | Test llama stack list-deps | [providers-list-deps.yml](providers-list-deps.yml) | Test llama stack list-deps | | Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project | diff --git a/.github/workflows/backward-compat.yml b/.github/workflows/backward-compat.yml new file mode 100644 index 000000000..9f950a8b9 --- /dev/null +++ b/.github/workflows/backward-compat.yml @@ -0,0 +1,578 @@ +name: Backward Compatibility Check + +run-name: Check backward compatibility for run.yaml configs + +on: + pull_request: + branches: + - main + - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+.[0-9]+' + - 'release-[0-9]+.[0-9]+' + paths: + - 'src/llama_stack/core/datatypes.py' + - 'src/llama_stack/providers/datatypes.py' + - 
'src/llama_stack/distributions/**/run.yaml' + - 'tests/backward_compat/**' + - '.github/workflows/backward-compat.yml' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + check-main-compatibility: + name: Check Compatibility with main + runs-on: ubuntu-latest + + steps: + - name: Checkout PR branch + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + fetch-depth: 0 # Need full history to access main branch + + - name: Set up Python + uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 + with: + python-version: '3.12' + + - name: Install uv + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 + with: + enable-cache: true + + - name: Install dependencies + run: | + uv sync --group dev + + - name: Extract run.yaml files from main branch + id: extract_configs + run: | + # Get list of run.yaml paths from main + git fetch origin main + CONFIG_PATHS=$(git ls-tree -r --name-only origin/main | grep "src/llama_stack/distributions/.*/run.yaml$" || true) + + if [ -z "$CONFIG_PATHS" ]; then + echo "No run.yaml files found in main branch" + exit 1 + fi + + # Extract all configs to a temp directory + mkdir -p /tmp/main_configs + echo "Extracting configs from main branch:" + + while IFS= read -r config_path; do + if [ -z "$config_path" ]; then + continue + fi + + # Extract filename for storage + filename=$(basename $(dirname "$config_path")) + echo " - $filename (from $config_path)" + + git show origin/main:"$config_path" > "/tmp/main_configs/${filename}.yaml" + done <<< "$CONFIG_PATHS" + + echo "" + echo "Extracted $(ls /tmp/main_configs/*.yaml | wc -l) config files" + + - name: Test all configs from main + id: test_configs + continue-on-error: true + run: | + # Run pytest once with all configs parameterized + if COMPAT_TEST_CONFIGS_DIR=/tmp/main_configs uv run pytest tests/backward_compat/test_run_config.py -v; then + echo "failed=false" >> 
$GITHUB_OUTPUT + else + echo "failed=true" >> $GITHUB_OUTPUT + exit 1 + fi + + - name: Check for breaking change acknowledgment + id: check_ack + if: steps.test_configs.outputs.failed == 'true' + run: | + echo "Breaking changes detected. Checking for acknowledgment..." + + # Check PR title for '!:' marker (conventional commits) + PR_TITLE="${{ github.event.pull_request.title }}" + if [[ "$PR_TITLE" =~ ^[a-z]+\!: ]]; then + echo "✓ Breaking change acknowledged in PR title" + echo "acknowledged=true" >> $GITHUB_OUTPUT + exit 0 + fi + + # Check commit messages for BREAKING CHANGE: + if git log origin/main..HEAD --format=%B | grep -q "BREAKING CHANGE:"; then + echo "✓ Breaking change acknowledged in commit message" + echo "acknowledged=true" >> $GITHUB_OUTPUT + exit 0 + fi + + echo "✗ Breaking change NOT acknowledged" + echo "acknowledged=false" >> $GITHUB_OUTPUT + env: + GH_TOKEN: ${{ github.token }} + + - name: Evaluate results + if: always() + run: | + FAILED="${{ steps.test_configs.outputs.failed }}" + ACKNOWLEDGED="${{ steps.check_ack.outputs.acknowledged }}" + + if [[ "$FAILED" == "true" ]]; then + if [[ "$ACKNOWLEDGED" == "true" ]]; then + echo "" + echo "⚠️ WARNING: Breaking changes detected but acknowledged" + echo "" + echo "This PR introduces backward-incompatible changes to run.yaml." + echo "The changes have been properly acknowledged." + echo "" + exit 0 # Pass the check + else + echo "" + echo "❌ ERROR: Breaking changes detected without acknowledgment" + echo "" + echo "This PR introduces backward-incompatible changes to run.yaml" + echo "that will break existing user configurations." + echo "" + echo "To acknowledge this breaking change, do ONE of:" + echo " 1. Add '!:' to your PR title (e.g., 'feat!: change xyz')" + echo " 2. Add the 'breaking-change' label to this PR" + echo " 3. 
Include 'BREAKING CHANGE:' in a commit message" + echo "" + exit 1 # Fail the check + fi + fi + + test-integration-main: + name: Run Integration Tests with main Config + runs-on: ubuntu-latest + + steps: + - name: Checkout PR branch + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + fetch-depth: 0 + + - name: Extract ci-tests run.yaml from main + run: | + git fetch origin main + git show origin/main:src/llama_stack/distributions/ci-tests/run.yaml > /tmp/main-ci-tests-run.yaml + echo "Extracted ci-tests run.yaml from main branch" + + - name: Setup test environment + uses: ./.github/actions/setup-test-environment + with: + python-version: '3.12' + client-version: 'latest' + setup: 'ollama' + suite: 'base' + inference-mode: 'replay' + + - name: Run integration tests with main config + id: test_integration + continue-on-error: true + uses: ./.github/actions/run-and-record-tests + with: + stack-config: /tmp/main-ci-tests-run.yaml + setup: 'ollama' + inference-mode: 'replay' + suite: 'base' + + - name: Check for breaking change acknowledgment + id: check_ack + if: steps.test_integration.outcome == 'failure' + run: | + echo "Integration tests failed. Checking for acknowledgment..." 
+ + # Check PR title for '!:' marker (conventional commits) + PR_TITLE="${{ github.event.pull_request.title }}" + if [[ "$PR_TITLE" =~ ^[a-z]+\!: ]]; then + echo "✓ Breaking change acknowledged in PR title" + echo "acknowledged=true" >> $GITHUB_OUTPUT + exit 0 + fi + + # Check commit messages for BREAKING CHANGE: + if git log origin/main..HEAD --format=%B | grep -q "BREAKING CHANGE:"; then + echo "✓ Breaking change acknowledged in commit message" + echo "acknowledged=true" >> $GITHUB_OUTPUT + exit 0 + fi + + echo "✗ Breaking change NOT acknowledged" + echo "acknowledged=false" >> $GITHUB_OUTPUT + env: + GH_TOKEN: ${{ github.token }} + + - name: Evaluate integration test results + if: always() + run: | + TEST_FAILED="${{ steps.test_integration.outcome == 'failure' }}" + ACKNOWLEDGED="${{ steps.check_ack.outputs.acknowledged }}" + + if [[ "$TEST_FAILED" == "true" ]]; then + if [[ "$ACKNOWLEDGED" == "true" ]]; then + echo "" + echo "⚠️ WARNING: Integration tests failed with main config but acknowledged" + echo "" + exit 0 # Pass the check + else + echo "" + echo "❌ ERROR: Integration tests failed with main config without acknowledgment" + echo "" + echo "To acknowledge this breaking change, do ONE of:" + echo " 1. Add '!:' to your PR title (e.g., 'feat!: change xyz')" + echo " 2. 
Include 'BREAKING CHANGE:' in a commit message" + echo "" + exit 1 # Fail the check + fi + fi + + test-integration-release: + name: Run Integration Tests with Latest Release (Informational) + runs-on: ubuntu-latest + + steps: + - name: Checkout PR branch + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + fetch-depth: 0 + + - name: Get latest release + id: get_release + run: | + # Get the latest release from GitHub + LATEST_TAG=$(gh release list --limit 1 --json tagName --jq '.[0].tagName' 2>/dev/null || echo "") + + if [ -z "$LATEST_TAG" ]; then + echo "No releases found, skipping release compatibility check" + echo "has_release=false" >> $GITHUB_OUTPUT + exit 0 + fi + + echo "Latest release: $LATEST_TAG" + echo "has_release=true" >> $GITHUB_OUTPUT + echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT + env: + GH_TOKEN: ${{ github.token }} + + - name: Extract ci-tests run.yaml from release + if: steps.get_release.outputs.has_release == 'true' + id: extract_config + run: | + RELEASE_TAG="${{ steps.get_release.outputs.tag }}" + + # Try with src/ prefix first (newer releases), then without (older releases) + if git show "$RELEASE_TAG:src/llama_stack/distributions/ci-tests/run.yaml" > /tmp/release-ci-tests-run.yaml 2>/dev/null; then + echo "Extracted ci-tests run.yaml from release $RELEASE_TAG (src/ path)" + echo "has_config=true" >> $GITHUB_OUTPUT + elif git show "$RELEASE_TAG:llama_stack/distributions/ci-tests/run.yaml" > /tmp/release-ci-tests-run.yaml 2>/dev/null; then + echo "Extracted ci-tests run.yaml from release $RELEASE_TAG (old path)" + echo "has_config=true" >> $GITHUB_OUTPUT + else + echo "::warning::ci-tests/run.yaml not found in release $RELEASE_TAG" + echo "has_config=false" >> $GITHUB_OUTPUT + fi + + - name: Setup test environment + if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true' + uses: ./.github/actions/setup-test-environment + with: + python-version: '3.12' + 
client-version: 'latest' + setup: 'ollama' + suite: 'base' + inference-mode: 'replay' + + - name: Run integration tests with release config (PR branch) + id: test_release_pr + if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true' + continue-on-error: true + uses: ./.github/actions/run-and-record-tests + with: + stack-config: /tmp/release-ci-tests-run.yaml + setup: 'ollama' + inference-mode: 'replay' + suite: 'base' + + - name: Checkout main branch to test baseline + if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true' + run: | + git checkout origin/main + + - name: Setup test environment for main + if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true' + uses: ./.github/actions/setup-test-environment + with: + python-version: '3.12' + client-version: 'latest' + setup: 'ollama' + suite: 'base' + inference-mode: 'replay' + + - name: Run integration tests with release config (main branch) + id: test_release_main + if: steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true' + continue-on-error: true + uses: ./.github/actions/run-and-record-tests + with: + stack-config: /tmp/release-ci-tests-run.yaml + setup: 'ollama' + inference-mode: 'replay' + suite: 'base' + + - name: Report results and post PR comment + if: always() && steps.get_release.outputs.has_release == 'true' && steps.extract_config.outputs.has_config == 'true' + run: | + RELEASE_TAG="${{ steps.get_release.outputs.tag }}" + PR_OUTCOME="${{ steps.test_release_pr.outcome }}" + MAIN_OUTCOME="${{ steps.test_release_main.outcome }}" + + if [[ "$PR_OUTCOME" == "failure" && "$MAIN_OUTCOME" == "success" ]]; then + # NEW breaking change - PR fails but main passes + echo "::error::🚨 This PR introduces a NEW breaking change!" 
+ + # Check if we already posted a comment (to avoid spam on every push) + EXISTING_COMMENT=$(gh pr view ${{ github.event.pull_request.number }} --json comments --jq '.comments[] | select(.body | contains("🚨 New Breaking Change Detected") and contains("Integration tests")) | .id' | head -1) + + if [[ -z "$EXISTING_COMMENT" ]]; then + gh pr comment ${{ github.event.pull_request.number }} --body "## 🚨 New Breaking Change Detected + + **Integration tests against release \`$RELEASE_TAG\` are now failing** + + ⚠️ This PR introduces a breaking change that affects compatibility with the latest release. + + - Users on release \`$RELEASE_TAG\` may not be able to upgrade + - Existing configurations may break + + The tests pass on \`main\` but fail with this PR's changes. + + > **Note:** This is informational only and does not block merge. + > Consider whether this breaking change is acceptable for users." + else + echo "Comment already exists, skipping to avoid spam" + fi + + cat >> $GITHUB_STEP_SUMMARY < **Note:** This is informational only and does not block merge. + > Consider whether this breaking change is acceptable for users. + EOF + + elif [[ "$PR_OUTCOME" == "failure" ]]; then + # Existing breaking change - both PR and main fail + echo "::warning::Breaking change already exists in main branch" + + cat >> $GITHUB_STEP_SUMMARY < **Note:** This is informational only. 
+ EOF + + else + # Success - tests pass + cat >> $GITHUB_STEP_SUMMARY </dev/null || echo "") + + if [ -z "$LATEST_TAG" ]; then + echo "No releases found, skipping release compatibility check" + echo "has_release=false" >> $GITHUB_OUTPUT + exit 0 + fi + + echo "Latest release: $LATEST_TAG" + echo "has_release=true" >> $GITHUB_OUTPUT + echo "tag=$LATEST_TAG" >> $GITHUB_OUTPUT + env: + GH_TOKEN: ${{ github.token }} + + - name: Extract configs from release + if: steps.get_release.outputs.has_release == 'true' + id: extract_release_configs + run: | + RELEASE_TAG="${{ steps.get_release.outputs.tag }}" + + # Get run.yaml files from the release (try both src/ and old path) + CONFIG_PATHS=$(git ls-tree -r --name-only "$RELEASE_TAG" | grep "llama_stack/distributions/.*/run.yaml$" || true) + + if [ -z "$CONFIG_PATHS" ]; then + echo "::warning::No run.yaml files found in release $RELEASE_TAG" + echo "has_configs=false" >> $GITHUB_OUTPUT + exit 0 + fi + + # Extract all configs to a temp directory + mkdir -p /tmp/release_configs + echo "Extracting configs from release $RELEASE_TAG:" + + while IFS= read -r config_path; do + if [ -z "$config_path" ]; then + continue + fi + + filename=$(basename $(dirname "$config_path")) + echo " - $filename (from $config_path)" + + git show "$RELEASE_TAG:$config_path" > "/tmp/release_configs/${filename}.yaml" 2>/dev/null || true + done <<< "$CONFIG_PATHS" + + echo "" + echo "Extracted $(ls /tmp/release_configs/*.yaml 2>/dev/null | wc -l) config files" + echo "has_configs=true" >> $GITHUB_OUTPUT + + - name: Test against release configs (PR branch) + id: test_schema_pr + if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true' + continue-on-error: true + run: | + RELEASE_TAG="${{ steps.get_release.outputs.tag }}" + COMPAT_TEST_CONFIGS_DIR=/tmp/release_configs uv run pytest tests/backward_compat/test_run_config.py -v --tb=short + + - name: Checkout main branch to test baseline + if: 
steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true' + run: | + git checkout origin/main + + - name: Install dependencies for main + if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true' + run: | + uv sync --group dev + + - name: Test against release configs (main branch) + id: test_schema_main + if: steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true' + continue-on-error: true + run: | + RELEASE_TAG="${{ steps.get_release.outputs.tag }}" + COMPAT_TEST_CONFIGS_DIR=/tmp/release_configs uv run pytest tests/backward_compat/test_run_config.py -v --tb=short + + - name: Report results and post PR comment + if: always() && steps.get_release.outputs.has_release == 'true' && steps.extract_release_configs.outputs.has_configs == 'true' + run: | + RELEASE_TAG="${{ steps.get_release.outputs.tag }}" + PR_OUTCOME="${{ steps.test_schema_pr.outcome }}" + MAIN_OUTCOME="${{ steps.test_schema_main.outcome }}" + + if [[ "$PR_OUTCOME" == "failure" && "$MAIN_OUTCOME" == "success" ]]; then + # NEW breaking change - PR fails but main passes + echo "::error::🚨 This PR introduces a NEW schema breaking change!" + + # Check if we already posted a comment (to avoid spam on every push) + EXISTING_COMMENT=$(gh pr view ${{ github.event.pull_request.number }} --json comments --jq '.comments[] | select(.body | contains("🚨 New Schema Breaking Change Detected")) | .id' | head -1) + + if [[ -z "$EXISTING_COMMENT" ]]; then + gh pr comment ${{ github.event.pull_request.number }} --body "## 🚨 New Schema Breaking Change Detected + + **Schema validation against release \`$RELEASE_TAG\` is now failing** + + ⚠️ This PR introduces a schema breaking change that affects compatibility with the latest release. 
+ + - Users on release \`$RELEASE_TAG\` will not be able to upgrade + - Existing run.yaml configurations will fail validation + + The tests pass on \`main\` but fail with this PR's changes. + + > **Note:** This is informational only and does not block merge. + > Consider whether this breaking change is acceptable for users." + else + echo "Comment already exists, skipping to avoid spam" + fi + + cat >> $GITHUB_STEP_SUMMARY < **Note:** This is informational only and does not block merge. + > Consider whether this breaking change is acceptable for users. + EOF + + elif [[ "$PR_OUTCOME" == "failure" ]]; then + # Existing breaking change - both PR and main fail + echo "::warning::Schema breaking change already exists in main branch" + + cat >> $GITHUB_STEP_SUMMARY < **Note:** This is informational only. + EOF + + else + # Success - tests pass + cat >> $GITHUB_STEP_SUMMARY <&1 | tee /tmp/precommit.log + status=${PIPESTATUS[0]} + echo "status=$status" >> $GITHUB_OUTPUT + exit 0 env: - SKIP: no-commit-to-branch + SKIP: no-commit-to-branch,mypy RUFF_OUTPUT_FORMAT: github - name: Check pre-commit results - if: steps.precommit.outcome == 'failure' + if: steps.precommit.outputs.status != '0' run: | echo "::error::Pre-commit hooks failed. Please run 'pre-commit run --all-files' locally and commit the fixes." - echo "::warning::Some pre-commit hooks failed. Check the output above for details." 
+ echo "" + echo "Failed hooks output:" + cat /tmp/precommit.log exit 1 - name: Debug @@ -109,3 +129,39 @@ jobs: echo "$unstaged_files" exit 1 fi + + - name: Configure client installation + id: client-config + uses: ./.github/actions/install-llama-stack-client + + - name: Sync dev + type_checking dependencies + env: + UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }} + run: | + if [ -n "$UV_EXTRA_INDEX_URL" ]; then + export UV_INDEX_STRATEGY="unsafe-best-match" + fi + + uv sync --group dev --group type_checking + + # Install specific client version after sync if needed + if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then + echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}" + uv pip install ${{ steps.client-config.outputs.install-source }} + fi + + - name: Run mypy (full type_checking) + env: + UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }} + run: | + if [ -n "$UV_EXTRA_INDEX_URL" ]; then + export UV_INDEX_STRATEGY="unsafe-best-match" + fi + + set +e + uv run --group dev --group type_checking mypy + status=$? + if [ $status -ne 0 ]; then + echo "::error::Full mypy failed. Reproduce locally with 'uv run pre-commit run mypy-full --hook-stage manual --all-files'." 
+ fi + exit $status diff --git a/.github/workflows/precommit-trigger.yml b/.github/workflows/precommit-trigger.yml deleted file mode 100644 index 502230448..000000000 --- a/.github/workflows/precommit-trigger.yml +++ /dev/null @@ -1,227 +0,0 @@ -name: Pre-commit Bot - -run-name: Pre-commit bot for PR #${{ github.event.issue.number }} - -on: - issue_comment: - types: [created] - -jobs: - pre-commit: - # Only run on pull request comments - if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit') - runs-on: ubuntu-latest - permissions: - contents: write - pull-requests: write - - steps: - - name: Check comment author and get PR details - id: check_author - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - // Get PR details - const pr = await github.rest.pulls.get({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.issue.number - }); - - // Check if commenter has write access or is the PR author - const commenter = context.payload.comment.user.login; - const prAuthor = pr.data.user.login; - - let hasPermission = false; - - // Check if commenter is PR author - if (commenter === prAuthor) { - hasPermission = true; - console.log(`Comment author ${commenter} is the PR author`); - } else { - // Check if commenter has write/admin access - try { - const permission = await github.rest.repos.getCollaboratorPermissionLevel({ - owner: context.repo.owner, - repo: context.repo.repo, - username: commenter - }); - - const level = permission.data.permission; - hasPermission = ['write', 'admin', 'maintain'].includes(level); - console.log(`Comment author ${commenter} has permission: ${level}`); - } catch (error) { - console.log(`Could not check permissions for ${commenter}: ${error.message}`); - } - } - - if (!hasPermission) { - await github.rest.issues.createComment({ - owner: context.repo.owner, - repo: 
context.repo.repo, - issue_number: context.issue.number, - body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.` - }); - core.setFailed(`User ${commenter} does not have permission`); - return; - } - - // Save PR info for later steps - core.setOutput('pr_number', context.issue.number); - core.setOutput('pr_head_ref', pr.data.head.ref); - core.setOutput('pr_head_sha', pr.data.head.sha); - core.setOutput('pr_head_repo', pr.data.head.repo.full_name); - core.setOutput('pr_base_ref', pr.data.base.ref); - core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name); - core.setOutput('authorized', 'true'); - - - name: React to comment - if: steps.check_author.outputs.authorized == 'true' - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - await github.rest.reactions.createForIssueComment({ - owner: context.repo.owner, - repo: context.repo.repo, - comment_id: context.payload.comment.id, - content: 'rocket' - }); - - - name: Comment starting - if: steps.check_author.outputs.authorized == 'true' - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - await github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: ${{ steps.check_author.outputs.pr_number }}, - body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...` - }); - - - name: Checkout PR branch (same-repo) - if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false' - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - with: - ref: ${{ steps.check_author.outputs.pr_head_ref }} - 
fetch-depth: 0 - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Checkout PR branch (fork) - if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true' - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - with: - repository: ${{ steps.check_author.outputs.pr_head_repo }} - ref: ${{ steps.check_author.outputs.pr_head_ref }} - fetch-depth: 0 - token: ${{ secrets.GITHUB_TOKEN }} - - - name: Verify checkout - if: steps.check_author.outputs.authorized == 'true' - run: | - echo "Current SHA: $(git rev-parse HEAD)" - echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}" - if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then - echo "::error::Checked out SHA does not match expected SHA" - exit 1 - fi - - - name: Set up Python - if: steps.check_author.outputs.authorized == 'true' - uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 - with: - python-version: '3.12' - cache: pip - cache-dependency-path: | - **/requirements*.txt - .pre-commit-config.yaml - - - name: Set up Node.js - if: steps.check_author.outputs.authorized == 'true' - uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0 - with: - node-version: '20' - cache: 'npm' - cache-dependency-path: 'src/llama_stack/ui/' - - - name: Install npm dependencies - if: steps.check_author.outputs.authorized == 'true' - run: npm ci - working-directory: src/llama_stack/ui - - - name: Run pre-commit - if: steps.check_author.outputs.authorized == 'true' - id: precommit - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 - continue-on-error: true - env: - SKIP: no-commit-to-branch - RUFF_OUTPUT_FORMAT: github - - - name: Check for changes - if: steps.check_author.outputs.authorized == 'true' - id: changes - run: | - if ! 
git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then - echo "has_changes=true" >> $GITHUB_OUTPUT - echo "Changes detected after pre-commit" - else - echo "has_changes=false" >> $GITHUB_OUTPUT - echo "No changes after pre-commit" - fi - - - name: Commit and push changes - if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true' - run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" - - git add -A - git commit -m "style: apply pre-commit fixes - - 🤖 Applied by @github-actions bot via pre-commit workflow" - - # Push changes - git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }} - - - name: Comment success with changes - if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true' - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - await github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: ${{ steps.check_author.outputs.pr_number }}, - body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.` - }); - - - name: Comment success without changes - if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success' - uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - await github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: ${{ steps.check_author.outputs.pr_number }}, - body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.` - }); - - - name: Comment failure - if: failure() - uses: 
actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - script: | - await github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: ${{ steps.check_author.outputs.pr_number }}, - body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.` - }); diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 2b2ca6330..f2559a258 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -72,10 +72,16 @@ jobs: - name: Build container image if: matrix.image-type == 'container' run: | + BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=${{ matrix.distro }}" + if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then + BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" + fi + if [ -n "${UV_INDEX_STRATEGY:-}" ]; then + BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" + fi docker build . 
\ -f containers/Containerfile \ - --build-arg INSTALL_MODE=editable \ - --build-arg DISTRO_NAME=${{ matrix.distro }} \ + $BUILD_ARGS \ --tag llama-stack:${{ matrix.distro }}-ci - name: Print dependencies in the image @@ -108,12 +114,18 @@ jobs: - name: Build container image run: | BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "python:3.12-slim"' src/llama_stack/distributions/ci-tests/build.yaml) + BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests" + BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE" + BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml" + if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then + BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" + fi + if [ -n "${UV_INDEX_STRATEGY:-}" ]; then + BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" + fi docker build . \ -f containers/Containerfile \ - --build-arg INSTALL_MODE=editable \ - --build-arg DISTRO_NAME=ci-tests \ - --build-arg BASE_IMAGE="$BASE_IMAGE" \ - --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \ + $BUILD_ARGS \ -t llama-stack:ci-tests - name: Inspect the container image entrypoint @@ -148,12 +160,18 @@ jobs: - name: Build UBI9 container image run: | BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "registry.access.redhat.com/ubi9:latest"' src/llama_stack/distributions/ci-tests/build.yaml) + BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests" + BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE" + BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml" + if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then + BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" + fi + if [ -n "${UV_INDEX_STRATEGY:-}" ]; then + BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" + fi docker 
build . \ -f containers/Containerfile \ - --build-arg INSTALL_MODE=editable \ - --build-arg DISTRO_NAME=ci-tests \ - --build-arg BASE_IMAGE="$BASE_IMAGE" \ - --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \ + $BUILD_ARGS \ -t llama-stack:ci-tests-ubi9 - name: Inspect UBI9 image diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index 49caea6b3..1f5c0aebf 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -24,7 +24,7 @@ jobs: uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install uv - uses: astral-sh/setup-uv@2ddd2b9cb38ad8efd50337e8ab201519a34c9f24 # v7.1.1 + uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 182643721..92c0a6a19 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -4,9 +4,13 @@ run-name: Run the unit test suite on: push: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x' pull_request: - branches: [ main ] + branches: + - main + - 'release-[0-9]+.[0-9]+.x' paths: - 'src/llama_stack/**' - '!src/llama_stack/ui/**' diff --git a/.gitignore b/.gitignore index e6198b72c..f5ca450b2 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,6 @@ CLAUDE.md docs/.docusaurus/ docs/node_modules/ docs/static/imported-files/ +docs/docs/api-deprecated/ +docs/docs/api-experimental/ +docs/docs/api/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f025bae5b..ce0d79b21 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,22 +52,19 @@ repos: additional_dependencies: - black==24.3.0 -- repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.7.20 - hooks: - - id: uv-lock -- repo: local +- repo: 
https://github.com/pre-commit/mirrors-mypy + rev: v1.18.2 hooks: - id: mypy - name: mypy additional_dependencies: - - uv==0.7.8 - entry: uv run --group dev --group type_checking mypy - language: python - types: [python] + - uv==0.6.2 + - mypy + - pytest + - rich + - types-requests + - pydantic pass_filenames: false - require_serial: true # - repo: https://github.com/tcort/markdown-link-check # rev: v3.11.2 @@ -77,11 +74,26 @@ repos: - repo: local hooks: + - id: uv-lock + name: uv-lock + additional_dependencies: + - uv==0.7.20 + entry: ./scripts/uv-run-with-index.sh lock + language: python + pass_filenames: false + require_serial: true + files: ^(pyproject\.toml|uv\.lock)$ + - id: mypy-full + name: mypy (full type_checking) + entry: ./scripts/uv-run-with-index.sh run --group dev --group type_checking mypy + language: system + pass_filenames: false + stages: [manual] - id: distro-codegen name: Distribution Template Codegen additional_dependencies: - uv==0.7.8 - entry: uv run --group codegen ./scripts/distro_codegen.py + entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/distro_codegen.py language: python pass_filenames: false require_serial: true @@ -90,7 +102,7 @@ repos: name: Provider Codegen additional_dependencies: - uv==0.7.8 - entry: uv run --group codegen ./scripts/provider_codegen.py + entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/provider_codegen.py language: python pass_filenames: false require_serial: true @@ -99,7 +111,7 @@ repos: name: API Spec Codegen additional_dependencies: - uv==0.7.8 - entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null' + entry: sh -c './scripts/uv-run-with-index.sh run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null' language: python pass_filenames: false require_serial: true @@ -140,7 +152,7 @@ repos: name: Generate CI documentation additional_dependencies: - uv==0.7.8 - entry: uv run ./scripts/gen-ci-docs.py + entry: 
./scripts/uv-run-with-index.sh run ./scripts/gen-ci-docs.py language: python pass_filenames: false require_serial: true @@ -171,6 +183,25 @@ repos: exit 1 fi exit 0 + - id: fips-compliance + name: Ensure llama-stack remains FIPS compliant + entry: bash + language: system + types: [python] + pass_filenames: true + exclude: '^tests/.*$' # Exclude test dir as some safety tests used MD5 + args: + - -c + - | + grep -EnH '^[^#]*\b(md5|sha1|uuid3|uuid5)\b' "$@" && { + echo; + echo "❌ Do not use any of the following functions: hashlib.md5, hashlib.sha1, uuid.uuid3, uuid.uuid5" + echo " These functions are not FIPS-compliant" + echo; + exit 1; + } || true + # Placeholder consumed as $0 by `bash -c`, so that "$@" receives every filename pre-commit appends (otherwise the first file of each batch is never scanned). + - '--' ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c869b4f5c..d84332829 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -61,6 +61,18 @@ uv run pre-commit run --all-files -v The `-v` (verbose) parameter is optional but often helpful for getting more information about any issues with that the pre-commit checks identify. +To run the expanded mypy configuration that CI enforces, use: + +```bash +uv run pre-commit run mypy-full --hook-stage manual --all-files +``` + +or invoke mypy directly with all optional dependencies: + +```bash +uv run --group dev --group type_checking mypy +``` + ```{caution} Before pushing your changes, make sure that the pre-commit hooks have passed successfully. ``` diff --git a/client-sdks/stainless/openapi.stainless.yml b/client-sdks/stainless/openapi.stainless.yml deleted file mode 100644 index 9461be996..000000000 --- a/client-sdks/stainless/openapi.stainless.yml +++ /dev/null @@ -1,610 +0,0 @@ -# yaml-language-server: $schema=https://app.stainlessapi.com/config-internal.schema.json - -organization: - # Name of your organization or company, used to determine the name of the client - # and headings.
- name: llama-stack-client - docs: https://llama-stack.readthedocs.io/en/latest/ - contact: llamastack@meta.com -security: - - {} - - BearerAuth: [] -security_schemes: - BearerAuth: - type: http - scheme: bearer -# `targets` define the output targets and their customization options, such as -# whether to emit the Node SDK and what it's package name should be. -targets: - node: - package_name: llama-stack-client - production_repo: llamastack/llama-stack-client-typescript - publish: - npm: false - python: - package_name: llama_stack_client - production_repo: llamastack/llama-stack-client-python - options: - use_uv: true - publish: - pypi: true - project_name: llama_stack_client - kotlin: - reverse_domain: com.llama_stack_client.api - production_repo: null - publish: - maven: false - go: - package_name: llama-stack-client - production_repo: llamastack/llama-stack-client-go - options: - enable_v2: true - back_compat_use_shared_package: false - -# `client_settings` define settings for the API client, such as extra constructor -# arguments (used for authentication), retry behavior, idempotency, etc. -client_settings: - default_env_prefix: LLAMA_STACK_CLIENT - opts: - api_key: - type: string - read_env: LLAMA_STACK_CLIENT_API_KEY - auth: { security_scheme: BearerAuth } - nullable: true - -# `environments` are a map of the name of the environment (e.g. "sandbox", -# "production") to the corresponding url to use. -environments: - production: http://any-hosted-llama-stack.com - -# `pagination` defines [pagination schemes] which provides a template to match -# endpoints and generate next-page and auto-pagination helpers in the SDKs. 
-pagination: - - name: datasets_iterrows - type: offset - request: - dataset_id: - type: string - start_index: - type: integer - x-stainless-pagination-property: - purpose: offset_count_param - limit: - type: integer - response: - data: - type: array - items: - type: object - next_index: - type: integer - x-stainless-pagination-property: - purpose: offset_count_start_field - - name: openai_cursor_page - type: cursor - request: - limit: - type: integer - after: - type: string - x-stainless-pagination-property: - purpose: next_cursor_param - response: - data: - type: array - items: {} - has_more: - type: boolean - last_id: - type: string - x-stainless-pagination-property: - purpose: next_cursor_field -# `resources` define the structure and organziation for your API, such as how -# methods and models are grouped together and accessed. See the [configuration -# guide] for more information. -# -# [configuration guide]: -# https://app.stainlessapi.com/docs/guides/configure#resources -resources: - $shared: - models: - agent_config: AgentConfig - interleaved_content_item: InterleavedContentItem - interleaved_content: InterleavedContent - param_type: ParamType - safety_violation: SafetyViolation - sampling_params: SamplingParams - scoring_result: ScoringResult - message: Message - user_message: UserMessage - completion_message: CompletionMessage - tool_response_message: ToolResponseMessage - system_message: SystemMessage - tool_call: ToolCall - query_result: RAGQueryResult - document: RAGDocument - query_config: RAGQueryConfig - response_format: ResponseFormat - toolgroups: - models: - tool_group: ToolGroup - list_tool_groups_response: ListToolGroupsResponse - methods: - register: post /v1/toolgroups - get: get /v1/toolgroups/{toolgroup_id} - list: get /v1/toolgroups - unregister: delete /v1/toolgroups/{toolgroup_id} - tools: - methods: - get: get /v1/tools/{tool_name} - list: - endpoint: get /v1/tools - paginated: false - - tool_runtime: - models: - tool_def: ToolDef - 
tool_invocation_result: ToolInvocationResult - methods: - list_tools: - endpoint: get /v1/tool-runtime/list-tools - paginated: false - invoke_tool: post /v1/tool-runtime/invoke - subresources: - rag_tool: - methods: - insert: post /v1/tool-runtime/rag-tool/insert - query: post /v1/tool-runtime/rag-tool/query - - responses: - models: - response_object_stream: OpenAIResponseObjectStream - response_object: OpenAIResponseObject - methods: - create: - type: http - endpoint: post /v1/responses - streaming: - stream_event_model: responses.response_object_stream - param_discriminator: stream - retrieve: get /v1/responses/{response_id} - list: - type: http - endpoint: get /v1/responses - delete: - type: http - endpoint: delete /v1/responses/{response_id} - subresources: - input_items: - methods: - list: - type: http - endpoint: get /v1/responses/{response_id}/input_items - - conversations: - models: - conversation_object: Conversation - methods: - create: - type: http - endpoint: post /v1/conversations - retrieve: get /v1/conversations/{conversation_id} - update: - type: http - endpoint: post /v1/conversations/{conversation_id} - delete: - type: http - endpoint: delete /v1/conversations/{conversation_id} - subresources: - items: - methods: - get: - type: http - endpoint: get /v1/conversations/{conversation_id}/items/{item_id} - list: - type: http - endpoint: get /v1/conversations/{conversation_id}/items - create: - type: http - endpoint: post /v1/conversations/{conversation_id}/items - - inspect: - models: - healthInfo: HealthInfo - providerInfo: ProviderInfo - routeInfo: RouteInfo - versionInfo: VersionInfo - methods: - health: get /v1/health - version: get /v1/version - - embeddings: - models: - create_embeddings_response: OpenAIEmbeddingsResponse - methods: - create: post /v1/embeddings - - chat: - models: - chat_completion_chunk: OpenAIChatCompletionChunk - subresources: - completions: - methods: - create: - type: http - endpoint: post /v1/chat/completions - streaming: 
- stream_event_model: chat.chat_completion_chunk - param_discriminator: stream - list: - type: http - endpoint: get /v1/chat/completions - retrieve: - type: http - endpoint: get /v1/chat/completions/{completion_id} - completions: - methods: - create: - type: http - endpoint: post /v1/completions - streaming: - param_discriminator: stream - - vector_io: - models: - queryChunksResponse: QueryChunksResponse - methods: - insert: post /v1/vector-io/insert - query: post /v1/vector-io/query - - vector_stores: - models: - vector_store: VectorStoreObject - list_vector_stores_response: VectorStoreListResponse - vector_store_delete_response: VectorStoreDeleteResponse - vector_store_search_response: VectorStoreSearchResponsePage - methods: - create: post /v1/vector_stores - list: - endpoint: get /v1/vector_stores - retrieve: get /v1/vector_stores/{vector_store_id} - update: post /v1/vector_stores/{vector_store_id} - delete: delete /v1/vector_stores/{vector_store_id} - search: post /v1/vector_stores/{vector_store_id}/search - subresources: - files: - models: - vector_store_file: VectorStoreFileObject - methods: - list: get /v1/vector_stores/{vector_store_id}/files - retrieve: get /v1/vector_stores/{vector_store_id}/files/{file_id} - update: post /v1/vector_stores/{vector_store_id}/files/{file_id} - delete: delete /v1/vector_stores/{vector_store_id}/files/{file_id} - create: post /v1/vector_stores/{vector_store_id}/files - content: get /v1/vector_stores/{vector_store_id}/files/{file_id}/content - file_batches: - models: - vector_store_file_batches: VectorStoreFileBatchObject - list_vector_store_files_in_batch_response: VectorStoreFilesListInBatchResponse - methods: - create: post /v1/vector_stores/{vector_store_id}/file_batches - retrieve: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id} - list_files: get /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/files - cancel: post /v1/vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel - - models: - 
models: - model: Model - list_models_response: ListModelsResponse - methods: - retrieve: get /v1/models/{model_id} - list: - endpoint: get /v1/models - paginated: false - register: post /v1/models - unregister: delete /v1/models/{model_id} - subresources: - openai: - methods: - list: - endpoint: get /v1/models - paginated: false - - providers: - models: - list_providers_response: ListProvidersResponse - methods: - list: - endpoint: get /v1/providers - paginated: false - retrieve: get /v1/providers/{provider_id} - - routes: - models: - list_routes_response: ListRoutesResponse - methods: - list: - endpoint: get /v1/inspect/routes - paginated: false - - - moderations: - models: - create_response: ModerationObject - methods: - create: post /v1/moderations - - - safety: - models: - run_shield_response: RunShieldResponse - methods: - run_shield: post /v1/safety/run-shield - - - shields: - models: - shield: Shield - list_shields_response: ListShieldsResponse - methods: - retrieve: get /v1/shields/{identifier} - list: - endpoint: get /v1/shields - paginated: false - register: post /v1/shields - delete: delete /v1/shields/{identifier} - - synthetic_data_generation: - models: - syntheticDataGenerationResponse: SyntheticDataGenerationResponse - methods: - generate: post /v1/synthetic-data-generation/generate - - telemetry: - models: - span_with_status: SpanWithStatus - trace: Trace - query_spans_response: QuerySpansResponse - event: Event - query_condition: QueryCondition - methods: - query_traces: - endpoint: post /v1alpha/telemetry/traces - skip_test_reason: 'unsupported query params in java / kotlin' - get_span_tree: post /v1alpha/telemetry/spans/{span_id}/tree - query_spans: - endpoint: post /v1alpha/telemetry/spans - skip_test_reason: 'unsupported query params in java / kotlin' - query_metrics: - endpoint: post /v1alpha/telemetry/metrics/{metric_name} - skip_test_reason: 'unsupported query params in java / kotlin' - # log_event: post /v1alpha/telemetry/events - 
save_spans_to_dataset: post /v1alpha/telemetry/spans/export - get_span: get /v1alpha/telemetry/traces/{trace_id}/spans/{span_id} - get_trace: get /v1alpha/telemetry/traces/{trace_id} - - scoring: - methods: - score: post /v1/scoring/score - score_batch: post /v1/scoring/score-batch - scoring_functions: - methods: - retrieve: get /v1/scoring-functions/{scoring_fn_id} - list: - endpoint: get /v1/scoring-functions - paginated: false - register: post /v1/scoring-functions - models: - scoring_fn: ScoringFn - scoring_fn_params: ScoringFnParams - list_scoring_functions_response: ListScoringFunctionsResponse - - benchmarks: - methods: - retrieve: get /v1alpha/eval/benchmarks/{benchmark_id} - list: - endpoint: get /v1alpha/eval/benchmarks - paginated: false - register: post /v1alpha/eval/benchmarks - models: - benchmark: Benchmark - list_benchmarks_response: ListBenchmarksResponse - - files: - methods: - create: post /v1/files - list: get /v1/files - retrieve: get /v1/files/{file_id} - delete: delete /v1/files/{file_id} - content: get /v1/files/{file_id}/content - models: - file: OpenAIFileObject - list_files_response: ListOpenAIFileResponse - delete_file_response: OpenAIFileDeleteResponse - - alpha: - subresources: - inference: - methods: - rerank: post /v1alpha/inference/rerank - - post_training: - models: - algorithm_config: AlgorithmConfig - post_training_job: PostTrainingJob - list_post_training_jobs_response: ListPostTrainingJobsResponse - methods: - preference_optimize: post /v1alpha/post-training/preference-optimize - supervised_fine_tune: post /v1alpha/post-training/supervised-fine-tune - subresources: - job: - methods: - artifacts: get /v1alpha/post-training/job/artifacts - cancel: post /v1alpha/post-training/job/cancel - status: get /v1alpha/post-training/job/status - list: - endpoint: get /v1alpha/post-training/jobs - paginated: false - - eval: - methods: - evaluate_rows: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations - run_eval: post 
/v1alpha/eval/benchmarks/{benchmark_id}/jobs - evaluate_rows_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/evaluations - run_eval_alpha: post /v1alpha/eval/benchmarks/{benchmark_id}/jobs - - subresources: - jobs: - methods: - cancel: delete /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id} - status: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id} - retrieve: get /v1alpha/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result - models: - evaluate_response: EvaluateResponse - benchmark_config: BenchmarkConfig - job: Job - - agents: - methods: - create: post /v1alpha/agents - list: get /v1alpha/agents - retrieve: get /v1alpha/agents/{agent_id} - delete: delete /v1alpha/agents/{agent_id} - models: - inference_step: InferenceStep - tool_execution_step: ToolExecutionStep - tool_response: ToolResponse - shield_call_step: ShieldCallStep - memory_retrieval_step: MemoryRetrievalStep - subresources: - session: - models: - session: Session - methods: - list: get /v1alpha/agents/{agent_id}/sessions - create: post /v1alpha/agents/{agent_id}/session - delete: delete /v1alpha/agents/{agent_id}/session/{session_id} - retrieve: get /v1alpha/agents/{agent_id}/session/{session_id} - steps: - methods: - retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id} - turn: - models: - turn: Turn - turn_response_event: AgentTurnResponseEvent - agent_turn_response_stream_chunk: AgentTurnResponseStreamChunk - methods: - create: - type: http - endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn - streaming: - stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk - param_discriminator: stream - retrieve: get /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id} - resume: - type: http - endpoint: post /v1alpha/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume - streaming: - stream_event_model: alpha.agents.turn.agent_turn_response_stream_chunk - param_discriminator: stream - - beta: - 
subresources: - datasets: - models: - list_datasets_response: ListDatasetsResponse - methods: - register: post /v1beta/datasets - retrieve: get /v1beta/datasets/{dataset_id} - list: - endpoint: get /v1beta/datasets - paginated: false - unregister: delete /v1beta/datasets/{dataset_id} - iterrows: get /v1beta/datasetio/iterrows/{dataset_id} - appendrows: post /v1beta/datasetio/append-rows/{dataset_id} - - -settings: - license: MIT - unwrap_response_fields: [ data ] - -openapi: - transformations: - - command: renameValue - reason: pydantic reserved name - args: - filter: - only: - - '$.components.schemas.InferenceStep.properties.model_response' - rename: - python: - property_name: 'inference_model_response' - - # - command: renameValue - # reason: pydantic reserved name - # args: - # filter: - # only: - # - '$.components.schemas.Model.properties.model_type' - # rename: - # python: - # property_name: 'type' - - command: mergeObject - reason: Better return_type using enum - args: - target: - - '$.components.schemas' - object: - ReturnType: - additionalProperties: false - properties: - type: - enum: - - string - - number - - boolean - - array - - object - - json - - union - - chat_completion_input - - completion_input - - agent_turn_input - required: - - type - type: object - - command: replaceProperties - reason: Replace return type properties with better model (see above) - args: - filter: - only: - - '$.components.schemas.ScoringFn.properties.return_type' - - '$.components.schemas.RegisterScoringFunctionRequest.properties.return_type' - value: - $ref: '#/components/schemas/ReturnType' - - command: oneOfToAnyOf - reason: Prism (mock server) doesn't like one of our requests as it technically matches multiple variants - - reason: For better names - command: extractToRefs - args: - ref: - target: '$.components.schemas.ToolCallDelta.properties.tool_call' - name: '#/components/schemas/ToolCallOrString' - -# `readme` is used to configure the code snippets that will be 
rendered in the -# README.md of various SDKs. In particular, you can change the `headline` -# snippet's endpoint and the arguments to call it with. -readme: - example_requests: - default: - type: request - endpoint: post /v1/chat/completions - params: &ref_0 {} - headline: - type: request - endpoint: post /v1/models - params: *ref_0 - pagination: - type: request - endpoint: post /v1/chat/completions - params: {} diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index d992b72eb..b080a9efd 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -15,6 +15,141 @@ info: servers: - url: http://any-hosted-llama-stack.com paths: + /v1/batches: + get: + responses: + '200': + description: A list of batch objects. + content: + application/json: + schema: + $ref: '#/components/schemas/ListBatchesResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: List all batches for the current user. + description: List all batches for the current user. + parameters: + - name: after + in: query + description: >- + A cursor for pagination; returns batches after this batch ID. + required: false + schema: + type: string + - name: limit + in: query + description: >- + Number of batches to return (default 20, max 100). + required: true + schema: + type: integer + deprecated: false + post: + responses: + '200': + description: The created batch object. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Create a new batch for processing multiple API requests. + description: >- + Create a new batch for processing multiple API requests. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateBatchRequest' + required: true + deprecated: false + /v1/batches/{batch_id}: + get: + responses: + '200': + description: The batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Retrieve information about a specific batch. + description: >- + Retrieve information about a specific batch. + parameters: + - name: batch_id + in: path + description: The ID of the batch to retrieve. + required: true + schema: + type: string + deprecated: false + /v1/batches/{batch_id}/cancel: + post: + responses: + '200': + description: The updated batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: Cancel a batch that is in progress. + description: Cancel a batch that is in progress. + parameters: + - name: batch_id + in: path + description: The ID of the batch to cancel. 
+ required: true + schema: + type: string + deprecated: false /v1/chat/completions: get: responses: @@ -821,7 +956,22 @@ paths: List routes. List all available API routes with their methods and implementing providers. - parameters: [] + parameters: + - name: api_filter + in: query + description: >- + Optional filter to control which routes are returned. Can be an API level + ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, + or 'deprecated' to show deprecated routes across all levels. If not specified, + returns only non-deprecated v1 routes. + required: false + schema: + type: string + enum: + - v1 + - v1alpha + - v1beta + - deprecated deprecated: false /v1/models: get: @@ -979,6 +1129,31 @@ paths: $ref: '#/components/schemas/RunModerationRequest' required: true deprecated: false + /v1/openai/v1/models: + get: + responses: + '200': + description: A OpenAIListModelsResponse. + content: + application/json: + schema: + $ref: '#/components/schemas/OpenAIListModelsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + summary: List models using the OpenAI API. + description: List models using the OpenAI API. 
+ parameters: [] + deprecated: false /v1/prompts: get: responses: @@ -1835,40 +2010,6 @@ paths: schema: type: string deprecated: false - /v1/synthetic-data-generation/generate: - post: - responses: - '200': - description: >- - Response containing filtered synthetic data samples and optional statistics - content: - application/json: - schema: - $ref: '#/components/schemas/SyntheticDataGenerationResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - SyntheticDataGeneration (Coming Soon) - summary: >- - Generate synthetic data based on input dialogs and apply filtering. - description: >- - Generate synthetic data based on input dialogs and apply filtering. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/SyntheticDataGenerateRequest' - required: true - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -4212,6 +4353,331 @@ components: title: Error description: >- Error response from the API. Roughly follows RFC 7807. 
+ ListBatchesResponse: + type: object + properties: + object: + type: string + const: list + default: list + data: + type: array + items: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - 
input_tokens + - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + default: false + additionalProperties: false + required: + - object + - data + - has_more + title: ListBatchesResponse + description: >- + Response containing a list of batch objects. + CreateBatchRequest: + type: object + properties: + input_file_id: + type: string + description: >- + The ID of an uploaded file containing requests for the batch. + endpoint: + type: string + description: >- + The endpoint to be used for all requests in the batch. + completion_window: + type: string + const: 24h + description: >- + The time window within which the batch should be processed. + metadata: + type: object + additionalProperties: + type: string + description: Optional metadata for the batch. + idempotency_key: + type: string + description: >- + Optional idempotency key. When provided, enables idempotent behavior. 
+ additionalProperties: false + required: + - input_file_id + - endpoint + - completion_window + title: CreateBatchRequest + Batch: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - input_tokens 
+ - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch Order: type: string enum: @@ -6554,6 +7020,48 @@ components: - metadata title: ModerationObjectResults description: A moderation object. + OpenAIModel: + type: object + properties: + id: + type: string + object: + type: string + const: model + default: model + created: + type: integer + owned_by: + type: string + custom_metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - id + - object + - created + - owned_by + title: OpenAIModel + description: A model from OpenAI. + OpenAIListModelsResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIModel' + additionalProperties: false + required: + - data + title: OpenAIListModelsResponse Prompt: type: object properties: @@ -9478,45 +9986,29 @@ components: required: - shield_id title: RegisterShieldRequest - CompletionMessage: + InvokeToolRequest: type: object properties: - role: + tool_name: type: string - const: assistant - default: assistant + description: The name of the tool to invoke. + kwargs: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object description: >- - Must be "assistant" to identify this as the model's response - content: - $ref: '#/components/schemas/InterleavedContent' - description: The content of the model's response - stop_reason: - type: string - enum: - - end_of_turn - - end_of_message - - out_of_tokens - description: >- - Reason why the model stopped generating. 
Options are: - `StopReason.end_of_turn`: - The model finished generating the entire response. - `StopReason.end_of_message`: - The model finished generating but generated a partial response -- usually, - a tool call. The user may call the tool and continue the conversation - with the tool's response. - `StopReason.out_of_tokens`: The model ran - out of token budget. - tool_calls: - type: array - items: - $ref: '#/components/schemas/ToolCall' - description: >- - List of tool calls. Each tool call is a ToolCall object. + A dictionary of arguments to pass to the tool. additionalProperties: false required: - - role - - content - - stop_reason - title: CompletionMessage - description: >- - A message containing the model's (assistant) response in a chat conversation. + - tool_name + - kwargs + title: InvokeToolRequest ImageContentItem: type: object properties: @@ -9563,41 +10055,6 @@ components: mapping: image: '#/components/schemas/ImageContentItem' text: '#/components/schemas/TextContentItem' - Message: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' - discriminator: - propertyName: role - mapping: - user: '#/components/schemas/UserMessage' - system: '#/components/schemas/SystemMessage' - tool: '#/components/schemas/ToolResponseMessage' - assistant: '#/components/schemas/CompletionMessage' - SystemMessage: - type: object - properties: - role: - type: string - const: system - default: system - description: >- - Must be "system" to identify this as a system message - content: - $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the "system prompt". If multiple system messages are provided, - they are concatenated. The underlying Llama Stack code may also add other - system messages (for example, for formatting tool definitions). 
- additionalProperties: false - required: - - role - - content - title: SystemMessage - description: >- - A system message providing instructions or context to the model. TextContentItem: type: object properties: @@ -9616,179 +10073,6 @@ components: - text title: TextContentItem description: A text content item - ToolCall: - type: object - properties: - call_id: - type: string - tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string - arguments: - type: string - additionalProperties: false - required: - - call_id - - tool_name - - arguments - title: ToolCall - ToolResponseMessage: - type: object - properties: - role: - type: string - const: tool - default: tool - description: >- - Must be "tool" to identify this as a tool response - call_id: - type: string - description: >- - Unique identifier for the tool call this response is for - content: - $ref: '#/components/schemas/InterleavedContent' - description: The response content from the tool - additionalProperties: false - required: - - role - - call_id - - content - title: ToolResponseMessage - description: >- - A message representing the result of a tool invocation. - URL: - type: object - properties: - uri: - type: string - description: The URL string pointing to the resource - additionalProperties: false - required: - - uri - title: URL - description: A URL reference to external content. - UserMessage: - type: object - properties: - role: - type: string - const: user - default: user - description: >- - Must be "user" to identify this as a user message - content: - $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the message, which can include text and other media - context: - $ref: '#/components/schemas/InterleavedContent' - description: >- - (Optional) This field is used internally by Llama Stack to pass RAG context. - This field may be removed in the API in the future. 
- additionalProperties: false - required: - - role - - content - title: UserMessage - description: >- - A message from the user in a chat conversation. - SyntheticDataGenerateRequest: - type: object - properties: - dialogs: - type: array - items: - $ref: '#/components/schemas/Message' - description: >- - List of conversation messages to use as input for synthetic data generation - filtering_function: - type: string - enum: - - none - - random - - top_k - - top_p - - top_k_top_p - - sigmoid - description: >- - Type of filtering to apply to generated synthetic data samples - model: - type: string - description: >- - (Optional) The identifier of the model to use. The model must be registered - with Llama Stack and available via the /models endpoint - additionalProperties: false - required: - - dialogs - - filtering_function - title: SyntheticDataGenerateRequest - SyntheticDataGenerationResponse: - type: object - properties: - synthetic_data: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - List of generated synthetic data samples that passed the filtering criteria - statistics: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - (Optional) Statistical information about the generation process and filtering - results - additionalProperties: false - required: - - synthetic_data - title: SyntheticDataGenerationResponse - description: >- - Response from the synthetic data generation. Batch of (prompt, response, score) - tuples that pass the threshold. - InvokeToolRequest: - type: object - properties: - tool_name: - type: string - description: The name of the tool to invoke. 
- kwargs: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool. - additionalProperties: false - required: - - tool_name - - kwargs - title: InvokeToolRequest ToolInvocationResult: type: object properties: @@ -9819,6 +10103,17 @@ components: additionalProperties: false title: ToolInvocationResult description: Result of a tool invocation. + URL: + type: object + properties: + uri: + type: string + description: The URL string pointing to the resource + additionalProperties: false + required: + - uri + title: URL + description: A URL reference to external content. ToolDef: type: object properties: @@ -10258,6 +10553,10 @@ components: description: >- The content of the chunk, which can be interleaved text, images, or other types. + chunk_id: + type: string + description: >- + Unique identifier for the chunk. Must be provided explicitly. metadata: type: object additionalProperties: @@ -10278,10 +10577,6 @@ components: description: >- Optional embedding for the chunk. If not provided, it will be computed later. - stored_chunk_id: - type: string - description: >- - The chunk ID that is stored in the vector database. Used for backend functionality. chunk_metadata: $ref: '#/components/schemas/ChunkMetadata' description: >- @@ -10290,6 +10585,7 @@ components: additionalProperties: false required: - content + - chunk_id - metadata title: Chunk description: >- @@ -11850,6 +12146,45 @@ components: title: AgentSessionCreateResponse description: >- Response returned when creating a new agent session. 
+ CompletionMessage: + type: object + properties: + role: + type: string + const: assistant + default: assistant + description: >- + Must be "assistant" to identify this as the model's response + content: + $ref: '#/components/schemas/InterleavedContent' + description: The content of the model's response + stop_reason: + type: string + enum: + - end_of_turn + - end_of_message + - out_of_tokens + description: >- + Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`: + The model finished generating the entire response. - `StopReason.end_of_message`: + The model finished generating but generated a partial response -- usually, + a tool call. The user may call the tool and continue the conversation + with the tool's response. - `StopReason.out_of_tokens`: The model ran + out of token budget. + tool_calls: + type: array + items: + $ref: '#/components/schemas/ToolCall' + description: >- + List of tool calls. Each tool call is a ToolCall object. + additionalProperties: false + required: + - role + - content + - stop_reason + title: CompletionMessage + description: >- + A message containing the model's (assistant) response in a chat conversation. InferenceStep: type: object properties: @@ -12002,6 +12337,29 @@ components: - step_type title: ShieldCallStep description: A shield call step in an agent turn. + ToolCall: + type: object + properties: + call_id: + type: string + tool_name: + oneOf: + - type: string + enum: + - brave_search + - wolfram_alpha + - photogen + - code_interpreter + title: BuiltinTool + - type: string + arguments: + type: string + additionalProperties: false + required: + - call_id + - tool_name + - arguments + title: ToolCall ToolExecutionStep: type: object properties: @@ -12089,6 +12447,30 @@ components: - content title: ToolResponse description: Response from a tool invocation. 
+ ToolResponseMessage: + type: object + properties: + role: + type: string + const: tool + default: tool + description: >- + Must be "tool" to identify this as a tool response + call_id: + type: string + description: >- + Unique identifier for the tool call this response is for + content: + $ref: '#/components/schemas/InterleavedContent' + description: The response content from the tool + additionalProperties: false + required: + - role + - call_id + - content + title: ToolResponseMessage + description: >- + A message representing the result of a tool invocation. Turn: type: object properties: @@ -12174,6 +12556,31 @@ components: title: Turn description: >- A single turn in an interaction with an Agentic System. + UserMessage: + type: object + properties: + role: + type: string + const: user + default: user + description: >- + Must be "user" to identify this as a user message + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The content of the message, which can include text and other media + context: + $ref: '#/components/schemas/InterleavedContent' + description: >- + (Optional) This field is used internally by Llama Stack to pass RAG context. + This field may be removed in the API in the future. + additionalProperties: false + required: + - role + - content + title: UserMessage + description: >- + A message from the user in a chat conversation. CreateAgentTurnRequest: type: object properties: @@ -12787,6 +13194,28 @@ components: - sampling_params title: ModelCandidate description: A model candidate for evaluation. + SystemMessage: + type: object + properties: + role: + type: string + const: system + default: system + description: >- + Must be "system" to identify this as a system message + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The content of the "system prompt". If multiple system messages are provided, + they are concatenated. 
The underlying Llama Stack code may also add other + system messages (for example, for formatting tool definitions). + additionalProperties: false + required: + - role + - content + title: SystemMessage + description: >- + A system message providing instructions or context to the model. EvaluateRowsRequest: type: object properties: @@ -13527,6 +13956,19 @@ tags: description: >- APIs for creating and interacting with agentic systems. x-displayName: Agents + - name: Batches + description: >- + The API is designed to allow use of openai client libraries for seamless integration. + + + This API provides the following extensions: + - idempotent batch creation + + Note: This API is currently under active development and may undergo changes. + x-displayName: >- + The Batches API enables efficient processing of multiple requests in a single + operation, particularly useful for processing large datasets, batch evaluation + workflows, and cost-effective inference at scale. - name: Benchmarks description: '' - name: Conversations @@ -13589,8 +14031,6 @@ tags: description: '' - name: Shields description: '' - - name: SyntheticDataGeneration (Coming Soon) - description: '' - name: ToolGroups description: '' - name: ToolRuntime @@ -13601,6 +14041,7 @@ x-tagGroups: - name: Operations tags: - Agents + - Batches - Benchmarks - Conversations - DatasetIO @@ -13617,7 +14058,6 @@ x-tagGroups: - Scoring - ScoringFunctions - Shields - - SyntheticDataGeneration (Coming Soon) - ToolGroups - ToolRuntime - VectorIO diff --git a/containers/Containerfile b/containers/Containerfile index 1c878ea9b..d2d066845 100644 --- a/containers/Containerfile +++ b/containers/Containerfile @@ -19,6 +19,8 @@ ARG KEEP_WORKSPACE="" ARG DISTRO_NAME="starter" ARG RUN_CONFIG_PATH="" ARG UV_HTTP_TIMEOUT=500 +ARG UV_EXTRA_INDEX_URL="" +ARG UV_INDEX_STRATEGY="" ENV UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT} ENV PYTHONDONTWRITEBYTECODE=1 ENV PIP_DISABLE_PIP_VERSION_CHECK=1 @@ -45,7 +47,7 @@ RUN set -eux; \ exit 1; \ fi -RUN 
pip install --no-cache uv +RUN pip install --no-cache-dir uv ENV UV_SYSTEM_PYTHON=1 ENV INSTALL_MODE=${INSTALL_MODE} @@ -62,47 +64,60 @@ COPY . /workspace # Install the client package if it is provided # NOTE: this is installed before llama-stack since llama-stack depends on llama-stack-client-python +# Unset UV index env vars to ensure we only use PyPI for the client RUN set -eux; \ + unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then \ if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then \ echo "LLAMA_STACK_CLIENT_DIR is set but $LLAMA_STACK_CLIENT_DIR does not exist" >&2; \ exit 1; \ fi; \ - uv pip install --no-cache -e "$LLAMA_STACK_CLIENT_DIR"; \ + uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"; \ fi; # Install llama-stack +# Use UV_EXTRA_INDEX_URL inline only for editable install with RC dependencies RUN set -eux; \ + SAVED_UV_EXTRA_INDEX_URL="${UV_EXTRA_INDEX_URL:-}"; \ + SAVED_UV_INDEX_STRATEGY="${UV_INDEX_STRATEGY:-}"; \ + unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \ if [ "$INSTALL_MODE" = "editable" ]; then \ if [ ! 
-d "$LLAMA_STACK_DIR" ]; then \ echo "INSTALL_MODE=editable requires LLAMA_STACK_DIR to point to a directory inside the build context" >&2; \ exit 1; \ fi; \ - uv pip install --no-cache -e "$LLAMA_STACK_DIR"; \ - elif [ "$INSTALL_MODE" = "test-pypi" ]; then \ - uv pip install --no-cache fastapi libcst; \ - if [ -n "$TEST_PYPI_VERSION" ]; then \ - uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \ + if [ -n "$SAVED_UV_EXTRA_INDEX_URL" ] && [ -n "$SAVED_UV_INDEX_STRATEGY" ]; then \ + UV_EXTRA_INDEX_URL="$SAVED_UV_EXTRA_INDEX_URL" UV_INDEX_STRATEGY="$SAVED_UV_INDEX_STRATEGY" \ + uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \ else \ - uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \ + uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \ + fi; \ + elif [ "$INSTALL_MODE" = "test-pypi" ]; then \ + uv pip install --no-cache-dir fastapi libcst; \ + if [ -n "$TEST_PYPI_VERSION" ]; then \ + uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \ + else \ + uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \ fi; \ else \ if [ -n "$PYPI_VERSION" ]; then \ - uv pip install --no-cache "llama-stack==$PYPI_VERSION"; \ + uv pip install --no-cache-dir "llama-stack==$PYPI_VERSION"; \ else \ - uv pip install --no-cache llama-stack; \ + uv pip install --no-cache-dir llama-stack; \ fi; \ fi; # Install the dependencies for the distribution +# Explicitly unset UV index env vars to ensure we only use PyPI for distribution deps RUN set -eux; \ + unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \ if [ -z "$DISTRO_NAME" ]; then \ echo "DISTRO_NAME must be provided" >&2; \ exit 1; \ fi; \ deps="$(llama stack list-deps "$DISTRO_NAME")"; \ if [ -n "$deps" ]; 
then \ - printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache; \ + printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache-dir; \ fi # Cleanup diff --git a/docs/docs/concepts/apis/index.mdx b/docs/docs/concepts/apis/index.mdx index 11b8b2e08..7519f6eff 100644 --- a/docs/docs/concepts/apis/index.mdx +++ b/docs/docs/concepts/apis/index.mdx @@ -23,5 +23,4 @@ A Llama Stack API is described as a collection of REST endpoints. We currently s We are working on adding a few more APIs to complete the application lifecycle. These will include: - **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs -- **Synthetic Data Generation**: generate synthetic data for model development - **Batches**: OpenAI-compatible batch management for inference diff --git a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md index b7134b3e1..9c4095e88 100644 --- a/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/docs/distributions/self_hosted_distro/meta-reference-gpu.md @@ -79,6 +79,33 @@ docker run \ --port $LLAMA_STACK_PORT ``` +### Via Docker with Custom Run Configuration + +You can also run the Docker container with a custom run configuration file by mounting it into the container: + +```bash +# Set the path to your custom run.yaml file +CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml +LLAMA_STACK_PORT=8321 + +docker run \ + -it \ + --pull always \ + --gpus all \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \ + -e RUN_CONFIG_PATH=/app/custom-run.yaml \ + llamastack/distribution-meta-reference-gpu \ + --port $LLAMA_STACK_PORT +``` + +**Note**: The run configuration must be mounted into the container before it can be used.
The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use. + +Available run configurations for this distribution: +- `run.yaml` +- `run-with-safety.yaml` + ### Via venv Make sure you have the Llama Stack CLI available. diff --git a/docs/docs/distributions/self_hosted_distro/nvidia.md b/docs/docs/distributions/self_hosted_distro/nvidia.md index 4a7d99ff5..c48a7d391 100644 --- a/docs/docs/distributions/self_hosted_distro/nvidia.md +++ b/docs/docs/distributions/self_hosted_distro/nvidia.md @@ -127,13 +127,39 @@ docker run \ -it \ --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ./run.yaml:/root/my-run.yaml \ + -v ~/.llama:/root/.llama \ -e NVIDIA_API_KEY=$NVIDIA_API_KEY \ llamastack/distribution-nvidia \ - --config /root/my-run.yaml \ --port $LLAMA_STACK_PORT ``` +### Via Docker with Custom Run Configuration + +You can also run the Docker container with a custom run configuration file by mounting it into the container: + +```bash +# Set the path to your custom run.yaml file +CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml +LLAMA_STACK_PORT=8321 + +docker run \ + -it \ + --pull always \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \ + -e RUN_CONFIG_PATH=/app/custom-run.yaml \ + -e NVIDIA_API_KEY=$NVIDIA_API_KEY \ + llamastack/distribution-nvidia \ + --port $LLAMA_STACK_PORT +``` + +**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use. + +Available run configurations for this distribution: +- `run.yaml` +- `run-with-safety.yaml` + ### Via venv If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment. 
diff --git a/docs/docs/providers/files/remote_openai.mdx b/docs/docs/providers/files/remote_openai.mdx new file mode 100644 index 000000000..3b5c40aad --- /dev/null +++ b/docs/docs/providers/files/remote_openai.mdx @@ -0,0 +1,27 @@ +--- +description: "OpenAI Files API provider for managing files through OpenAI's native file storage service." +sidebar_label: Remote - Openai +title: remote::openai +--- + +# remote::openai + +## Description + +OpenAI Files API provider for managing files through OpenAI's native file storage service. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `api_key` | `` | No | | OpenAI API key for authentication | +| `metadata_store` | `` | No | | SQL store configuration for file metadata | + +## Sample Configuration + +```yaml +api_key: ${env.OPENAI_API_KEY} +metadata_store: + table_name: openai_files_metadata + backend: sql_default +``` diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx index b4e04176c..57c64ab46 100644 --- a/docs/docs/providers/inference/remote_nvidia.mdx +++ b/docs/docs/providers/inference/remote_nvidia.mdx @@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services. | `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM | | `timeout` | `` | No | 60 | Timeout for the HTTP requests | | `append_api_version` | `` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. 
| +| `rerank_model_to_url` | `dict[str, str` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. | ## Sample Configuration diff --git a/docs/notebooks/llamastack_agents_getting_started_examples.ipynb b/docs/notebooks/llamastack_agents_getting_started_examples.ipynb new file mode 100644 index 000000000..1ac1a2f92 --- /dev/null +++ b/docs/notebooks/llamastack_agents_getting_started_examples.ipynb @@ -0,0 +1,1036 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/llamastack_agents_getting_started_examples.ipynb)\n", + "\n", + "# Llama Stack Agents - Getting Started Guide\n", + "\n", + "This notebook provides a comprehensive introduction to building AI agents with Llama Stack. The Agent SDK is built on top of an open source version of **OpenAI's Responses+ APIs**, providing a standardized interface for agent workflows.\n", + "\n", + "## What You'll Learn\n", + "\n", + "1. **Basic Agent Creation** - Simple Q&A agents with streaming\n", + "2. **Multi-Turn Conversations** - Maintaining context across conversations\n", + "3. **RAG Integration** - Adding knowledge bases to your agents \n", + "4. 
**MCP Tools** - Extending agents with Model Context Protocol tools\n", + "\n", + "## Prerequisites\n", + "\n", + "- Llama Stack server running: `llama stack run starter --port 8321`\n", + "- A model provider configured (Ollama, Fireworks, etc.)\n", + "- Python 3.10+\n", + "\n", + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Client initialized successfully!\n", + " Base URL: http://localhost:8321\n" + ] + } + ], + "source": [ + "# Import required libraries\n", + "import json\n", + "from typing import Any, Dict\n", + "\n", + "from llama_stack_client import LlamaStackClient, Agent\n", + "from llama_stack_client.types import UserMessage\n", + "\n", + "# Initialize client\n", + "client = LlamaStackClient(base_url=\"http://localhost:8321\")\n", + "\n", + "print(\"✅ Client initialized successfully!\")\n", + "print(f\" Base URL: http://localhost:8321\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Created agent successfully\n" + ] + } + ], + "source": [ + "# Create a basic agent using the Agent class\n", + "agent = Agent(\n", + " client=client,\n", + " model=\"ollama/llama3.3:70b\",\n", + " instructions=\"You are a helpful AI assistant that can answer questions and help with tasks.\",\n", + ")\n", + "\n", + "print(\"✅ Created agent successfully\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# Part 1: Basic Agent Example\n", + "\n", + "Let's start with a simple agent that can answer questions. 
This demonstrates:\n", + "- Agent creation with basic configuration\n", + "- Session management\n", + "- Streaming responses" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/conversations \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Created session: conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c\n" + ] + } + ], + "source": [ + "# Create agent session\n", + "basic_session_id = agent.create_session(session_name=\"basic_example_session\")\n", + "\n", + "print(f\"✅ Created session: {basic_session_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/responses \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "User: What is the capital of France? Please explain briefly.\n", + "\n", + "Assistant: The capital of France is Paris. It's the country's largest city, known for iconic landmarks like the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum, serving as the center of French politics, culture, and economy.The capital of France is Paris. It's the country's largest city, known for iconic landmarks like the Eiffel Tower, Notre-Dame Cathedral, and the Louvre Museum, serving as the center of French politics, culture, and economy.\n", + "\n", + "✅ Response captured: 223 characters\n" + ] + } + ], + "source": [ + "# Send a message to the agent with streaming\n", + "query = \"What is the capital of France? 
Please explain briefly.\"\n", + "\n", + "print(f\"User: {query}\\n\")\n", + "print(\"Assistant: \", end='')\n", + "\n", + "# Create a turn with streaming\n", + "response = agent.create_turn(\n", + " session_id=basic_session_id,\n", + " messages=[UserMessage(content=query, role=\"user\")],\n", + " stream=True,\n", + ")\n", + "\n", + "# Stream the response\n", + "output_text = \"\"\n", + "for chunk in response:\n", + " if chunk.event.event_type == \"turn_completed\":\n", + " output_text = chunk.event.final_text\n", + " print(output_text)\n", + " break\n", + " elif chunk.event.event_type == \"step_progress\":\n", + " # Print text deltas as they arrive\n", + " if hasattr(chunk.event.delta, 'text'):\n", + " print(chunk.event.delta.text, end='', flush=True)\n", + "\n", + "print(f\"\\n✅ Response captured: {len(output_text)} characters\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: DELETE http://localhost:8321/v1/conversations/conv_e6afd7aaa97b49ce8f4f96a801b07893d9cb784d72e53e3c \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Session cleaned up\n" + ] + } + ], + "source": [ + "# Clean up the session\n", + "client.conversations.delete(conversation_id=basic_session_id)\n", + "print(\"✅ Session cleaned up\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# Part 2: Advanced Agent Features\n", + "\n", + "Now let's explore more advanced capabilities:\n", + "- Multi-turn conversations with context memory\n", + "- RAG (Retrieval-Augmented Generation) patterns" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.1 Multi-Turn Conversation\n", + "\n", + "Demonstrate how agents can maintain context across multiple conversation turns." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/conversations \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Created conversation agent\n", + "✅ Created session: conv_936121c2e27b7d1f7d3f0b6a62adce867d79268f5f9ce265\n" + ] + } + ], + "source": [ + "# Create agent for multi-turn conversation\n", + "conv_agent = Agent(\n", + " client=client,\n", + " model=\"ollama/llama3.3:70b\",\n", + " instructions=\"You are a helpful assistant that remembers context from previous messages.\",\n", + ")\n", + "\n", + "print(\"✅ Created conversation agent\")\n", + "\n", + "conv_session_id = conv_agent.create_session(session_name=\"multi_turn_session\")\n", + "print(f\"✅ Created session: {conv_session_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/responses \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "============================================================\n", + "Turn 1\n", + "============================================================\n", + "User: My name is Alice and I'm learning about AI.\n", + "Assistant: Nice to meet you, Alice! It's great that you're interested in learning about AI. What aspects of AI would you like to explore? Are you curious about machine learning, natural language processing, or something else? I'll be happy to help and provide information tailored to your interests.Nice to meet you, Alice! It's great that you're interested in learning about AI. What aspects of AI would you like to explore? Are you curious about machine learning, natural language processing, or something else? 
I'll be happy to help and provide information tailored to your interests.\n", + "\n", + "============================================================\n", + "Turn 2\n", + "============================================================\n", + "User: What are some good resources for beginners?\n", + "Assistant: " + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/responses \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "As a beginner, it's essential to start with resources that provide a solid foundation in AI concepts. Here are some recommendations:\n", + "\n", + "1. **Online Courses**:\n", + "\t* Andrew Ng's Machine Learning course on Coursera: A popular and comprehensive introduction to machine learning.\n", + "\t* Stanford University's Natural Language Processing with Deep Learning Specialization on Coursera: Covers NLP fundamentals and deep learning techniques.\n", + "2. **Books**:\n", + "\t* \"Introduction to Artificial Intelligence\" by Philip C. Jackson Jr.: A gentle introduction to AI concepts, including machine learning and computer vision.\n", + "\t* \"Deep Learning\" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville: A detailed book on deep learning techniques, although it may require some prior knowledge of linear algebra and calculus.\n", + "3. **Websites and Blogs**:\n", + "\t* Machine Learning Mastery: A website offering tutorials, examples, and explanations on various machine learning topics.\n", + "\t* KDnuggets: A popular blog covering AI, machine learning, and data science news, tutorials, and research papers.\n", + "4. **YouTube Channels**:\n", + "\t* 3Blue1Brown (Grant Sanderson): Engaging video explanations on AI, machine learning, and linear algebra concepts.\n", + "\t* Sentdex: Offers video tutorials on various AI topics, including machine learning, deep learning, and computer vision.\n", + "5. 
**Communities and Forums**:\n", + "\t* Kaggle: A platform for data science competitions and hosting datasets, where you can learn from others and participate in discussions.\n", + "\t* Reddit's r/MachineLearning and r/AI: Active communities discussing AI-related topics, sharing resources, and providing feedback on projects.\n", + "\n", + "Remember, learning about AI is a continuous process. Start with the basics, build projects, and gradually move on to more advanced topics. Practice and experimentation are key to gaining hands-on experience.\n", + "\n", + "What specific area of AI would you like to explore first, Alice?As a beginner, it's essential to start with resources that provide a solid foundation in AI concepts. Here are some recommendations:\n", + "\n", + "1. **Online Courses**:\n", + "\t* Andrew Ng's Machine Learning course on Coursera: A popular and comprehensive introduction to machine learning.\n", + "\t* Stanford University's Natural Language Processing with Deep Learning Specialization on Coursera: Covers NLP fundamentals and deep learning techniques.\n", + "2. **Books**:\n", + "\t* \"Introduction to Artificial Intelligence\" by Philip C. Jackson Jr.: A gentle introduction to AI concepts, including machine learning and computer vision.\n", + "\t* \"Deep Learning\" by Ian Goodfellow, Yoshua Bengio, and Aaron Courville: A detailed book on deep learning techniques, although it may require some prior knowledge of linear algebra and calculus.\n", + "3. **Websites and Blogs**:\n", + "\t* Machine Learning Mastery: A website offering tutorials, examples, and explanations on various machine learning topics.\n", + "\t* KDnuggets: A popular blog covering AI, machine learning, and data science news, tutorials, and research papers.\n", + "4. 
**YouTube Channels**:\n", + "\t* 3Blue1Brown (Grant Sanderson): Engaging video explanations on AI, machine learning, and linear algebra concepts.\n", + "\t* Sentdex: Offers video tutorials on various AI topics, including machine learning, deep learning, and computer vision.\n", + "5. **Communities and Forums**:\n", + "\t* Kaggle: A platform for data science competitions and hosting datasets, where you can learn from others and participate in discussions.\n", + "\t* Reddit's r/MachineLearning and r/AI: Active communities discussing AI-related topics, sharing resources, and providing feedback on projects.\n", + "\n", + "Remember, learning about AI is a continuous process. Start with the basics, build projects, and gradually move on to more advanced topics. Practice and experimentation are key to gaining hands-on experience.\n", + "\n", + "What specific area of AI would you like to explore first, Alice?\n", + "\n", + "============================================================\n", + "Turn 3\n", + "============================================================\n", + "User: Can you remind me what my name is?\n", + "Assistant: " + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/responses \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Your name is Alice! I remember that from our previous conversation when you introduced yourself as someone interested in learning about AI. How can I assist you further today?Your name is Alice! I remember that from our previous conversation when you introduced yourself as someone interested in learning about AI. 
How can I assist you further today?\n", + "\n", + "✅ Completed 3 conversational turns with context retention\n" + ] + } + ], + "source": [ + "# Conversation turns that build on each other\n", + "conversation_turns = [\n", + " \"My name is Alice and I'm learning about AI.\",\n", + " \"What are some good resources for beginners?\",\n", + " \"Can you remind me what my name is?\",\n", + "]\n", + "\n", + "for i, query in enumerate(conversation_turns, 1):\n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"Turn {i}\")\n", + " print(f\"{'='*60}\")\n", + " print(f\"User: {query}\")\n", + "\n", + " response = conv_agent.create_turn(\n", + " session_id=conv_session_id,\n", + " messages=[UserMessage(content=query, role=\"user\")],\n", + " stream=True,\n", + " )\n", + "\n", + " print(\"Assistant: \", end='')\n", + " for chunk in response:\n", + " if chunk.event.event_type == \"turn_completed\":\n", + " output = chunk.event.final_text\n", + " print(output)\n", + " break\n", + " elif chunk.event.event_type == \"step_progress\":\n", + " if hasattr(chunk.event.delta, 'text'):\n", + " print(chunk.event.delta.text, end='', flush=True)\n", + "\n", + "print(f\"\\n✅ Completed {len(conversation_turns)} conversational turns with context retention\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: DELETE http://localhost:8321/v1/conversations/conv_936121c2e27b7d1f7d3f0b6a62adce867d79268f5f9ce265 \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Session cleaned up\n" + ] + } + ], + "source": [ + "# Cleanup\n", + "client.conversations.delete(conversation_id=conv_session_id)\n", + "print(\"✅ Session cleaned up\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.2 RAG (Retrieval-Augmented Generation) Pattern\n", + "\n", + "Demonstrate how to provide context to the agent for more 
accurate responses." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Knowledge base: 3 Paul Graham essay excerpts\n", + " - pg_essay_1: What I Worked On\n", + " - pg_essay_2: How to Start a Startup\n", + " - pg_essay_3: Maker's Schedule, Manager's Schedule\n" + ] + } + ], + "source": [ + "# Sample knowledge base: Paul Graham essay excerpts\n", + "# This is a common RAG example - using actual content from Paul Graham's essays\n", + "documents = [\n", + " {\n", + " \"doc_id\": \"pg_essay_1\",\n", + " \"content\": \"\"\"What I Worked On\n", + "\n", + " Before college the two main things I worked on, outside of school, were writing and programming.\n", + " I didn't write essays. I wrote what beginning writers were supposed to write then, and probably\n", + " still are: short stories. My stories were awful. They had hardly any plot, just characters with\n", + " strong feelings, which I imagined made them deep.\n", + "\n", + " The first programs I tried writing were on the IBM 1401 that our school district used for what\n", + " was then called 'data processing.' This was in 9th grade, so I was 13 or 14. The school district's\n", + " 1401 happened to be in the basement of our junior high school, and my friend Rich Draves and I got\n", + " permission to use it.\"\"\",\n", + " \"metadata\": {\"essay\": \"What I Worked On\", \"author\": \"Paul Graham\", \"year\": 2021}\n", + " },\n", + " {\n", + " \"doc_id\": \"pg_essay_2\",\n", + " \"content\": \"\"\"How to Start a Startup\n", + "\n", + " You need three things to create a successful startup: to start with good people, to make something\n", + " customers actually want, and to spend as little money as possible. Most startups that fail do it\n", + " because they fail at one of these. 
A startup that does all three will probably succeed.\n", + "\n", + " And that's kind of exciting, when you think about it, because all three are doable. Hard, but doable.\n", + " And since a startup that succeeds ordinarily makes its founders rich, that implies getting rich is\n", + " doable too. Hard, but doable.\"\"\",\n", + " \"metadata\": {\"essay\": \"How to Start a Startup\", \"author\": \"Paul Graham\", \"year\": 2005}\n", + " },\n", + " {\n", + " \"doc_id\": \"pg_essay_3\",\n", + " \"content\": \"\"\"Maker's Schedule, Manager's Schedule\n", + "\n", + " One reason programmers dislike meetings so much is that they're on a different type of schedule\n", + " from other people. Meetings cost them more.\n", + "\n", + " There are two types of schedule, which I'll call the manager's schedule and the maker's schedule.\n", + " The manager's schedule is for bosses. It's embodied in the traditional appointment book, with each\n", + " day cut into one hour intervals. When you use time that way, it's merely a practical problem to\n", + " meet with someone. But there's another way of using time that's common among people who make things,\n", + " like programmers and writers. 
They generally prefer to use time in units of half a day at least.\"\"\",\n", + " \"metadata\": {\"essay\": \"Maker's Schedule, Manager's Schedule\", \"author\": \"Paul Graham\", \"year\": 2009}\n", + " },\n", + "]\n", + "\n", + "print(f\"Knowledge base: {len(documents)} Paul Graham essay excerpts\")\n", + "for doc in documents:\n", + " print(f\" - {doc['doc_id']}: {doc['metadata']['essay']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Created RAG agent\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/conversations \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Created session: conv_9ae94374c781501f2d712620dcc8e55961b5a226df229b1d\n" + ] + } + ], + "source": [ + "# Create RAG-enabled agent\n", + "rag_agent = Agent(\n", + " client=client,\n", + " model=\"ollama/llama3.3:70b\",\n", + " instructions=(\n", + " \"You are a helpful AI assistant with access to a knowledge base. \"\n", + " \"When answering questions, use the provided context from the knowledge base. 
\"\n", + " \"If the context doesn't contain relevant information, say so.\"\n", + " ),\n", + ")\n", + "\n", + "print(\"✅ Created RAG agent\")\n", + "\n", + "rag_session_id = rag_agent.create_session(session_name=\"rag_session\")\n", + "print(f\"✅ Created session: {rag_session_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/responses \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query: What did Paul Graham work on before college?\n", + "Retrieved 1 relevant document(s)\n", + "\n", + "Answer: Based on the provided context from \"What I Worked On\", before college, Paul Graham worked on two main things outside of school: \n", + "\n", + "1. Writing (specifically short stories)\n", + "2. Programming (initially on the IBM 1401)Based on the provided context from \"What I Worked On\", before college, Paul Graham worked on two main things outside of school: \n", + "\n", + "1. Writing (specifically short stories)\n", + "2. 
Programming (initially on the IBM 1401)\n", + "\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: DELETE http://localhost:8321/v1/conversations/conv_9ae94374c781501f2d712620dcc8e55961b5a226df229b1d \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Session cleaned up\n" + ] + } + ], + "source": [ + "# Query with context from Paul Graham essays\n", + "query = \"What did Paul Graham work on before college?\"\n", + "\n", + "# Simulate retrieval (in production, use vector search)\n", + "relevant_docs = [doc for doc in documents if \"before college\" in doc[\"content\"].lower()]\n", + "context = \"\\n\\n\".join([f\"From '{doc['metadata']['essay']}':\\n{doc['content']}\"\n", + " for doc in relevant_docs])\n", + "\n", + "# Create prompt with retrieved context\n", + "prompt_with_context = f\"\"\"Context from knowledge base:\n", + "{context}\n", + "\n", + "Question: {query}\n", + "\n", + "Please answer based on the provided context.\"\"\"\n", + "\n", + "print(f\"Query: {query}\")\n", + "print(f\"Retrieved {len(relevant_docs)} relevant document(s)\\n\")\n", + "print(\"Answer: \", end='')\n", + "\n", + "response = rag_agent.create_turn(\n", + " session_id=rag_session_id,\n", + " messages=[UserMessage(content=prompt_with_context, role=\"user\")],\n", + " stream=True,\n", + ")\n", + "\n", + "for chunk in response:\n", + " if chunk.event.event_type == \"turn_completed\":\n", + " output = chunk.event.final_text\n", + " print(output)\n", + " break\n", + " elif chunk.event.event_type == \"step_progress\":\n", + " if hasattr(chunk.event.delta, 'text'):\n", + " print(chunk.event.delta.text, end='', flush=True)\n", + "\n", + "print(\"\\n\")\n", + "client.conversations.delete(conversation_id=rag_session_id)\n", + "print(\"✅ Session cleaned up\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# Part 3: MCP (Model Context Protocol) 
Tools\n", + "\n", + "MCP provides a standardized way for AI models to interact with external tools and data sources.\n", + "\n", + "We'll demonstrate:\n", + "- Defining MCP-compatible tools\n", + "- Agent tool selection\n", + "- Tool execution and response handling" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created 3 MCP tools:\n", + " - get_weather: Get current weather information for a specified location\n", + " - execute_code: Execute Python code and return the result\n", + " - web_search: Search the web for information\n" + ] + } + ], + "source": [ + "def create_mcp_tools():\n", + " \"\"\"Create MCP-compatible tool definitions.\"\"\"\n", + " return [\n", + " {\n", + " \"tool_name\": \"get_weather\",\n", + " \"description\": \"Get current weather information for a specified location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"City and state/country, e.g., 'San Francisco, CA'\"\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"string\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " \"description\": \"Temperature unit\",\n", + " \"default\": \"fahrenheit\"\n", + " }\n", + " },\n", + " \"required\": [\"location\"]\n", + " }\n", + " },\n", + " {\n", + " \"tool_name\": \"execute_code\",\n", + " \"description\": \"Execute Python code and return the result\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"code\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Python code to execute\"\n", + " }\n", + " },\n", + " \"required\": [\"code\"]\n", + " }\n", + " },\n", + " {\n", + " \"tool_name\": \"web_search\",\n", + " \"description\": \"Search the web for information\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"query\": 
{\n", + " \"type\": \"string\",\n", + " \"description\": \"Search query\"\n", + " }\n", + " },\n", + " \"required\": [\"query\"]\n", + " }\n", + " },\n", + " ]\n", + "\n", + "tools = create_mcp_tools()\n", + "print(f\"Created {len(tools)} MCP tools:\")\n", + "for tool in tools:\n", + " print(f\" - {tool['tool_name']}: {tool['description']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MCP tool configuration ready\n", + " Server: http://localhost:3000/sse\n", + " Format: MCP server-based\n", + "\n", + " To use MCP tools:\n", + " 1. Set up your MCP server\n", + " 2. Update MCP_SERVER_URL and MCP_ACCESS_TOKEN above\n", + " 3. Pass mcp_tools to Agent(tools=mcp_tools)\n" + ] + } + ], + "source": [ + "# Example 2: MCP Server Configuration (0.3.0 format)\n", + "\n", + "# MCP server configuration\n", + "# Replace with your actual MCP server URL and credentials\n", + "MCP_SERVER_URL = \"https://api.example.com/mcp\" # Your MCP server endpoint\n", + "MCP_ACCESS_TOKEN = \"your-token-here\" # Your authentication token\n", + "\n", + "MCP_ACCESS_TOKEN = \"YOUR_ACCESS_TOKEN_HERE\"\n", + "## If you are running an MCP server locally, replace this value with your MCP server URL\n", + "MCP_SERVER_URL = \"http://localhost:3000/sse\"\n", + "#MCP_SERVER_URL = \"https://mcp.deepwiki.com/sse\"\n", + "mcp_tools = [\n", + " {\n", + " \"type\": \"mcp\",\n", + " \"server_url\": MCP_SERVER_URL,\n", + " \"server_label\": \"weather\",\n", + " \"headers\": {\n", + " \"Authorization\": f\"Bearer {MCP_ACCESS_TOKEN}\",\n", + " },\n", + " }\n", + "]\n", + "\n", + "\n", + "print(\"MCP tool configuration ready\")\n", + "print(f\" Server: {MCP_SERVER_URL}\")\n", + "print(\" Format: MCP server-based\")\n", + "print(\"\\n To use MCP tools:\")\n", + "print(\" 1. Set up your MCP server\")\n", + "print(\" 2. Update MCP_SERVER_URL and MCP_ACCESS_TOKEN above\")\n", + "print(\" 3. 
Pass mcp_tools to Agent(tools=mcp_tools)\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tool execution simulator ready\n" + ] + } + ], + "source": [ + "def simulate_tool_execution(tool_name: str, arguments: Dict[str, Any]) -> str:\n", + " \"\"\"Simulate tool execution (replace with real implementations).\"\"\"\n", + " if tool_name == \"get_weather\":\n", + " location = arguments.get(\"location\", \"Unknown\")\n", + " unit = arguments.get(\"unit\", \"fahrenheit\")\n", + " temp = \"72°F\" if unit == \"fahrenheit\" else \"22°C\"\n", + " return json.dumps({\n", + " \"location\": location,\n", + " \"temperature\": temp,\n", + " \"condition\": \"Partly cloudy\",\n", + " \"humidity\": \"65%\",\n", + " \"wind\": \"10 mph NW\"\n", + " })\n", + " elif tool_name == \"execute_code\":\n", + " code = arguments.get(\"code\", \"\")\n", + " return json.dumps({\n", + " \"status\": \"success\",\n", + " \"output\": f\"Code execution simulated for: {code[:50]}...\"\n", + " })\n", + " elif tool_name == \"web_search\":\n", + " query = arguments.get(\"query\", \"\")\n", + " return json.dumps({\n", + " \"status\": \"success\",\n", + " \"results\": [\n", + " {\"title\": f\"Result {i+1}\", \"url\": f\"https://example.com/{i+1}\",\n", + " \"snippet\": f\"Information about {query}\"}\n", + " for i in range(3)\n", + " ]\n", + " })\n", + " return json.dumps({\"error\": \"Unknown tool\"})\n", + "\n", + "print(\"Tool execution simulator ready\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created MCP agent\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/conversations \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Created session: 
conv_5613324aa4c3193b1434bf562fe1c75dc2e0563c681738b1\n" + ] + } + ], + "source": [ + "mcp_agent = Agent(\n", + " client=client,\n", + " model=\"ollama/llama3.3:70b\",\n", + " instructions=\"You are a helpful AI assistant that can answer questions and help with various tasks.\",\n", + " tools=mcp_tools # you can set this field to tools when experimenting with the tools created by create_mcp_tools above.\n", + ")\n", + "\n", + "print(\"Created MCP agent\")\n", + "\n", + "mcp_session_id = mcp_agent.create_session(session_name=\"mcp_tools_session\")\n", + "print(f\"✅ Created session: {mcp_session_id}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://localhost:8321/v1/responses \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "======================================================================\n", + "MCP TOOL EXAMPLE\n", + "======================================================================\n", + "\n", + " User: What's the weather like in New York City?\n", + "\n", + " Assistant: \n", + "\n", + " [Tool Execution Started]\n", + "\n", + "\n", + " [Tool Execution Started]\n", + "The current weather in New York City is mostly cloudy with a temperature of 49°F and a wind speed of 17 mph NE. Today, it will be partly sunny with a high of 55°F. Tonight, there's a chance of rain showers with a low of 53°F. The rest of the week will see a mix of rain, thunderstorms, and sunshine, with temperatures ranging from the mid-50s to the mid-60s. It's a good idea to check the forecast regularly for updates.The current weather in New York City is mostly cloudy with a temperature of 49°F and a wind speed of 17 mph NE. Today, it will be partly sunny with a high of 55°F. Tonight, there's a chance of rain showers with a low of 53°F. 
The rest of the week will see a mix of rain, thunderstorms, and sunshine, with temperatures ranging from the mid-50s to the mid-60s. It's a good idea to check the forecast regularly for updates.\n", + "\n", + "\n", + " Summary: Used 2 tool(s) to answer the query\n" + ] + } + ], + "source": [ + "# Example: Weather query that should trigger tool usage\n", + "query = \"What's the weather like in New York City?\"\n", + "\n", + "print(f\"{'='*70}\")\n", + "print(f\"MCP TOOL EXAMPLE\")\n", + "print(f\"{'='*70}\")\n", + "print(f\"\\n User: {query}\")\n", + "\n", + "response = mcp_agent.create_turn(\n", + " session_id=mcp_session_id,\n", + " messages=[UserMessage(content=query, role=\"user\")],\n", + " stream=True,\n", + ")\n", + "\n", + "print(\"\\n Assistant: \", end='')\n", + "tool_calls_made = []\n", + "\n", + "for chunk in response:\n", + " event_type = chunk.event.event_type\n", + "\n", + " if event_type == \"step_started\":\n", + " if chunk.event.step_type == \"tool_execution\":\n", + " print(f\"\\n\\n [Tool Execution Started]\")\n", + "\n", + " elif event_type == \"step_progress\":\n", + " # Check for tool call deltas\n", + " if hasattr(chunk.event.delta, 'delta_type'):\n", + " if chunk.event.delta.delta_type == \"tool_call_issued\":\n", + " tool_calls_made.append(chunk.event.delta)\n", + " result = simulate_tool_execution(\n", + " chunk.event.delta.tool_name,\n", + " json.loads(chunk.event.delta.arguments)\n", + " )\n", + " if hasattr(chunk.event.delta, 'text'):\n", + " print(chunk.event.delta.text, end='', flush=True)\n", + "\n", + " elif event_type == \"turn_completed\":\n", + " output = chunk.event.final_text\n", + " if output:\n", + " print(output)\n", + "\n", + "print()\n", + "if tool_calls_made:\n", + " print(f\"\\n Summary: Used {len(tool_calls_made)} tool(s) to answer the query\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP 
Request: DELETE http://localhost:8321/v1/conversations/conv_5613324aa4c3193b1434bf562fe1c75dc2e0563c681738b1 \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ Session cleaned up\n" + ] + } + ], + "source": [ + "# Cleanup\n", + "client.conversations.delete(conversation_id=mcp_session_id)\n", + "print(\"✅ Session cleaned up\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "\n", + "# Summary\n", + "\n", + "This notebook demonstrated three levels of Llama Stack agent capabilities:\n", + "\n", + "## 1. Basic Agent\n", + "- ✅ Simple agent creation\n", + "- ✅ Session management \n", + "- ✅ Streaming responses\n", + "\n", + "## 2. Advanced Features\n", + "- ✅ Multi-turn conversations\n", + "- ✅ RAG (Retrieval-Augmented Generation) pattern\n", + "- ✅ Custom knowledge base integration\n", + "\n", + "## 3. MCP Tools Integration\n", + "- ✅ MCP-compatible tool definitions\n", + "- ✅ Automatic tool selection by the agent\n", + "- ✅ Tool execution and response handling\n", + "- ✅ Real-time streaming with tool calls\n", + "\n", + "\n", + "## Resources\n", + "\n", + "- [Llama Stack Documentation](https://llama-stack.readthedocs.io/)\n", + "- [Llama Stack GitHub](https://github.com/meta-llama/llama-stack)\n", + "- [MCP Protocol Specification](https://modelcontextprotocol.io/)\n", + "- [Ollama Documentation](https://ollama.ai/)" + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "92b7454e-a941-41f0-bd02-6d5e728f20f1", + "isAdHoc": false, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git 
a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index b489833b3..65720df4a 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -84,7 +84,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo ) yaml_filename = f"{filename_prefix}llama-stack-spec.yaml" - html_filename = f"{filename_prefix}llama-stack-spec.html" with open(output_dir / yaml_filename, "w", encoding="utf-8") as fp: y = yaml.YAML() @@ -102,11 +101,6 @@ def generate_spec(output_dir: Path, stability_filter: str = None, main_spec: boo fp, ) - with open(output_dir / html_filename, "w") as fp: - spec.write_html(fp, pretty_print=True) - - print(f"Generated {yaml_filename} and {html_filename}") - def main(output_dir: str): output_dir = Path(output_dir) if not output_dir.exists(): diff --git a/docs/sidebars.ts b/docs/sidebars.ts index f2cfe3798..641c2eed3 100644 --- a/docs/sidebars.ts +++ b/docs/sidebars.ts @@ -242,15 +242,6 @@ const sidebars: SidebarsConfig = { 'providers/eval/remote_nvidia' ], }, - { - type: 'category', - label: 'Telemetry', - collapsed: true, - items: [ - 'providers/telemetry/index', - 'providers/telemetry/inline_meta-reference' - ], - }, { type: 'category', label: 'Batches', diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html deleted file mode 100644 index dca129631..000000000 --- a/docs/static/deprecated-llama-stack-spec.html +++ /dev/null @@ -1,13582 +0,0 @@ - - - - - - - OpenAPI specification - - - - - - - - - - - - - diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 12d1327a2..15a3166de 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -1012,6 +1012,141 @@ paths: schema: type: string deprecated: true + /v1/openai/v1/batches: + get: + responses: + '200': + description: A list of batch objects. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ListBatchesResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: List all batches for the current user. + description: List all batches for the current user. + parameters: + - name: after + in: query + description: >- + A cursor for pagination; returns batches after this batch ID. + required: false + schema: + type: string + - name: limit + in: query + description: >- + Number of batches to return (default 20, max 100). + required: true + schema: + type: integer + deprecated: true + post: + responses: + '200': + description: The created batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Create a new batch for processing multiple API requests. + description: >- + Create a new batch for processing multiple API requests. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateBatchRequest' + required: true + deprecated: true + /v1/openai/v1/batches/{batch_id}: + get: + responses: + '200': + description: The batch object. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Retrieve information about a specific batch. + description: >- + Retrieve information about a specific batch. + parameters: + - name: batch_id + in: path + description: The ID of the batch to retrieve. + required: true + schema: + type: string + deprecated: true + /v1/openai/v1/batches/{batch_id}/cancel: + post: + responses: + '200': + description: The updated batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: Cancel a batch that is in progress. + description: Cancel a batch that is in progress. + parameters: + - name: batch_id + in: path + description: The ID of the batch to cancel. + required: true + schema: + type: string + deprecated: true /v1/openai/v1/chat/completions: get: responses: @@ -1426,31 +1561,6 @@ paths: schema: type: string deprecated: true - /v1/openai/v1/models: - get: - responses: - '200': - description: A OpenAIListModelsResponse. - content: - application/json: - schema: - $ref: '#/components/schemas/OpenAIListModelsResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - Models - summary: List models using the OpenAI API. - description: List models using the OpenAI API. 
- parameters: [] - deprecated: true /v1/openai/v1/moderations: post: responses: @@ -4736,6 +4846,331 @@ components: title: Job description: >- A job execution instance with status tracking. + ListBatchesResponse: + type: object + properties: + object: + type: string + const: list + default: list + data: + type: array + items: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + 
type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - input_tokens + - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + default: false + additionalProperties: false + required: + - object + - data + - has_more + title: ListBatchesResponse + description: >- + Response containing a list of batch objects. + CreateBatchRequest: + type: object + properties: + input_file_id: + type: string + description: >- + The ID of an uploaded file containing requests for the batch. + endpoint: + type: string + description: >- + The endpoint to be used for all requests in the batch. + completion_window: + type: string + const: 24h + description: >- + The time window within which the batch should be processed. + metadata: + type: object + additionalProperties: + type: string + description: Optional metadata for the batch. + idempotency_key: + type: string + description: >- + Optional idempotency key. When provided, enables idempotent behavior. 
+ additionalProperties: false + required: + - input_file_id + - endpoint + - completion_window + title: CreateBatchRequest + Batch: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - input_tokens 
+ - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch Order: type: string enum: @@ -6056,38 +6491,6 @@ components: Response: type: object title: Response - OpenAIModel: - type: object - properties: - id: - type: string - object: - type: string - const: model - default: model - created: - type: integer - owned_by: - type: string - additionalProperties: false - required: - - id - - object - - created - - owned_by - title: OpenAIModel - description: A model from OpenAI. - OpenAIListModelsResponse: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/OpenAIModel' - additionalProperties: false - required: - - data - title: OpenAIListModelsResponse RunModerationRequest: type: object properties: @@ -10263,6 +10666,19 @@ tags: - **Responses API**: Use the stable v1 Responses API endpoints x-displayName: Agents + - name: Batches + description: >- + The API is designed to allow use of openai client libraries for seamless integration. + + + This API provides the following extensions: + - idempotent batch creation + + Note: This API is currently under active development and may undergo changes. + x-displayName: >- + The Batches API enables efficient processing of multiple requests in a single + operation, particularly useful for processing large datasets, batch evaluation + workflows, and cost-effective inference at scale. - name: Benchmarks description: '' - name: DatasetIO @@ -10295,8 +10711,6 @@ tags: - Rerank models: these models reorder the documents based on their relevance to a query. 
x-displayName: Inference - - name: Models - description: '' - name: PostTraining (Coming Soon) description: '' - name: Safety @@ -10308,13 +10722,13 @@ x-tagGroups: - name: Operations tags: - Agents + - Batches - Benchmarks - DatasetIO - Datasets - Eval - Files - Inference - - Models - PostTraining (Coming Soon) - Safety - VectorIO diff --git a/docs/static/experimental-llama-stack-spec.html b/docs/static/experimental-llama-stack-spec.html deleted file mode 100644 index 22473ec11..000000000 --- a/docs/static/experimental-llama-stack-spec.html +++ /dev/null @@ -1,5552 +0,0 @@ - - - - - - - OpenAPI specification - - - - - - - - - - - - - diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index a481fe074..514bff145 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -40,6 +40,193 @@ } ], "paths": { + "/v1/batches": { + "get": { + "responses": { + "200": { + "description": "A list of batch objects.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ListBatchesResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Batches" + ], + "summary": "List all batches for the current user.", + "description": "List all batches for the current user.", + "parameters": [ + { + "name": "after", + "in": "query", + "description": "A cursor for pagination; returns batches after this batch ID.", + "required": false, + "schema": { + "type": "string" + } + }, + { + "name": "limit", + "in": "query", + "description": "Number of batches to return (default 20, max 100).", + "required": true, + "schema": { + "type": "integer" + } + } + ], + "deprecated": false + }, + "post": { + "responses": { + "200": { + "description": "The 
created batch object.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Batch" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Batches" + ], + "summary": "Create a new batch for processing multiple API requests.", + "description": "Create a new batch for processing multiple API requests.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CreateBatchRequest" + } + } + }, + "required": true + }, + "deprecated": false + } + }, + "/v1/batches/{batch_id}": { + "get": { + "responses": { + "200": { + "description": "The batch object.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Batch" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Batches" + ], + "summary": "Retrieve information about a specific batch.", + "description": "Retrieve information about a specific batch.", + "parameters": [ + { + "name": "batch_id", + "in": "path", + "description": "The ID of the batch to retrieve.", + "required": true, + "schema": { + "type": "string" + } + } + ], + "deprecated": false + } + }, + "/v1/batches/{batch_id}/cancel": { + "post": { + "responses": { + "200": { + "description": "The updated batch object.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Batch" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": 
"#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Batches" + ], + "summary": "Cancel a batch that is in progress.", + "description": "Cancel a batch that is in progress.", + "parameters": [ + { + "name": "batch_id", + "in": "path", + "description": "The ID of the batch to cancel.", + "required": true, + "schema": { + "type": "string" + } + } + ], + "deprecated": false + } + }, "/v1/chat/completions": { "get": { "responses": { @@ -1071,7 +1258,23 @@ ], "summary": "List routes.", "description": "List routes.\nList all available API routes with their methods and implementing providers.", - "parameters": [], + "parameters": [ + { + "name": "api_filter", + "in": "query", + "description": "Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. 
If not specified, returns only non-deprecated v1 routes.", + "required": false, + "schema": { + "type": "string", + "enum": [ + "v1", + "v1alpha", + "v1beta", + "deprecated" + ] + } + } + ], "deprecated": false } }, @@ -2447,51 +2650,6 @@ "deprecated": false } }, - "/v1/synthetic-data-generation/generate": { - "post": { - "responses": { - "200": { - "description": "Response containing filtered synthetic data samples and optional statistics", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/SyntheticDataGenerationResponse" - } - } - } - }, - "400": { - "$ref": "#/components/responses/BadRequest400" - }, - "429": { - "$ref": "#/components/responses/TooManyRequests429" - }, - "500": { - "$ref": "#/components/responses/InternalServerError500" - }, - "default": { - "$ref": "#/components/responses/DefaultError" - } - }, - "tags": [ - "SyntheticDataGeneration (Coming Soon)" - ], - "summary": "Generate synthetic data based on input dialogs and apply filtering.", - "description": "Generate synthetic data based on input dialogs and apply filtering.", - "parameters": [], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/SyntheticDataGenerateRequest" - } - } - }, - "required": true - }, - "deprecated": false - } - }, "/v1/tool-runtime/invoke": { "post": { "responses": { @@ -4005,6 +4163,451 @@ "title": "Error", "description": "Error response from the API. Roughly follows RFC 7807." 
}, + "ListBatchesResponse": { + "type": "object", + "properties": { + "object": { + "type": "string", + "const": "list", + "default": "list" + }, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "completion_window": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "endpoint": { + "type": "string" + }, + "input_file_id": { + "type": "string" + }, + "object": { + "type": "string", + "const": "batch" + }, + "status": { + "type": "string", + "enum": [ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled" + ] + }, + "cancelled_at": { + "type": "integer" + }, + "cancelling_at": { + "type": "integer" + }, + "completed_at": { + "type": "integer" + }, + "error_file_id": { + "type": "string" + }, + "errors": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "code": { + "type": "string" + }, + "line": { + "type": "integer" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + } + }, + "additionalProperties": false, + "title": "BatchError" + } + }, + "object": { + "type": "string" + } + }, + "additionalProperties": false, + "title": "Errors" + }, + "expired_at": { + "type": "integer" + }, + "expires_at": { + "type": "integer" + }, + "failed_at": { + "type": "integer" + }, + "finalizing_at": { + "type": "integer" + }, + "in_progress_at": { + "type": "integer" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "model": { + "type": "string" + }, + "output_file_id": { + "type": "string" + }, + "request_counts": { + "type": "object", + "properties": { + "completed": { + "type": "integer" + }, + "failed": { + "type": "integer" + }, + "total": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "completed", + "failed", + "total" + ], + "title": 
"BatchRequestCounts" + }, + "usage": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "input_tokens_details": { + "type": "object", + "properties": { + "cached_tokens": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "cached_tokens" + ], + "title": "InputTokensDetails" + }, + "output_tokens": { + "type": "integer" + }, + "output_tokens_details": { + "type": "object", + "properties": { + "reasoning_tokens": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "reasoning_tokens" + ], + "title": "OutputTokensDetails" + }, + "total_tokens": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "input_tokens", + "input_tokens_details", + "output_tokens", + "output_tokens_details", + "total_tokens" + ], + "title": "BatchUsage" + } + }, + "additionalProperties": false, + "required": [ + "id", + "completion_window", + "created_at", + "endpoint", + "input_file_id", + "object", + "status" + ], + "title": "Batch" + } + }, + "first_id": { + "type": "string" + }, + "last_id": { + "type": "string" + }, + "has_more": { + "type": "boolean", + "default": false + } + }, + "additionalProperties": false, + "required": [ + "object", + "data", + "has_more" + ], + "title": "ListBatchesResponse", + "description": "Response containing a list of batch objects." + }, + "CreateBatchRequest": { + "type": "object", + "properties": { + "input_file_id": { + "type": "string", + "description": "The ID of an uploaded file containing requests for the batch." + }, + "endpoint": { + "type": "string", + "description": "The endpoint to be used for all requests in the batch." + }, + "completion_window": { + "type": "string", + "const": "24h", + "description": "The time window within which the batch should be processed." + }, + "metadata": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "description": "Optional metadata for the batch." 
+ }, + "idempotency_key": { + "type": "string", + "description": "Optional idempotency key. When provided, enables idempotent behavior." + } + }, + "additionalProperties": false, + "required": [ + "input_file_id", + "endpoint", + "completion_window" + ], + "title": "CreateBatchRequest" + }, + "Batch": { + "type": "object", + "properties": { + "id": { + "type": "string" + }, + "completion_window": { + "type": "string" + }, + "created_at": { + "type": "integer" + }, + "endpoint": { + "type": "string" + }, + "input_file_id": { + "type": "string" + }, + "object": { + "type": "string", + "const": "batch" + }, + "status": { + "type": "string", + "enum": [ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled" + ] + }, + "cancelled_at": { + "type": "integer" + }, + "cancelling_at": { + "type": "integer" + }, + "completed_at": { + "type": "integer" + }, + "error_file_id": { + "type": "string" + }, + "errors": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "code": { + "type": "string" + }, + "line": { + "type": "integer" + }, + "message": { + "type": "string" + }, + "param": { + "type": "string" + } + }, + "additionalProperties": false, + "title": "BatchError" + } + }, + "object": { + "type": "string" + } + }, + "additionalProperties": false, + "title": "Errors" + }, + "expired_at": { + "type": "integer" + }, + "expires_at": { + "type": "integer" + }, + "failed_at": { + "type": "integer" + }, + "finalizing_at": { + "type": "integer" + }, + "in_progress_at": { + "type": "integer" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "type": "string" + } + }, + "model": { + "type": "string" + }, + "output_file_id": { + "type": "string" + }, + "request_counts": { + "type": "object", + "properties": { + "completed": { + "type": "integer" + }, + "failed": { + "type": "integer" + }, + "total": { + "type": "integer" + } + }, 
+ "additionalProperties": false, + "required": [ + "completed", + "failed", + "total" + ], + "title": "BatchRequestCounts" + }, + "usage": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "input_tokens_details": { + "type": "object", + "properties": { + "cached_tokens": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "cached_tokens" + ], + "title": "InputTokensDetails" + }, + "output_tokens": { + "type": "integer" + }, + "output_tokens_details": { + "type": "object", + "properties": { + "reasoning_tokens": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "reasoning_tokens" + ], + "title": "OutputTokensDetails" + }, + "total_tokens": { + "type": "integer" + } + }, + "additionalProperties": false, + "required": [ + "input_tokens", + "input_tokens_details", + "output_tokens", + "output_tokens_details", + "total_tokens" + ], + "title": "BatchUsage" + } + }, + "additionalProperties": false, + "required": [ + "id", + "completion_window", + "created_at", + "endpoint", + "input_file_id", + "object", + "status" + ], + "title": "Batch" + }, "Order": { "type": "string", "enum": [ @@ -10830,44 +11433,46 @@ ], "title": "RegisterShieldRequest" }, - "CompletionMessage": { + "InvokeToolRequest": { "type": "object", "properties": { - "role": { + "tool_name": { "type": "string", - "const": "assistant", - "default": "assistant", - "description": "Must be \"assistant\" to identify this as the model's response" + "description": "The name of the tool to invoke." }, - "content": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "The content of the model's response" - }, - "stop_reason": { - "type": "string", - "enum": [ - "end_of_turn", - "end_of_message", - "out_of_tokens" - ], - "description": "Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`: The model finished generating the entire response. 
- `StopReason.end_of_message`: The model finished generating but generated a partial response -- usually, a tool call. The user may call the tool and continue the conversation with the tool's response. - `StopReason.out_of_tokens`: The model ran out of token budget." - }, - "tool_calls": { - "type": "array", - "items": { - "$ref": "#/components/schemas/ToolCall" + "kwargs": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] }, - "description": "List of tool calls. Each tool call is a ToolCall object." + "description": "A dictionary of arguments to pass to the tool." } }, "additionalProperties": false, "required": [ - "role", - "content", - "stop_reason" + "tool_name", + "kwargs" ], - "title": "CompletionMessage", - "description": "A message containing the model's (assistant) response in a chat conversation." + "title": "InvokeToolRequest" }, "ImageContentItem": { "type": "object", @@ -10936,53 +11541,6 @@ } } }, - "Message": { - "oneOf": [ - { - "$ref": "#/components/schemas/UserMessage" - }, - { - "$ref": "#/components/schemas/SystemMessage" - }, - { - "$ref": "#/components/schemas/ToolResponseMessage" - }, - { - "$ref": "#/components/schemas/CompletionMessage" - } - ], - "discriminator": { - "propertyName": "role", - "mapping": { - "user": "#/components/schemas/UserMessage", - "system": "#/components/schemas/SystemMessage", - "tool": "#/components/schemas/ToolResponseMessage", - "assistant": "#/components/schemas/CompletionMessage" - } - } - }, - "SystemMessage": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "system", - "default": "system", - "description": "Must be \"system\" to identify this as a system message" - }, - "content": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "The content of the \"system prompt\". 
If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions)." - } - }, - "additionalProperties": false, - "required": [ - "role", - "content" - ], - "title": "SystemMessage", - "description": "A system message providing instructions or context to the model." - }, "TextContentItem": { "type": "object", "properties": { @@ -11005,250 +11563,6 @@ "title": "TextContentItem", "description": "A text content item" }, - "ToolCall": { - "type": "object", - "properties": { - "call_id": { - "type": "string" - }, - "tool_name": { - "oneOf": [ - { - "type": "string", - "enum": [ - "brave_search", - "wolfram_alpha", - "photogen", - "code_interpreter" - ], - "title": "BuiltinTool" - }, - { - "type": "string" - } - ] - }, - "arguments": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "call_id", - "tool_name", - "arguments" - ], - "title": "ToolCall" - }, - "ToolResponseMessage": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "tool", - "default": "tool", - "description": "Must be \"tool\" to identify this as a tool response" - }, - "call_id": { - "type": "string", - "description": "Unique identifier for the tool call this response is for" - }, - "content": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "The response content from the tool" - } - }, - "additionalProperties": false, - "required": [ - "role", - "call_id", - "content" - ], - "title": "ToolResponseMessage", - "description": "A message representing the result of a tool invocation." - }, - "URL": { - "type": "object", - "properties": { - "uri": { - "type": "string", - "description": "The URL string pointing to the resource" - } - }, - "additionalProperties": false, - "required": [ - "uri" - ], - "title": "URL", - "description": "A URL reference to external content." 
- }, - "UserMessage": { - "type": "object", - "properties": { - "role": { - "type": "string", - "const": "user", - "default": "user", - "description": "Must be \"user\" to identify this as a user message" - }, - "content": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "The content of the message, which can include text and other media" - }, - "context": { - "$ref": "#/components/schemas/InterleavedContent", - "description": "(Optional) This field is used internally by Llama Stack to pass RAG context. This field may be removed in the API in the future." - } - }, - "additionalProperties": false, - "required": [ - "role", - "content" - ], - "title": "UserMessage", - "description": "A message from the user in a chat conversation." - }, - "SyntheticDataGenerateRequest": { - "type": "object", - "properties": { - "dialogs": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Message" - }, - "description": "List of conversation messages to use as input for synthetic data generation" - }, - "filtering_function": { - "type": "string", - "enum": [ - "none", - "random", - "top_k", - "top_p", - "top_k_top_p", - "sigmoid" - ], - "description": "Type of filtering to apply to generated synthetic data samples" - }, - "model": { - "type": "string", - "description": "(Optional) The identifier of the model to use. 
The model must be registered with Llama Stack and available via the /models endpoint" - } - }, - "additionalProperties": false, - "required": [ - "dialogs", - "filtering_function" - ], - "title": "SyntheticDataGenerateRequest" - }, - "SyntheticDataGenerationResponse": { - "type": "object", - "properties": { - "synthetic_data": { - "type": "array", - "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "description": "List of generated synthetic data samples that passed the filtering criteria" - }, - "statistics": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - }, - "description": "(Optional) Statistical information about the generation process and filtering results" - } - }, - "additionalProperties": false, - "required": [ - "synthetic_data" - ], - "title": "SyntheticDataGenerationResponse", - "description": "Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold." - }, - "InvokeToolRequest": { - "type": "object", - "properties": { - "tool_name": { - "type": "string", - "description": "The name of the tool to invoke." - }, - "kwargs": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - }, - "description": "A dictionary of arguments to pass to the tool." 
- } - }, - "additionalProperties": false, - "required": [ - "tool_name", - "kwargs" - ], - "title": "InvokeToolRequest" - }, "ToolInvocationResult": { "type": "object", "properties": { @@ -11295,6 +11609,21 @@ "title": "ToolInvocationResult", "description": "Result of a tool invocation." }, + "URL": { + "type": "object", + "properties": { + "uri": { + "type": "string", + "description": "The URL string pointing to the resource" + } + }, + "additionalProperties": false, + "required": [ + "uri" + ], + "title": "URL", + "description": "A URL reference to external content." + }, "ToolDef": { "type": "object", "properties": { @@ -11897,6 +12226,10 @@ "$ref": "#/components/schemas/InterleavedContent", "description": "The content of the chunk, which can be interleaved text, images, or other types." }, + "chunk_id": { + "type": "string", + "description": "Unique identifier for the chunk. Must be provided explicitly." + }, "metadata": { "type": "object", "additionalProperties": { @@ -11930,10 +12263,6 @@ }, "description": "Optional embedding for the chunk. If not provided, it will be computed later." }, - "stored_chunk_id": { - "type": "string", - "description": "The chunk ID that is stored in the vector database. Used for backend functionality." - }, "chunk_metadata": { "$ref": "#/components/schemas/ChunkMetadata", "description": "Metadata for the chunk that will NOT be used in the context during inference. The `chunk_metadata` is required backend functionality." @@ -11942,6 +12271,7 @@ "additionalProperties": false, "required": [ "content", + "chunk_id", "metadata" ], "title": "Chunk", @@ -13288,6 +13618,11 @@ "description": "APIs for creating and interacting with agentic systems.\n\n## Responses API\n\nThe Responses API provides OpenAI-compatible functionality with enhanced capabilities for dynamic, stateful interactions.\n\n> **✅ STABLE**: This API is production-ready with backward compatibility guarantees. 
Recommended for production applications.\n\n### ✅ Supported Tools\n\nThe Responses API supports the following tool types:\n\n- **`web_search`**: Search the web for current information and real-time data\n- **`file_search`**: Search through uploaded files and vector stores\n - Supports dynamic `vector_store_ids` per call\n - Compatible with OpenAI file search patterns\n- **`function`**: Call custom functions with JSON schema validation\n- **`mcp_tool`**: Model Context Protocol integration\n\n### ✅ Supported Fields & Features\n\n**Core Capabilities:**\n- **Dynamic Configuration**: Switch models, vector stores, and tools per request without pre-configuration\n- **Conversation Branching**: Use `previous_response_id` to branch conversations and explore different paths\n- **Rich Annotations**: Automatic file citations, URL citations, and container file citations\n- **Status Tracking**: Monitor tool call execution status and handle failures gracefully\n\n### 🚧 Work in Progress\n\n- Full real-time response streaming support\n- `tool_choice` parameter\n- `max_tool_calls` parameter\n- Built-in tools (code interpreter, containers API)\n- Safety & guardrails\n- `reasoning` capabilities\n- `service_tier`\n- `logprobs`\n- `max_output_tokens`\n- `metadata` handling\n- `instructions`\n- `incomplete_details`\n- `background`", "x-displayName": "Agents" }, + { + "name": "Batches", + "description": "The API is designed to allow use of openai client libraries for seamless integration.\n\nThis API provides the following extensions:\n - idempotent batch creation\n\nNote: This API is currently under active development and may undergo changes.", + "x-displayName": "The Batches API enables efficient processing of multiple requests in a single operation, particularly useful for processing large datasets, batch evaluation workflows, and cost-effective inference at scale." 
+ }, { "name": "Conversations", "description": "Protocol for conversation management operations.", @@ -13339,10 +13674,6 @@ "name": "Shields", "description": "" }, - { - "name": "SyntheticDataGeneration (Coming Soon)", - "description": "" - }, { "name": "ToolGroups", "description": "" @@ -13361,6 +13692,7 @@ "name": "Operations", "tags": [ "Agents", + "Batches", "Conversations", "Files", "Inference", @@ -13372,7 +13704,6 @@ "Scoring", "ScoringFunctions", "Shields", - "SyntheticDataGeneration (Coming Soon)", "ToolGroups", "ToolRuntime", "VectorIO" diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index ae582580b..d366a2dd8 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -12,6 +12,141 @@ info: servers: - url: http://any-hosted-llama-stack.com paths: + /v1/batches: + get: + responses: + '200': + description: A list of batch objects. + content: + application/json: + schema: + $ref: '#/components/schemas/ListBatchesResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: List all batches for the current user. + description: List all batches for the current user. + parameters: + - name: after + in: query + description: >- + A cursor for pagination; returns batches after this batch ID. + required: false + schema: + type: string + - name: limit + in: query + description: >- + Number of batches to return (default 20, max 100). + required: true + schema: + type: integer + deprecated: false + post: + responses: + '200': + description: The created batch object. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Create a new batch for processing multiple API requests. + description: >- + Create a new batch for processing multiple API requests. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateBatchRequest' + required: true + deprecated: false + /v1/batches/{batch_id}: + get: + responses: + '200': + description: The batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Retrieve information about a specific batch. + description: >- + Retrieve information about a specific batch. + parameters: + - name: batch_id + in: path + description: The ID of the batch to retrieve. + required: true + schema: + type: string + deprecated: false + /v1/batches/{batch_id}/cancel: + post: + responses: + '200': + description: The updated batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: Cancel a batch that is in progress. + description: Cancel a batch that is in progress. + parameters: + - name: batch_id + in: path + description: The ID of the batch to cancel. 
+ required: true + schema: + type: string + deprecated: false /v1/chat/completions: get: responses: @@ -818,7 +953,22 @@ paths: List routes. List all available API routes with their methods and implementing providers. - parameters: [] + parameters: + - name: api_filter + in: query + description: >- + Optional filter to control which routes are returned. Can be an API level + ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, + or 'deprecated' to show deprecated routes across all levels. If not specified, + returns only non-deprecated v1 routes. + required: false + schema: + type: string + enum: + - v1 + - v1alpha + - v1beta + - deprecated deprecated: false /v1/models: get: @@ -976,6 +1126,31 @@ paths: $ref: '#/components/schemas/RunModerationRequest' required: true deprecated: false + /v1/openai/v1/models: + get: + responses: + '200': + description: A OpenAIListModelsResponse. + content: + application/json: + schema: + $ref: '#/components/schemas/OpenAIListModelsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + summary: List models using the OpenAI API. + description: List models using the OpenAI API. 
+ parameters: [] + deprecated: false /v1/prompts: get: responses: @@ -1832,40 +2007,6 @@ paths: schema: type: string deprecated: false - /v1/synthetic-data-generation/generate: - post: - responses: - '200': - description: >- - Response containing filtered synthetic data samples and optional statistics - content: - application/json: - schema: - $ref: '#/components/schemas/SyntheticDataGenerationResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - SyntheticDataGeneration (Coming Soon) - summary: >- - Generate synthetic data based on input dialogs and apply filtering. - description: >- - Generate synthetic data based on input dialogs and apply filtering. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/SyntheticDataGenerateRequest' - required: true - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -2999,6 +3140,331 @@ components: title: Error description: >- Error response from the API. Roughly follows RFC 7807. 
+ ListBatchesResponse: + type: object + properties: + object: + type: string + const: list + default: list + data: + type: array + items: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - 
input_tokens + - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + default: false + additionalProperties: false + required: + - object + - data + - has_more + title: ListBatchesResponse + description: >- + Response containing a list of batch objects. + CreateBatchRequest: + type: object + properties: + input_file_id: + type: string + description: >- + The ID of an uploaded file containing requests for the batch. + endpoint: + type: string + description: >- + The endpoint to be used for all requests in the batch. + completion_window: + type: string + const: 24h + description: >- + The time window within which the batch should be processed. + metadata: + type: object + additionalProperties: + type: string + description: Optional metadata for the batch. + idempotency_key: + type: string + description: >- + Optional idempotency key. When provided, enables idempotent behavior. 
+ additionalProperties: false + required: + - input_file_id + - endpoint + - completion_window + title: CreateBatchRequest + Batch: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - input_tokens 
+ - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch Order: type: string enum: @@ -5341,6 +5807,48 @@ components: - metadata title: ModerationObjectResults description: A moderation object. + OpenAIModel: + type: object + properties: + id: + type: string + object: + type: string + const: model + default: model + created: + type: integer + owned_by: + type: string + custom_metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - id + - object + - created + - owned_by + title: OpenAIModel + description: A model from OpenAI. + OpenAIListModelsResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIModel' + additionalProperties: false + required: + - data + title: OpenAIListModelsResponse Prompt: type: object properties: @@ -8265,45 +8773,29 @@ components: required: - shield_id title: RegisterShieldRequest - CompletionMessage: + InvokeToolRequest: type: object properties: - role: + tool_name: type: string - const: assistant - default: assistant + description: The name of the tool to invoke. + kwargs: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object description: >- - Must be "assistant" to identify this as the model's response - content: - $ref: '#/components/schemas/InterleavedContent' - description: The content of the model's response - stop_reason: - type: string - enum: - - end_of_turn - - end_of_message - - out_of_tokens - description: >- - Reason why the model stopped generating. 
Options are: - `StopReason.end_of_turn`: - The model finished generating the entire response. - `StopReason.end_of_message`: - The model finished generating but generated a partial response -- usually, - a tool call. The user may call the tool and continue the conversation - with the tool's response. - `StopReason.out_of_tokens`: The model ran - out of token budget. - tool_calls: - type: array - items: - $ref: '#/components/schemas/ToolCall' - description: >- - List of tool calls. Each tool call is a ToolCall object. + A dictionary of arguments to pass to the tool. additionalProperties: false required: - - role - - content - - stop_reason - title: CompletionMessage - description: >- - A message containing the model's (assistant) response in a chat conversation. + - tool_name + - kwargs + title: InvokeToolRequest ImageContentItem: type: object properties: @@ -8350,41 +8842,6 @@ components: mapping: image: '#/components/schemas/ImageContentItem' text: '#/components/schemas/TextContentItem' - Message: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' - discriminator: - propertyName: role - mapping: - user: '#/components/schemas/UserMessage' - system: '#/components/schemas/SystemMessage' - tool: '#/components/schemas/ToolResponseMessage' - assistant: '#/components/schemas/CompletionMessage' - SystemMessage: - type: object - properties: - role: - type: string - const: system - default: system - description: >- - Must be "system" to identify this as a system message - content: - $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the "system prompt". If multiple system messages are provided, - they are concatenated. The underlying Llama Stack code may also add other - system messages (for example, for formatting tool definitions). 
- additionalProperties: false - required: - - role - - content - title: SystemMessage - description: >- - A system message providing instructions or context to the model. TextContentItem: type: object properties: @@ -8403,179 +8860,6 @@ components: - text title: TextContentItem description: A text content item - ToolCall: - type: object - properties: - call_id: - type: string - tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string - arguments: - type: string - additionalProperties: false - required: - - call_id - - tool_name - - arguments - title: ToolCall - ToolResponseMessage: - type: object - properties: - role: - type: string - const: tool - default: tool - description: >- - Must be "tool" to identify this as a tool response - call_id: - type: string - description: >- - Unique identifier for the tool call this response is for - content: - $ref: '#/components/schemas/InterleavedContent' - description: The response content from the tool - additionalProperties: false - required: - - role - - call_id - - content - title: ToolResponseMessage - description: >- - A message representing the result of a tool invocation. - URL: - type: object - properties: - uri: - type: string - description: The URL string pointing to the resource - additionalProperties: false - required: - - uri - title: URL - description: A URL reference to external content. - UserMessage: - type: object - properties: - role: - type: string - const: user - default: user - description: >- - Must be "user" to identify this as a user message - content: - $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the message, which can include text and other media - context: - $ref: '#/components/schemas/InterleavedContent' - description: >- - (Optional) This field is used internally by Llama Stack to pass RAG context. - This field may be removed in the API in the future. 
- additionalProperties: false - required: - - role - - content - title: UserMessage - description: >- - A message from the user in a chat conversation. - SyntheticDataGenerateRequest: - type: object - properties: - dialogs: - type: array - items: - $ref: '#/components/schemas/Message' - description: >- - List of conversation messages to use as input for synthetic data generation - filtering_function: - type: string - enum: - - none - - random - - top_k - - top_p - - top_k_top_p - - sigmoid - description: >- - Type of filtering to apply to generated synthetic data samples - model: - type: string - description: >- - (Optional) The identifier of the model to use. The model must be registered - with Llama Stack and available via the /models endpoint - additionalProperties: false - required: - - dialogs - - filtering_function - title: SyntheticDataGenerateRequest - SyntheticDataGenerationResponse: - type: object - properties: - synthetic_data: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - List of generated synthetic data samples that passed the filtering criteria - statistics: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - (Optional) Statistical information about the generation process and filtering - results - additionalProperties: false - required: - - synthetic_data - title: SyntheticDataGenerationResponse - description: >- - Response from the synthetic data generation. Batch of (prompt, response, score) - tuples that pass the threshold. - InvokeToolRequest: - type: object - properties: - tool_name: - type: string - description: The name of the tool to invoke. 
- kwargs: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool. - additionalProperties: false - required: - - tool_name - - kwargs - title: InvokeToolRequest ToolInvocationResult: type: object properties: @@ -8606,6 +8890,17 @@ components: additionalProperties: false title: ToolInvocationResult description: Result of a tool invocation. + URL: + type: object + properties: + uri: + type: string + description: The URL string pointing to the resource + additionalProperties: false + required: + - uri + title: URL + description: A URL reference to external content. ToolDef: type: object properties: @@ -9045,6 +9340,10 @@ components: description: >- The content of the chunk, which can be interleaved text, images, or other types. + chunk_id: + type: string + description: >- + Unique identifier for the chunk. Must be provided explicitly. metadata: type: object additionalProperties: @@ -9065,10 +9364,6 @@ components: description: >- Optional embedding for the chunk. If not provided, it will be computed later. - stored_chunk_id: - type: string - description: >- - The chunk ID that is stored in the vector database. Used for backend functionality. chunk_metadata: $ref: '#/components/schemas/ChunkMetadata' description: >- @@ -9077,6 +9372,7 @@ components: additionalProperties: false required: - content + - chunk_id - metadata title: Chunk description: >- @@ -10143,6 +10439,19 @@ tags: - `background` x-displayName: Agents + - name: Batches + description: >- + The API is designed to allow use of openai client libraries for seamless integration. + + + This API provides the following extensions: + - idempotent batch creation + + Note: This API is currently under active development and may undergo changes. 
+ x-displayName: >- + The Batches API enables efficient processing of multiple requests in a single + operation, particularly useful for processing large datasets, batch evaluation + workflows, and cost-effective inference at scale. - name: Conversations description: >- Protocol for conversation management operations. @@ -10193,8 +10502,6 @@ tags: description: '' - name: Shields description: '' - - name: SyntheticDataGeneration (Coming Soon) - description: '' - name: ToolGroups description: '' - name: ToolRuntime @@ -10205,6 +10512,7 @@ x-tagGroups: - name: Operations tags: - Agents + - Batches - Conversations - Files - Inference @@ -10216,7 +10524,6 @@ x-tagGroups: - Scoring - ScoringFunctions - Shields - - SyntheticDataGeneration (Coming Soon) - ToolGroups - ToolRuntime - VectorIO diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html deleted file mode 100644 index daa5db20e..000000000 --- a/docs/static/stainless-llama-stack-spec.html +++ /dev/null @@ -1,18091 +0,0 @@ - - - - - - - OpenAPI specification - - - - - - - - - - - - - diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index d992b72eb..b080a9efd 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -15,6 +15,141 @@ info: servers: - url: http://any-hosted-llama-stack.com paths: + /v1/batches: + get: + responses: + '200': + description: A list of batch objects. + content: + application/json: + schema: + $ref: '#/components/schemas/ListBatchesResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: List all batches for the current user. + description: List all batches for the current user. 
+ parameters: + - name: after + in: query + description: >- + A cursor for pagination; returns batches after this batch ID. + required: false + schema: + type: string + - name: limit + in: query + description: >- + Number of batches to return (default 20, max 100). + required: true + schema: + type: integer + deprecated: false + post: + responses: + '200': + description: The created batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Create a new batch for processing multiple API requests. + description: >- + Create a new batch for processing multiple API requests. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateBatchRequest' + required: true + deprecated: false + /v1/batches/{batch_id}: + get: + responses: + '200': + description: The batch object. + content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: >- + Retrieve information about a specific batch. + description: >- + Retrieve information about a specific batch. + parameters: + - name: batch_id + in: path + description: The ID of the batch to retrieve. + required: true + schema: + type: string + deprecated: false + /v1/batches/{batch_id}/cancel: + post: + responses: + '200': + description: The updated batch object. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/Batch' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Batches + summary: Cancel a batch that is in progress. + description: Cancel a batch that is in progress. + parameters: + - name: batch_id + in: path + description: The ID of the batch to cancel. + required: true + schema: + type: string + deprecated: false /v1/chat/completions: get: responses: @@ -821,7 +956,22 @@ paths: List routes. List all available API routes with their methods and implementing providers. - parameters: [] + parameters: + - name: api_filter + in: query + description: >- + Optional filter to control which routes are returned. Can be an API level + ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, + or 'deprecated' to show deprecated routes across all levels. If not specified, + returns only non-deprecated v1 routes. + required: false + schema: + type: string + enum: + - v1 + - v1alpha + - v1beta + - deprecated deprecated: false /v1/models: get: @@ -979,6 +1129,31 @@ paths: $ref: '#/components/schemas/RunModerationRequest' required: true deprecated: false + /v1/openai/v1/models: + get: + responses: + '200': + description: A OpenAIListModelsResponse. + content: + application/json: + schema: + $ref: '#/components/schemas/OpenAIListModelsResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Models + summary: List models using the OpenAI API. + description: List models using the OpenAI API. 
+ parameters: [] + deprecated: false /v1/prompts: get: responses: @@ -1835,40 +2010,6 @@ paths: schema: type: string deprecated: false - /v1/synthetic-data-generation/generate: - post: - responses: - '200': - description: >- - Response containing filtered synthetic data samples and optional statistics - content: - application/json: - schema: - $ref: '#/components/schemas/SyntheticDataGenerationResponse' - '400': - $ref: '#/components/responses/BadRequest400' - '429': - $ref: >- - #/components/responses/TooManyRequests429 - '500': - $ref: >- - #/components/responses/InternalServerError500 - default: - $ref: '#/components/responses/DefaultError' - tags: - - SyntheticDataGeneration (Coming Soon) - summary: >- - Generate synthetic data based on input dialogs and apply filtering. - description: >- - Generate synthetic data based on input dialogs and apply filtering. - parameters: [] - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/SyntheticDataGenerateRequest' - required: true - deprecated: false /v1/tool-runtime/invoke: post: responses: @@ -4212,6 +4353,331 @@ components: title: Error description: >- Error response from the API. Roughly follows RFC 7807. 
+ ListBatchesResponse: + type: object + properties: + object: + type: string + const: list + default: list + data: + type: array + items: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - 
input_tokens + - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + default: false + additionalProperties: false + required: + - object + - data + - has_more + title: ListBatchesResponse + description: >- + Response containing a list of batch objects. + CreateBatchRequest: + type: object + properties: + input_file_id: + type: string + description: >- + The ID of an uploaded file containing requests for the batch. + endpoint: + type: string + description: >- + The endpoint to be used for all requests in the batch. + completion_window: + type: string + const: 24h + description: >- + The time window within which the batch should be processed. + metadata: + type: object + additionalProperties: + type: string + description: Optional metadata for the batch. + idempotency_key: + type: string + description: >- + Optional idempotency key. When provided, enables idempotent behavior. 
+ additionalProperties: false + required: + - input_file_id + - endpoint + - completion_window + title: CreateBatchRequest + Batch: + type: object + properties: + id: + type: string + completion_window: + type: string + created_at: + type: integer + endpoint: + type: string + input_file_id: + type: string + object: + type: string + const: batch + status: + type: string + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + cancelled_at: + type: integer + cancelling_at: + type: integer + completed_at: + type: integer + error_file_id: + type: string + errors: + type: object + properties: + data: + type: array + items: + type: object + properties: + code: + type: string + line: + type: integer + message: + type: string + param: + type: string + additionalProperties: false + title: BatchError + object: + type: string + additionalProperties: false + title: Errors + expired_at: + type: integer + expires_at: + type: integer + failed_at: + type: integer + finalizing_at: + type: integer + in_progress_at: + type: integer + metadata: + type: object + additionalProperties: + type: string + model: + type: string + output_file_id: + type: string + request_counts: + type: object + properties: + completed: + type: integer + failed: + type: integer + total: + type: integer + additionalProperties: false + required: + - completed + - failed + - total + title: BatchRequestCounts + usage: + type: object + properties: + input_tokens: + type: integer + input_tokens_details: + type: object + properties: + cached_tokens: + type: integer + additionalProperties: false + required: + - cached_tokens + title: InputTokensDetails + output_tokens: + type: integer + output_tokens_details: + type: object + properties: + reasoning_tokens: + type: integer + additionalProperties: false + required: + - reasoning_tokens + title: OutputTokensDetails + total_tokens: + type: integer + additionalProperties: false + required: + - input_tokens 
+ - input_tokens_details + - output_tokens + - output_tokens_details + - total_tokens + title: BatchUsage + additionalProperties: false + required: + - id + - completion_window + - created_at + - endpoint + - input_file_id + - object + - status + title: Batch Order: type: string enum: @@ -6554,6 +7020,48 @@ components: - metadata title: ModerationObjectResults description: A moderation object. + OpenAIModel: + type: object + properties: + id: + type: string + object: + type: string + const: model + default: model + created: + type: integer + owned_by: + type: string + custom_metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - id + - object + - created + - owned_by + title: OpenAIModel + description: A model from OpenAI. + OpenAIListModelsResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/OpenAIModel' + additionalProperties: false + required: + - data + title: OpenAIListModelsResponse Prompt: type: object properties: @@ -9478,45 +9986,29 @@ components: required: - shield_id title: RegisterShieldRequest - CompletionMessage: + InvokeToolRequest: type: object properties: - role: + tool_name: type: string - const: assistant - default: assistant + description: The name of the tool to invoke. + kwargs: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object description: >- - Must be "assistant" to identify this as the model's response - content: - $ref: '#/components/schemas/InterleavedContent' - description: The content of the model's response - stop_reason: - type: string - enum: - - end_of_turn - - end_of_message - - out_of_tokens - description: >- - Reason why the model stopped generating. 
Options are: - `StopReason.end_of_turn`: - The model finished generating the entire response. - `StopReason.end_of_message`: - The model finished generating but generated a partial response -- usually, - a tool call. The user may call the tool and continue the conversation - with the tool's response. - `StopReason.out_of_tokens`: The model ran - out of token budget. - tool_calls: - type: array - items: - $ref: '#/components/schemas/ToolCall' - description: >- - List of tool calls. Each tool call is a ToolCall object. + A dictionary of arguments to pass to the tool. additionalProperties: false required: - - role - - content - - stop_reason - title: CompletionMessage - description: >- - A message containing the model's (assistant) response in a chat conversation. + - tool_name + - kwargs + title: InvokeToolRequest ImageContentItem: type: object properties: @@ -9563,41 +10055,6 @@ components: mapping: image: '#/components/schemas/ImageContentItem' text: '#/components/schemas/TextContentItem' - Message: - oneOf: - - $ref: '#/components/schemas/UserMessage' - - $ref: '#/components/schemas/SystemMessage' - - $ref: '#/components/schemas/ToolResponseMessage' - - $ref: '#/components/schemas/CompletionMessage' - discriminator: - propertyName: role - mapping: - user: '#/components/schemas/UserMessage' - system: '#/components/schemas/SystemMessage' - tool: '#/components/schemas/ToolResponseMessage' - assistant: '#/components/schemas/CompletionMessage' - SystemMessage: - type: object - properties: - role: - type: string - const: system - default: system - description: >- - Must be "system" to identify this as a system message - content: - $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the "system prompt". If multiple system messages are provided, - they are concatenated. The underlying Llama Stack code may also add other - system messages (for example, for formatting tool definitions). 
- additionalProperties: false - required: - - role - - content - title: SystemMessage - description: >- - A system message providing instructions or context to the model. TextContentItem: type: object properties: @@ -9616,179 +10073,6 @@ components: - text title: TextContentItem description: A text content item - ToolCall: - type: object - properties: - call_id: - type: string - tool_name: - oneOf: - - type: string - enum: - - brave_search - - wolfram_alpha - - photogen - - code_interpreter - title: BuiltinTool - - type: string - arguments: - type: string - additionalProperties: false - required: - - call_id - - tool_name - - arguments - title: ToolCall - ToolResponseMessage: - type: object - properties: - role: - type: string - const: tool - default: tool - description: >- - Must be "tool" to identify this as a tool response - call_id: - type: string - description: >- - Unique identifier for the tool call this response is for - content: - $ref: '#/components/schemas/InterleavedContent' - description: The response content from the tool - additionalProperties: false - required: - - role - - call_id - - content - title: ToolResponseMessage - description: >- - A message representing the result of a tool invocation. - URL: - type: object - properties: - uri: - type: string - description: The URL string pointing to the resource - additionalProperties: false - required: - - uri - title: URL - description: A URL reference to external content. - UserMessage: - type: object - properties: - role: - type: string - const: user - default: user - description: >- - Must be "user" to identify this as a user message - content: - $ref: '#/components/schemas/InterleavedContent' - description: >- - The content of the message, which can include text and other media - context: - $ref: '#/components/schemas/InterleavedContent' - description: >- - (Optional) This field is used internally by Llama Stack to pass RAG context. - This field may be removed in the API in the future. 
- additionalProperties: false - required: - - role - - content - title: UserMessage - description: >- - A message from the user in a chat conversation. - SyntheticDataGenerateRequest: - type: object - properties: - dialogs: - type: array - items: - $ref: '#/components/schemas/Message' - description: >- - List of conversation messages to use as input for synthetic data generation - filtering_function: - type: string - enum: - - none - - random - - top_k - - top_p - - top_k_top_p - - sigmoid - description: >- - Type of filtering to apply to generated synthetic data samples - model: - type: string - description: >- - (Optional) The identifier of the model to use. The model must be registered - with Llama Stack and available via the /models endpoint - additionalProperties: false - required: - - dialogs - - filtering_function - title: SyntheticDataGenerateRequest - SyntheticDataGenerationResponse: - type: object - properties: - synthetic_data: - type: array - items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - List of generated synthetic data samples that passed the filtering criteria - statistics: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - (Optional) Statistical information about the generation process and filtering - results - additionalProperties: false - required: - - synthetic_data - title: SyntheticDataGenerationResponse - description: >- - Response from the synthetic data generation. Batch of (prompt, response, score) - tuples that pass the threshold. - InvokeToolRequest: - type: object - properties: - tool_name: - type: string - description: The name of the tool to invoke. 
- kwargs: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - description: >- - A dictionary of arguments to pass to the tool. - additionalProperties: false - required: - - tool_name - - kwargs - title: InvokeToolRequest ToolInvocationResult: type: object properties: @@ -9819,6 +10103,17 @@ components: additionalProperties: false title: ToolInvocationResult description: Result of a tool invocation. + URL: + type: object + properties: + uri: + type: string + description: The URL string pointing to the resource + additionalProperties: false + required: + - uri + title: URL + description: A URL reference to external content. ToolDef: type: object properties: @@ -10258,6 +10553,10 @@ components: description: >- The content of the chunk, which can be interleaved text, images, or other types. + chunk_id: + type: string + description: >- + Unique identifier for the chunk. Must be provided explicitly. metadata: type: object additionalProperties: @@ -10278,10 +10577,6 @@ components: description: >- Optional embedding for the chunk. If not provided, it will be computed later. - stored_chunk_id: - type: string - description: >- - The chunk ID that is stored in the vector database. Used for backend functionality. chunk_metadata: $ref: '#/components/schemas/ChunkMetadata' description: >- @@ -10290,6 +10585,7 @@ components: additionalProperties: false required: - content + - chunk_id - metadata title: Chunk description: >- @@ -11850,6 +12146,45 @@ components: title: AgentSessionCreateResponse description: >- Response returned when creating a new agent session. 
+ CompletionMessage: + type: object + properties: + role: + type: string + const: assistant + default: assistant + description: >- + Must be "assistant" to identify this as the model's response + content: + $ref: '#/components/schemas/InterleavedContent' + description: The content of the model's response + stop_reason: + type: string + enum: + - end_of_turn + - end_of_message + - out_of_tokens + description: >- + Reason why the model stopped generating. Options are: - `StopReason.end_of_turn`: + The model finished generating the entire response. - `StopReason.end_of_message`: + The model finished generating but generated a partial response -- usually, + a tool call. The user may call the tool and continue the conversation + with the tool's response. - `StopReason.out_of_tokens`: The model ran + out of token budget. + tool_calls: + type: array + items: + $ref: '#/components/schemas/ToolCall' + description: >- + List of tool calls. Each tool call is a ToolCall object. + additionalProperties: false + required: + - role + - content + - stop_reason + title: CompletionMessage + description: >- + A message containing the model's (assistant) response in a chat conversation. InferenceStep: type: object properties: @@ -12002,6 +12337,29 @@ components: - step_type title: ShieldCallStep description: A shield call step in an agent turn. + ToolCall: + type: object + properties: + call_id: + type: string + tool_name: + oneOf: + - type: string + enum: + - brave_search + - wolfram_alpha + - photogen + - code_interpreter + title: BuiltinTool + - type: string + arguments: + type: string + additionalProperties: false + required: + - call_id + - tool_name + - arguments + title: ToolCall ToolExecutionStep: type: object properties: @@ -12089,6 +12447,30 @@ components: - content title: ToolResponse description: Response from a tool invocation. 
+ ToolResponseMessage: + type: object + properties: + role: + type: string + const: tool + default: tool + description: >- + Must be "tool" to identify this as a tool response + call_id: + type: string + description: >- + Unique identifier for the tool call this response is for + content: + $ref: '#/components/schemas/InterleavedContent' + description: The response content from the tool + additionalProperties: false + required: + - role + - call_id + - content + title: ToolResponseMessage + description: >- + A message representing the result of a tool invocation. Turn: type: object properties: @@ -12174,6 +12556,31 @@ components: title: Turn description: >- A single turn in an interaction with an Agentic System. + UserMessage: + type: object + properties: + role: + type: string + const: user + default: user + description: >- + Must be "user" to identify this as a user message + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The content of the message, which can include text and other media + context: + $ref: '#/components/schemas/InterleavedContent' + description: >- + (Optional) This field is used internally by Llama Stack to pass RAG context. + This field may be removed in the API in the future. + additionalProperties: false + required: + - role + - content + title: UserMessage + description: >- + A message from the user in a chat conversation. CreateAgentTurnRequest: type: object properties: @@ -12787,6 +13194,28 @@ components: - sampling_params title: ModelCandidate description: A model candidate for evaluation. + SystemMessage: + type: object + properties: + role: + type: string + const: system + default: system + description: >- + Must be "system" to identify this as a system message + content: + $ref: '#/components/schemas/InterleavedContent' + description: >- + The content of the "system prompt". If multiple system messages are provided, + they are concatenated. 
The underlying Llama Stack code may also add other + system messages (for example, for formatting tool definitions). + additionalProperties: false + required: + - role + - content + title: SystemMessage + description: >- + A system message providing instructions or context to the model. EvaluateRowsRequest: type: object properties: @@ -13527,6 +13956,19 @@ tags: description: >- APIs for creating and interacting with agentic systems. x-displayName: Agents + - name: Batches + description: >- + The API is designed to allow use of openai client libraries for seamless integration. + + + This API provides the following extensions: + - idempotent batch creation + + Note: This API is currently under active development and may undergo changes. + x-displayName: >- + The Batches API enables efficient processing of multiple requests in a single + operation, particularly useful for processing large datasets, batch evaluation + workflows, and cost-effective inference at scale. - name: Benchmarks description: '' - name: Conversations @@ -13589,8 +14031,6 @@ tags: description: '' - name: Shields description: '' - - name: SyntheticDataGeneration (Coming Soon) - description: '' - name: ToolGroups description: '' - name: ToolRuntime @@ -13601,6 +14041,7 @@ x-tagGroups: - name: Operations tags: - Agents + - Batches - Benchmarks - Conversations - DatasetIO @@ -13617,7 +14058,6 @@ x-tagGroups: - Scoring - ScoringFunctions - Shields - - SyntheticDataGeneration (Coming Soon) - ToolGroups - ToolRuntime - VectorIO diff --git a/pyproject.toml b/pyproject.toml index 1093a4c82..8f07f9cbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ required-version = ">=0.7.0" [project] name = "llama_stack" -version = "0.3.0" +version = "0.4.0.dev0" authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }] description = "Llama Stack" readme = "README.md" @@ -284,7 +284,6 @@ exclude = [ "^src/llama_stack/models/llama/llama3/interface\\.py$", 
"^src/llama_stack/models/llama/llama3/tokenizer\\.py$", "^src/llama_stack/models/llama/llama3/tool_utils\\.py$", - "^src/llama_stack/providers/inline/agents/meta_reference/", "^src/llama_stack/providers/inline/datasetio/localfs/", "^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$", "^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$", diff --git a/scripts/docker.sh b/scripts/docker.sh index a0690c8a9..b56df8c03 100755 --- a/scripts/docker.sh +++ b/scripts/docker.sh @@ -215,6 +215,16 @@ build_image() { --build-arg "LLAMA_STACK_DIR=/workspace" ) + # Pass UV index configuration for release branches + if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then + echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL" + build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL") + fi + if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then + echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY" + build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY") + fi + if ! "${build_cmd[@]}"; then echo "❌ Failed to build Docker image" exit 1 diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index 93739052b..cdd3e736f 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -23,7 +23,7 @@ COLLECT_ONLY=false # Function to display usage usage() { - cat << EOF + cat < /dev/null; then +if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &>/dev/null; then echo "llama could not be found, ensure llama-stack is installed" exit 1 fi -if ! command -v pytest &> /dev/null; then +if ! command -v pytest &>/dev/null; then echo "pytest could not be found, ensure pytest is installed" exit 1 fi +# Helper function to find next available port +find_available_port() { + local start_port=$1 + local port=$start_port + for ((i=0; i<100; i++)); do + if ! 
lsof -Pi :$port -sTCP:LISTEN -t >/dev/null 2>&1; then + echo $port + return 0 + fi + ((port++)) + done + echo "Failed to find available port starting from $start_port" >&2 + return 1 +} + # Start Llama Stack Server if needed if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then + # Find an available port for the server + LLAMA_STACK_PORT=$(find_available_port 8321) + if [[ $? -ne 0 ]]; then + echo "Error: $LLAMA_STACK_PORT" + exit 1 + fi + export LLAMA_STACK_PORT + echo "Will use port: $LLAMA_STACK_PORT" + stop_server() { echo "Stopping Llama Stack Server..." - pids=$(lsof -i :8321 | awk 'NR>1 {print $2}') + pids=$(lsof -i :$LLAMA_STACK_PORT | awk 'NR>1 {print $2}') if [[ -n "$pids" ]]; then echo "Killing Llama Stack Server processes: $pids" kill -9 $pids @@ -201,33 +224,37 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then echo "Llama Stack Server stopped" } - # check if server is already running - if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then - echo "Llama Stack Server is already running, skipping start" - else - echo "=== Starting Llama Stack Server ===" - export LLAMA_STACK_LOG_WIDTH=120 + echo "=== Starting Llama Stack Server ===" + export LLAMA_STACK_LOG_WIDTH=120 - # remove "server:" from STACK_CONFIG - stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://') - nohup llama stack run $stack_config > server.log 2>&1 & + # Configure telemetry collector for server mode + # Use a fixed port for the OTEL collector so the server can connect to it + COLLECTOR_PORT=4317 + export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}" + export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}" + export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf" + export OTEL_BSP_SCHEDULE_DELAY="200" + export OTEL_BSP_EXPORT_TIMEOUT="2000" - echo "Waiting for Llama Stack Server to start..." 
- for i in {1..30}; do - if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then - echo "✅ Llama Stack Server started successfully" - break - fi - if [[ $i -eq 30 ]]; then - echo "❌ Llama Stack Server failed to start" - echo "Server logs:" - cat server.log - exit 1 - fi - sleep 1 - done - echo "" - fi + # remove "server:" from STACK_CONFIG + stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://') + nohup llama stack run $stack_config >server.log 2>&1 & + + echo "Waiting for Llama Stack Server to start on port $LLAMA_STACK_PORT..." + for i in {1..30}; do + if curl -s http://localhost:$LLAMA_STACK_PORT/v1/health 2>/dev/null | grep -q "OK"; then + echo "✅ Llama Stack Server started successfully" + break + fi + if [[ $i -eq 30 ]]; then + echo "❌ Llama Stack Server failed to start" + echo "Server logs:" + cat server.log + exit 1 + fi + sleep 1 + done + echo "" trap stop_server EXIT ERR INT TERM fi @@ -239,7 +266,7 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then container_name="llama-stack-test-$DISTRO" if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then echo "Dumping container logs before stopping..." - docker logs "$container_name" > "docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true + docker logs "$container_name" >"docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true echo "Stopping and removing container: $container_name" docker stop "$container_name" 2>/dev/null || true docker rm "$container_name" 2>/dev/null || true @@ -251,7 +278,14 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then # Extract distribution name from docker:distro format DISTRO=$(echo "$STACK_CONFIG" | sed 's/^docker://') - export LLAMA_STACK_PORT=8321 + # Find an available port for the docker container + LLAMA_STACK_PORT=$(find_available_port 8321) + if [[ $? 
-ne 0 ]]; then + echo "Error: $LLAMA_STACK_PORT" + exit 1 + fi + export LLAMA_STACK_PORT + echo "Will use port: $LLAMA_STACK_PORT" echo "=== Building Docker Image for distribution: $DISTRO ===" containerfile="$ROOT_DIR/containers/Containerfile" @@ -271,6 +305,16 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then --build-arg "LLAMA_STACK_DIR=/workspace" ) + # Pass UV index configuration for release branches + if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then + echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL" + build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL") + fi + if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then + echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY" + build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY") + fi + if ! "${build_cmd[@]}"; then echo "❌ Failed to build Docker image" exit 1 @@ -284,10 +328,15 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then docker stop "$container_name" 2>/dev/null || true docker rm "$container_name" 2>/dev/null || true + # Configure telemetry collector port shared between host and container + COLLECTOR_PORT=4317 + export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}" + # Build environment variables for docker run DOCKER_ENV_VARS="" DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_INFERENCE_MODE=$INFERENCE_MODE" DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server" + DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:${COLLECTOR_PORT}" # Pass through API keys if they exist [ -n "${TOGETHER_API_KEY:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e TOGETHER_API_KEY=$TOGETHER_API_KEY" @@ -308,8 +357,20 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then fi echo "Using image: $IMAGE_NAME" - docker run -d --network host --name "$container_name" \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + # On macOS/Darwin, --network host doesn't work 
as expected due to Docker running in a VM + # Use regular port mapping instead + NETWORK_MODE="" + PORT_MAPPINGS="" + if [[ "$(uname)" != "Darwin" ]] && [[ "$(uname)" != *"MINGW"* ]]; then + NETWORK_MODE="--network host" + else + # On non-Linux (macOS, Windows), need explicit port mappings for both app and telemetry + PORT_MAPPINGS="-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT -p $COLLECTOR_PORT:$COLLECTOR_PORT" + echo "Using bridge networking with port mapping (non-Linux)" + fi + + docker run -d $NETWORK_MODE --name "$container_name" \ + $PORT_MAPPINGS \ $DOCKER_ENV_VARS \ "$IMAGE_NAME" \ --port $LLAMA_STACK_PORT @@ -411,17 +472,13 @@ elif [ $exit_code -eq 5 ]; then else echo "❌ Tests failed" echo "" - echo "=== Dumping last 100 lines of logs for debugging ===" - # Output server or container logs based on stack config if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then - echo "--- Last 100 lines of server.log ---" - tail -100 server.log + echo "--- Server side failures can be located inside server.log (available from artifacts on CI) ---" elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log" if [[ -f "$docker_log_file" ]]; then - echo "--- Last 100 lines of $docker_log_file ---" - tail -100 "$docker_log_file" + echo "--- Server side failures can be located inside $docker_log_file (available from artifacts on CI) ---" fi fi diff --git a/scripts/uv-run-with-index.sh b/scripts/uv-run-with-index.sh new file mode 100755 index 000000000..18d0a0e9c --- /dev/null +++ b/scripts/uv-run-with-index.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +set -euo pipefail + +# Detect current branch and target branch +# In GitHub Actions, use GITHUB_REF/GITHUB_BASE_REF +if [[ -n "${GITHUB_REF:-}" ]]; then + BRANCH="${GITHUB_REF#refs/heads/}" +else + BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "") +fi + +# For PRs, check the target branch +if [[ -n "${GITHUB_BASE_REF:-}" ]]; then + TARGET_BRANCH="${GITHUB_BASE_REF}" +else + TARGET_BRANCH=$(git rev-parse --abbrev-ref HEAD@{upstream} 2>/dev/null | sed 's|origin/||' || echo "") +fi + +# Check if on a release branch or targeting one, or LLAMA_STACK_RELEASE_MODE is set +IS_RELEASE=false +if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then + IS_RELEASE=true +elif [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then + IS_RELEASE=true +elif [[ "${LLAMA_STACK_RELEASE_MODE:-}" == "true" ]]; then + IS_RELEASE=true +fi + +# On release branches, use test.pypi as extra index for RC versions +if [[ "$IS_RELEASE" == "true" ]]; then + export UV_EXTRA_INDEX_URL="https://test.pypi.org/simple/" + export UV_INDEX_STRATEGY="unsafe-best-match" +fi + +# Run uv with all arguments passed through +exec uv "$@" diff --git a/src/llama_stack/apis/agents/openai_responses.py b/src/llama_stack/apis/agents/openai_responses.py index 972b03c94..69e2b2012 100644 --- a/src/llama_stack/apis/agents/openai_responses.py +++ b/src/llama_stack/apis/agents/openai_responses.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from collections.abc import Sequence from typing import Annotated, Any, Literal from pydantic import BaseModel, Field, model_validator @@ -202,7 +203,7 @@ class OpenAIResponseMessage(BaseModel): scenarios. 
""" - content: str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent] + content: str | Sequence[OpenAIResponseInputMessageContent] | Sequence[OpenAIResponseOutputMessageContent] role: Literal["system"] | Literal["developer"] | Literal["user"] | Literal["assistant"] type: Literal["message"] = "message" @@ -254,10 +255,10 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): """ id: str - queries: list[str] + queries: Sequence[str] status: str type: Literal["file_search_call"] = "file_search_call" - results: list[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None + results: Sequence[OpenAIResponseOutputMessageFileSearchToolCallResults] | None = None @json_schema_type @@ -597,7 +598,7 @@ class OpenAIResponseObject(BaseModel): id: str model: str object: Literal["response"] = "response" - output: list[OpenAIResponseOutput] + output: Sequence[OpenAIResponseOutput] parallel_tool_calls: bool = False previous_response_id: str | None = None prompt: OpenAIResponsePrompt | None = None @@ -607,7 +608,7 @@ class OpenAIResponseObject(BaseModel): # before the field was added. New responses will have this set always. 
text: OpenAIResponseText = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) top_p: float | None = None - tools: list[OpenAIResponseTool] | None = None + tools: Sequence[OpenAIResponseTool] | None = None truncation: str | None = None usage: OpenAIResponseUsage | None = None instructions: str | None = None @@ -1315,7 +1316,7 @@ class ListOpenAIResponseInputItem(BaseModel): :param object: Object type identifier, always "list" """ - data: list[OpenAIResponseInput] + data: Sequence[OpenAIResponseInput] object: Literal["list"] = "list" @@ -1326,7 +1327,7 @@ class OpenAIResponseObjectWithInput(OpenAIResponseObject): :param input: List of input items that led to this response """ - input: list[OpenAIResponseInput] + input: Sequence[OpenAIResponseInput] def to_response_object(self) -> OpenAIResponseObject: """Convert to OpenAIResponseObject by excluding input field.""" @@ -1344,7 +1345,7 @@ class ListOpenAIResponseObject(BaseModel): :param object: Object type identifier, always "list" """ - data: list[OpenAIResponseObjectWithInput] + data: Sequence[OpenAIResponseObjectWithInput] has_more: bool first_id: str last_id: str diff --git a/src/llama_stack/apis/inspect/inspect.py b/src/llama_stack/apis/inspect/inspect.py index 8b0996e69..4e0e2548b 100644 --- a/src/llama_stack/apis/inspect/inspect.py +++ b/src/llama_stack/apis/inspect/inspect.py @@ -4,14 +4,21 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable from pydantic import BaseModel -from llama_stack.apis.version import LLAMA_STACK_API_V1 +from llama_stack.apis.version import ( + LLAMA_STACK_API_V1, +) from llama_stack.providers.datatypes import HealthStatus from llama_stack.schema_utils import json_schema_type, webmethod +# Valid values for the route filter parameter. 
+# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated) +# Special filter value: "deprecated" (shows deprecated routes regardless of level) +ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"] + @json_schema_type class RouteInfo(BaseModel): @@ -64,11 +71,12 @@ class Inspect(Protocol): """ @webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1) - async def list_routes(self) -> ListRoutesResponse: + async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse: """List routes. List all available API routes with their methods and implementing providers. + :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes. :returns: Response containing information about all available routes. """ ... diff --git a/src/llama_stack/apis/models/models.py b/src/llama_stack/apis/models/models.py index 903bd6510..a963c8dcc 100644 --- a/src/llama_stack/apis/models/models.py +++ b/src/llama_stack/apis/models/models.py @@ -90,12 +90,14 @@ class OpenAIModel(BaseModel): :object: The object type, which will be "model" :created: The Unix timestamp in seconds when the model was created :owned_by: The owner of the model + :custom_metadata: Llama Stack-specific metadata including model_type, provider info, and additional metadata """ id: str object: Literal["model"] = "model" created: int owned_by: str + custom_metadata: dict[str, Any] | None = None class OpenAIListModelsResponse(BaseModel): @@ -113,7 +115,7 @@ class Models(Protocol): """ ... 
- @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1, deprecated=True) + @webmethod(route="/openai/v1/models", method="GET", level=LLAMA_STACK_API_V1) async def openai_list_models(self) -> OpenAIListModelsResponse: """List models using the OpenAI API. diff --git a/src/llama_stack/apis/synthetic_data_generation/__init__.py b/src/llama_stack/apis/synthetic_data_generation/__init__.py deleted file mode 100644 index bc169e8e6..000000000 --- a/src/llama_stack/apis/synthetic_data_generation/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .synthetic_data_generation import * diff --git a/src/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/src/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py deleted file mode 100644 index c13e2c17c..000000000 --- a/src/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import Enum -from typing import Any, Protocol - -from pydantic import BaseModel - -from llama_stack.apis.inference import Message -from llama_stack.apis.version import LLAMA_STACK_API_V1 -from llama_stack.schema_utils import json_schema_type, webmethod - - -class FilteringFunction(Enum): - """The type of filtering function. 
- - :cvar none: No filtering applied, accept all generated synthetic data - :cvar random: Random sampling of generated data points - :cvar top_k: Keep only the top-k highest scoring synthetic data samples - :cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold - :cvar top_k_top_p: Combined top-k and top-p filtering strategy - :cvar sigmoid: Apply sigmoid function for probability-based filtering - """ - - none = "none" - random = "random" - top_k = "top_k" - top_p = "top_p" - top_k_top_p = "top_k_top_p" - sigmoid = "sigmoid" - - -@json_schema_type -class SyntheticDataGenerationRequest(BaseModel): - """Request to generate synthetic data. A small batch of prompts and a filtering function - - :param dialogs: List of conversation messages to use as input for synthetic data generation - :param filtering_function: Type of filtering to apply to generated synthetic data samples - :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint - """ - - dialogs: list[Message] - filtering_function: FilteringFunction = FilteringFunction.none - model: str | None = None - - -@json_schema_type -class SyntheticDataGenerationResponse(BaseModel): - """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold. 
- - :param synthetic_data: List of generated synthetic data samples that passed the filtering criteria - :param statistics: (Optional) Statistical information about the generation process and filtering results - """ - - synthetic_data: list[dict[str, Any]] - statistics: dict[str, Any] | None = None - - -class SyntheticDataGeneration(Protocol): - @webmethod(route="/synthetic-data-generation/generate", level=LLAMA_STACK_API_V1) - def synthetic_data_generate( - self, - dialogs: list[Message], - filtering_function: FilteringFunction = FilteringFunction.none, - model: str | None = None, - ) -> SyntheticDataGenerationResponse: - """Generate synthetic data based on input dialogs and apply filtering. - - :param dialogs: List of conversation messages to use as input for synthetic data generation - :param filtering_function: Type of filtering to apply to generated synthetic data samples - :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint - :returns: Response containing filtered synthetic data samples and optional statistics - """ - ... diff --git a/src/llama_stack/apis/vector_io/vector_io.py b/src/llama_stack/apis/vector_io/vector_io.py index 19703e7bb..0ef2a6fd6 100644 --- a/src/llama_stack/apis/vector_io/vector_io.py +++ b/src/llama_stack/apis/vector_io/vector_io.py @@ -8,7 +8,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import uuid from typing import Annotated, Any, Literal, Protocol, runtime_checkable from fastapi import Body @@ -18,7 +17,6 @@ from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_stores import VectorStore from llama_stack.apis.version import LLAMA_STACK_API_V1 from llama_stack.core.telemetry.trace_protocol import trace_protocol -from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.strong_typing.schema import register_schema @@ -61,38 +59,19 @@ class Chunk(BaseModel): """ A chunk of content that can be inserted into a vector database. :param content: The content of the chunk, which can be interleaved text, images, or other types. - :param embedding: Optional embedding for the chunk. If not provided, it will be computed later. + :param chunk_id: Unique identifier for the chunk. Must be provided explicitly. :param metadata: Metadata associated with the chunk that will be used in the model context during inference. - :param stored_chunk_id: The chunk ID that is stored in the vector database. Used for backend functionality. + :param embedding: Optional embedding for the chunk. If not provided, it will be computed later. :param chunk_metadata: Metadata for the chunk that will NOT be used in the context during inference. The `chunk_metadata` is required backend functionality. 
""" content: InterleavedContent + chunk_id: str metadata: dict[str, Any] = Field(default_factory=dict) embedding: list[float] | None = None - # The alias parameter serializes the field as "chunk_id" in JSON but keeps the internal name as "stored_chunk_id" - stored_chunk_id: str | None = Field(default=None, alias="chunk_id") chunk_metadata: ChunkMetadata | None = None - model_config = {"populate_by_name": True} - - def model_post_init(self, __context): - # Extract chunk_id from metadata if present - if self.metadata and "chunk_id" in self.metadata: - self.stored_chunk_id = self.metadata.pop("chunk_id") - - @property - def chunk_id(self) -> str: - """Returns the chunk ID, which is either an input `chunk_id` or a generated one if not set.""" - if self.stored_chunk_id: - return self.stored_chunk_id - - if "document_id" in self.metadata: - return generate_chunk_id(self.metadata["document_id"], str(self.content)) - - return generate_chunk_id(str(uuid.uuid4()), str(self.content)) - @property def document_id(self) -> str | None: """Returns the document_id from either metadata or chunk_metadata, with metadata taking precedence.""" diff --git a/src/llama_stack/cli/stack/run.py b/src/llama_stack/cli/stack/run.py index 2882500ce..9ceb238fa 100644 --- a/src/llama_stack/cli/stack/run.py +++ b/src/llama_stack/cli/stack/run.py @@ -8,16 +8,30 @@ import argparse import os import ssl import subprocess +import sys from pathlib import Path import uvicorn import yaml +from termcolor import cprint from llama_stack.cli.stack.utils import ImageType from llama_stack.cli.subcommand import Subcommand -from llama_stack.core.datatypes import StackRunConfig +from llama_stack.core.datatypes import Api, Provider, StackRunConfig +from llama_stack.core.distribution import get_provider_registry from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars +from llama_stack.core.storage.datatypes import ( + InferenceStoreReference, + KVStoreReference, + ServerStoresConfig, + 
SqliteKVStoreConfig, + SqliteSqlStoreConfig, + SqlStoreReference, + StorageConfig, +) +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro +from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import LoggingConfig, get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -68,6 +82,12 @@ class StackRun(Subcommand): action="store_true", help="Start the UI server", ) + self.parser.add_argument( + "--providers", + type=str, + default=None, + help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.", + ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml @@ -93,6 +113,55 @@ class StackRun(Subcommand): config_file = resolve_config_or_distro(args.config, Mode.RUN) except ValueError as e: self.parser.error(str(e)) + elif args.providers: + provider_list: dict[str, list[Provider]] = dict() + for api_provider in args.providers.split(","): + if "=" not in api_provider: + cprint( + "Could not parse `--providers`. 
Please ensure the list is in the format api1=provider1,api2=provider2", + color="red", + file=sys.stderr, + ) + sys.exit(1) + api, provider_type = api_provider.split("=") + providers_for_api = get_provider_registry().get(Api(api), None) + if providers_for_api is None: + cprint( + f"{api} is not a valid API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + if provider_type in providers_for_api: + config_type = instantiate_class_type(providers_for_api[provider_type].config_class) + if config_type is not None and hasattr(config_type, "sample_run_config"): + config = config_type.sample_run_config(__distro_dir__="~/.llama/distributions/providers-run") + else: + config = {} + provider = Provider( + provider_type=provider_type, + config=config, + provider_id=provider_type.split("::")[1], + ) + provider_list.setdefault(api, []).append(provider) + else: + cprint( + f"{provider} is not a valid provider for the {api} API.", + color="red", + file=sys.stderr, + ) + sys.exit(1) + run_config = self._generate_run_config_from_providers(providers=provider_list) + config_dict = run_config.model_dump(mode="json") + + # Write config to disk in providers-run directory + distro_dir = DISTRIBS_BASE_DIR / "providers-run" + config_file = distro_dir / "run.yaml" + + logger.info(f"Writing generated config to: {config_file}") + with open(config_file, "w") as f: + yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False) + else: config_file = None @@ -106,7 +175,8 @@ class StackRun(Subcommand): try: config = parse_and_maybe_upgrade_config(config_dict) - if not os.path.exists(str(config.external_providers_dir)): + # Create external_providers_dir if it's specified and doesn't exist + if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)): os.makedirs(str(config.external_providers_dir), exist_ok=True) except AttributeError as e: self.parser.error(f"failed to parse config file '{config_file}':\n {e}") @@ -127,7 +197,7 @@ class 
StackRun(Subcommand): config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents))) port = args.port or config.server.port - host = config.server.host or ["::", "0.0.0.0"] + host = config.server.host or "0.0.0.0" # Set the config file in environment so create_app can find it os.environ["LLAMA_STACK_CONFIG"] = str(config_file) @@ -139,6 +209,7 @@ class StackRun(Subcommand): "lifespan": "on", "log_level": logger.getEffectiveLevel(), "log_config": logger_config, + "workers": config.server.workers, } keyfile = config.server.tls_keyfile @@ -212,3 +283,44 @@ class StackRun(Subcommand): ) except Exception as e: logger.error(f"Failed to start UI development server in {ui_dir}: {e}") + + def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]): + apis = list(providers.keys()) + distro_dir = DISTRIBS_BASE_DIR / "providers-run" + # need somewhere to put the storage. + os.makedirs(distro_dir, exist_ok=True) + storage = StorageConfig( + backends={ + "kv_default": SqliteKVStoreConfig( + db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db", + ), + "sql_default": SqliteSqlStoreConfig( + db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db", + ), + }, + stores=ServerStoresConfig( + metadata=KVStoreReference( + backend="kv_default", + namespace="registry", + ), + inference=InferenceStoreReference( + backend="sql_default", + table_name="inference_store", + ), + conversations=SqlStoreReference( + backend="sql_default", + table_name="openai_conversations", + ), + prompts=KVStoreReference( + backend="kv_default", + namespace="prompts", + ), + ), + ) + + return StackRunConfig( + image_name="providers-run", + apis=apis, + providers=providers, + storage=storage, + ) diff --git a/src/llama_stack/core/configure.py b/src/llama_stack/core/configure.py index 734839ea9..5d4a54184 100644 --- a/src/llama_stack/core/configure.py +++ b/src/llama_stack/core/configure.py @@ -17,7 +17,6 @@ from llama_stack.core.distribution import 
( get_provider_registry, ) from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars -from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.prompt_for_config import prompt_for_config from llama_stack.log import get_logger @@ -194,19 +193,11 @@ def upgrade_from_routing_table( def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig: - version = config_dict.get("version", None) - if version == LLAMA_STACK_RUN_CONFIG_VERSION: - processed_config_dict = replace_env_vars(config_dict) - return StackRunConfig(**cast_image_name_to_string(processed_config_dict)) - if "routing_table" in config_dict: logger.info("Upgrading config...") config_dict = upgrade_from_routing_table(config_dict) config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION - if not config_dict.get("external_providers_dir", None): - config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR - processed_config_dict = replace_env_vars(config_dict) return StackRunConfig(**cast_image_name_to_string(processed_config_dict)) diff --git a/src/llama_stack/core/datatypes.py b/src/llama_stack/core/datatypes.py index 95907adcf..2182ea4e5 100644 --- a/src/llama_stack/core/datatypes.py +++ b/src/llama_stack/core/datatypes.py @@ -473,6 +473,10 @@ class ServerConfig(BaseModel): "- true: Enable localhost CORS for development\n" "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", ) + workers: int = Field( + default=1, + description="Number of workers to use for the server", + ) class StackRunConfig(BaseModel): diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 37dab4199..6352af00f 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -15,6 +15,7 @@ from llama_stack.apis.inspect import ( RouteInfo, VersionInfo, ) +from llama_stack.apis.version import LLAMA_STACK_API_V1 from 
llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis from llama_stack.core.server.routes import get_all_api_routes @@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect): async def initialize(self) -> None: pass - async def list_routes(self) -> ListRoutesResponse: + async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse: run_config: StackRunConfig = self.config.run_config + # Helper function to determine if a route should be included based on api_filter + def should_include_route(webmethod) -> bool: + if api_filter is None: + # Default: only non-deprecated v1 APIs + return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1 + elif api_filter == "deprecated": + # Special filter: show deprecated routes regardless of their actual level + return bool(webmethod.deprecated) + else: + # Filter by API level (non-deprecated routes only) + return not webmethod.deprecated and webmethod.level == api_filter + ret = [] external_apis = load_external_apis(run_config) all_endpoints = get_all_api_routes(external_apis) @@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect): method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[], # These APIs don't have "real" providers - they're internal to the stack ) - for e, _ in endpoints - if e.methods is not None + for e, webmethod in endpoints + if e.methods is not None and should_include_route(webmethod) ] ) else: @@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect): method=next(iter([m for m in e.methods if m != "HEAD"])), provider_types=[p.provider_type for p in providers], ) - for e, _ in endpoints - if e.methods is not None + for e, webmethod in endpoints + if e.methods is not None and should_include_route(webmethod) ] ) diff --git a/src/llama_stack/core/routers/inference.py b/src/llama_stack/core/routers/inference.py index ef8270093..a4f0f4411 100644 --- a/src/llama_stack/core/routers/inference.py +++ 
b/src/llama_stack/core/routers/inference.py @@ -6,7 +6,7 @@ import asyncio import time -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncIterator from datetime import UTC, datetime from typing import Annotated, Any @@ -15,20 +15,10 @@ from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatC from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam from pydantic import TypeAdapter -from llama_stack.apis.common.content_types import ( - InterleavedContent, -) from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError from llama_stack.apis.inference import ( - ChatCompletionResponse, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, Inference, ListOpenAIChatCompletionResponse, - Message, OpenAIAssistantMessageParam, OpenAIChatCompletion, OpenAIChatCompletionChunk, @@ -45,15 +35,13 @@ from llama_stack.apis.inference import ( OpenAIMessageParam, Order, RerankResponse, - StopReason, - ToolPromptFormat, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartTextParam, ) -from llama_stack.apis.models import Model, ModelType -from llama_stack.core.telemetry.telemetry import MetricEvent, MetricInResponse +from llama_stack.apis.models import ModelType +from llama_stack.core.telemetry.telemetry import MetricEvent from llama_stack.core.telemetry.tracing import enqueue_event, get_current_span from llama_stack.log import get_logger from llama_stack.models.llama.llama3.chat_format import ChatFormat @@ -153,35 +141,6 @@ class InferenceRouter(Inference): ) return metric_events - async def _compute_and_log_token_usage( - self, - prompt_tokens: int, - completion_tokens: int, - total_tokens: int, - model: Model, - ) -> list[MetricInResponse]: - metrics = self._construct_metrics( - prompt_tokens, 
completion_tokens, total_tokens, model.model_id, model.provider_id - ) - if self.telemetry_enabled: - for metric in metrics: - enqueue_event(metric) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] - - async def _count_tokens( - self, - messages: list[Message] | InterleavedContent, - tool_prompt_format: ToolPromptFormat | None = None, - ) -> int | None: - if not hasattr(self, "formatter") or self.formatter is None: - return None - - if isinstance(messages, list): - encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) - else: - encoded = self.formatter.encode_content(messages) - return len(encoded.tokens) if encoded and encoded.tokens else 0 - async def _get_model_provider(self, model_id: str, expected_model_type: str) -> tuple[Inference, str]: model = await self.routing_table.get_object_by_identifier("model", model_id) if model: @@ -375,121 +334,6 @@ class InferenceRouter(Inference): ) return health_statuses - async def stream_tokens_and_compute_metrics( - self, - response, - prompt_tokens, - fully_qualified_model_id: str, - provider_id: str, - tool_prompt_format: ToolPromptFormat | None = None, - ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]: - completion_text = "" - async for chunk in response: - complete = False - if hasattr(chunk, "event"): # only ChatCompletions have .event - if chunk.event.event_type == ChatCompletionResponseEventType.progress: - if chunk.event.delta.type == "text": - completion_text += chunk.event.delta.text - if chunk.event.event_type == ChatCompletionResponseEventType.complete: - complete = True - completion_tokens = await self._count_tokens( - [ - CompletionMessage( - content=completion_text, - stop_reason=StopReason.end_of_turn, - ) - ], - tool_prompt_format=tool_prompt_format, - ) - else: - if hasattr(chunk, "delta"): - completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and 
chunk.stop_reason and self.telemetry_enabled: - complete = True - completion_tokens = await self._count_tokens(completion_text) - # if we are done receiving tokens - if complete: - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - - # Create a separate span for streaming completion metrics - if self.telemetry_enabled: - # Log metrics in the new span context - completion_metrics = self._construct_metrics( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - fully_qualified_model_id=fully_qualified_model_id, - provider_id=provider_id, - ) - for metric in completion_metrics: - if metric.metric in [ - "completion_tokens", - "total_tokens", - ]: # Only log completion and total tokens - enqueue_event(metric) - - # Return metrics in response - async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics - ] - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics - else: - # Fallback if no telemetry - completion_metrics = self._construct_metrics( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - fully_qualified_model_id=fully_qualified_model_id, - provider_id=provider_id, - ) - async_metrics = [ - MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics - ] - chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics - yield chunk - - async def count_tokens_and_compute_metrics( - self, - response: ChatCompletionResponse | CompletionResponse, - prompt_tokens, - fully_qualified_model_id: str, - provider_id: str, - tool_prompt_format: ToolPromptFormat | None = None, - ): - if isinstance(response, ChatCompletionResponse): - content = [response.completion_message] - else: - content = response.content - completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format) - total_tokens = (prompt_tokens or 0) + 
(completion_tokens or 0) - - # Create a separate span for completion metrics - if self.telemetry_enabled: - # Log metrics in the new span context - completion_metrics = self._construct_metrics( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - fully_qualified_model_id=fully_qualified_model_id, - provider_id=provider_id, - ) - for metric in completion_metrics: - if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens - enqueue_event(metric) - - # Return metrics in response - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] - - # Fallback if no telemetry - metrics = self._construct_metrics( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - fully_qualified_model_id=fully_qualified_model_id, - provider_id=provider_id, - ) - return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] - async def stream_tokens_and_compute_metrics_openai_chat( self, response: AsyncIterator[OpenAIChatCompletionChunk], diff --git a/src/llama_stack/core/routing_tables/models.py b/src/llama_stack/core/routing_tables/models.py index 7e43d7273..1fb1186cd 100644 --- a/src/llama_stack/core/routing_tables/models.py +++ b/src/llama_stack/core/routing_tables/models.py @@ -13,6 +13,8 @@ from llama_stack.core.datatypes import ( ModelWithOwner, RegistryEntrySource, ) +from llama_stack.core.request_headers import PROVIDER_DATA_VAR, NeedsRequestProviderData +from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model @@ -42,19 +44,104 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): await self.update_registered_models(provider_id, models) + async def _get_dynamic_models_from_provider_data(self) -> list[Model]: + """ + Fetch models from providers that have credentials in the current request's 
provider_data. + + This allows users to see models available to them from providers that require + per-request API keys (via X-LlamaStack-Provider-Data header). + + Returns models with fully qualified identifiers (provider_id/model_id) but does NOT + cache them in the registry since they are user-specific. + """ + provider_data = PROVIDER_DATA_VAR.get() + if not provider_data: + return [] + + dynamic_models = [] + + for provider_id, provider in self.impls_by_provider_id.items(): + # Check if this provider supports provider_data + if not isinstance(provider, NeedsRequestProviderData): + continue + + # Check if provider has a validator (some providers like ollama don't need per-request credentials) + spec = getattr(provider, "__provider_spec__", None) + if not spec or not getattr(spec, "provider_data_validator", None): + continue + + # Validate provider_data silently - we're speculatively checking all providers + # so validation failures are expected when user didn't provide keys for this provider + try: + validator = instantiate_class_type(spec.provider_data_validator) + validator(**provider_data) + except Exception: + # User didn't provide credentials for this provider - skip silently + continue + + # Validation succeeded! 
User has credentials for this provider + # Now try to list models + try: + models = await provider.list_models() + if not models: + continue + + # Ensure models have fully qualified identifiers with provider_id prefix + for model in models: + # Only add prefix if model identifier doesn't already have it + if not model.identifier.startswith(f"{provider_id}/"): + model.identifier = f"{provider_id}/{model.provider_resource_id}" + + dynamic_models.append(model) + + logger.debug(f"Fetched {len(models)} models from provider {provider_id} using provider_data") + + except Exception as e: + logger.debug(f"Failed to list models from provider {provider_id} with provider_data: {e}") + continue + + return dynamic_models + async def list_models(self) -> ListModelsResponse: - return ListModelsResponse(data=await self.get_all_with_type("model")) + # Get models from registry + registry_models = await self.get_all_with_type("model") + + # Get additional models available via provider_data (user-specific, not cached) + dynamic_models = await self._get_dynamic_models_from_provider_data() + + # Combine, avoiding duplicates (registry takes precedence) + registry_identifiers = {m.identifier for m in registry_models} + unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers] + + return ListModelsResponse(data=registry_models + unique_dynamic_models) async def openai_list_models(self) -> OpenAIListModelsResponse: - models = await self.get_all_with_type("model") + # Get models from registry + registry_models = await self.get_all_with_type("model") + + # Get additional models available via provider_data (user-specific, not cached) + dynamic_models = await self._get_dynamic_models_from_provider_data() + + # Combine, avoiding duplicates (registry takes precedence) + registry_identifiers = {m.identifier for m in registry_models} + unique_dynamic_models = [m for m in dynamic_models if m.identifier not in registry_identifiers] + + all_models = registry_models 
+ unique_dynamic_models + openai_models = [ OpenAIModel( id=model.identifier, object="model", created=int(time.time()), owned_by="llama_stack", + custom_metadata={ + "model_type": model.model_type, + "provider_id": model.provider_id, + "provider_resource_id": model.provider_resource_id, + **model.metadata, + }, ) - for model in models + for model in all_models ] return OpenAIListModelsResponse(data=openai_models) diff --git a/src/llama_stack/core/stack.py b/src/llama_stack/core/stack.py index eccc562ae..2ff7db6eb 100644 --- a/src/llama_stack/core/stack.py +++ b/src/llama_stack/core/stack.py @@ -14,6 +14,7 @@ from typing import Any import yaml from llama_stack.apis.agents import Agents +from llama_stack.apis.batches import Batches from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.conversations import Conversations from llama_stack.apis.datasetio import DatasetIO @@ -30,7 +31,6 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.shields import Shields -from llama_stack.apis.synthetic_data_generation import SyntheticDataGeneration from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.core.conversations.conversations import ConversationServiceConfig, ConversationServiceImpl @@ -63,8 +63,8 @@ class LlamaStack( Providers, Inference, Agents, + Batches, Safety, - SyntheticDataGeneration, Datasets, PostTraining, VectorIO, diff --git a/src/llama_stack/distributions/dell/doc_template.md b/src/llama_stack/distributions/dell/doc_template.md index 4e28673e8..1530f665a 100644 --- a/src/llama_stack/distributions/dell/doc_template.md +++ b/src/llama_stack/distributions/dell/doc_template.md @@ -152,6 +152,37 @@ docker run \ --port $LLAMA_STACK_PORT ``` +### Via Docker with Custom Run Configuration + +You can also run the Docker container with a custom 
run configuration file by mounting it into the container: + +```bash +# Set the path to your custom run.yaml file +CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml + +docker run -it \ + --pull always \ + --network host \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v $HOME/.llama:/root/.llama \ + -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \ + -e RUN_CONFIG_PATH=/app/custom-run.yaml \ + -e INFERENCE_MODEL=$INFERENCE_MODEL \ + -e DEH_URL=$DEH_URL \ + -e CHROMA_URL=$CHROMA_URL \ + llamastack/distribution-{{ name }} \ + --port $LLAMA_STACK_PORT +``` + +**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use. + +{% if run_configs %} +Available run configurations for this distribution: +{% for config in run_configs %} +- `{{ config }}` +{% endfor %} +{% endif %} + ### Via Conda Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available. 
diff --git a/src/llama_stack/distributions/meta-reference-gpu/doc_template.md b/src/llama_stack/distributions/meta-reference-gpu/doc_template.md index ec4452d81..af71d8388 100644 --- a/src/llama_stack/distributions/meta-reference-gpu/doc_template.md +++ b/src/llama_stack/distributions/meta-reference-gpu/doc_template.md @@ -68,6 +68,36 @@ docker run \ --port $LLAMA_STACK_PORT ``` +### Via Docker with Custom Run Configuration + +You can also run the Docker container with a custom run configuration file by mounting it into the container: + +```bash +# Set the path to your custom run.yaml file +CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml +LLAMA_STACK_PORT=8321 + +docker run \ + -it \ + --pull always \ + --gpus all \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \ + -e RUN_CONFIG_PATH=/app/custom-run.yaml \ + llamastack/distribution-{{ name }} \ + --port $LLAMA_STACK_PORT +``` + +**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use. + +{% if run_configs %} +Available run configurations for this distribution: +{% for config in run_configs %} +- `{{ config }}` +{% endfor %} +{% endif %} + ### Via venv Make sure you have the Llama Stack CLI available. 
diff --git a/src/llama_stack/distributions/nvidia/doc_template.md b/src/llama_stack/distributions/nvidia/doc_template.md index 40f39e4f3..054a1e3ec 100644 --- a/src/llama_stack/distributions/nvidia/doc_template.md +++ b/src/llama_stack/distributions/nvidia/doc_template.md @@ -117,13 +117,42 @@ docker run \ -it \ --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ./run.yaml:/root/my-run.yaml \ + -v ~/.llama:/root/.llama \ -e NVIDIA_API_KEY=$NVIDIA_API_KEY \ llamastack/distribution-{{ name }} \ - --config /root/my-run.yaml \ --port $LLAMA_STACK_PORT ``` +### Via Docker with Custom Run Configuration + +You can also run the Docker container with a custom run configuration file by mounting it into the container: + +```bash +# Set the path to your custom run.yaml file +CUSTOM_RUN_CONFIG=/path/to/your/custom-run.yaml +LLAMA_STACK_PORT=8321 + +docker run \ + -it \ + --pull always \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ~/.llama:/root/.llama \ + -v $CUSTOM_RUN_CONFIG:/app/custom-run.yaml \ + -e RUN_CONFIG_PATH=/app/custom-run.yaml \ + -e NVIDIA_API_KEY=$NVIDIA_API_KEY \ + llamastack/distribution-{{ name }} \ + --port $LLAMA_STACK_PORT +``` + +**Note**: The run configuration must be mounted into the container before it can be used. The `-v` flag mounts your local file into the container, and the `RUN_CONFIG_PATH` environment variable tells the entrypoint script which configuration to use. + +{% if run_configs %} +Available run configurations for this distribution: +{% for config in run_configs %} +- `{{ config }}` +{% endfor %} +{% endif %} + ### Via venv If you've set up your local development environment, you can also install the distribution dependencies using your local virtual environment. 
diff --git a/src/llama_stack/distributions/template.py b/src/llama_stack/distributions/template.py index 1dad60064..e6813806a 100644 --- a/src/llama_stack/distributions/template.py +++ b/src/llama_stack/distributions/template.py @@ -424,6 +424,7 @@ class DistributionTemplate(BaseModel): providers_table=providers_table, run_config_env_vars=self.run_config_env_vars, default_models=default_models, + run_configs=list(self.run_configs.keys()), ) return "" diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 80ef068c7..a2a49abd3 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -11,6 +11,7 @@ import uuid import warnings from collections.abc import AsyncGenerator from datetime import UTC, datetime +from typing import Any, cast import httpx @@ -125,12 +126,12 @@ class ChatAgent(ShieldRunnerMixin): ) def turn_to_messages(self, turn: Turn) -> list[Message]: - messages = [] + messages: list[Message] = [] # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages tool_call_ids = set() for step in turn.steps: - if step.step_type == StepType.tool_execution.value: + if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep): for response in step.tool_responses: tool_call_ids.add(response.call_id) @@ -149,9 +150,9 @@ class ChatAgent(ShieldRunnerMixin): messages.append(msg) for step in turn.steps: - if step.step_type == StepType.inference.value: + if step.step_type == StepType.inference.value and isinstance(step, InferenceStep): messages.append(step.model_response) - elif step.step_type == StepType.tool_execution.value: + elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep): for response in step.tool_responses: messages.append( 
ToolResponseMessage( @@ -159,8 +160,8 @@ class ChatAgent(ShieldRunnerMixin): content=response.content, ) ) - elif step.step_type == StepType.shield_call.value: - if step.violation: + elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep): + if step.violation and step.violation.user_message: # CompletionMessage itself in the ShieldResponse messages.append( CompletionMessage( @@ -174,7 +175,7 @@ class ChatAgent(ShieldRunnerMixin): return await self.storage.create_session(name) async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: - messages = [] + messages: list[Message] = [] if self.agent_config.instructions != "": messages.append(SystemMessage(content=self.agent_config.instructions)) @@ -231,7 +232,9 @@ class ChatAgent(ShieldRunnerMixin): steps = [] messages = await self.get_messages_from_turns(turns) + if is_resume: + assert isinstance(request, AgentTurnResumeRequest) tool_response_messages = [ ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses ] @@ -252,42 +255,52 @@ class ChatAgent(ShieldRunnerMixin): in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( request.session_id, request.turn_id ) - now = datetime.now(UTC).isoformat() + now_dt = datetime.now(UTC) tool_execution_step = ToolExecutionStep( step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_responses=request.tool_responses, - completed_at=now, - started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + completed_at=now_dt, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt), ) steps.append(tool_execution_step) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - 
step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=tool_execution_step.step_id, step_details=tool_execution_step, ) ) ) - input_messages = last_turn.input_messages + # Cast needed due to list invariance - last_turn.input_messages is the right type + input_messages = last_turn.input_messages # type: ignore[assignment] - turn_id = request.turn_id + actual_turn_id = request.turn_id start_time = last_turn.started_at else: + assert isinstance(request, AgentTurnCreateRequest) messages.extend(request.messages) - start_time = datetime.now(UTC).isoformat() - input_messages = request.messages + start_time = datetime.now(UTC) + # Cast needed due to list invariance - request.messages is the right type + input_messages = request.messages # type: ignore[assignment] + # Use the generated turn_id from beginning of function + actual_turn_id = turn_id if turn_id else str(uuid.uuid4()) output_message = None + req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None + req_sampling = ( + self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams() + ) + async for chunk in self.run( session_id=request.session_id, - turn_id=turn_id, + turn_id=actual_turn_id, input_messages=messages, - sampling_params=self.agent_config.sampling_params, + sampling_params=req_sampling, stream=request.stream, - documents=request.documents if not is_resume else None, + documents=req_documents, ): if isinstance(chunk, CompletionMessage): output_message = chunk @@ -295,20 +308,23 @@ class ChatAgent(ShieldRunnerMixin): assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: - steps.append(event.payload.step_details) + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr( + event.payload, "step_details" + ): + 
step_details = event.payload.step_details + steps.append(step_details) yield chunk assert output_message is not None turn = Turn( - turn_id=turn_id, + turn_id=actual_turn_id, session_id=request.session_id, - input_messages=input_messages, + input_messages=input_messages, # type: ignore[arg-type] output_message=output_message, started_at=start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), steps=steps, ) await self.storage.add_turn_to_session(request.session_id, turn) @@ -345,9 +361,9 @@ class ChatAgent(ShieldRunnerMixin): # return a "final value" for the `yield from` statement. we simulate that by yielding a # final boolean (to see whether an exception happened) and then explicitly testing for it. - if len(self.input_shields) > 0: + if self.input_shields: async for res in self.run_multiple_shields_wrapper( - turn_id, input_messages, self.input_shields, "user-input" + turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input" ): if isinstance(res, bool): return @@ -374,9 +390,9 @@ class ChatAgent(ShieldRunnerMixin): # for output shields run on the full input and output combination messages = input_messages + [final_response] - if len(self.output_shields) > 0: + if self.output_shields: async for res in self.run_multiple_shields_wrapper( - turn_id, messages, self.output_shields, "assistant-output" + turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output" ): if isinstance(res, bool): return @@ -388,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin): async def run_multiple_shields_wrapper( self, turn_id: str, - messages: list[Message], + messages: list[OpenAIMessageParam], shields: list[str], touchpoint: str, ) -> AsyncGenerator: @@ -402,12 +418,12 @@ class ChatAgent(ShieldRunnerMixin): return step_id = str(uuid.uuid4()) - shield_call_start_time = datetime.now(UTC).isoformat() + shield_call_start_time = datetime.now(UTC) try: yield AgentTurnResponseStreamChunk( 
event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( - step_type=StepType.shield_call.value, + step_type=StepType.shield_call, step_id=step_id, metadata=dict(touchpoint=touchpoint), ) @@ -419,14 +435,14 @@ class ChatAgent(ShieldRunnerMixin): yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.shield_call.value, + step_type=StepType.shield_call, step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, violation=e.violation, started_at=shield_call_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ), ) ) @@ -443,14 +459,14 @@ class ChatAgent(ShieldRunnerMixin): yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.shield_call.value, + step_type=StepType.shield_call, step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, violation=None, started_at=shield_call_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ), ) ) @@ -496,21 +512,22 @@ class ChatAgent(ShieldRunnerMixin): else: self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id) - output_attachments = [] + output_attachments: list[Attachment] = [] n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 # Build a map of custom tools to their definitions for faster lookup client_tools = {} - for tool in self.agent_config.client_tools: - client_tools[tool.name] = tool + if self.agent_config.client_tools: + for tool in self.agent_config.client_tools: + client_tools[tool.name] = tool while True: step_id = str(uuid.uuid4()) - inference_start_time = datetime.now(UTC).isoformat() + inference_start_time = datetime.now(UTC) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( - 
step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, ) ) @@ -538,7 +555,7 @@ class ChatAgent(ShieldRunnerMixin): else: return value - def _add_type(openai_msg: dict) -> OpenAIMessageParam: + def _add_type(openai_msg: Any) -> OpenAIMessageParam: # Serialize any nested Pydantic models to plain dicts openai_msg = _serialize_nested(openai_msg) @@ -588,7 +605,7 @@ class ChatAgent(ShieldRunnerMixin): messages=openai_messages, tools=openai_tools if openai_tools else None, tool_choice=tool_choice, - response_format=self.agent_config.response_format, + response_format=self.agent_config.response_format, # type: ignore[arg-type] temperature=temperature, top_p=top_p, max_tokens=max_tokens, @@ -598,7 +615,8 @@ class ChatAgent(ShieldRunnerMixin): # Convert OpenAI stream back to Llama Stack format response_stream = convert_openai_chat_completion_stream( - openai_stream, enable_incremental_tool_calls=True + openai_stream, # type: ignore[arg-type] + enable_incremental_tool_calls=True, ) async for chunk in response_stream: @@ -620,7 +638,7 @@ class ChatAgent(ShieldRunnerMixin): yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, delta=delta, ) @@ -633,7 +651,7 @@ class ChatAgent(ShieldRunnerMixin): yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, delta=delta, ) @@ -651,7 +669,9 @@ class ChatAgent(ShieldRunnerMixin): output_attr = json.dumps( { "content": content, - "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], + "tool_calls": [ + json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall) + ], } ) span.set_attribute("output", output_attr) @@ -667,16 +687,18 @@ class ChatAgent(ShieldRunnerMixin): if tool_calls: content 
= "" + # Filter out string tool calls for CompletionMessage (only keep ToolCall objects) + valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)] message = CompletionMessage( content=content, stop_reason=stop_reason, - tool_calls=tool_calls, + tool_calls=valid_tool_calls if valid_tool_calls else None, ) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, step_details=InferenceStep( # somewhere deep, we are re-assigning message or closing over some @@ -686,13 +708,14 @@ class ChatAgent(ShieldRunnerMixin): turn_id=turn_id, model_response=copy.deepcopy(message), started_at=inference_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ), ) ) ) - if n_iter >= self.agent_config.max_infer_iters: + max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10 + if n_iter >= max_iters: logger.info(f"done with MAX iterations ({n_iter}), exiting.") # NOTE: mark end_of_turn to indicate to client that we are done with the turn # Do not continue the tool call loop after this point @@ -705,14 +728,16 @@ class ChatAgent(ShieldRunnerMixin): yield message break - if len(message.tool_calls) == 0: + if not message.tool_calls or len(message.tool_calls) == 0: if stop_reason == StopReason.end_of_turn: # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS) if len(output_attachments) > 0: if isinstance(message.content, list): - message.content += output_attachments + # List invariance - attachments are compatible at runtime + message.content += output_attachments # type: ignore[arg-type] else: - message.content = [message.content] + output_attachments + # List invariance - attachments are compatible at runtime + message.content = [message.content] + output_attachments # type: ignore[assignment] yield message else: 
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") @@ -725,11 +750,12 @@ class ChatAgent(ShieldRunnerMixin): non_client_tool_calls = [] # Separate client and non-client tool calls - for tool_call in message.tool_calls: - if tool_call.tool_name in client_tools: - client_tool_calls.append(tool_call) - else: - non_client_tool_calls.append(tool_call) + if message.tool_calls: + for tool_call in message.tool_calls: + if tool_call.tool_name in client_tools: + client_tool_calls.append(tool_call) + else: + non_client_tool_calls.append(tool_call) # Process non-client tool calls first for tool_call in non_client_tool_calls: @@ -737,7 +763,7 @@ class ChatAgent(ShieldRunnerMixin): yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=step_id, ) ) @@ -746,7 +772,7 @@ class ChatAgent(ShieldRunnerMixin): yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=step_id, delta=ToolCallDelta( parse_status=ToolCallParseStatus.in_progress, @@ -766,7 +792,7 @@ class ChatAgent(ShieldRunnerMixin): if self.telemetry_enabled else {}, ) as span: - tool_execution_start_time = datetime.now(UTC).isoformat() + tool_execution_start_time = datetime.now(UTC) tool_result = await self.execute_tool_call_maybe( session_id, tool_call, @@ -796,14 +822,14 @@ class ChatAgent(ShieldRunnerMixin): ) ], started_at=tool_execution_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ) # Yield the step completion event yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=step_id, step_details=tool_execution_step, ) @@ 
-833,7 +859,7 @@ class ChatAgent(ShieldRunnerMixin): turn_id=turn_id, tool_calls=client_tool_calls, tool_responses=[], - started_at=datetime.now(UTC).isoformat(), + started_at=datetime.now(UTC), ), ) @@ -868,19 +894,20 @@ class ChatAgent(ShieldRunnerMixin): toolgroup_to_args = toolgroup_to_args or {} - tool_name_to_def = {} + tool_name_to_def: dict[str, ToolDefinition] = {} tool_name_to_args = {} - for tool_def in self.agent_config.client_tools: - if tool_name_to_def.get(tool_def.name, None): - raise ValueError(f"Tool {tool_def.name} already exists") + if self.agent_config.client_tools: + for tool_def in self.agent_config.client_tools: + if tool_name_to_def.get(tool_def.name, None): + raise ValueError(f"Tool {tool_def.name} already exists") - # Use input_schema from ToolDef directly - tool_name_to_def[tool_def.name] = ToolDefinition( - tool_name=tool_def.name, - description=tool_def.description, - input_schema=tool_def.input_schema, - ) + # Use input_schema from ToolDef directly + tool_name_to_def[tool_def.name] = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + input_schema=tool_def.input_schema, + ) for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) @@ -908,15 +935,17 @@ class ChatAgent(ShieldRunnerMixin): else: identifier = None - if tool_name_to_def.get(identifier, None): - raise ValueError(f"Tool {identifier} already exists") if identifier: - tool_name_to_def[identifier] = ToolDefinition( - tool_name=identifier, + # Convert BuiltinTool to string for dictionary key + identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier + if tool_name_to_def.get(identifier_str, None): + raise ValueError(f"Tool {identifier_str} already exists") + tool_name_to_def[identifier_str] = ToolDefinition( + tool_name=identifier_str, 
description=tool_def.description, input_schema=tool_def.input_schema, ) - tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {}) + tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {}) self.tool_defs, self.tool_name_to_args = ( list(tool_name_to_def.values()), @@ -966,14 +995,17 @@ class ChatAgent(ShieldRunnerMixin): except json.JSONDecodeError as e: raise ValueError(f"Failed to parse arguments for tool call: {tool_call.arguments}") from e - result = await self.tool_runtime_api.invoke_tool( - tool_name=tool_name_str, - kwargs={ - "session_id": session_id, - # get the arguments generated by the model and augment with toolgroup arg overrides for the agent - **args, - **self.tool_name_to_args.get(tool_name_str, {}), - }, + result = cast( + ToolInvocationResult, + await self.tool_runtime_api.invoke_tool( + tool_name=tool_name_str, + kwargs={ + "session_id": session_id, + # get the arguments generated by the model and augment with toolgroup arg overrides for the agent + **args, + **self.tool_name_to_args.get(tool_name_str, {}), + }, + ), ) logger.debug(f"tool call {tool_name_str} completed with result: {result}") return result @@ -983,7 +1015,7 @@ async def load_data_from_url(url: str) -> str: if url.startswith("http"): async with httpx.AsyncClient() as client: r = await client.get(url) - resp = r.text + resp: str = r.text return resp raise ValueError(f"Unexpected URL: {type(url)}") @@ -1017,7 +1049,7 @@ def _interpret_content_as_attachment( snippet = match.group(1) data = json.loads(snippet) return Attachment( - url=URL(uri="file://" + data["filepath"]), + content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"], ) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py index b4b77bacd..85c6cb251 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py +++ 
b/src/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -21,6 +21,7 @@ from llama_stack.apis.agents import ( Document, ListOpenAIResponseInputItem, ListOpenAIResponseObject, + OpenAIDeleteResponseObject, OpenAIResponseInput, OpenAIResponseInputTool, OpenAIResponseObject, @@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents): persistence_store=( self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store ), - created_at=agent_info.created_at, + created_at=agent_info.created_at.isoformat(), policy=self.policy, telemetry_enabled=self.telemetry_enabled, ) @@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_id: str, messages: list[UserMessage | ToolResponseMessage], - toolgroups: list[AgentToolGroup] | None = None, - documents: list[Document] | None = None, stream: bool | None = False, + documents: list[Document] | None = None, + toolgroups: list[AgentToolGroup] | None = None, tool_config: ToolConfig | None = None, ) -> AsyncGenerator: request = AgentTurnCreateRequest( @@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents): async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: agent = await self._get_agent_impl(agent_id) turn = await agent.storage.get_session_turn(session_id, turn_id) + if turn is None: + raise ValueError(f"Turn {turn_id} not found in session {session_id}") return turn async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse: @@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents): async def get_agents_session( self, - agent_id: str, session_id: str, + agent_id: str, turn_ids: list[str] | None = None, ) -> Session: agent = await self._get_agent_impl(agent_id) session_info = await agent.storage.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") turns = await agent.storage.get_session_turns(session_id) if turn_ids: turns = 
[turn for turn in turns if turn.turn_id in turn_ids] @@ -249,7 +254,7 @@ class MetaReferenceAgentsImpl(Agents): started_at=session_info.started_at, ) - async def delete_agents_session(self, agent_id: str, session_id: str) -> None: + async def delete_agents_session(self, session_id: str, agent_id: str) -> None: agent = await self._get_agent_impl(agent_id) # Delete turns first, then the session @@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents): agent = Agent( agent_id=agent_id, agent_config=chat_agent.agent_config, - created_at=chat_agent.created_at, + created_at=datetime.fromisoformat(chat_agent.created_at), ) return agent @@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents): self, response_id: str, ) -> OpenAIResponseObject: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.get_openai_response(response_id) async def create_openai_response( @@ -342,7 +348,8 @@ class MetaReferenceAgentsImpl(Agents): max_infer_iters: int | None = 10, guardrails: list[ResponseGuardrail] | None = None, ) -> OpenAIResponseObject: - return await self.openai_responses_impl.create_openai_response( + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" + result = await self.openai_responses_impl.create_openai_response( input, model, prompt, @@ -358,6 +365,7 @@ class MetaReferenceAgentsImpl(Agents): max_infer_iters, guardrails, ) + return result # type: ignore[no-any-return] async def list_openai_responses( self, @@ -366,6 +374,7 @@ class MetaReferenceAgentsImpl(Agents): model: str | None = None, order: Order | None = Order.desc, ) -> ListOpenAIResponseObject: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.list_openai_responses(after, limit, model, order) async def list_openai_response_input_items( @@ -377,9 +386,11 @@ class MetaReferenceAgentsImpl(Agents): limit: int | None = 20, order: Order | None = 
Order.desc, ) -> ListOpenAIResponseInputItem: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.list_openai_response_input_items( response_id, after, before, include, limit, order ) - async def delete_openai_response(self, response_id: str) -> None: + async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: + assert self.openai_responses_impl is not None, "OpenAI responses not initialized" return await self.openai_responses_impl.delete_openai_response(response_id) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py index 26a2151e3..9e0598bf1 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -6,12 +6,14 @@ import json import uuid +from dataclasses import dataclass from datetime import UTC, datetime from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.apis.common.errors import SessionNotFoundError from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed -from llama_stack.core.access_control.datatypes import AccessRule +from llama_stack.core.access_control.conditions import User as ProtocolUser +from llama_stack.core.access_control.datatypes import AccessRule, Action from llama_stack.core.datatypes import User from llama_stack.core.request_headers import get_authenticated_user from llama_stack.log import get_logger @@ -33,6 +35,15 @@ class AgentInfo(AgentConfig): created_at: datetime +@dataclass +class SessionResource: + """Concrete implementation of ProtectedResource for session access control.""" + + type: str + identifier: str + owner: ProtocolUser # Use the protocol type for structural compatibility + + class AgentPersistence: def __init__(self, agent_id: str, 
kvstore: KVStore, policy: list[AccessRule]): self.agent_id = agent_id @@ -53,8 +64,15 @@ class AgentPersistence: turns=[], identifier=name, # should this be qualified in any way? ) - if not is_action_allowed(self.policy, "create", session_info, user): - raise AccessDeniedError("create", session_info, user) + # Only perform access control if we have an authenticated user + if user is not None and session_info.identifier is not None: + resource = SessionResource( + type=session_info.type, + identifier=session_info.identifier, + owner=user, + ) + if not is_action_allowed(self.policy, Action.CREATE, resource, user): + raise AccessDeniedError(Action.CREATE, resource, user) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}", @@ -62,7 +80,7 @@ class AgentPersistence: ) return session_id - async def get_session_info(self, session_id: str) -> AgentSessionInfo: + async def get_session_info(self, session_id: str) -> AgentSessionInfo | None: value = await self.kvstore.get( key=f"session:{self.agent_id}:{session_id}", ) @@ -83,7 +101,22 @@ class AgentPersistence: if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"): return True - return is_action_allowed(self.policy, "read", session_info, get_authenticated_user()) + # Get current user - if None, skip access control (e.g., in tests) + user = get_authenticated_user() + if user is None: + return True + + # Access control requires identifier and owner to be set + if session_info.identifier is None or session_info.owner is None: + return True + + # At this point, both identifier and owner are guaranteed to be non-None + resource = SessionResource( + type=session_info.type, + identifier=session_info.identifier, + owner=session_info.owner, + ) + return is_action_allowed(self.policy, Action.READ, resource, user) async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None: """Get session info if the user has access to it. 
For internal use by sub-session methods.""" diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 524ca1b0e..933cfe963 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -91,7 +91,8 @@ class OpenAIResponsesImpl: input: str | list[OpenAIResponseInput], previous_response: _OpenAIResponseObjectWithInputAndMessages, ): - new_input_items = previous_response.input.copy() + # Convert Sequence to list for mutation + new_input_items = list(previous_response.input) new_input_items.extend(previous_response.output) if isinstance(input, str): @@ -107,7 +108,7 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None, previous_response_id: str | None, conversation: str | None, - ) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam]]: + ) -> tuple[str | list[OpenAIResponseInput], list[OpenAIMessageParam], ToolContext]: """Process input with optional previous response context. 
Returns: @@ -208,6 +209,9 @@ class OpenAIResponsesImpl: messages: list[OpenAIMessageParam], ) -> None: new_input_id = f"msg_{uuid.uuid4()}" + # Type input_items_data as the full OpenAIResponseInput union to avoid list invariance issues + input_items_data: list[OpenAIResponseInput] = [] + if isinstance(input, str): # synthesize a message from the input string input_content = OpenAIResponseInputMessageContentText(text=input) @@ -219,7 +223,6 @@ class OpenAIResponsesImpl: input_items_data = [input_content_item] else: # we already have a list of messages - input_items_data = [] for input_item in input: if isinstance(input_item, OpenAIResponseMessage): # These may or may not already have an id, so dump to dict, check for id, and add if missing @@ -251,7 +254,7 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, - guardrails: list[ResponseGuardrailSpec] | None = None, + guardrails: list[str | ResponseGuardrailSpec] | None = None, ): stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text @@ -289,16 +292,19 @@ class OpenAIResponsesImpl: failed_response = None async for stream_chunk in stream_gen: - if stream_chunk.type in {"response.completed", "response.incomplete"}: - if final_response is not None: - raise ValueError( - "The response stream produced multiple terminal responses! " - f"Earlier response from {final_event_type}" - ) - final_response = stream_chunk.response - final_event_type = stream_chunk.type - elif stream_chunk.type == "response.failed": - failed_response = stream_chunk.response + match stream_chunk.type: + case "response.completed" | "response.incomplete": + if final_response is not None: + raise ValueError( + "The response stream produced multiple terminal responses! 
" + f"Earlier response from {final_event_type}" + ) + final_response = stream_chunk.response + final_event_type = stream_chunk.type + case "response.failed": + failed_response = stream_chunk.response + case _: + pass # Other event types don't have .response if failed_response is not None: error_message = ( @@ -326,6 +332,11 @@ class OpenAIResponsesImpl: max_infer_iters: int | None = 10, guardrail_ids: list[str] | None = None, ) -> AsyncIterator[OpenAIResponseObjectStream]: + # These should never be None when called from create_openai_response (which sets defaults) + # but we assert here to help mypy understand the types + assert text is not None, "text must not be None" + assert max_infer_iters is not None, "max_infer_iters must not be None" + # Input preprocessing all_input, messages, tool_context = await self._process_input_with_previous_response( input, tools, previous_response_id, conversation @@ -368,16 +379,19 @@ class OpenAIResponsesImpl: final_response = None failed_response = None - output_items = [] + # Type as ConversationItem to avoid list invariance issues + output_items: list[ConversationItem] = [] async for stream_chunk in orchestrator.create_response(): - if stream_chunk.type in {"response.completed", "response.incomplete"}: - final_response = stream_chunk.response - elif stream_chunk.type == "response.failed": - failed_response = stream_chunk.response - - if stream_chunk.type == "response.output_item.done": - item = stream_chunk.item - output_items.append(item) + match stream_chunk.type: + case "response.completed" | "response.incomplete": + final_response = stream_chunk.response + case "response.failed": + failed_response = stream_chunk.response + case "response.output_item.done": + item = stream_chunk.item + output_items.append(item) + case _: + pass # Other event types # Store and sync before yielding terminal events # This ensures the storage/syncing happens even if the consumer breaks after receiving the event @@ -410,7 +424,8 @@ class 
OpenAIResponsesImpl: self, conversation_id: str, input: str | list[OpenAIResponseInput] | None, output_items: list[ConversationItem] ) -> None: """Sync content and response messages to the conversation.""" - conversation_items = [] + # Type as ConversationItem union to avoid list invariance issues + conversation_items: list[ConversationItem] = [] if isinstance(input, str): conversation_items.append( diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 2cbfead40..ef5603420 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -111,7 +111,7 @@ class StreamingResponseOrchestrator: text: OpenAIResponseText, max_infer_iters: int, tool_executor, # Will be the tool execution logic from the main class - instructions: str, + instructions: str | None, safety_api, guardrail_ids: list[str] | None = None, prompt: OpenAIResponsePrompt | None = None, @@ -128,7 +128,9 @@ class StreamingResponseOrchestrator: self.prompt = prompt self.sequence_number = 0 # Store MCP tool mapping that gets built during tool processing - self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {} + self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ( + ctx.tool_context.previous_tools if ctx.tool_context else {} + ) # Track final messages after all tool executions self.final_messages: list[OpenAIMessageParam] = [] # mapping for annotations @@ -229,7 +231,8 @@ class StreamingResponseOrchestrator: params = OpenAIChatCompletionRequestWithExtraBody( model=self.ctx.model, messages=messages, - tools=self.ctx.chat_tools, + # Pydantic models are dict-compatible but mypy treats them as distinct types + tools=self.ctx.chat_tools, # type: ignore[arg-type] stream=True, temperature=self.ctx.temperature, 
response_format=response_format, @@ -272,7 +275,12 @@ class StreamingResponseOrchestrator: # Handle choices with no tool calls for choice in current_response.choices: - if not (choice.message.tool_calls and self.ctx.response_tools): + has_tool_calls = ( + isinstance(choice.message, OpenAIAssistantMessageParam) + and choice.message.tool_calls + and self.ctx.response_tools + ) + if not has_tool_calls: output_messages.append( await convert_chat_choice_to_response_message( choice, @@ -722,7 +730,10 @@ class StreamingResponseOrchestrator: ) # Accumulate arguments for final response (only for subsequent chunks) - if not is_new_tool_call: + if not is_new_tool_call and response_tool_call is not None: + # Both should have functions since we're inside the tool_call.function check above + assert response_tool_call.function is not None + assert tool_call.function is not None response_tool_call.function.arguments = ( response_tool_call.function.arguments or "" ) + tool_call.function.arguments @@ -747,10 +758,13 @@ class StreamingResponseOrchestrator: for tool_call_index in sorted(chat_response_tool_calls.keys()): tool_call = chat_response_tool_calls[tool_call_index] # Ensure that arguments, if sent back to the inference provider, are not None - tool_call.function.arguments = tool_call.function.arguments or "{}" + if tool_call.function: + tool_call.function.arguments = tool_call.function.arguments or "{}" tool_call_item_id = tool_call_item_ids[tool_call_index] - final_arguments = tool_call.function.arguments - tool_call_name = chat_response_tool_calls[tool_call_index].function.name + final_arguments: str = tool_call.function.arguments or "{}" if tool_call.function else "{}" + func = chat_response_tool_calls[tool_call_index].function + + tool_call_name = func.name if func else "" # Check if this is an MCP tool call is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server @@ -894,12 +908,11 @@ class StreamingResponseOrchestrator: self.sequence_number += 1 if 
tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server: - item = OpenAIResponseOutputMessageMCPCall( + item: OpenAIResponseOutput = OpenAIResponseOutputMessageMCPCall( arguments="", name=tool_call.function.name, id=matching_item_id, server_label=self.mcp_tool_to_server[tool_call.function.name].server_label, - status="in_progress", ) elif tool_call.function.name == "web_search": item = OpenAIResponseOutputMessageWebSearchToolCall( @@ -1008,7 +1021,7 @@ class StreamingResponseOrchestrator: description=tool.description, input_schema=tool.input_schema, ) - return convert_tooldef_to_openai_tool(tool_def) + return convert_tooldef_to_openai_tool(tool_def) # type: ignore[return-value] # Returns dict but ChatCompletionToolParam expects TypedDict # Initialize chat_tools if not already set if self.ctx.chat_tools is None: @@ -1016,7 +1029,7 @@ class StreamingResponseOrchestrator: for input_tool in tools: if input_tool.type == "function": - self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) + self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) # type: ignore[typeddict-item,arg-type] # Dict compatible with FunctionDefinition elif input_tool.type in WebSearchToolTypes: tool_name = "web_search" # Need to access tool_groups_api from tool_executor @@ -1055,8 +1068,8 @@ class StreamingResponseOrchestrator: if isinstance(mcp_tool.allowed_tools, list): always_allowed = mcp_tool.allowed_tools elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter): - always_allowed = mcp_tool.allowed_tools.always - never_allowed = mcp_tool.allowed_tools.never + # AllowedToolsFilter only has tool_names field (not allowed/disallowed) + always_allowed = mcp_tool.allowed_tools.tool_names # Call list_mcp_tools tool_defs = None @@ -1088,7 +1101,7 @@ class StreamingResponseOrchestrator: openai_tool = convert_tooldef_to_chat_tool(t) if self.ctx.chat_tools is None: 
self.ctx.chat_tools = [] - self.ctx.chat_tools.append(openai_tool) + self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict # Add to MCP tool mapping if t.name in self.mcp_tool_to_server: @@ -1120,13 +1133,17 @@ class StreamingResponseOrchestrator: self, output_messages: list[OpenAIResponseOutput] ) -> AsyncIterator[OpenAIResponseObjectStream]: # Handle all mcp tool lists from previous response that are still valid: - for tool in self.ctx.tool_context.previous_tool_listings: - async for evt in self._reuse_mcp_list_tools(tool, output_messages): - yield evt - # Process all remaining tools (including MCP tools) and emit streaming events - if self.ctx.tool_context.tools_to_process: - async for stream_event in self._process_new_tools(self.ctx.tool_context.tools_to_process, output_messages): - yield stream_event + # tool_context can be None when no tools are provided in the response request + if self.ctx.tool_context: + for tool in self.ctx.tool_context.previous_tool_listings: + async for evt in self._reuse_mcp_list_tools(tool, output_messages): + yield evt + # Process all remaining tools (including MCP tools) and emit streaming events + if self.ctx.tool_context.tools_to_process: + async for stream_event in self._process_new_tools( + self.ctx.tool_context.tools_to_process, output_messages + ): + yield stream_event def _approval_required(self, tool_name: str) -> bool: if tool_name not in self.mcp_tool_to_server: @@ -1220,7 +1237,7 @@ class StreamingResponseOrchestrator: openai_tool = convert_tooldef_to_openai_tool(tool_def) if self.ctx.chat_tools is None: self.ctx.chat_tools = [] - self.ctx.chat_tools.append(openai_tool) + self.ctx.chat_tools.append(openai_tool) # type: ignore[arg-type] # Returns dict but ChatCompletionToolParam expects TypedDict mcp_list_message = OpenAIResponseOutputMessageMCPListTools( id=f"mcp_list_{uuid.uuid4()}", diff --git 
a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 8e0dc9ecb..09a161d50 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -7,6 +7,7 @@ import asyncio import json from collections.abc import AsyncIterator +from typing import Any from llama_stack.apis.agents.openai_responses import ( OpenAIResponseInputToolFileSearch, @@ -22,6 +23,7 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObjectStreamResponseWebSearchCallSearching, OpenAIResponseOutputMessageFileSearchToolCall, OpenAIResponseOutputMessageFileSearchToolCallResults, + OpenAIResponseOutputMessageMCPCall, OpenAIResponseOutputMessageWebSearchToolCall, ) from llama_stack.apis.common.content_types import ( @@ -67,7 +69,7 @@ class ToolExecutor: ) -> AsyncIterator[ToolExecutionResult]: tool_call_id = tool_call.id function = tool_call.function - tool_kwargs = json.loads(function.arguments) if function.arguments else {} + tool_kwargs = json.loads(function.arguments) if function and function.arguments else {} if not function or not tool_call_id or not function.name: yield ToolExecutionResult(sequence_number=sequence_number) @@ -84,7 +86,16 @@ class ToolExecutor: error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) # Emit completion events for tool execution - has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + has_error = bool( + error_exc + or ( + result + and ( + ((error_code := getattr(result, "error_code", None)) and error_code > 0) + or getattr(result, "error_message", None) + ) + ) + ) async for event_result in self._emit_completion_events( function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server ): @@ 
-101,7 +112,9 @@ class ToolExecutor: sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message, - citation_files=result.metadata.get("citation_files") if result and result.metadata else None, + citation_files=( + metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None + ), ) async def _execute_knowledge_search_via_vector_store( @@ -188,8 +201,9 @@ class ToolExecutor: citation_files[file_id] = filename + # Cast to proper InterleavedContent type (list invariance) return ToolInvocationResult( - content=content_items, + content=content_items, # type: ignore[arg-type] metadata={ "document_ids": [r.file_id for r in search_results], "chunks": [r.content[0].text if r.content else "" for r in search_results], @@ -209,51 +223,60 @@ class ToolExecutor: ) -> AsyncIterator[ToolExecutionResult]: """Emit progress events for tool execution start.""" # Emit in_progress event based on tool type (only for tools with specific streaming events) - progress_event = None if mcp_tool_to_server and function_name in mcp_tool_to_server: sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) elif function_name == "web_search": sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) elif function_name == "knowledge_search": sequence_number += 1 - progress_event = 
OpenAIResponseObjectStreamResponseFileSearchCallInProgress( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) - if progress_event: - yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) - # For web search, emit searching event if function_name == "web_search": sequence_number += 1 - searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) # For file search, emit searching event if function_name == "knowledge_search": sequence_number += 1 - searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) async def _execute_tool( self, @@ -261,7 +284,7 @@ class ToolExecutor: tool_kwargs: dict, ctx: ChatCompletionContext, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, - ) -> tuple[Exception | None, any]: + ) -> tuple[Exception | None, Any]: """Execute the tool and return error exception and result.""" error_exc = None result = None @@ -284,9 +307,13 @@ class ToolExecutor: kwargs=tool_kwargs, ) elif function_name == "knowledge_search": - 
response_file_search_tool = next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), - None, + response_file_search_tool = ( + next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, + ) + if ctx.response_tools + else None ) if response_file_search_tool: # Use vector_stores.search API instead of knowledge_search tool @@ -322,35 +349,34 @@ class ToolExecutor: mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, ) -> AsyncIterator[ToolExecutionResult]: """Emit completion or failure events for tool execution.""" - completion_event = None - if mcp_tool_to_server and function_name in mcp_tool_to_server: sequence_number += 1 if has_error: - completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed( sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number) else: - completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number) elif function_name == "web_search": sequence_number += 1 - completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number) elif function_name == "knowledge_search": sequence_number += 1 - completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( + file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - - 
if completion_event: - yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number) async def _build_result_messages( self, @@ -360,21 +386,18 @@ class ToolExecutor: tool_kwargs: dict, ctx: ChatCompletionContext, error_exc: Exception | None, - result: any, + result: Any, has_error: bool, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, - ) -> tuple[any, any]: + ) -> tuple[Any, Any]: """Build output and input messages from tool execution results.""" from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) # Build output message + message: Any if mcp_tool_to_server and function.name in mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import ( - OpenAIResponseOutputMessageMCPCall, - ) - message = OpenAIResponseOutputMessageMCPCall( id=item_id, arguments=function.arguments, @@ -383,10 +406,14 @@ class ToolExecutor: ) if error_exc: message.error = str(error_exc) - elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): - message.error = f"Error (code {result.error_code}): {result.error_message}" - elif result and result.content: - message.output = interleaved_content_as_str(result.content) + elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or ( + result and getattr(result, "error_message", None) + ): + ec = getattr(result, "error_code", "unknown") + em = getattr(result, "error_message", "") + message.error = f"Error (code {ec}): {em}" + elif result and (content := getattr(result, "content", None)): + message.output = interleaved_content_as_str(content) else: if function.name == "web_search": message = OpenAIResponseOutputMessageWebSearchToolCall( @@ -401,17 +428,17 @@ class ToolExecutor: queries=[tool_kwargs.get("query", "")], status="completed", ) - if result and "document_ids" in 
result.metadata: + if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata: message.results = [] - for i, doc_id in enumerate(result.metadata["document_ids"]): - text = result.metadata["chunks"][i] if "chunks" in result.metadata else None - score = result.metadata["scores"][i] if "scores" in result.metadata else None + for i, doc_id in enumerate(metadata["document_ids"]): + text = metadata["chunks"][i] if "chunks" in metadata else None + score = metadata["scores"][i] if "scores" in metadata else None message.results.append( OpenAIResponseOutputMessageFileSearchToolCallResults( file_id=doc_id, filename=doc_id, - text=text, - score=score, + text=text if text is not None else "", + score=score if score is not None else 0.0, attributes={}, ) ) @@ -421,27 +448,32 @@ class ToolExecutor: raise ValueError(f"Unknown tool {function.name} called") # Build input message - input_message = None - if result and result.content: - if isinstance(result.content, str): - content = result.content - elif isinstance(result.content, list): - content = [] - for item in result.content: + input_message: OpenAIToolMessageParam | None = None + if result and (result_content := getattr(result, "content", None)): + # all the mypy contortions here are still unsatisfactory with random Any typing + if isinstance(result_content, str): + msg_content: str | list[Any] = result_content + elif isinstance(result_content, list): + content_list: list[Any] = [] + for item in result_content: + part: Any if isinstance(item, TextContentItem): part = OpenAIChatCompletionContentPartTextParam(text=item.text) elif isinstance(item, ImageContentItem): if item.image.data: - url = f"data:image;base64,{item.image.data}" + url_value = f"data:image;base64,{item.image.data}" else: - url = item.image.url - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + url_value = str(item.image.url) if item.image.url else "" + part = 
OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value)) else: raise ValueError(f"Unknown result content type: {type(item)}") - content.append(part) + content_list.append(part) + msg_content = content_list else: - raise ValueError(f"Unknown result content type: {type(result.content)}") - input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + raise ValueError(f"Unknown result content type: {type(result_content)}") + # OpenAIToolMessageParam accepts str | list[TextParam] but we may have images + # This is runtime-safe as the API accepts it, but mypy complains + input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type] else: text = str(error_exc) if error_exc else "Tool execution failed" input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py index 829badf38..3b9a14b01 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/types.py @@ -5,6 +5,7 @@ # the root directory of this source tree. from dataclasses import dataclass +from typing import cast from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel @@ -100,17 +101,19 @@ class ToolContext(BaseModel): if isinstance(tool, OpenAIResponseToolMCP): previous_tools_by_label[tool.server_label] = tool # collect tool definitions which are the same in current and previous requests: - tools_to_process = [] + tools_to_process: list[OpenAIResponseInputTool] = [] matched: dict[str, OpenAIResponseInputToolMCP] = {} - for tool in self.current_tools: + # Mypy confuses OpenAIResponseInputTool (Input union) with OpenAIResponseTool (output union) + # which differ only in MCP type (InputToolMCP vs ToolMCP). 
Code is correct. + for tool in cast(list[OpenAIResponseInputTool], self.current_tools): # type: ignore[assignment] if isinstance(tool, OpenAIResponseInputToolMCP) and tool.server_label in previous_tools_by_label: previous_tool = previous_tools_by_label[tool.server_label] if previous_tool.allowed_tools == tool.allowed_tools: matched[tool.server_label] = tool else: - tools_to_process.append(tool) + tools_to_process.append(tool) # type: ignore[arg-type] else: - tools_to_process.append(tool) + tools_to_process.append(tool) # type: ignore[arg-type] # tools that are not the same or were not previously defined need to be processed: self.tools_to_process = tools_to_process # for all matched definitions, get the mcp_list_tools objects from the previous output: @@ -119,9 +122,11 @@ class ToolContext(BaseModel): ] # reconstruct the tool to server mappings that can be reused: for listing in self.previous_tool_listings: + # listing is OpenAIResponseOutputMessageMCPListTools which has tools: list[MCPListToolsTool] definition = matched[listing.server_label] - for tool in listing.tools: - self.previous_tools[tool.name] = definition + for mcp_tool in listing.tools: + # mcp_tool is MCPListToolsTool which has a name: str field + self.previous_tools[mcp_tool.name] = definition def available_tools(self) -> list[OpenAIResponseTool]: if not self.current_tools: @@ -139,6 +144,8 @@ class ToolContext(BaseModel): server_label=tool.server_label, allowed_tools=tool.allowed_tools, ) + # Exhaustive check - all tool types should be handled above + raise AssertionError(f"Unexpected tool type: {type(tool)}") return [convert_tool(tool) for tool in self.current_tools] diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index 7ca8af632..26af1d595 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ 
b/src/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -7,6 +7,7 @@ import asyncio import re import uuid +from collections.abc import Sequence from llama_stack.apis.agents.agents import ResponseGuardrailSpec from llama_stack.apis.agents.openai_responses import ( @@ -71,14 +72,14 @@ async def convert_chat_choice_to_response_message( return OpenAIResponseMessage( id=message_id or f"msg_{uuid.uuid4()}", - content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=annotations)], + content=[OpenAIResponseOutputMessageContentOutputText(text=clean_text, annotations=list(annotations))], status="completed", role="assistant", ) async def convert_response_content_to_chat_content( - content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), + content: str | Sequence[OpenAIResponseInputMessageContent | OpenAIResponseOutputMessageContent], ) -> str | list[OpenAIChatCompletionContentPartParam]: """ Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. 
@@ -88,7 +89,8 @@ async def convert_response_content_to_chat_content( if isinstance(content, str): return content - converted_parts = [] + # Type with union to avoid list invariance issues + converted_parts: list[OpenAIChatCompletionContentPartParam] = [] for content_part in content: if isinstance(content_part, OpenAIResponseInputMessageContentText): converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) @@ -158,9 +160,11 @@ async def convert_response_input_to_chat_messages( ), ) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + # Output can be None, use empty string as fallback + output_content = input_item.output if input_item.output is not None else "" messages.append( OpenAIToolMessageParam( - content=input_item.output, + content=output_content, tool_call_id=input_item.id, ) ) @@ -172,7 +176,8 @@ async def convert_response_input_to_chat_messages( ): # these are handled by the responses impl itself and not pass through to chat completions pass - else: + elif isinstance(input_item, OpenAIResponseMessage): + # Narrow type to OpenAIResponseMessage which has content and role attributes content = await convert_response_content_to_chat_content(input_item.content) message_type = await get_message_type_by_role(input_item.role) if message_type is None: @@ -191,7 +196,8 @@ async def convert_response_input_to_chat_messages( last_user_content = getattr(last_user_msg, "content", None) if last_user_content == content: continue # Skip duplicate user message - messages.append(message_type(content=content)) + # Dynamic message type call - different message types have different content expectations + messages.append(message_type(content=content)) # type: ignore[call-arg,arg-type] if len(tool_call_results): # Check if unpaired function_call_outputs reference function_calls from previous messages if previous_messages: @@ -237,8 +243,11 @@ async def convert_response_text_to_chat_response_format( if text.format["type"] == 
"json_object": return OpenAIResponseFormatJSONObject() if text.format["type"] == "json_schema": + # Assert name exists for json_schema format + assert text.format.get("name"), "json_schema format requires a name" + schema_name: str = text.format["name"] # type: ignore[assignment] return OpenAIResponseFormatJSONSchema( - json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) + json_schema=OpenAIJSONSchema(name=schema_name, schema=text.format["schema"]) ) raise ValueError(f"Unsupported text format: {text.format}") @@ -251,7 +260,7 @@ async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None "assistant": OpenAIAssistantMessageParam, "developer": OpenAIDeveloperMessageParam, } - return role_to_type.get(role) + return role_to_type.get(role) # type: ignore[return-value] # Pydantic models use ModelMetaclass def _extract_citations_from_text( @@ -320,7 +329,8 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[ # Look up shields to get their provider_resource_id (actual model ID) model_ids = [] - shields_list = await safety_api.routing_table.list_shields() + # TODO: list_shields not in Safety interface but available at runtime via API routing + shields_list = await safety_api.routing_table.list_shields() # type: ignore[attr-defined] for guardrail_id in guardrail_ids: matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id] @@ -337,7 +347,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[ for result in response.results: if result.flagged: message = result.user_message or "Content blocked by safety guardrails" - flagged_categories = [cat for cat, flagged in result.categories.items() if flagged] + flagged_categories = ( + [cat for cat, flagged in result.categories.items() if flagged] if result.categories else [] + ) violation_type = result.metadata.get("violation_type", []) if result.metadata else [] if 
flagged_categories: @@ -347,6 +359,9 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[ return message + # No violations found + return None + def extract_guardrail_ids(guardrails: list | None) -> list[str]: """Extract guardrail IDs from guardrails parameter, handling both string IDs and ResponseGuardrailSpec objects.""" diff --git a/src/llama_stack/providers/inline/agents/meta_reference/safety.py b/src/llama_stack/providers/inline/agents/meta_reference/safety.py index 9baf5a14d..f0ae51423 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -6,7 +6,7 @@ import asyncio -from llama_stack.apis.inference import Message +from llama_stack.apis.inference import OpenAIMessageParam from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.core.telemetry import tracing from llama_stack.log import get_logger @@ -31,7 +31,7 @@ class ShieldRunnerMixin: self.input_shields = input_shields self.output_shields = output_shields - async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None: + async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None: async def run_shield_with_span(identifier: str): async with tracing.span(f"run_shield_{identifier}"): return await self.safety_api.run_shield( diff --git a/src/llama_stack/providers/registry/files.py b/src/llama_stack/providers/registry/files.py index 9acabfacd..3f5949ba2 100644 --- a/src/llama_stack/providers/registry/files.py +++ b/src/llama_stack/providers/registry/files.py @@ -28,4 +28,13 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", ), + RemoteProviderSpec( + api=Api.files, + 
provider_type="remote::openai", + adapter_type="openai", + pip_packages=["openai"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.openai", + config_class="llama_stack.providers.remote.files.openai.config.OpenAIFilesImplConfig", + description="OpenAI Files API provider for managing files through OpenAI's native file storage service.", + ), ] diff --git a/src/llama_stack/providers/remote/files/openai/__init__.py b/src/llama_stack/providers/remote/files/openai/__init__.py new file mode 100644 index 000000000..58f86ecfd --- /dev/null +++ b/src/llama_stack/providers/remote/files/openai/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.core.datatypes import AccessRule, Api + +from .config import OpenAIFilesImplConfig + + +async def get_adapter_impl(config: OpenAIFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None): + from .files import OpenAIFilesImpl + + impl = OpenAIFilesImpl(config, policy or []) + await impl.initialize() + return impl diff --git a/src/llama_stack/providers/remote/files/openai/config.py b/src/llama_stack/providers/remote/files/openai/config.py new file mode 100644 index 000000000..a38031e41 --- /dev/null +++ b/src/llama_stack/providers/remote/files/openai/config.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.core.storage.datatypes import SqlStoreReference + + +class OpenAIFilesImplConfig(BaseModel): + """Configuration for OpenAI Files API provider.""" + + api_key: str = Field(description="OpenAI API key for authentication") + metadata_store: SqlStoreReference = Field(description="SQL store configuration for file metadata") + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: + return { + "api_key": "${env.OPENAI_API_KEY}", + "metadata_store": SqlStoreReference( + backend="sql_default", + table_name="openai_files_metadata", + ).model_dump(exclude_none=True), + } diff --git a/src/llama_stack/providers/remote/files/openai/files.py b/src/llama_stack/providers/remote/files/openai/files.py new file mode 100644 index 000000000..c5d4194df --- /dev/null +++ b/src/llama_stack/providers/remote/files/openai/files.py @@ -0,0 +1,239 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from datetime import UTC, datetime +from typing import Annotated, Any + +from fastapi import Depends, File, Form, Response, UploadFile + +from llama_stack.apis.common.errors import ResourceNotFoundError +from llama_stack.apis.common.responses import Order +from llama_stack.apis.files import ( + ExpiresAfter, + Files, + ListOpenAIFileResponse, + OpenAIFileDeleteResponse, + OpenAIFileObject, + OpenAIFilePurpose, +) +from llama_stack.core.datatypes import AccessRule +from llama_stack.providers.utils.files.form_data import parse_expires_after +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType +from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl +from openai import OpenAI + +from .config import OpenAIFilesImplConfig + + +def _make_file_object( + *, + id: str, + filename: str, + purpose: str, + bytes: int, + created_at: int, + expires_at: int, + **kwargs: Any, +) -> OpenAIFileObject: + """ + Construct an OpenAIFileObject and normalize expires_at. + + If expires_at is greater than the max we treat it as no-expiration and + return None for expires_at. 
+ """ + obj = OpenAIFileObject( + id=id, + filename=filename, + purpose=OpenAIFilePurpose(purpose), + bytes=bytes, + created_at=created_at, + expires_at=expires_at, + ) + + if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX): + obj.expires_at = None # type: ignore + + return obj + + +class OpenAIFilesImpl(Files): + """OpenAI Files API implementation.""" + + def __init__(self, config: OpenAIFilesImplConfig, policy: list[AccessRule]) -> None: + self._config = config + self.policy = policy + self._client: OpenAI | None = None + self._sql_store: AuthorizedSqlStore | None = None + + def _now(self) -> int: + """Return current UTC timestamp as int seconds.""" + return int(datetime.now(UTC).timestamp()) + + async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]: + where: dict[str, str | dict] = {"id": file_id} + if not return_expired: + where["expires_at"] = {">": self._now()} + if not (row := await self.sql_store.fetch_one("openai_files", where=where)): + raise ResourceNotFoundError(file_id, "File", "files.list()") + return row + + async def _delete_file(self, file_id: str) -> None: + """Delete a file from OpenAI and the database.""" + try: + self.client.files.delete(file_id) + except Exception as e: + # If file doesn't exist on OpenAI side, just remove from metadata store + if "not found" not in str(e).lower(): + raise RuntimeError(f"Failed to delete file from OpenAI: {e}") from e + + await self.sql_store.delete("openai_files", where={"id": file_id}) + + async def _delete_if_expired(self, file_id: str) -> None: + """If the file exists and is expired, delete it.""" + if row := await self._get_file(file_id, return_expired=True): + if (expires_at := row.get("expires_at")) and expires_at <= self._now(): + await self._delete_file(file_id) + + async def initialize(self) -> None: + self._client = OpenAI(api_key=self._config.api_key) + + self._sql_store = 
AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy) + await self._sql_store.create_table( + "openai_files", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "filename": ColumnType.STRING, + "purpose": ColumnType.STRING, + "bytes": ColumnType.INTEGER, + "created_at": ColumnType.INTEGER, + "expires_at": ColumnType.INTEGER, + }, + ) + + async def shutdown(self) -> None: + pass + + @property + def client(self) -> OpenAI: + assert self._client is not None, "Provider not initialized" + return self._client + + @property + def sql_store(self) -> AuthorizedSqlStore: + assert self._sql_store is not None, "Provider not initialized" + return self._sql_store + + async def openai_upload_file( + self, + file: Annotated[UploadFile, File()], + purpose: Annotated[OpenAIFilePurpose, Form()], + expires_after: Annotated[ExpiresAfter | None, Depends(parse_expires_after)] = None, + ) -> OpenAIFileObject: + filename = getattr(file, "filename", None) or "uploaded_file" + content = await file.read() + file_size = len(content) + + created_at = self._now() + + expires_at = created_at + ExpiresAfter.MAX * 42 + if purpose == OpenAIFilePurpose.BATCH: + expires_at = created_at + ExpiresAfter.MAX + + if expires_after is not None: + expires_at = created_at + expires_after.seconds + + try: + from io import BytesIO + + file_obj = BytesIO(content) + file_obj.name = filename + + response = self.client.files.create( + file=file_obj, + purpose=purpose.value, + ) + + file_id = response.id + + entry: dict[str, Any] = { + "id": file_id, + "filename": filename, + "purpose": purpose.value, + "bytes": file_size, + "created_at": created_at, + "expires_at": expires_at, + } + + await self.sql_store.insert("openai_files", entry) + + return _make_file_object(**entry) + + except Exception as e: + raise RuntimeError(f"Failed to upload file to OpenAI: {e}") from e + + async def openai_list_files( + self, + after: str | None = None, + limit: int | None = 10000, + order: 
Order | None = Order.desc, + purpose: OpenAIFilePurpose | None = None, + ) -> ListOpenAIFileResponse: + if not order: + order = Order.desc + + where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}} + if purpose: + where_conditions["purpose"] = purpose.value + + paginated_result = await self.sql_store.fetch_all( + table="openai_files", + where=where_conditions, + order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, + limit=limit, + ) + + files = [_make_file_object(**row) for row in paginated_result.data] + + return ListOpenAIFileResponse( + data=files, + has_more=paginated_result.has_more, + first_id=files[0].id if files else "", + last_id=files[-1].id if files else "", + ) + + async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: + await self._delete_if_expired(file_id) + row = await self._get_file(file_id) + return _make_file_object(**row) + + async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: + await self._delete_if_expired(file_id) + _ = await self._get_file(file_id) + await self._delete_file(file_id) + return OpenAIFileDeleteResponse(id=file_id, deleted=True) + + async def openai_retrieve_file_content(self, file_id: str) -> Response: + await self._delete_if_expired(file_id) + + row = await self._get_file(file_id) + + try: + response = self.client.files.content(file_id) + file_content = response.content + + except Exception as e: + if "not found" in str(e).lower(): + await self._delete_file(file_id) + raise ResourceNotFoundError(file_id, "File", "files.list()") from e + raise RuntimeError(f"Failed to download file from OpenAI: {e}") from e + + return Response( + content=file_content, + media_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + ) diff --git a/src/llama_stack/providers/remote/inference/anthropic/anthropic.py b/src/llama_stack/providers/remote/inference/anthropic/anthropic.py index 
dc9d8fb40..112b70524 100644 --- a/src/llama_stack/providers/remote/inference/anthropic/anthropic.py +++ b/src/llama_stack/providers/remote/inference/anthropic/anthropic.py @@ -33,4 +33,5 @@ class AnthropicInferenceAdapter(OpenAIMixin): return "https://api.anthropic.com/v1" async def list_provider_model_ids(self) -> Iterable[str]: - return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()] + api_key = self._get_api_key_from_config_or_provider_data() + return [m.id async for m in AsyncAnthropic(api_key=api_key).models.list()] diff --git a/src/llama_stack/providers/remote/inference/databricks/databricks.py b/src/llama_stack/providers/remote/inference/databricks/databricks.py index 8a8c5d4e3..636241383 100644 --- a/src/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/src/llama_stack/providers/remote/inference/databricks/databricks.py @@ -33,10 +33,11 @@ class DatabricksInferenceAdapter(OpenAIMixin): async def list_provider_model_ids(self) -> Iterable[str]: # Filter out None values from endpoint names + api_token = self._get_api_key_from_config_or_provider_data() return [ endpoint.name # type: ignore[misc] for endpoint in WorkspaceClient( - host=self.config.url, token=self.get_api_key() + host=self.config.url, token=api_token ).serving_endpoints.list() # TODO: this is not async ] diff --git a/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index f1a828413..97fa95a1f 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/src/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create( print(f"VLM Response: {vlm_response.choices[0].message.content}") ``` + +### Rerank Example + +The following example shows how to rerank documents using an NVIDIA NIM. 
+ +```python +rerank_response = client.alpha.inference.rerank( + model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2", + query="query", + items=[ + "item_1", + "item_2", + "item_3", + ], +) + +for i, result in enumerate(rerank_response): + print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]") +``` \ No newline at end of file diff --git a/src/llama_stack/providers/remote/inference/nvidia/config.py b/src/llama_stack/providers/remote/inference/nvidia/config.py index 3545d2b11..618bbe078 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/config.py +++ b/src/llama_stack/providers/remote/inference/nvidia/config.py @@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig): Attributes: url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000 api_key (str): The access key for the hosted NIM endpoints + rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints There are two ways to access NVIDIA NIMs - 0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com @@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig): default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false", description="When set to false, the API version will not be appended to the base_url. By default, it is true.", ) + rerank_model_to_url: dict[str, str] = Field( + default_factory=lambda: { + "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking", + "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking", + }, + description="Mapping of rerank model identifiers to their API endpoints. 
", + ) @classmethod def sample_run_config( diff --git a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py index ea11b49cd..bc5aa7953 100644 --- a/src/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/src/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -5,6 +5,19 @@ # the root directory of this source tree. +from collections.abc import Iterable + +import aiohttp + +from llama_stack.apis.inference import ( + RerankData, + RerankResponse, +) +from llama_stack.apis.inference.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, +) +from llama_stack.apis.models import Model, ModelType from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin @@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin): :return: The NVIDIA API base URL """ return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url + + async def list_provider_model_ids(self) -> Iterable[str]: + """ + Return both dynamic model IDs and statically configured rerank model IDs. + """ + dynamic_ids: Iterable[str] = [] + try: + dynamic_ids = await super().list_provider_model_ids() + except Exception: + # If the dynamic listing fails, proceed with just configured rerank IDs + dynamic_ids = [] + + configured_rerank_ids = list(self.config.rerank_model_to_url.keys()) + return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates + + def construct_model_from_identifier(self, identifier: str) -> Model: + """ + Classify rerank models from config; otherwise use the base behavior. 
+ """ + if identifier in self.config.rerank_model_to_url: + return Model( + provider_id=self.__provider_id__, # type: ignore[attr-defined] + provider_resource_id=identifier, + identifier=identifier, + model_type=ModelType.rerank, + ) + return super().construct_model_from_identifier(identifier) + + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + provider_model_id = await self._get_provider_model_id(model) + + ranking_url = self.get_base_url() + + if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url: + ranking_url = self.config.rerank_model_to_url[provider_model_id] + + logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}") + + # Convert query to text format + if isinstance(query, str): + query_text = query + elif isinstance(query, OpenAIChatCompletionContentPartTextParam): + query_text = query.text + else: + raise ValueError("Query must be a string or text content part") + + # Convert items to text format + passages = [] + for item in items: + if isinstance(item, str): + passages.append({"text": item}) + elif isinstance(item, OpenAIChatCompletionContentPartTextParam): + passages.append({"text": item.text}) + else: + raise ValueError("Items must be strings or text content parts") + + payload = { + "model": provider_model_id, + "query": {"text": query_text}, + "passages": passages, + } + + headers = { + "Authorization": f"Bearer {self.get_api_key()}", + "Content-Type": "application/json", + } + + try: + async with aiohttp.ClientSession() as session: + async with session.post(ranking_url, headers=headers, json=payload) as response: + if response.status != 200: + response_text = await response.text() + raise ConnectionError( + f"NVIDIA rerank 
API request failed with status {response.status}: {response_text}" + ) + + result = await response.json() + rankings = result.get("rankings", []) + + # Convert to RerankData format + rerank_data = [] + for ranking in rankings: + rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"])) + + # Apply max_num_results limit + if max_num_results is not None: + rerank_data = rerank_data[:max_num_results] + + return RerankResponse(data=rerank_data) + + except aiohttp.ClientError as e: + raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e diff --git a/src/llama_stack/providers/utils/inference/inference_store.py b/src/llama_stack/providers/utils/inference/inference_store.py index 8e20bca6b..2bf947a8d 100644 --- a/src/llama_stack/providers/utils/inference/inference_store.py +++ b/src/llama_stack/providers/utils/inference/inference_store.py @@ -35,6 +35,7 @@ class InferenceStore: self.reference = reference self.sql_store = None self.policy = policy + self.enable_write_queue = True # Async write queue and worker control self._queue: asyncio.Queue[tuple[OpenAIChatCompletion, list[OpenAIMessageParam]]] | None = None @@ -47,14 +48,13 @@ class InferenceStore: base_store = sqlstore_impl(self.reference) self.sql_store = AuthorizedSqlStore(base_store, self.policy) - # Disable write queue for SQLite to avoid concurrency issues - backend_name = self.reference.backend - backend_config = _SQLSTORE_BACKENDS.get(backend_name) - if backend_config is None: - raise ValueError( - f"Unregistered SQL backend '{backend_name}'. 
Registered backends: {sorted(_SQLSTORE_BACKENDS)}" - ) - self.enable_write_queue = backend_config.type != StorageBackendType.SQL_SQLITE + # Disable write queue for SQLite since WAL mode handles concurrency + # Keep it enabled for other backends (like Postgres) for performance + backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend) + if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE: + self.enable_write_queue = False + logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)") + await self.sql_store.create_table( "chat_completions", { @@ -70,8 +70,9 @@ class InferenceStore: self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) for _ in range(self._num_writers): self._worker_tasks.append(asyncio.create_task(self._worker_loop())) - else: - logger.info("Write queue disabled for SQLite to avoid concurrency issues") + logger.debug( + f"Inference store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}" + ) async def shutdown(self) -> None: if not self._worker_tasks: diff --git a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 3eef1f272..223497fb8 100644 --- a/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -128,7 +128,9 @@ class LiteLLMOpenAIMixin( return schema async def _get_params(self, request: ChatCompletionRequest) -> dict: - input_dict = {} + from typing import Any + + input_dict: dict[str, Any] = {} input_dict["messages"] = [ await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages @@ -139,30 +141,27 @@ class LiteLLMOpenAIMixin( f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." 
) - fmt = fmt.json_schema - name = fmt["title"] - del fmt["title"] - fmt["additionalProperties"] = False + # Convert to dict for manipulation + fmt_dict = dict(fmt.json_schema) + name = fmt_dict["title"] + del fmt_dict["title"] + fmt_dict["additionalProperties"] = False # Apply additionalProperties: False recursively to all objects - fmt = self._add_additional_properties_recursive(fmt) + fmt_dict = self._add_additional_properties_recursive(fmt_dict) input_dict["response_format"] = { "type": "json_schema", "json_schema": { "name": name, - "schema": fmt, + "schema": fmt_dict, "strict": self.json_schema_strict, }, } if request.tools: input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config.tool_choice: - input_dict["tool_choice"] = ( - request.tool_config.tool_choice.value - if isinstance(request.tool_config.tool_choice, ToolChoice) - else request.tool_config.tool_choice - ) + if request.tool_config and (tool_choice := request.tool_config.tool_choice): + input_dict["tool_choice"] = tool_choice.value if isinstance(tool_choice, ToolChoice) else tool_choice return { "model": request.model, @@ -176,10 +175,10 @@ class LiteLLMOpenAIMixin( def get_api_key(self) -> str: provider_data = self.get_request_provider_data() key_field = self.provider_data_api_key_field - if provider_data and getattr(provider_data, key_field, None): - api_key = getattr(provider_data, key_field) - else: - api_key = self.api_key_from_config + if provider_data and key_field and (api_key := getattr(provider_data, key_field, None)): + return str(api_key) # type: ignore[no-any-return] # getattr returns Any, can't narrow without runtime type inspection + + api_key = self.api_key_from_config if not api_key: raise ValueError( "API key is not set. 
Please provide a valid API key in the " @@ -192,7 +191,13 @@ class LiteLLMOpenAIMixin( self, params: OpenAIEmbeddingsRequestWithExtraBody, ) -> OpenAIEmbeddingsResponse: + if not self.model_store: + raise ValueError("Model store is not initialized") + model_obj = await self.model_store.get_model(params.model) + if model_obj.provider_resource_id is None: + raise ValueError(f"Model {params.model} has no provider_resource_id") + provider_resource_id = model_obj.provider_resource_id # Convert input to list if it's a string input_list = [params.input] if isinstance(params.input, str) else params.input @@ -200,7 +205,7 @@ class LiteLLMOpenAIMixin( # Call litellm embedding function # litellm.drop_params = True response = litellm.embedding( - model=self.get_litellm_model_name(model_obj.provider_resource_id), + model=self.get_litellm_model_name(provider_resource_id), input=input_list, api_key=self.get_api_key(), api_base=self.api_base, @@ -217,7 +222,7 @@ class LiteLLMOpenAIMixin( return OpenAIEmbeddingsResponse( data=data, - model=model_obj.provider_resource_id, + model=provider_resource_id, usage=usage, ) @@ -225,10 +230,16 @@ class LiteLLMOpenAIMixin( self, params: OpenAICompletionRequestWithExtraBody, ) -> OpenAICompletion: + if not self.model_store: + raise ValueError("Model store is not initialized") + model_obj = await self.model_store.get_model(params.model) + if model_obj.provider_resource_id is None: + raise ValueError(f"Model {params.model} has no provider_resource_id") + provider_resource_id = model_obj.provider_resource_id request_params = await prepare_openai_completion_params( - model=self.get_litellm_model_name(model_obj.provider_resource_id), + model=self.get_litellm_model_name(provider_resource_id), prompt=params.prompt, best_of=params.best_of, echo=params.echo, @@ -249,7 +260,8 @@ class LiteLLMOpenAIMixin( api_key=self.get_api_key(), api_base=self.api_base, ) - return await litellm.atext_completion(**request_params) + # LiteLLM returns compatible type but 
mypy can't verify external library + return await litellm.atext_completion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs async def openai_chat_completion( self, @@ -265,10 +277,16 @@ class LiteLLMOpenAIMixin( elif "include_usage" not in stream_options: stream_options = {**stream_options, "include_usage": True} + if not self.model_store: + raise ValueError("Model store is not initialized") + model_obj = await self.model_store.get_model(params.model) + if model_obj.provider_resource_id is None: + raise ValueError(f"Model {params.model} has no provider_resource_id") + provider_resource_id = model_obj.provider_resource_id request_params = await prepare_openai_completion_params( - model=self.get_litellm_model_name(model_obj.provider_resource_id), + model=self.get_litellm_model_name(provider_resource_id), messages=params.messages, frequency_penalty=params.frequency_penalty, function_call=params.function_call, @@ -294,7 +312,8 @@ class LiteLLMOpenAIMixin( api_key=self.get_api_key(), api_base=self.api_base, ) - return await litellm.acompletion(**request_params) + # LiteLLM returns compatible type but mypy can't verify external library + return await litellm.acompletion(**request_params) # type: ignore[no-any-return] # external lib lacks type stubs async def check_model_availability(self, model: str) -> bool: """ diff --git a/src/llama_stack/providers/utils/inference/model_registry.py b/src/llama_stack/providers/utils/inference/model_registry.py index d60d00f87..8a120b698 100644 --- a/src/llama_stack/providers/utils/inference/model_registry.py +++ b/src/llama_stack/providers/utils/inference/model_registry.py @@ -20,7 +20,7 @@ logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): - allowed_models: list[str] | None = Field( # TODO: make this non-optional and give a list() default + allowed_models: list[str] | None = Field( default=None, description="List of models that should be 
registered with the model registry. If None, all models are allowed.", ) diff --git a/src/llama_stack/providers/utils/inference/openai_compat.py b/src/llama_stack/providers/utils/inference/openai_compat.py index 7e465a14c..aabcb50f8 100644 --- a/src/llama_stack/providers/utils/inference/openai_compat.py +++ b/src/llama_stack/providers/utils/inference/openai_compat.py @@ -161,8 +161,10 @@ def get_sampling_strategy_options(params: SamplingParams) -> dict: if isinstance(params.strategy, GreedySamplingStrategy): options["temperature"] = 0.0 elif isinstance(params.strategy, TopPSamplingStrategy): - options["temperature"] = params.strategy.temperature - options["top_p"] = params.strategy.top_p + if params.strategy.temperature is not None: + options["temperature"] = params.strategy.temperature + if params.strategy.top_p is not None: + options["top_p"] = params.strategy.top_p elif isinstance(params.strategy, TopKSamplingStrategy): options["top_k"] = params.strategy.top_k else: @@ -192,12 +194,12 @@ def get_sampling_options(params: SamplingParams | None) -> dict: def text_from_choice(choice) -> str: if hasattr(choice, "delta") and choice.delta: - return choice.delta.content + return choice.delta.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations if hasattr(choice, "message"): - return choice.message.content + return choice.message.content # type: ignore[no-any-return] # external OpenAI types lack precise annotations - return choice.text + return choice.text # type: ignore[no-any-return] # external OpenAI types lack precise annotations def get_stop_reason(finish_reason: str) -> StopReason: @@ -216,7 +218,7 @@ def convert_openai_completion_logprobs( ) -> list[TokenLogProbs] | None: if not logprobs: return None - if hasattr(logprobs, "top_logprobs"): + if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] # Together supports logprobs with top_k=1 only. 
This means for each token position, @@ -236,7 +238,7 @@ def convert_openai_completion_logprobs_stream(text: str, logprobs: float | OpenA if isinstance(logprobs, float): # Adapt response from Together CompletionChoicesChunk return [TokenLogProbs(logprobs_by_token={text: logprobs})] - if hasattr(logprobs, "top_logprobs"): + if hasattr(logprobs, "top_logprobs") and logprobs.top_logprobs: return [TokenLogProbs(logprobs_by_token=x) for x in logprobs.top_logprobs] return None @@ -245,23 +247,24 @@ def process_completion_response( response: OpenAICompatCompletionResponse, ) -> CompletionResponse: choice = response.choices[0] + text = choice.text or "" # drop suffix if present and return stop reason as end of turn - if choice.text.endswith("<|eot_id|>"): + if text.endswith("<|eot_id|>"): return CompletionResponse( stop_reason=StopReason.end_of_turn, - content=choice.text[: -len("<|eot_id|>")], + content=text[: -len("<|eot_id|>")], logprobs=convert_openai_completion_logprobs(choice.logprobs), ) # drop suffix if present and return stop reason as end of message - if choice.text.endswith("<|eom_id|>"): + if text.endswith("<|eom_id|>"): return CompletionResponse( stop_reason=StopReason.end_of_message, - content=choice.text[: -len("<|eom_id|>")], + content=text[: -len("<|eom_id|>")], logprobs=convert_openai_completion_logprobs(choice.logprobs), ) return CompletionResponse( - stop_reason=get_stop_reason(choice.finish_reason), - content=choice.text, + stop_reason=get_stop_reason(choice.finish_reason or "stop"), + content=text, logprobs=convert_openai_completion_logprobs(choice.logprobs), ) @@ -272,10 +275,10 @@ def process_chat_completion_response( ) -> ChatCompletionResponse: choice = response.choices[0] if choice.finish_reason == "tool_calls": - if not choice.message or not choice.message.tool_calls: + if not hasattr(choice, "message") or not choice.message or not choice.message.tool_calls: # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed raise 
ValueError("Tool calls are not present in the response") - tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] + tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] # type: ignore[attr-defined] # OpenAICompatCompletionChoice is runtime duck-typed if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -287,9 +290,11 @@ def process_chat_completion_response( ) else: # Otherwise, return tool calls as normal + # Filter to only valid ToolCall objects + valid_tool_calls = [tc for tc in tool_calls if isinstance(tc, ToolCall)] return ChatCompletionResponse( completion_message=CompletionMessage( - tool_calls=tool_calls, + tool_calls=valid_tool_calls, stop_reason=StopReason.end_of_turn, # Content is not optional content="", @@ -299,7 +304,7 @@ def process_chat_completion_response( # TODO: This does not work well with tool calls for vLLM remote provider # Ref: https://github.com/meta-llama/llama-stack/issues/1058 - raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason)) + raw_message = decode_assistant_message(text_from_choice(choice), get_stop_reason(choice.finish_reason or "stop")) # NOTE: If we do not set tools in chat-completion request, we should not # expect the ToolCall in the response. 
Instead, we should return the raw @@ -324,8 +329,8 @@ def process_chat_completion_response( return ChatCompletionResponse( completion_message=CompletionMessage( - content=raw_message.content, - stop_reason=raw_message.stop_reason, + content=raw_message.content, # type: ignore[arg-type] # decode_assistant_message returns Union[str, InterleavedContent] + stop_reason=raw_message.stop_reason or StopReason.end_of_turn, tool_calls=raw_message.tool_calls, ), logprobs=None, @@ -448,7 +453,7 @@ async def process_chat_completion_stream_response( ) # parse tool calls and report errors - message = decode_assistant_message(buffer, stop_reason) + message = decode_assistant_message(buffer, stop_reason or StopReason.end_of_turn) parsed_tool_calls = len(message.tool_calls) > 0 if ipython and not parsed_tool_calls: @@ -463,7 +468,7 @@ async def process_chat_completion_stream_response( ) ) - request_tools = {t.tool_name: t for t in request.tools} + request_tools = {t.tool_name: t for t in (request.tools or [])} for tool_call in message.tool_calls: if tool_call.tool_name in request_tools: yield ChatCompletionResponseStreamChunk( @@ -525,7 +530,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals } if hasattr(message, "tool_calls") and message.tool_calls: - result["tool_calls"] = [] + tool_calls_list = [] for tc in message.tool_calls: # The tool.tool_name can be a str or a BuiltinTool enum. If # it's the latter, convert to a string. 
@@ -533,7 +538,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value - result["tool_calls"].append( + tool_calls_list.append( { "id": tc.call_id, "type": "function", @@ -543,6 +548,7 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals }, } ) + result["tool_calls"] = tool_calls_list # type: ignore[assignment] # dict allows Any value, stricter type expected return result @@ -608,7 +614,7 @@ async def convert_message_to_openai_dict_new( ), ) elif isinstance(content_, list): - return [await impl(item) for item in content_] + return [await impl(item) for item in content_] # type: ignore[misc] # recursive list comprehension confuses mypy's type narrowing else: raise ValueError(f"Unsupported content type: {type(content_)}") @@ -620,7 +626,7 @@ async def convert_message_to_openai_dict_new( else: return [ret] - out: OpenAIChatCompletionMessage = None + out: OpenAIChatCompletionMessage if isinstance(message, UserMessage): out = OpenAIChatCompletionUserMessage( role="user", @@ -636,7 +642,7 @@ async def convert_message_to_openai_dict_new( ), type="function", ) - for tool in message.tool_calls + for tool in (message.tool_calls or []) ] params = {} if tool_calls: @@ -644,18 +650,18 @@ async def convert_message_to_openai_dict_new( out = OpenAIChatCompletionAssistantMessage( role="assistant", content=await _convert_message_content(message.content), - **params, + **params, # type: ignore[typeddict-item] # tool_calls dict expansion conflicts with TypedDict optional field ) elif isinstance(message, ToolResponseMessage): out = OpenAIChatCompletionToolMessage( role="tool", tool_call_id=message.call_id, - content=await _convert_message_content(message.content), + content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement ) elif isinstance(message, SystemMessage): 
out = OpenAIChatCompletionSystemMessage( role="system", - content=await _convert_message_content(message.content), + content=await _convert_message_content(message.content), # type: ignore[typeddict-item] # content union type incompatible with TypedDict str requirement ) else: raise ValueError(f"Unsupported message type: {type(message)}") @@ -758,16 +764,16 @@ def convert_tooldef_to_openai_tool(tool: ToolDefinition) -> dict: function = out["function"] if isinstance(tool.tool_name, BuiltinTool): - function["name"] = tool.tool_name.value + function["name"] = tool.tool_name.value # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str] else: - function["name"] = tool.tool_name + function["name"] = tool.tool_name # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str] if tool.description: - function["description"] = tool.description + function["description"] = tool.description # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str] if tool.input_schema: # Pass through the entire JSON Schema as-is - function["parameters"] = tool.input_schema + function["parameters"] = tool.input_schema # type: ignore[index] # dict value inferred as Any but mypy sees Collection[str] # NOTE: OpenAI does not support output_schema, so we drop it here # It's stored in LlamaStack for validation and other provider usage @@ -815,15 +821,15 @@ def _convert_openai_request_tool_config(tool_choice: str | dict[str, Any] | None tool_config = ToolConfig() if tool_choice: try: - tool_choice = ToolChoice(tool_choice) + tool_choice = ToolChoice(tool_choice) # type: ignore[assignment] # reassigning to enum narrows union but mypy can't track after exception except ValueError: pass - tool_config.tool_choice = tool_choice + tool_config.tool_choice = tool_choice # type: ignore[assignment] # ToolConfig.tool_choice accepts Union[ToolChoice, dict] but mypy tracks narrower type return tool_config def _convert_openai_request_tools(tools: 
list[dict[str, Any]] | None = None) -> list[ToolDefinition]: - lls_tools = [] + lls_tools: list[ToolDefinition] = [] if not tools: return lls_tools @@ -843,16 +849,16 @@ def _convert_openai_request_tools(tools: list[dict[str, Any]] | None = None) -> def _convert_openai_request_response_format( - response_format: OpenAIResponseFormatParam = None, + response_format: OpenAIResponseFormatParam | None = None, ): if not response_format: return None # response_format can be a dict or a pydantic model - response_format = dict(response_format) - if response_format.get("type", "") == "json_schema": + response_format_dict = dict(response_format) # type: ignore[arg-type] # OpenAIResponseFormatParam union needs dict conversion + if response_format_dict.get("type", "") == "json_schema": return JsonSchemaResponseFormat( - type="json_schema", - json_schema=response_format.get("json_schema", {}).get("schema", ""), + type="json_schema", # type: ignore[arg-type] # Literal["json_schema"] incompatible with expected type + json_schema=response_format_dict.get("json_schema", {}).get("schema", ""), ) return None @@ -938,16 +944,15 @@ def _convert_openai_sampling_params( # Map an explicit temperature of 0 to greedy sampling if temperature == 0: - strategy = GreedySamplingStrategy() + sampling_params.strategy = GreedySamplingStrategy() else: # OpenAI defaults to 1.0 for temperature and top_p if unset if temperature is None: temperature = 1.0 if top_p is None: top_p = 1.0 - strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) + sampling_params.strategy = TopPSamplingStrategy(temperature=temperature, top_p=top_p) # type: ignore[assignment] # SamplingParams.strategy union accepts this type - sampling_params.strategy = strategy return sampling_params @@ -957,23 +962,24 @@ def openai_messages_to_messages( """ Convert a list of OpenAIChatCompletionMessage into a list of Message. 
""" - converted_messages = [] + converted_messages: list[Message] = [] for message in messages: + converted_message: Message if message.role == "system": - converted_message = SystemMessage(content=openai_content_to_content(message.content)) + converted_message = SystemMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types elif message.role == "user": - converted_message = UserMessage(content=openai_content_to_content(message.content)) + converted_message = UserMessage(content=openai_content_to_content(message.content)) # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types elif message.role == "assistant": converted_message = CompletionMessage( - content=openai_content_to_content(message.content), - tool_calls=_convert_openai_tool_calls(message.tool_calls), + content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types + tool_calls=_convert_openai_tool_calls(message.tool_calls) if message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls type incompatible with conversion function stop_reason=StopReason.end_of_turn, ) elif message.role == "tool": converted_message = ToolResponseMessage( role="tool", call_id=message.tool_call_id, - content=openai_content_to_content(message.content), + content=openai_content_to_content(message.content), # type: ignore[arg-type] # OpenAI SDK uses aliased types internally that mypy sees as incompatible with base types ) else: raise ValueError(f"Unknown role {message.role}") @@ -990,9 +996,9 @@ def openai_content_to_content(content: str | Iterable[OpenAIChatCompletionConten return [openai_content_to_content(c) for c in content] elif hasattr(content, "type"): if content.type == "text": - return TextContentItem(type="text", text=content.text) 
+ return TextContentItem(type="text", text=content.text) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track elif content.type == "image_url": - return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) + return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url))) # type: ignore[attr-defined] # Iterable narrowed by hasattr check but mypy doesn't track else: raise ValueError(f"Unknown content type: {content.type}") else: @@ -1041,9 +1047,9 @@ def convert_openai_chat_completion_choice( completion_message=CompletionMessage( content=choice.message.content or "", # CompletionMessage content is not optional stop_reason=_convert_openai_finish_reason(choice.finish_reason), - tool_calls=_convert_openai_tool_calls(choice.message.tool_calls), + tool_calls=_convert_openai_tool_calls(choice.message.tool_calls) if choice.message.tool_calls else [], # type: ignore[arg-type] # OpenAI tool_calls Optional type broadens union ), - logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), + logprobs=_convert_openai_logprobs(getattr(choice, "logprobs", None)), # type: ignore[arg-type] # getattr returns Any, can't narrow without inspection ) @@ -1070,7 +1076,7 @@ async def convert_openai_chat_completion_stream( choice = chunk.choices[0] # assuming only one choice per chunk # we assume there's only one finish_reason in the stream - stop_reason = _convert_openai_finish_reason(choice.finish_reason) or stop_reason + stop_reason = _convert_openai_finish_reason(choice.finish_reason) if choice.finish_reason else stop_reason logprobs = getattr(choice, "logprobs", None) # if there's a tool call, emit an event for each tool in the list @@ -1083,7 +1089,7 @@ async def convert_openai_chat_completion_stream( event=ChatCompletionResponseEvent( event_type=event_type, delta=TextDelta(text=choice.delta.content), - logprobs=_convert_openai_logprobs(logprobs), + 
logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result ) ) @@ -1101,10 +1107,10 @@ async def convert_openai_chat_completion_stream( event=ChatCompletionResponseEvent( event_type=event_type, delta=ToolCallDelta( - tool_call=_convert_openai_tool_calls([tool_call])[0], + tool_call=_convert_openai_tool_calls([tool_call])[0], # type: ignore[arg-type, list-item] # delta tool_call type differs from complete tool_call parse_status=ToolCallParseStatus.succeeded, ), - logprobs=_convert_openai_logprobs(logprobs), + logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result ) ) else: @@ -1125,12 +1131,15 @@ async def convert_openai_chat_completion_stream( if tool_call.function.name: buffer["name"] = tool_call.function.name delta = f"{buffer['name']}(" - buffer["content"] += delta + if buffer["content"] is not None: + buffer["content"] += delta if tool_call.function.arguments: delta = tool_call.function.arguments - buffer["arguments"] += delta - buffer["content"] += delta + if buffer["arguments"] is not None and delta: + buffer["arguments"] += delta + if buffer["content"] is not None and delta: + buffer["content"] += delta yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -1139,7 +1148,7 @@ async def convert_openai_chat_completion_stream( tool_call=delta, parse_status=ToolCallParseStatus.in_progress, ), - logprobs=_convert_openai_logprobs(logprobs), + logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result ) ) elif choice.delta.content: @@ -1147,7 +1156,7 @@ async def convert_openai_chat_completion_stream( event=ChatCompletionResponseEvent( event_type=event_type, delta=TextDelta(text=choice.delta.content or ""), - logprobs=_convert_openai_logprobs(logprobs), + logprobs=_convert_openai_logprobs(logprobs), # type: ignore[arg-type] # logprobs type broadened from getattr result 
) ) @@ -1155,7 +1164,8 @@ async def convert_openai_chat_completion_stream( logger.debug(f"toolcall_buffer[{idx}]: {buffer}") if buffer["name"]: delta = ")" - buffer["content"] += delta + if buffer["content"] is not None: + buffer["content"] += delta yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=event_type, @@ -1168,16 +1178,16 @@ async def convert_openai_chat_completion_stream( ) try: - tool_call = ToolCall( - call_id=buffer["call_id"], - tool_name=buffer["name"], - arguments=buffer["arguments"], + parsed_tool_call = ToolCall( + call_id=buffer["call_id"] or "", + tool_name=buffer["name"] or "", + arguments=buffer["arguments"] or "", ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( - tool_call=tool_call, + tool_call=parsed_tool_call, # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] parse_status=ToolCallParseStatus.succeeded, ), stop_reason=stop_reason, @@ -1189,7 +1199,7 @@ async def convert_openai_chat_completion_stream( event=ChatCompletionResponseEvent( event_type=ChatCompletionResponseEventType.progress, delta=ToolCallDelta( - tool_call=buffer["content"], + tool_call=buffer["content"], # type: ignore[arg-type] # ToolCallDelta.tool_call accepts Union[str, ToolCall] parse_status=ToolCallParseStatus.failed, ), stop_reason=stop_reason, @@ -1250,7 +1260,7 @@ class OpenAIChatCompletionToLlamaStackMixin: top_p: float | None = None, user: str | None = None, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - messages = openai_messages_to_messages(messages) + messages = openai_messages_to_messages(messages) # type: ignore[assignment] # converted from OpenAI to LlamaStack message format response_format = _convert_openai_request_response_format(response_format) sampling_params = _convert_openai_sampling_params( max_tokens=max_tokens, @@ -1259,15 +1269,15 @@ class 
OpenAIChatCompletionToLlamaStackMixin: ) tool_config = _convert_openai_request_tool_config(tool_choice) - tools = _convert_openai_request_tools(tools) + tools = _convert_openai_request_tools(tools) # type: ignore[assignment] # converted from OpenAI to LlamaStack tool format if tool_config.tool_choice == ToolChoice.none: - tools = [] + tools = [] # type: ignore[assignment] # empty list narrows return type but mypy tracks broader type outstanding_responses = [] # "n" is the number of completions to generate per prompt n = n or 1 for _i in range(0, n): - response = self.chat_completion( + response = self.chat_completion( # type: ignore[attr-defined] # mixin expects class to implement chat_completion model_id=model, messages=messages, sampling_params=sampling_params, @@ -1279,7 +1289,7 @@ class OpenAIChatCompletionToLlamaStackMixin: outstanding_responses.append(response) if stream: - return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) + return OpenAIChatCompletionToLlamaStackMixin._process_stream_response(self, model, outstanding_responses) # type: ignore[no-any-return] # mixin async generator return type too complex for mypy return await OpenAIChatCompletionToLlamaStackMixin._process_non_stream_response( self, model, outstanding_responses @@ -1295,14 +1305,16 @@ class OpenAIChatCompletionToLlamaStackMixin: response = await outstanding_response async for chunk in response: event = chunk.event - finish_reason = _convert_stop_reason_to_openai_finish_reason(event.stop_reason) + finish_reason = ( + _convert_stop_reason_to_openai_finish_reason(event.stop_reason) if event.stop_reason else None + ) if isinstance(event.delta, TextDelta): text_delta = event.delta.text delta = OpenAIChoiceDelta(content=text_delta) yield OpenAIChatCompletionChunk( id=id, - choices=[OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)], + choices=[OpenAIChatCompletionChunkChoice(index=i, 
finish_reason=finish_reason, delta=delta)], # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union created=int(time.time()), model=model, object="chat.completion.chunk", @@ -1310,13 +1322,17 @@ class OpenAIChatCompletionToLlamaStackMixin: elif isinstance(event.delta, ToolCallDelta): if event.delta.parse_status == ToolCallParseStatus.succeeded: tool_call = event.delta.tool_call + if isinstance(tool_call, str): + continue # First chunk includes full structure openai_tool_call = OpenAIChoiceDeltaToolCall( index=0, id=tool_call.call_id, function=OpenAIChoiceDeltaToolCallFunction( - name=tool_call.tool_name, + name=tool_call.tool_name + if isinstance(tool_call.tool_name, str) + else tool_call.tool_name.value, # type: ignore[arg-type] # enum .value extraction on Union confuses mypy arguments="", ), ) @@ -1324,7 +1340,7 @@ class OpenAIChatCompletionToLlamaStackMixin: yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) + OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union ], created=int(time.time()), model=model, @@ -1341,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin: yield OpenAIChatCompletionChunk( id=id, choices=[ - OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) + OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta) # type: ignore[arg-type] # finish_reason Optional[str] incompatible with Literal union ], created=int(time.time()), model=model, @@ -1351,7 +1367,7 @@ class OpenAIChatCompletionToLlamaStackMixin: async def _process_non_stream_response( self, model: str, outstanding_responses: list[Awaitable[ChatCompletionResponse]] ) -> OpenAIChatCompletion: - choices = [] + choices: list[OpenAIChatCompletionChoice] = [] for outstanding_response in outstanding_responses: 
response = await outstanding_response completion_message = response.completion_message @@ -1360,14 +1376,14 @@ class OpenAIChatCompletionToLlamaStackMixin: choice = OpenAIChatCompletionChoice( index=len(choices), - message=message, + message=message, # type: ignore[arg-type] # OpenAIChatCompletionMessage union incompatible with narrower Message type finish_reason=finish_reason, ) - choices.append(choice) + choices.append(choice) # type: ignore[arg-type] # OpenAIChatCompletionChoice type annotation mismatch return OpenAIChatCompletion( id=f"chatcmpl-{uuid.uuid4()}", - choices=choices, + choices=choices, # type: ignore[arg-type] # list[OpenAIChatCompletionChoice] union incompatible created=int(time.time()), model=model, object="chat.completion", diff --git a/src/llama_stack/providers/utils/inference/openai_mixin.py b/src/llama_stack/providers/utils/inference/openai_mixin.py index 941772b0f..09059da09 100644 --- a/src/llama_stack/providers/utils/inference/openai_mixin.py +++ b/src/llama_stack/providers/utils/inference/openai_mixin.py @@ -83,9 +83,6 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): # This is set in list_models() and used in check_model_availability() _model_cache: dict[str, Model] = {} - # List of allowed models for this provider, if empty all models allowed - allowed_models: list[str] = [] - # Optional field name in provider data to look for API key, which takes precedence provider_data_api_key_field: str | None = None @@ -441,7 +438,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel): for provider_model_id in provider_models_ids: if not isinstance(provider_model_id, str): raise ValueError(f"Model ID {provider_model_id} from list_provider_model_ids() is not a string") - if self.allowed_models and provider_model_id not in self.allowed_models: + if self.config.allowed_models is not None and provider_model_id not in self.config.allowed_models: logger.info(f"Skipping model {provider_model_id} as it is not in the allowed models 
list") continue model = self.construct_model_from_identifier(provider_model_id) diff --git a/src/llama_stack/providers/utils/memory/vector_store.py b/src/llama_stack/providers/utils/memory/vector_store.py index 6c8746e92..99f875227 100644 --- a/src/llama_stack/providers/utils/memory/vector_store.py +++ b/src/llama_stack/providers/utils/memory/vector_store.py @@ -196,6 +196,7 @@ def make_overlapped_chunks( chunks.append( Chunk( content=chunk, + chunk_id=chunk_id, metadata=chunk_metadata, chunk_metadata=backend_chunk_metadata, ) diff --git a/src/llama_stack/providers/utils/responses/responses_store.py b/src/llama_stack/providers/utils/responses/responses_store.py index d5c243252..40466d00c 100644 --- a/src/llama_stack/providers/utils/responses/responses_store.py +++ b/src/llama_stack/providers/utils/responses/responses_store.py @@ -70,13 +70,13 @@ class ResponsesStore: base_store = sqlstore_impl(self.reference) self.sql_store = AuthorizedSqlStore(base_store, self.policy) + # Disable write queue for SQLite since WAL mode handles concurrency + # Keep it enabled for other backends (like Postgres) for performance backend_config = _SQLSTORE_BACKENDS.get(self.reference.backend) - if backend_config is None: - raise ValueError( - f"Unregistered SQL backend '{self.reference.backend}'. 
Registered backends: {sorted(_SQLSTORE_BACKENDS)}" - ) - if backend_config.type == StorageBackendType.SQL_SQLITE: + if backend_config and backend_config.type == StorageBackendType.SQL_SQLITE: self.enable_write_queue = False + logger.debug("Write queue disabled for SQLite (WAL mode handles concurrency)") + await self.sql_store.create_table( "openai_responses", { @@ -99,8 +99,9 @@ class ResponsesStore: self._queue = asyncio.Queue(maxsize=self._max_write_queue_size) for _ in range(self._num_writers): self._worker_tasks.append(asyncio.create_task(self._worker_loop())) - else: - logger.debug("Write queue disabled for SQLite to avoid concurrency issues") + logger.debug( + f"Responses store write queue enabled with {self._num_writers} writers, max queue size {self._max_write_queue_size}" + ) async def shutdown(self) -> None: if not self._worker_tasks: diff --git a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 1bd364d43..356f49ed1 100644 --- a/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/src/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -17,6 +17,7 @@ from sqlalchemy import ( String, Table, Text, + event, inspect, select, text, @@ -75,7 +76,36 @@ class SqlAlchemySqlStoreImpl(SqlStore): self.metadata = MetaData() def create_engine(self) -> AsyncEngine: - return create_async_engine(self.config.engine_str, pool_pre_ping=True) + # Configure connection args for better concurrency support + connect_args = {} + if "sqlite" in self.config.engine_str: + # SQLite-specific optimizations for concurrent access + # With WAL mode, most locks resolve in milliseconds, but allow up to 5s for edge cases + connect_args["timeout"] = 5.0 + connect_args["check_same_thread"] = False # Allow usage across asyncio tasks + + engine = create_async_engine( + self.config.engine_str, + pool_pre_ping=True, + connect_args=connect_args, + ) + + # Enable WAL mode for SQLite to 
support concurrent readers and writers + if "sqlite" in self.config.engine_str: + + @event.listens_for(engine.sync_engine, "connect") + def set_sqlite_pragma(dbapi_conn, connection_record): + cursor = dbapi_conn.cursor() + # Enable Write-Ahead Logging for better concurrency + cursor.execute("PRAGMA journal_mode=WAL") + # Set busy timeout to 5 seconds (retry instead of immediate failure) + # With WAL mode, locks should be brief; if we hit 5s there's a bigger issue + cursor.execute("PRAGMA busy_timeout=5000") + # Use NORMAL synchronous mode for better performance (still safe with WAL) + cursor.execute("PRAGMA synchronous=NORMAL") + cursor.close() + + return engine async def create_table( self, diff --git a/src/llama_stack/strong_typing/inspection.py b/src/llama_stack/strong_typing/inspection.py index d3ebc7585..319d12657 100644 --- a/src/llama_stack/strong_typing/inspection.py +++ b/src/llama_stack/strong_typing/inspection.py @@ -430,6 +430,32 @@ def _unwrap_generic_list(typ: type[list[T]]) -> type[T]: return list_type # type: ignore[no-any-return] +def is_generic_sequence(typ: object) -> bool: + "True if the specified type is a generic Sequence, i.e. `Sequence[T]`." + import collections.abc + + typ = unwrap_annotated_type(typ) + return typing.get_origin(typ) is collections.abc.Sequence + + +def unwrap_generic_sequence(typ: object) -> type: + """ + Extracts the item type of a Sequence type. + + :param typ: The Sequence type `Sequence[T]`. + :returns: The item type `T`. + """ + + return rewrap_annotated_type(_unwrap_generic_sequence, typ) # type: ignore[arg-type] + + +def _unwrap_generic_sequence(typ: object) -> type: + "Extracts the item type of a Sequence type (e.g. returns `T` for `Sequence[T]`)." + + (sequence_type,) = typing.get_args(typ) # unpack single tuple element + return sequence_type # type: ignore[no-any-return] + + def is_generic_set(typ: object) -> TypeGuard[type[set]]: "True if the specified type is a generic set, i.e. `Set[T]`." 
diff --git a/src/llama_stack/strong_typing/name.py b/src/llama_stack/strong_typing/name.py index 00cdc2ae2..60501ac43 100644 --- a/src/llama_stack/strong_typing/name.py +++ b/src/llama_stack/strong_typing/name.py @@ -18,10 +18,12 @@ from .inspection import ( TypeLike, is_generic_dict, is_generic_list, + is_generic_sequence, is_type_optional, is_type_union, unwrap_generic_dict, unwrap_generic_list, + unwrap_generic_sequence, unwrap_optional_type, unwrap_union_types, ) @@ -155,24 +157,28 @@ def python_type_to_name(data_type: TypeLike, force: bool = False) -> str: if metadata is not None: # type is Annotated[T, ...] arg = typing.get_args(data_type)[0] - return python_type_to_name(arg) + return python_type_to_name(arg, force=force) if force: # generic types if is_type_optional(data_type, strict=True): - inner_name = python_type_to_name(unwrap_optional_type(data_type)) + inner_name = python_type_to_name(unwrap_optional_type(data_type), force=True) return f"Optional__{inner_name}" elif is_generic_list(data_type): - item_name = python_type_to_name(unwrap_generic_list(data_type)) + item_name = python_type_to_name(unwrap_generic_list(data_type), force=True) + return f"List__{item_name}" + elif is_generic_sequence(data_type): + # Treat Sequence the same as List for schema generation purposes + item_name = python_type_to_name(unwrap_generic_sequence(data_type), force=True) return f"List__{item_name}" elif is_generic_dict(data_type): key_type, value_type = unwrap_generic_dict(data_type) - key_name = python_type_to_name(key_type) - value_name = python_type_to_name(value_type) + key_name = python_type_to_name(key_type, force=True) + value_name = python_type_to_name(value_type, force=True) return f"Dict__{key_name}__{value_name}" elif is_type_union(data_type): member_types = unwrap_union_types(data_type) - member_names = "__".join(python_type_to_name(member_type) for member_type in member_types) + member_names = "__".join(python_type_to_name(member_type, force=True) for 
member_type in member_types) return f"Union__{member_names}" # named system or user-defined type diff --git a/src/llama_stack/strong_typing/schema.py b/src/llama_stack/strong_typing/schema.py index 15a3bbbfc..916690e41 100644 --- a/src/llama_stack/strong_typing/schema.py +++ b/src/llama_stack/strong_typing/schema.py @@ -111,7 +111,7 @@ def get_class_property_docstrings( def docstring_to_schema(data_type: type) -> Schema: short_description, long_description = get_class_docstrings(data_type) schema: Schema = { - "title": python_type_to_name(data_type), + "title": python_type_to_name(data_type, force=True), } description = "\n".join(filter(None, [short_description, long_description])) @@ -417,6 +417,10 @@ class JsonSchemaGenerator: if origin_type is list: (list_type,) = typing.get_args(typ) # unpack single tuple element return {"type": "array", "items": self.type_to_schema(list_type)} + elif origin_type is collections.abc.Sequence: + # Treat Sequence the same as list for JSON schema (both are arrays) + (sequence_type,) = typing.get_args(typ) # unpack single tuple element + return {"type": "array", "items": self.type_to_schema(sequence_type)} elif origin_type is dict: key_type, value_type = typing.get_args(typ) if not (key_type is str or key_type is int or is_type_enum(key_type)): diff --git a/src/llama_stack/ui/app/api/v1/[...path]/route.ts b/src/llama_stack/ui/app/api/v1/[...path]/route.ts index 51c1f8004..d1aa31014 100644 --- a/src/llama_stack/ui/app/api/v1/[...path]/route.ts +++ b/src/llama_stack/ui/app/api/v1/[...path]/route.ts @@ -51,10 +51,14 @@ async function proxyRequest(request: NextRequest, method: string) { ); // Create response with same status and headers - const proxyResponse = new NextResponse(responseText, { - status: response.status, - statusText: response.statusText, - }); + // Handle 204 No Content responses specially + const proxyResponse = + response.status === 204 + ? 
new NextResponse(null, { status: 204 }) + : new NextResponse(responseText, { + status: response.status, + statusText: response.statusText, + }); // Copy response headers (except problematic ones) response.headers.forEach((value, key) => { diff --git a/src/llama_stack/ui/app/prompts/page.tsx b/src/llama_stack/ui/app/prompts/page.tsx new file mode 100644 index 000000000..30106a056 --- /dev/null +++ b/src/llama_stack/ui/app/prompts/page.tsx @@ -0,0 +1,5 @@ +import { PromptManagement } from "@/components/prompts"; + +export default function PromptsPage() { + return ; +} diff --git a/src/llama_stack/ui/components/layout/app-sidebar.tsx b/src/llama_stack/ui/components/layout/app-sidebar.tsx index 373f0c5ae..a5df60aef 100644 --- a/src/llama_stack/ui/components/layout/app-sidebar.tsx +++ b/src/llama_stack/ui/components/layout/app-sidebar.tsx @@ -8,6 +8,7 @@ import { MessageCircle, Settings2, Compass, + FileText, } from "lucide-react"; import Link from "next/link"; import { usePathname } from "next/navigation"; @@ -50,6 +51,11 @@ const manageItems = [ url: "/logs/vector-stores", icon: Database, }, + { + title: "Prompts", + url: "/prompts", + icon: FileText, + }, { title: "Documentation", url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html", diff --git a/src/llama_stack/ui/components/prompts/index.ts b/src/llama_stack/ui/components/prompts/index.ts new file mode 100644 index 000000000..d190c5eb6 --- /dev/null +++ b/src/llama_stack/ui/components/prompts/index.ts @@ -0,0 +1,4 @@ +export { PromptManagement } from "./prompt-management"; +export { PromptList } from "./prompt-list"; +export { PromptEditor } from "./prompt-editor"; +export * from "./types"; diff --git a/src/llama_stack/ui/components/prompts/prompt-editor.test.tsx b/src/llama_stack/ui/components/prompts/prompt-editor.test.tsx new file mode 100644 index 000000000..458a5f942 --- /dev/null +++ b/src/llama_stack/ui/components/prompts/prompt-editor.test.tsx @@ -0,0 +1,309 @@ +import 
React from "react"; +import { render, screen, fireEvent } from "@testing-library/react"; +import "@testing-library/jest-dom"; +import { PromptEditor } from "./prompt-editor"; +import type { Prompt, PromptFormData } from "./types"; + +describe("PromptEditor", () => { + const mockOnSave = jest.fn(); + const mockOnCancel = jest.fn(); + const mockOnDelete = jest.fn(); + + const defaultProps = { + onSave: mockOnSave, + onCancel: mockOnCancel, + onDelete: mockOnDelete, + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + describe("Create Mode", () => { + test("renders create form correctly", () => { + render(); + + expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument(); + expect(screen.getByText("Variables")).toBeInTheDocument(); + expect(screen.getByText("Preview")).toBeInTheDocument(); + expect(screen.getByText("Create Prompt")).toBeInTheDocument(); + expect(screen.getByText("Cancel")).toBeInTheDocument(); + }); + + test("shows preview placeholder when no content", () => { + render(); + + expect( + screen.getByText("Enter content to preview the compiled prompt") + ).toBeInTheDocument(); + }); + + test("submits form with correct data", () => { + render(); + + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Hello {{name}}, welcome!" 
}, + }); + + fireEvent.click(screen.getByText("Create Prompt")); + + expect(mockOnSave).toHaveBeenCalledWith({ + prompt: "Hello {{name}}, welcome!", + variables: [], + }); + }); + + test("prevents submission with empty prompt", () => { + render(); + + fireEvent.click(screen.getByText("Create Prompt")); + + expect(mockOnSave).not.toHaveBeenCalled(); + }); + }); + + describe("Edit Mode", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}, how is {{weather}}?", + version: 1, + variables: ["name", "weather"], + is_default: true, + }; + + test("renders edit form with existing data", () => { + render(); + + expect( + screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?") + ).toBeInTheDocument(); + expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview + expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview + expect(screen.getByText("Update Prompt")).toBeInTheDocument(); + expect(screen.getByText("Delete Prompt")).toBeInTheDocument(); + }); + + test("submits updated data correctly", () => { + render(); + + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Updated: Hello {{name}}!" }, + }); + + fireEvent.click(screen.getByText("Update Prompt")); + + expect(mockOnSave).toHaveBeenCalledWith({ + prompt: "Updated: Hello {{name}}!", + variables: ["name", "weather"], + }); + }); + }); + + describe("Variables Management", () => { + test("adds new variable", () => { + render(); + + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. 
user_name, topic)" + ); + fireEvent.change(variableInput, { target: { value: "testVar" } }); + fireEvent.click(screen.getByText("Add")); + + expect(screen.getByText("testVar")).toBeInTheDocument(); + }); + + test("prevents adding duplicate variables", () => { + render(); + + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. user_name, topic)" + ); + + // Add first variable + fireEvent.change(variableInput, { target: { value: "test" } }); + fireEvent.click(screen.getByText("Add")); + + // Try to add same variable again + fireEvent.change(variableInput, { target: { value: "test" } }); + + // Button should be disabled + expect(screen.getByText("Add")).toBeDisabled(); + }); + + test("removes variable", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}", + version: 1, + variables: ["name", "location"], + is_default: true, + }; + + render(); + + // Check that both variables are present initially + expect(screen.getAllByText("name").length).toBeGreaterThan(0); + expect(screen.getAllByText("location").length).toBeGreaterThan(0); + + // Remove the location variable by clicking the X button with the specific title + const removeLocationButton = screen.getByTitle( + "Remove location variable" + ); + fireEvent.click(removeLocationButton); + + // Name should still be there, location should be gone from the variables section + expect(screen.getAllByText("name").length).toBeGreaterThan(0); + expect( + screen.queryByTitle("Remove location variable") + ).not.toBeInTheDocument(); + }); + + test("adds variable on Enter key", () => { + render(); + + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. 
user_name, topic)" + ); + fireEvent.change(variableInput, { target: { value: "enterVar" } }); + + // Simulate Enter key press + fireEvent.keyPress(variableInput, { + key: "Enter", + code: "Enter", + charCode: 13, + preventDefault: jest.fn(), + }); + + // Check if the variable was added by looking for the badge + expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0); + }); + }); + + describe("Preview Functionality", () => { + test("shows live preview with variables", () => { + render(); + + // Add prompt content + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Hello {{name}}, welcome to {{place}}!" }, + }); + + // Add variables + const variableInput = screen.getByPlaceholderText( + "Add variable name (e.g. user_name, topic)" + ); + fireEvent.change(variableInput, { target: { value: "name" } }); + fireEvent.click(screen.getByText("Add")); + + fireEvent.change(variableInput, { target: { value: "place" } }); + fireEvent.click(screen.getByText("Add")); + + // Check that preview area shows the content + expect(screen.getByText("Compiled Prompt")).toBeInTheDocument(); + }); + + test("shows variable value inputs in preview", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}", + version: 1, + variables: ["name"], + is_default: true, + }; + + render(); + + expect(screen.getByText("Variable Values")).toBeInTheDocument(); + expect( + screen.getByPlaceholderText("Enter value for name") + ).toBeInTheDocument(); + }); + + test("shows color legend for variable states", () => { + render(); + + // Add content to show preview + const promptInput = screen.getByLabelText("Prompt Content *"); + fireEvent.change(promptInput, { + target: { value: "Hello {{name}}" }, + }); + + expect(screen.getByText("Used")).toBeInTheDocument(); + expect(screen.getByText("Unused")).toBeInTheDocument(); + expect(screen.getByText("Undefined")).toBeInTheDocument(); + }); + }); + + 
describe("Error Handling", () => { + test("displays error message", () => { + const errorMessage = "Prompt contains undeclared variables"; + render(); + + expect(screen.getByText(errorMessage)).toBeInTheDocument(); + }); + }); + + describe("Delete Functionality", () => { + const mockPrompt: Prompt = { + prompt_id: "prompt_123", + prompt: "Hello {{name}}", + version: 1, + variables: ["name"], + is_default: true, + }; + + test("shows delete button in edit mode", () => { + render(); + + expect(screen.getByText("Delete Prompt")).toBeInTheDocument(); + }); + + test("hides delete button in create mode", () => { + render(); + + expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument(); + }); + + test("calls onDelete with confirmation", () => { + const originalConfirm = window.confirm; + window.confirm = jest.fn(() => true); + + render(); + + fireEvent.click(screen.getByText("Delete Prompt")); + + expect(window.confirm).toHaveBeenCalledWith( + "Are you sure you want to delete this prompt? This action cannot be undone." 
+ ); + expect(mockOnDelete).toHaveBeenCalledWith("prompt_123"); + + window.confirm = originalConfirm; + }); + + test("does not delete when confirmation is cancelled", () => { + const originalConfirm = window.confirm; + window.confirm = jest.fn(() => false); + + render(); + + fireEvent.click(screen.getByText("Delete Prompt")); + + expect(mockOnDelete).not.toHaveBeenCalled(); + + window.confirm = originalConfirm; + }); + }); + + describe("Cancel Functionality", () => { + test("calls onCancel when cancel button is clicked", () => { + render(); + + fireEvent.click(screen.getByText("Cancel")); + + expect(mockOnCancel).toHaveBeenCalled(); + }); + }); +}); diff --git a/src/llama_stack/ui/components/prompts/prompt-editor.tsx b/src/llama_stack/ui/components/prompts/prompt-editor.tsx new file mode 100644 index 000000000..efa76f757 --- /dev/null +++ b/src/llama_stack/ui/components/prompts/prompt-editor.tsx @@ -0,0 +1,346 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Textarea } from "@/components/ui/textarea"; +import { Badge } from "@/components/ui/badge"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { Separator } from "@/components/ui/separator"; +import { X, Plus, Save, Trash2 } from "lucide-react"; +import { Prompt, PromptFormData } from "./types"; + +interface PromptEditorProps { + prompt?: Prompt; + onSave: (prompt: PromptFormData) => void; + onCancel: () => void; + onDelete?: (promptId: string) => void; + error?: string | null; +} + +export function PromptEditor({ + prompt, + onSave, + onCancel, + onDelete, + error, +}: PromptEditorProps) { + const [formData, setFormData] = useState({ + prompt: "", + variables: [], + }); + + const [newVariable, setNewVariable] = useState(""); + const [variableValues, setVariableValues] = 
useState>( + {} + ); + + useEffect(() => { + if (prompt) { + setFormData({ + prompt: prompt.prompt || "", + variables: prompt.variables || [], + }); + } + }, [prompt]); + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault(); + if (!formData.prompt.trim()) { + return; + } + onSave(formData); + }; + + const addVariable = () => { + if ( + newVariable.trim() && + !formData.variables.includes(newVariable.trim()) + ) { + setFormData(prev => ({ + ...prev, + variables: [...prev.variables, newVariable.trim()], + })); + setNewVariable(""); + } + }; + + const removeVariable = (variableToRemove: string) => { + setFormData(prev => ({ + ...prev, + variables: prev.variables.filter( + variable => variable !== variableToRemove + ), + })); + }; + + const renderPreview = () => { + const text = formData.prompt; + if (!text) return text; + + // Split text by variable patterns and process each part + const parts = text.split(/(\{\{\s*\w+\s*\}\})/g); + + return parts.map((part, index) => { + const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/); + if (variableMatch) { + const variableName = variableMatch[1]; + const isDefined = formData.variables.includes(variableName); + const value = variableValues[variableName]; + + if (!isDefined) { + // Variable not in variables list - likely a typo/bug (RED) + return ( + + {part} + + ); + } else if (value && value.trim()) { + // Variable defined and has value - show the value (GREEN) + return ( + + {value} + + ); + } else { + // Variable defined but empty (YELLOW) + return ( + + {part} + + ); + } + } + return part; + }); + }; + + const updateVariableValue = (variable: string, value: string) => { + setVariableValues(prev => ({ + ...prev, + [variable]: value, + })); + }; + + return ( +
+ {error && ( +
+

{error}

+
+ )} +
+ {/* Form Section */} +
+
+ +