diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml new file mode 100644 index 000000000..573148e46 --- /dev/null +++ b/.github/actions/run-and-record-tests/action.yml @@ -0,0 +1,82 @@ +name: 'Run and Record Tests' +description: 'Run integration tests and handle recording/artifact upload' + +inputs: + test-types: + description: 'JSON array of test types to run' + required: true + stack-config: + description: 'Stack configuration to use' + required: true + provider: + description: 'Provider to use for tests' + required: true + inference-mode: + description: 'Inference mode (record or replay)' + required: true + run-vision-tests: + description: 'Whether to run vision tests' + required: false + default: 'false' + +runs: + using: 'composite' + steps: + - name: Check Storage and Memory Available Before Tests + if: ${{ always() }} + shell: bash + run: | + free -h + df -h + + - name: Run Integration Tests + shell: bash + run: | + ./scripts/integration-tests.sh \ + --stack-config '${{ inputs.stack-config }}' \ + --provider '${{ inputs.provider }}' \ + --test-types '${{ inputs.test-types }}' \ + --inference-mode '${{ inputs.inference-mode }}' \ + ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} + + + - name: Commit and push recordings + if: ${{ inputs.inference-mode == 'record' }} + shell: bash + run: | + echo "Checking for recording changes" + git status --porcelain tests/integration/recordings/ + + if [[ -n $(git status --porcelain tests/integration/recordings/) ]]; then + echo "New recordings detected, committing and pushing" + git add tests/integration/recordings/ + + if [ "${{ inputs.run-vision-tests }}" == "true" ]; then + git commit -m "Recordings update from CI (vision)" + else + git commit -m "Recordings update from CI" + fi + + git fetch origin ${{ github.event.pull_request.head.ref }} + git rebase origin/${{ github.event.pull_request.head.ref }} + echo "Rebased successfully" + git push origin HEAD:${{ github.event.pull_request.head.ref }} + echo "Pushed successfully" + else + echo "No recording changes" + fi + + - name: Write inference logs to file + if: ${{ always() }} + shell: bash + run: | + sudo docker logs ollama > ollama-${{ inputs.inference-mode }}.log || true + + - name: Upload logs + if: ${{ always() }} + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }} + path: | + *.log + retention-days: 1 diff --git a/.github/actions/setup-ollama/action.yml b/.github/actions/setup-ollama/action.yml index bb08520f1..e57876cb0 100644 --- a/.github/actions/setup-ollama/action.yml +++ b/.github/actions/setup-ollama/action.yml @@ -1,11 +1,23 @@ name: Setup Ollama description: Start Ollama +inputs: + run-vision-tests: + description: 'Run vision tests: "true" or "false"' + required: false + default: 'false' runs: using: "composite" steps: - name: Start Ollama shell: bash run: | - docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models + if [ "${{ inputs.run-vision-tests }}" == "true" ]; then + image="ollama-with-vision-model" + else + image="ollama-with-models" + fi + + echo "Starting Ollama with image: $image" + docker run -d --name ollama -p 11434:11434 docker.io/llamastack/$image echo "Verifying Ollama status..." timeout 30 bash -c 'while ! 
curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done' diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml new file mode 100644 index 000000000..30b9b0130 --- /dev/null +++ b/.github/actions/setup-test-environment/action.yml @@ -0,0 +1,51 @@ +name: 'Setup Test Environment' +description: 'Common setup steps for integration tests including dependencies, providers, and build' + +inputs: + python-version: + description: 'Python version to use' + required: true + client-version: + description: 'Client version (latest or published)' + required: true + provider: + description: 'Provider to set up (ollama or vllm)' + required: true + default: 'ollama' + run-vision-tests: + description: 'Whether to set up the provider for vision tests' + required: false + default: 'false' + inference-mode: + description: 'Inference mode (record or replay)' + required: true + +runs: + using: 'composite' + steps: + - name: Install dependencies + uses: ./.github/actions/setup-runner + with: + python-version: ${{ inputs.python-version }} + client-version: ${{ inputs.client-version }} + + - name: Setup ollama + if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} + uses: ./.github/actions/setup-ollama + with: + run-vision-tests: ${{ inputs.run-vision-tests }} + + - name: Setup vllm + if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} + uses: ./.github/actions/setup-vllm + + - name: Build Llama Stack + shell: bash + run: | + uv run llama stack build --template ci-tests --image-type venv + + - name: Configure git for commits + shell: bash + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 3347b05f8..3c3d93dc2 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -1,19 +1,19 @@ # Llama Stack CI -Llama Stack uses GitHub Actions for Continous Integration (CI). Below is a table detailing what CI the project includes and the purpose. +Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a table detailing what CI the project includes and the purpose. 
| Name | File | Purpose | | ---- | ---- | ------- | | Update Changelog | [changelog.yml](changelog.yml) | Creates PR for updating the CHANGELOG.md | -| Coverage Badge | [coverage-badge.yml](coverage-badge.yml) | Creates PR for updating the code coverage badge | | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | -| Integration Tests | [integration-tests.yml](integration-tests.yml) | Run the integration test suite with Ollama | +| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | | Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project | +| Integration Tests (Record) | [record-integration-tests.yml](record-integration-tests.yml) | Run the integration test suite from tests/integration | | Check semantic PR titles | [semantic-pr.yml](semantic-pr.yml) | Ensure that PR titles follow the conventional commit spec | | Close stale issues and PRs | [stale_bot.yml](stale_bot.yml) | Run the Stale Bot action | | Test External Providers Installed via Module | [test-external-provider-module.yml](test-external-provider-module.yml) | Test External Provider installation via Python module | diff --git a/.github/workflows/coverage-badge.yml b/.github/workflows/coverage-badge.yml deleted file mode 100644 index 75428539e..000000000 --- a/.github/workflows/coverage-badge.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: Coverage Badge - -run-name: Creates PR for updating the code coverage badge - -on: - push: - branches: [ main ] - paths: - - 'llama_stack/**' - - 'tests/unit/**' - - 'uv.lock' - - 'pyproject.toml' - - 'requirements.txt' - - '.github/workflows/unit-tests.yml' - - '.github/workflows/coverage-badge.yml' # This workflow - workflow_dispatch: - -jobs: - unit-tests: - permissions: - contents: write # for peter-evans/create-pull-request to create branch - pull-requests: write # for peter-evans/create-pull-request to create a PR - runs-on: ubuntu-latest - steps: - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Install dependencies - uses: ./.github/actions/setup-runner - - - name: Run unit tests - run: | - ./scripts/unit-tests.sh - - - name: Coverage Badge - uses: tj-actions/coverage-badge-py@1788babcb24544eb5bbb6e0d374df5d1e54e670f # v2.0.4 - - - name: Verify Changed files - uses: tj-actions/verify-changed-files@a1c6acee9df209257a246f2cc6ae8cb6581c1edf # v20.0.4 - id: verify-changed-files - with: - files: coverage.svg - - - name: Commit files - if: steps.verify-changed-files.outputs.files_changed == 'true' - run: | - git config --local user.email "github-actions[bot]@users.noreply.github.com" - git config --local user.name "github-actions[bot]" - git add coverage.svg - git commit -m "Updated coverage.svg" - - - name: 
Create Pull Request - if: steps.verify-changed-files.outputs.files_changed == 'true' - uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 - with: - token: ${{ secrets.GITHUB_TOKEN }} - title: "ci: [Automatic] Coverage Badge Update" - body: | - This PR updates the coverage badge based on the latest coverage report. - - Automatically generated by the [workflow coverage-badge.yaml](.github/workflows/coverage-badge.yaml) - delete-branch: true diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index be2613fbb..a2a56c003 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,20 +1,22 @@ -name: Integration Tests +name: Integration Tests (Replay) -run-name: Run the integration test suite with Ollama +run-name: Run the integration test suite from tests/integration in replay mode on: push: branches: [ main ] pull_request: branches: [ main ] + types: [opened, synchronize, reopened] paths: - 'llama_stack/**' - 'tests/**' - 'uv.lock' - 'pyproject.toml' - - 'requirements.txt' - '.github/workflows/integration-tests.yml' # This workflow - '.github/actions/setup-ollama/action.yml' + - '.github/actions/setup-test-environment/action.yml' + - '.github/actions/run-and-record-tests/action.yml' schedule: # If changing the cron schedule, update the provider in the test-matrix job - cron: '0 0 * * *' # (test latest client) Daily at 12 AM UTC @@ -31,129 +33,64 @@ on: default: 'ollama' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + # Skip concurrency for pushes to main - each commit should be tested independently + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: discover-tests: runs-on: ubuntu-latest outputs: - test-type: ${{ steps.generate-matrix.outputs.test-type }} + test-types: ${{ steps.generate-test-types.outputs.test-types }} + steps: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Generate test matrix - id: generate-matrix + - name: Generate test types + id: generate-test-types run: | # Get test directories dynamically, excluding non-test directories + # NOTE: we are excluding post_training since the tests take too long TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | - grep -Ev "^(__pycache__|fixtures|test_cases)$" | + grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" | sort | jq -R -s -c 'split("\n")[:-1]') - echo "test-type=$TEST_TYPES" >> $GITHUB_OUTPUT + echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT - test-matrix: + run-replay-mode-tests: needs: discover-tests runs-on: ubuntu-latest + name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} strategy: fail-fast: false matrix: - test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }} client-type: [library, server] # Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama) provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }} - python-version: ["3.12", "3.13"] - client-version: ${{ (github.event.schedule == '0 0 * * 0' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || 
fromJSON('["latest"]') }} - exclude: # TODO: look into why these tests are failing and fix them - - provider: vllm - test-type: safety - - provider: vllm - test-type: post_training - - provider: vllm - test-type: tool_runtime + # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 + python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} + client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} + run-vision-tests: [true, false] steps: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Install dependencies - uses: ./.github/actions/setup-runner + - name: Setup test environment + uses: ./.github/actions/setup-test-environment with: python-version: ${{ matrix.python-version }} client-version: ${{ matrix.client-version }} + provider: ${{ matrix.provider }} + run-vision-tests: ${{ matrix.run-vision-tests }} + inference-mode: 'replay' - - name: Setup ollama - if: ${{ matrix.provider == 'ollama' }} - uses: ./.github/actions/setup-ollama - - - name: Setup vllm - if: ${{ matrix.provider == 'vllm' }} - uses: ./.github/actions/setup-vllm - - - name: Build Llama Stack - run: | - uv run llama stack build --template ci-tests --image-type venv - - - name: Check Storage and Memory Available Before Tests - if: ${{ always() }} - run: | - free -h - df -h - - - name: Run Integration Tests - env: - LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations - # Use 'shell' to get pipefail behavior - # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference - # TODO: write a precommit hook to detect if a test contains a pipe but does not use 'shell: bash' - shell: bash - run: | - if [ "${{ matrix.client-type }}" == "library" ]; then - stack_config="ci-tests" - else - stack_config="server:ci-tests" - fi - - EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag" - if [ "${{ matrix.provider }}" == "ollama" ]; then - export OLLAMA_URL="http://0.0.0.0:11434" - export TEXT_MODEL=ollama/llama3.2:3b-instruct-fp16 - export SAFETY_MODEL="ollama/llama-guard3:1b" - EXTRA_PARAMS="--safety-shield=llama-guard" - else - export VLLM_URL="http://localhost:8000/v1" - export TEXT_MODEL=vllm/meta-llama/Llama-3.2-1B-Instruct - # TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently - EXTRA_PARAMS= - EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls" - fi - - - uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ - -k "not( ${EXCLUDE_TESTS} )" \ - --text-model=$TEXT_MODEL \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes ${EXTRA_PARAMS} \ - --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log - - - name: Check Storage and Memory Available After Tests - if: ${{ always() }} - run: | - free -h - df -h - - - name: Write inference logs to file - if: ${{ always() }} - run: | - sudo docker logs ollama > ollama.log || true - sudo docker logs vllm > vllm.log || true - - - name: Upload all logs to artifacts - if: ${{ always() }} - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + - name: Run tests + uses: ./.github/actions/run-and-record-tests with: - name: logs-${{ github.run_id }}-${{ 
github.run_attempt }}-${{ matrix.provider }}-${{ matrix.client-type }}-${{ matrix.test-type }}-${{ matrix.python-version }}-${{ matrix.client-version }} - path: | - *.log - retention-days: 1 + test-types: ${{ needs.discover-tests.outputs.test-types }} + stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} + provider: ${{ matrix.provider }} + inference-mode: 'replay' + run-vision-tests: ${{ matrix.run-vision-tests }} diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 9a02bbcf8..aa239572b 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector"] + vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] python-version: ["3.12", "3.13"] fail-fast: false # we want to run all tests regardless of failure @@ -48,6 +48,14 @@ jobs: -e ANONYMIZED_TELEMETRY=FALSE \ chromadb/chroma:latest + - name: Setup Weaviate + if: matrix.vector-io-provider == 'remote::weaviate' + run: | + docker run --rm -d --pull always \ + --name weaviate \ + -p 8080:8080 -p 50051:50051 \ + cr.weaviate.io/semitechnologies/weaviate:1.32.0 + - name: Start PGVector DB if: matrix.vector-io-provider == 'remote::pgvector' run: | @@ -78,6 +86,29 @@ jobs: PGPASSWORD=llamastack psql -h localhost -U llamastack -d llamastack \ -c "CREATE EXTENSION IF NOT EXISTS vector;" + - name: Setup Qdrant + if: matrix.vector-io-provider == 'remote::qdrant' + run: | + docker run --rm -d --pull always \ + --name qdrant \ + -p 6333:6333 \ + qdrant/qdrant + + - name: Wait for Qdrant to be ready + if: matrix.vector-io-provider == 'remote::qdrant' + run: | + echo "Waiting for Qdrant to be ready..." + for i in {1..30}; do + if curl -s http://localhost:6333/collections | grep -q '"status":"ok"'; then + echo "Qdrant is ready!" + exit 0 + fi + sleep 2 + done + echo "Qdrant failed to start" + docker logs qdrant + exit 1 + - name: Wait for ChromaDB to be ready if: matrix.vector-io-provider == 'remote::chromadb' run: | @@ -93,6 +124,21 @@ jobs: docker logs chromadb exit 1 + - name: Wait for Weaviate to be ready + if: matrix.vector-io-provider == 'remote::weaviate' + run: | + echo "Waiting for Weaviate to be ready..." + for i in {1..30}; do + if curl -s http://localhost:8080 | grep -q "https://weaviate.io/developers/weaviate/current/"; then + echo "Weaviate is ready!" 
+ exit 0 + fi + sleep 2 + done + echo "Weaviate failed to start" + docker logs weaviate + exit 1 + - name: Build Llama Stack run: | uv run llama stack build --template ci-tests --image-type venv @@ -113,6 +159,10 @@ jobs: PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }} + ENABLE_QDRANT: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'true' || '' }} + QDRANT_URL: ${{ matrix.vector-io-provider == 'remote::qdrant' && 'http://localhost:6333' || '' }} + ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} + WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} run: | uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ @@ -134,6 +184,11 @@ jobs: run: | docker logs chromadb > chromadb.log + - name: Write Qdrant logs to file + if: ${{ always() && matrix.vector-io-provider == 'remote::qdrant' }} + run: | + docker logs qdrant > qdrant.log + - name: Upload all logs to artifacts if: ${{ always() }} uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 284076d50..929d76760 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -9,20 +9,20 @@ on: paths: - 'llama_stack/cli/stack/build.py' - 'llama_stack/cli/stack/_build.py' - - 'llama_stack/distribution/build.*' - - 'llama_stack/distribution/*.sh' + - 'llama_stack/core/build.*' + - 'llama_stack/core/*.sh' - '.github/workflows/providers-build.yml' - - 'llama_stack/templates/**' + - 'llama_stack/distributions/**' - 'pyproject.toml' pull_request: paths: - 'llama_stack/cli/stack/build.py' - 'llama_stack/cli/stack/_build.py' - - 'llama_stack/distribution/build.*' - - 'llama_stack/distribution/*.sh' + - 'llama_stack/core/build.*' + - 'llama_stack/core/*.sh' - '.github/workflows/providers-build.yml' - - 'llama_stack/templates/**' + - 'llama_stack/distributions/**' - 'pyproject.toml' concurrency: @@ -33,23 +33,23 @@ jobs: generate-matrix: runs-on: ubuntu-latest outputs: - templates: ${{ steps.set-matrix.outputs.templates }} + distros: ${{ steps.set-matrix.outputs.distros }} steps: - name: Checkout repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - name: Generate Template List + - name: Generate Distribution List id: set-matrix run: | - templates=$(ls llama_stack/templates/*/*build.yaml | awk -F'/' '{print $(NF-1)}' | jq -R -s -c 'split("\n")[:-1]') - echo "templates=$templates" >> "$GITHUB_OUTPUT" + distros=$(ls llama_stack/distributions/*/*build.yaml | awk -F'/' '{print $(NF-1)}' | jq -R -s -c 'split("\n")[:-1]') + echo "distros=$distros" >> "$GITHUB_OUTPUT" build: needs: generate-matrix runs-on: ubuntu-latest strategy: matrix: - template: ${{ fromJson(needs.generate-matrix.outputs.templates) }} + distro: ${{ fromJson(needs.generate-matrix.outputs.distros) }} image-type: [venv, container] fail-fast: false # We want to run all jobs even if some fail @@ -62,13 +62,13 @@ jobs: - name: Print build dependencies run: | - uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test --print-deps-only + uv 
run llama stack build --distro ${{ matrix.distro }} --image-type ${{ matrix.image-type }} --image-name test --print-deps-only - name: Run Llama Stack Build run: | # USE_COPY_NOT_MOUNT is set to true since mounting is not supported by docker buildx, we use COPY instead # LLAMA_STACK_DIR is set to the current directory so we are building from the source - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template ${{ matrix.template }} --image-type ${{ matrix.image-type }} --image-name test + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --distro ${{ matrix.distro }} --image-type ${{ matrix.image-type }} --image-name test - name: Print dependencies in the image if: matrix.image-type == 'venv' @@ -99,16 +99,16 @@ jobs: - name: Build a single provider run: | - yq -i '.image_type = "container"' llama_stack/templates/ci-tests/build.yaml - yq -i '.image_name = "test"' llama_stack/templates/ci-tests/build.yaml - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml + yq -i '.image_type = "container"' llama_stack/distributions/ci-tests/build.yaml + yq -i '.image_name = "test"' llama_stack/distributions/ci-tests/build.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/distributions/ci-tests/build.yaml - name: Inspect the container image entrypoint run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" - if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then + if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then echo "Entrypoint is not correct" exit 1 fi @@ -122,27 +122,27 @@ jobs: - name: Install dependencies uses: ./.github/actions/setup-runner - - name: Pin template to UBI9 base + - name: Pin distribution to UBI9 base run: | yq -i ' .image_type = "container" | .image_name = "ubi9-test" | .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest" - ' llama_stack/templates/ci-tests/build.yaml + ' llama_stack/distributions/ci-tests/build.yaml - name: Build dev container (UBI9) env: USE_COPY_NOT_MOUNT: "true" LLAMA_STACK_DIR: "." 
run: | - uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml + uv run llama stack build --config llama_stack/distributions/ci-tests/build.yaml - name: Inspect UBI9 image run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" - if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then + if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then echo "Entrypoint is not correct" exit 1 fi diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml new file mode 100644 index 000000000..12957db27 --- /dev/null +++ b/.github/workflows/record-integration-tests.yml @@ -0,0 +1,109 @@ +name: Integration Tests (Record) + +run-name: Run the integration test suite from tests/integration + +on: + pull_request: + branches: [ main ] + types: [opened, synchronize, labeled] + paths: + - 'llama_stack/**' + - 'tests/**' + - 'uv.lock' + - 'pyproject.toml' + - '.github/workflows/record-integration-tests.yml' # This workflow + - '.github/actions/setup-ollama/action.yml' + - '.github/actions/setup-test-environment/action.yml' + - '.github/actions/run-and-record-tests/action.yml' + workflow_dispatch: + inputs: + test-provider: + description: 'Test against a specific provider' + type: string + default: 'ollama' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + discover-tests: + if: contains(github.event.pull_request.labels.*.name, 're-record-tests') || + contains(github.event.pull_request.labels.*.name, 're-record-vision-tests') + runs-on: ubuntu-latest + outputs: + test-types: ${{ steps.generate-test-types.outputs.test-types }} + matrix-modes: ${{ steps.generate-test-types.outputs.matrix-modes }} + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Generate test types + id: generate-test-types + run: | + # Get test directories dynamically, excluding non-test directories + TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | + grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" | + sort | jq -R -s -c 'split("\n")[:-1]') + echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT + + labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name') + echo "labels=$labels" + + modes_array=() + if [[ $labels == *"re-record-vision-tests"* ]]; then + modes_array+=("vision") + fi + if [[ $labels == *"re-record-tests"* ]]; then + modes_array+=("non-vision") + fi + + # Convert to JSON array + if [ ${#modes_array[@]} -eq 0 ]; then + matrix_modes="[]" + else + matrix_modes=$(printf '%s\n' "${modes_array[@]}" | jq -R -s -c 'split("\n")[:-1]') + fi + echo "matrix_modes=$matrix_modes" + echo "matrix-modes=$matrix_modes" >> $GITHUB_OUTPUT + + env: + GH_TOKEN: ${{ github.token }} + + record-tests: + needs: discover-tests + runs-on: ubuntu-latest + + permissions: + contents: write + + strategy: + fail-fast: false + matrix: + mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }} + + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + ref: ${{ github.event.pull_request.head.ref }} + fetch-depth: 0 + + - name: Setup test environment + uses: 
./.github/actions/setup-test-environment + with: + python-version: "3.12" # Use single Python version for recording + client-version: "latest" + provider: ${{ inputs.test-provider || 'ollama' }} + run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }} + inference-mode: 'record' + + - name: Run and record tests + uses: ./.github/actions/run-and-record-tests + with: + test-types: ${{ needs.discover-tests.outputs.test-types }} + stack-config: 'server:ci-tests' # recording must be done with server since more tests are run + provider: ${{ inputs.test-provider || 'ollama' }} + inference-mode: 'record' + run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }} diff --git a/.github/workflows/test-external-provider-module.yml b/.github/workflows/test-external-provider-module.yml index 8567a9446..d61b0dfe9 100644 --- a/.github/workflows/test-external-provider-module.yml +++ b/.github/workflows/test-external-provider-module.yml @@ -12,12 +12,13 @@ on: - 'tests/integration/**' - 'uv.lock' - 'pyproject.toml' - - 'requirements.txt' - 'tests/external/*' - '.github/workflows/test-external-provider-module.yml' # This workflow jobs: test-external-providers-from-module: + # This workflow is disabled. See https://github.com/meta-llama/llama-stack/pull/2975#issuecomment-3138702984 for details + if: false runs-on: ubuntu-latest strategy: matrix: @@ -47,7 +48,7 @@ jobs: - name: Build distro from config file run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/ramalama-stack/build.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external/ramalama-stack/build.yaml - name: Start Llama Stack server in background if: ${{ matrix.image-type }} == 'venv' diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml index 053b38fab..27181a236 100644 --- a/.github/workflows/test-external.yml +++ b/.github/workflows/test-external.yml @@ -43,11 +43,11 @@ jobs: - name: Print distro dependencies run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml --print-deps-only + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external/build.yaml --print-deps-only - name: Build distro from config file run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external/build.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external/build.yaml - name: Start Llama Stack server in background if: ${{ matrix.image-type }} == 'venv' diff --git a/CHANGELOG.md b/CHANGELOG.md index aad0cb529..2f47c3ae3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -451,7 +451,7 @@ GenAI application developers need more than just an LLM - they need to integrate Llama Stack was created to provide developers with a comprehensive and coherent interface that simplifies AI application development and codifies best practices across the Llama ecosystem. Since our launch in September 2024, we have seen a huge uptick in interest in Llama Stack APIs by both AI developers and from partners building AI services with Llama models. Partners like Nvidia, Fireworks, and Ollama have collaborated with us to develop implementations across various APIs, including inference, memory, and safety. -With Llama Stack, you can easily build a RAG agent which can also search the web, do complex math, and custom tool calling. You can use telemetry to inspect those traces, and convert telemetry into evals datasets. 
And with Llama Stack’s plugin architecture and prepackage distributions, you choose to run your agent anywhere - in the cloud with our partners, deploy your own environment using virtualenv, conda, or Docker, operate locally with Ollama, or even run on mobile devices with our SDKs. Llama Stack offers unprecedented flexibility while also simplifying the developer experience. +With Llama Stack, you can easily build a RAG agent which can also search the web, do complex math, and custom tool calling. You can use telemetry to inspect those traces, and convert telemetry into evals datasets. And with Llama Stack’s plugin architecture and prepackaged distributions, you choose to run your agent anywhere - in the cloud with our partners, deploy your own environment using virtualenv or Docker, operate locally with Ollama, or even run on mobile devices with our SDKs. Llama Stack offers unprecedented flexibility while also simplifying the developer experience. ## Release After iterating on the APIs for the last 3 months, today we’re launching a stable release (V1) of the Llama Stack APIs and the corresponding llama-stack server and client packages(v0.1.0). We now have automated tests for providers. These tests make sure that all provider implementations are verified. Developers can now easily and reliably select distributions or providers based on their specific requirements. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8d866328b..066fcecf0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -157,6 +157,7 @@ uv sync that describes the configuration. These descriptions will be used to generate the provider documentation. * When possible, use keyword arguments only when calling functions. +* Llama Stack provides [custom Exception classes](llama_stack/apis/common/errors.py) for certain Resources; use them where applicable. ## Common Tasks @@ -164,7 +165,7 @@ Some tips about common tasks you work on while contributing to Llama Stack: ### Using `llama stack build` -Building a stack image (conda / docker) will use the production version of the `llama-stack` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_STACK_CLIENT_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands. +Building a stack image will use the production version of the `llama-stack` and `llama-stack-client` packages. If you are developing with a llama-stack repository checked out and need your code to be reflected in the stack image, set `LLAMA_STACK_DIR` and `LLAMA_STACK_CLIENT_DIR` to the appropriate checked out directories when running any of the `llama` CLI commands. 
Example: ```bash @@ -172,7 +173,7 @@ cd work/ git clone https://github.com/meta-llama/llama-stack.git git clone https://github.com/meta-llama/llama-stack-client-python.git cd llama-stack -LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --template <...> +LLAMA_STACK_DIR=$(pwd) LLAMA_STACK_CLIENT_DIR=../llama-stack-client-python llama stack build --distro <...> ``` ### Updating distribution configurations diff --git a/MANIFEST.in b/MANIFEST.in index 88bd11767..e678e6b01 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,9 @@ include pyproject.toml include llama_stack/models/llama/llama3/tokenizer.model include llama_stack/models/llama/llama4/tokenizer.model -include llama_stack/distribution/*.sh +include llama_stack/core/*.sh include llama_stack/cli/scripts/*.sh -include llama_stack/templates/*/*.yaml +include llama_stack/distributions/*/*.yaml include llama_stack/providers/tests/test_cases/inference/*.json include llama_stack/models/llama/*/*.md include llama_stack/tests/integration/*.jpg diff --git a/README.md b/README.md index 7f0fed345..03aa3dd50 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,6 @@ [![Discord](https://img.shields.io/discord/1257833999603335178?color=6A7EC2&logo=discord&logoColor=ffffff)](https://discord.gg/llama-stack) [![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain) [![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain) -![coverage badge](./coverage.svg) [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack) @@ -112,29 +111,33 @@ Here is a list of the various API providers and available distributions that can Please checkout for [full list](https://llama-stack.readthedocs.io/en/latest/providers/index.html) | API Provider Builder | Environments | Agents | Inference | VectorIO | Safety | Telemetry | Post Training | Eval | DatasetIO | -|:-------------------:|:------------:|:------:|:---------:|:--------:|:------:|:---------:|:-------------:|:----:|:--------:| -| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| SambaNova | Hosted | | ✅ | | ✅ | | | | | -| Cerebras | Hosted | | ✅ | | | | | | | -| Fireworks | Hosted | ✅ | ✅ | ✅ | | | | | | -| AWS Bedrock | Hosted | | ✅ | | ✅ | | | | | -| Together | Hosted | ✅ | ✅ | | ✅ | | | | | -| Groq | Hosted | | ✅ | | | | | | | -| Ollama | Single Node | | ✅ | | | | | | | -| TGI | Hosted/Single Node | | ✅ | | | | | | | -| NVIDIA NIM | Hosted/Single Node | | ✅ | | ✅ | | | | | -| ChromaDB | Hosted/Single Node | | | ✅ | | | | | | -| PG Vector | Single Node | | | ✅ | | | | | | -| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | | | | -| vLLM | Single Node | | ✅ | | | | | | | -| OpenAI | Hosted | | ✅ | | | | | | | -| Anthropic | Hosted | | ✅ | | | | | | | -| Gemini | Hosted | | ✅ | | | | | | | -| WatsonX | Hosted | | ✅ | | | | | | | -| HuggingFace | Single Node | | | | | | ✅ | | ✅ | -| TorchTune | Single Node | | | | | | ✅ | | | -| NVIDIA NEMO | Hosted | | ✅ | ✅ | | | ✅ | ✅ | ✅ | -| NVIDIA | Hosted | | | | | | ✅ | ✅ | ✅ | 
+|:--------------------:|:------------:|:------:|:---------:|:--------:|:------:|:---------:|:-------------:|:----:|:--------:| +| Meta Reference | Single Node | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| SambaNova | Hosted | | ✅ | | ✅ | | | | | +| Cerebras | Hosted | | ✅ | | | | | | | +| Fireworks | Hosted | ✅ | ✅ | ✅ | | | | | | +| AWS Bedrock | Hosted | | ✅ | | ✅ | | | | | +| Together | Hosted | ✅ | ✅ | | ✅ | | | | | +| Groq | Hosted | | ✅ | | | | | | | +| Ollama | Single Node | | ✅ | | | | | | | +| TGI | Hosted/Single Node | | ✅ | | | | | | | +| NVIDIA NIM | Hosted/Single Node | | ✅ | | ✅ | | | | | +| ChromaDB | Hosted/Single Node | | | ✅ | | | | | | +| Milvus | Hosted/Single Node | | | ✅ | | | | | | +| Qdrant | Hosted/Single Node | | | ✅ | | | | | | +| Weaviate | Hosted/Single Node | | | ✅ | | | | | | +| SQLite-vec | Single Node | | | ✅ | | | | | | +| PG Vector | Single Node | | | ✅ | | | | | | +| PyTorch ExecuTorch | On-device iOS | ✅ | ✅ | | | | | | | +| vLLM | Single Node | | ✅ | | | | | | | +| OpenAI | Hosted | | ✅ | | | | | | | +| Anthropic | Hosted | | ✅ | | | | | | | +| Gemini | Hosted | | ✅ | | | | | | | +| WatsonX | Hosted | | ✅ | | | | | | | +| HuggingFace | Single Node | | | | | | ✅ | | ✅ | +| TorchTune | Single Node | | | | | | ✅ | | | +| NVIDIA NEMO | Hosted | | ✅ | ✅ | | | ✅ | ✅ | ✅ | +| NVIDIA | Hosted | | | | | | ✅ | ✅ | ✅ | > **Note**: Additional providers are available through external packages. See [External Providers](https://llama-stack.readthedocs.io/en/latest/providers/external.html) documentation. diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 65b515ef4..d480ff592 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1452,6 +1452,40 @@ } } ] + }, + "delete": { + "responses": { + "200": { + "description": "OK" + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Shields" + ], + "description": "Unregister a shield.", + "parameters": [ + { + "name": "identifier", + "in": "path", + "description": "The identifier of the shield to unregister.", + "required": true, + "schema": { + "type": "string" + } + } + ] } }, "/v1/telemetry/traces/{trace_id}/spans/{span_id}": { @@ -1922,7 +1956,7 @@ "get": { "responses": { "200": { - "description": "A HealthInfo.", + "description": "Health information indicating if the service is operational.", "content": { "application/json": { "schema": { @@ -1947,7 +1981,7 @@ "tags": [ "Inspect" ], - "description": "Get the health of the service.", + "description": "Get the current health status of the service.", "parameters": [] } }, @@ -1973,7 +2007,7 @@ "tags": [ "ToolRuntime" ], - "description": "Index documents so they can be used by the RAG system", + "description": "Index documents so they can be used by the RAG system.", "parameters": [], "requestBody": { "content": { @@ -2839,7 +2873,7 @@ "get": { "responses": { "200": { - "description": "A ListRoutesResponse.", + "description": "Response containing information about all available routes.", "content": { "application/json": { "schema": { @@ -2864,7 +2898,7 @@ "tags": [ "Inspect" ], - "description": "List all routes.", + "description": "List all available API routes with their methods and implementing providers.", "parameters": [] } }, @@ -3324,6 +3358,7 @@ { "name": 
"limit", "in": "query", + "description": "(Optional) A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", "required": false, "schema": { "type": "integer" @@ -3332,6 +3367,7 @@ { "name": "order", "in": "query", + "description": "(Optional) Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.", "required": false, "schema": { "type": "string" @@ -3340,6 +3376,7 @@ { "name": "after", "in": "query", + "description": "(Optional) A cursor for use in pagination. `after` is an object ID that defines your place in the list.", "required": false, "schema": { "type": "string" @@ -3348,6 +3385,7 @@ { "name": "before", "in": "query", + "description": "(Optional) A cursor for use in pagination. `before` is an object ID that defines your place in the list.", "required": false, "schema": { "type": "string" @@ -3356,6 +3394,7 @@ { "name": "filter", "in": "query", + "description": "(Optional) Filter by file status to only return files with the specified status.", "required": false, "schema": { "$ref": "#/components/schemas/VectorStoreFileStatus" @@ -4345,7 +4384,7 @@ "post": { "responses": { "200": { - "description": "OK", + "description": "RAGQueryResult containing the retrieved content and metadata", "content": { "application/json": { "schema": { @@ -4370,7 +4409,7 @@ "tags": [ "ToolRuntime" ], - "description": "Query the RAG system for context; typically invoked by the agent", + "description": "Query the RAG system for context; typically invoked by the agent.", "parameters": [], "requestBody": { "content": { @@ -4695,6 +4734,49 @@ } } }, + "/v1/openai/v1/moderations": { + "post": { + "responses": { + "200": { + "description": "A moderation object.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ModerationObject" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Safety" + ], + "description": "Classifies if text and/or image inputs are potentially harmful.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RunModerationRequest" + } + } + }, + "required": true + } + } + }, "/v1/safety/run-shield": { "post": { "responses": { @@ -4907,7 +4989,7 @@ "post": { "responses": { "200": { - "description": "OK", + "description": "Response containing filtered synthetic data samples and optional statistics", "content": { "application/json": { "schema": { @@ -4932,7 +5014,7 @@ "tags": [ "SyntheticDataGeneration (Coming Soon)" ], - "description": "", + "description": "Generate synthetic data based on input dialogs and apply filtering.", "parameters": [], "requestBody": { "content": { @@ -4950,7 +5032,7 @@ "get": { "responses": { "200": { - "description": "A VersionInfo.", + "description": "Version information containing the service version number.", "content": { "application/json": { "schema": { @@ -5144,14 +5226,16 @@ "type": { "type": "string", "const": "greedy", - "default": "greedy" + "default": "greedy", + "description": "Must be \"greedy\" to identify this sampling strategy" } }, "additionalProperties": false, "required": [ "type" ], - "title": "GreedySamplingStrategy" + "title": "GreedySamplingStrategy", + "description": "Greedy 
sampling strategy that selects the highest probability token at each step." }, "ImageContentItem": { "type": "object", @@ -5671,10 +5755,12 @@ "type": { "type": "string", "const": "top_k", - "default": "top_k" + "default": "top_k", + "description": "Must be \"top_k\" to identify this sampling strategy" }, "top_k": { - "type": "integer" + "type": "integer", + "description": "Number of top tokens to consider for sampling. Must be at least 1" } }, "additionalProperties": false, @@ -5682,7 +5768,8 @@ "type", "top_k" ], - "title": "TopKSamplingStrategy" + "title": "TopKSamplingStrategy", + "description": "Top-k sampling strategy that restricts sampling to the k most likely tokens." }, "TopPSamplingStrategy": { "type": "object", @@ -5690,34 +5777,40 @@ "type": { "type": "string", "const": "top_p", - "default": "top_p" + "default": "top_p", + "description": "Must be \"top_p\" to identify this sampling strategy" }, "temperature": { - "type": "number" + "type": "number", + "description": "Controls randomness in sampling. Higher values increase randomness" }, "top_p": { "type": "number", - "default": 0.95 + "default": 0.95, + "description": "Cumulative probability threshold for nucleus sampling. Defaults to 0.95" } }, "additionalProperties": false, "required": [ "type" ], - "title": "TopPSamplingStrategy" + "title": "TopPSamplingStrategy", + "description": "Top-p (nucleus) sampling strategy that samples from the smallest set of tokens with cumulative probability >= p." }, "URL": { "type": "object", "properties": { "uri": { - "type": "string" + "type": "string", + "description": "The URL string pointing to the resource" } }, "additionalProperties": false, "required": [ "uri" ], - "title": "URL" + "title": "URL", + "description": "A URL reference to external content." }, "UserMessage": { "type": "object", @@ -5808,14 +5901,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/ChatCompletionResponse" - } + }, + "description": "List of chat completion responses, one for each conversation in the batch" } }, "additionalProperties": false, "required": [ "batch" ], - "title": "BatchChatCompletionResponse" + "title": "BatchChatCompletionResponse", + "description": "Response from a batch chat completion request." }, "ChatCompletionResponse": { "type": "object", @@ -5824,7 +5919,8 @@ "type": "array", "items": { "$ref": "#/components/schemas/MetricInResponse" - } + }, + "description": "(Optional) List of metrics associated with the API response" }, "completion_message": { "$ref": "#/components/schemas/CompletionMessage", @@ -5849,7 +5945,8 @@ "type": "object", "properties": { "metric": { - "type": "string" + "type": "string", + "description": "The name of the metric" }, "value": { "oneOf": [ @@ -5859,10 +5956,12 @@ { "type": "number" } - ] + ], + "description": "The numeric value of the metric" }, "unit": { - "type": "string" + "type": "string", + "description": "(Optional) The unit of measurement for the metric value" } }, "additionalProperties": false, @@ -5870,7 +5969,8 @@ "metric", "value" ], - "title": "MetricInResponse" + "title": "MetricInResponse", + "description": "A metric value included in API responses." 
}, "TokenLogProbs": { "type": "object", @@ -5939,14 +6039,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/CompletionResponse" - } + }, + "description": "List of completion responses, one for each input in the batch" } }, "additionalProperties": false, "required": [ "batch" ], - "title": "BatchCompletionResponse" + "title": "BatchCompletionResponse", + "description": "Response from a batch completion request." }, "CompletionResponse": { "type": "object", @@ -5955,7 +6057,8 @@ "type": "array", "items": { "$ref": "#/components/schemas/MetricInResponse" - } + }, + "description": "(Optional) List of metrics associated with the API response" }, "content": { "type": "string", @@ -6123,7 +6226,8 @@ "type": "array", "items": { "$ref": "#/components/schemas/MetricInResponse" - } + }, + "description": "(Optional) List of metrics associated with the API response" }, "event": { "$ref": "#/components/schemas/ChatCompletionResponseEvent", @@ -6164,11 +6268,13 @@ "type": { "type": "string", "const": "image", - "default": "image" + "default": "image", + "description": "Discriminator type of the delta. Always \"image\"" }, "image": { "type": "string", - "contentEncoding": "base64" + "contentEncoding": "base64", + "description": "The incremental image data as bytes" } }, "additionalProperties": false, @@ -6176,7 +6282,8 @@ "type", "image" ], - "title": "ImageDelta" + "title": "ImageDelta", + "description": "An image content delta for streaming responses." }, "TextDelta": { "type": "object", @@ -6184,10 +6291,12 @@ "type": { "type": "string", "const": "text", - "default": "text" + "default": "text", + "description": "Discriminator type of the delta. Always \"text\"" }, "text": { - "type": "string" + "type": "string", + "description": "The incremental text content" } }, "additionalProperties": false, @@ -6195,7 +6304,8 @@ "type", "text" ], - "title": "TextDelta" + "title": "TextDelta", + "description": "A text content delta for streaming responses." }, "ToolCallDelta": { "type": "object", @@ -6203,7 +6313,8 @@ "type": { "type": "string", "const": "tool_call", - "default": "tool_call" + "default": "tool_call", + "description": "Discriminator type of the delta. Always \"tool_call\"" }, "tool_call": { "oneOf": [ @@ -6213,7 +6324,8 @@ { "$ref": "#/components/schemas/ToolCall" } - ] + ], + "description": "Either an in-progress tool call string or the final parsed tool call" }, "parse_status": { "type": "string", @@ -6223,7 +6335,7 @@ "failed", "succeeded" ], - "title": "ToolCallParseStatus" + "description": "Current parsing status of the tool call" } }, "additionalProperties": false, @@ -6232,7 +6344,8 @@ "tool_call", "parse_status" ], - "title": "ToolCallDelta" + "title": "ToolCallDelta", + "description": "A tool call content delta for streaming responses." 
}, "CompletionRequest": { "type": "object", @@ -6284,7 +6397,8 @@ "type": "array", "items": { "$ref": "#/components/schemas/MetricInResponse" - } + }, + "description": "(Optional) List of metrics associated with the API response" }, "delta": { "type": "string", @@ -6453,16 +6567,19 @@ "type": "object", "properties": { "name": { - "type": "string" + "type": "string", + "description": "Name of the tool" }, "description": { - "type": "string" + "type": "string", + "description": "(Optional) Human-readable description of what the tool does" }, "parameters": { "type": "array", "items": { "$ref": "#/components/schemas/ToolParameter" - } + }, + "description": "(Optional) List of parameters this tool accepts" }, "metadata": { "type": "object", @@ -6487,30 +6604,36 @@ "type": "object" } ] - } + }, + "description": "(Optional) Additional metadata about the tool" } }, "additionalProperties": false, "required": [ "name" ], - "title": "ToolDef" + "title": "ToolDef", + "description": "Tool definition used in runtime contexts." }, "ToolParameter": { "type": "object", "properties": { "name": { - "type": "string" + "type": "string", + "description": "Name of the parameter" }, "parameter_type": { - "type": "string" + "type": "string", + "description": "Type of the parameter (e.g., string, integer)" }, "description": { - "type": "string" + "type": "string", + "description": "Human-readable description of what the parameter does" }, "required": { "type": "boolean", - "default": true + "default": true, + "description": "Whether this parameter is required for tool invocation" }, "default": { "oneOf": [ @@ -6532,7 +6655,8 @@ { "type": "object" } - ] + ], + "description": "(Optional) Default value for the parameter if not provided" } }, "additionalProperties": false, @@ -6542,7 +6666,8 @@ "description", "required" ], - "title": "ToolParameter" + "title": "ToolParameter", + "description": "Parameter definition for a tool." }, "CreateAgentRequest": { "type": "object", @@ -6562,14 +6687,16 @@ "type": "object", "properties": { "agent_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the created agent" } }, "additionalProperties": false, "required": [ "agent_id" ], - "title": "AgentCreateResponse" + "title": "AgentCreateResponse", + "description": "Response returned when creating a new agent." }, "CreateAgentSessionRequest": { "type": "object", @@ -6589,14 +6716,16 @@ "type": "object", "properties": { "session_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the created session" } }, "additionalProperties": false, "required": [ "session_id" ], - "title": "AgentSessionCreateResponse" + "title": "AgentSessionCreateResponse", + "description": "Response returned when creating a new agent session." 
}, "CreateAgentTurnRequest": { "type": "object", @@ -6784,10 +6913,12 @@ "type": "object", "properties": { "violation_level": { - "$ref": "#/components/schemas/ViolationLevel" + "$ref": "#/components/schemas/ViolationLevel", + "description": "Severity level of the violation" }, "user_message": { - "type": "string" + "type": "string", + "description": "(Optional) Message to convey to the user about the violation" }, "metadata": { "type": "object", @@ -6812,7 +6943,8 @@ "type": "object" } ] - } + }, + "description": "Additional metadata including specific violation codes for debugging and telemetry" } }, "additionalProperties": false, @@ -6820,7 +6952,8 @@ "violation_level", "metadata" ], - "title": "SafetyViolation" + "title": "SafetyViolation", + "description": "Details of a safety violation detected by content moderation." }, "ShieldCallStep": { "type": "object", @@ -6934,7 +7067,8 @@ "type": "object", "properties": { "call_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the tool call this response is for" }, "tool_name": { "oneOf": [ @@ -6951,10 +7085,12 @@ { "type": "string" } - ] + ], + "description": "Name of the tool that was invoked" }, "content": { - "$ref": "#/components/schemas/InterleavedContent" + "$ref": "#/components/schemas/InterleavedContent", + "description": "The response content from the tool" }, "metadata": { "type": "object", @@ -6979,7 +7115,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) Additional metadata about the tool response" } }, "additionalProperties": false, @@ -6988,16 +7125,19 @@ "tool_name", "content" ], - "title": "ToolResponse" + "title": "ToolResponse", + "description": "Response from a tool invocation." }, "Turn": { "type": "object", "properties": { "turn_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the turn within a session" }, "session_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the conversation session" }, "input_messages": { "type": "array", @@ -7010,7 +7150,8 @@ "$ref": "#/components/schemas/ToolResponseMessage" } ] - } + }, + "description": "List of messages that initiated this turn" }, "steps": { "type": "array", @@ -7038,10 +7179,12 @@ "memory_retrieval": "#/components/schemas/MemoryRetrievalStep" } } - } + }, + "description": "Ordered list of processing steps executed during this turn" }, "output_message": { - "$ref": "#/components/schemas/CompletionMessage" + "$ref": "#/components/schemas/CompletionMessage", + "description": "The model's generated response containing content and metadata" }, "output_attachments": { "type": "array", @@ -7080,15 +7223,18 @@ ], "title": "Attachment", "description": "An attachment to an agent turn." - } + }, + "description": "(Optional) Files or media attached to the agent's response" }, "started_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the turn began" }, "completed_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "(Optional) Timestamp when the turn finished, if completed" } }, "additionalProperties": false, @@ -7110,20 +7256,23 @@ "warn", "error" ], - "title": "ViolationLevel" + "title": "ViolationLevel", + "description": "Severity level of a safety violation." 
}, "AgentTurnResponseEvent": { "type": "object", "properties": { "payload": { - "$ref": "#/components/schemas/AgentTurnResponseEventPayload" + "$ref": "#/components/schemas/AgentTurnResponseEventPayload", + "description": "Event-specific payload containing event data" } }, "additionalProperties": false, "required": [ "payload" ], - "title": "AgentTurnResponseEvent" + "title": "AgentTurnResponseEvent", + "description": "An event in an agent turn response stream." }, "AgentTurnResponseEventPayload": { "oneOf": [ @@ -7171,9 +7320,9 @@ "turn_complete", "turn_awaiting_input" ], - "title": "AgentTurnResponseEventType", "const": "step_complete", - "default": "step_complete" + "default": "step_complete", + "description": "Type of event being reported" }, "step_type": { "type": "string", @@ -7183,11 +7332,11 @@ "shield_call", "memory_retrieval" ], - "title": "StepType", - "description": "Type of the step in an agent turn." + "description": "Type of step being executed" }, "step_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the step within a turn" }, "step_details": { "oneOf": [ @@ -7212,7 +7361,8 @@ "shield_call": "#/components/schemas/ShieldCallStep", "memory_retrieval": "#/components/schemas/MemoryRetrievalStep" } - } + }, + "description": "Complete details of the executed step" } }, "additionalProperties": false, @@ -7222,7 +7372,8 @@ "step_id", "step_details" ], - "title": "AgentTurnResponseStepCompletePayload" + "title": "AgentTurnResponseStepCompletePayload", + "description": "Payload for step completion events in agent turn responses." }, "AgentTurnResponseStepProgressPayload": { "type": "object", @@ -7237,9 +7388,9 @@ "turn_complete", "turn_awaiting_input" ], - "title": "AgentTurnResponseEventType", "const": "step_progress", - "default": "step_progress" + "default": "step_progress", + "description": "Type of event being reported" }, "step_type": { "type": "string", @@ -7249,14 +7400,15 @@ "shield_call", "memory_retrieval" ], - "title": "StepType", - "description": "Type of the step in an agent turn." + "description": "Type of step being executed" }, "step_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the step within a turn" }, "delta": { - "$ref": "#/components/schemas/ContentDelta" + "$ref": "#/components/schemas/ContentDelta", + "description": "Incremental content changes during step execution" } }, "additionalProperties": false, @@ -7266,7 +7418,8 @@ "step_id", "delta" ], - "title": "AgentTurnResponseStepProgressPayload" + "title": "AgentTurnResponseStepProgressPayload", + "description": "Payload for step progress events in agent turn responses." }, "AgentTurnResponseStepStartPayload": { "type": "object", @@ -7281,9 +7434,9 @@ "turn_complete", "turn_awaiting_input" ], - "title": "AgentTurnResponseEventType", "const": "step_start", - "default": "step_start" + "default": "step_start", + "description": "Type of event being reported" }, "step_type": { "type": "string", @@ -7293,11 +7446,11 @@ "shield_call", "memory_retrieval" ], - "title": "StepType", - "description": "Type of the step in an agent turn." 
+ "description": "Type of step being executed" }, "step_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the step within a turn" }, "metadata": { "type": "object", @@ -7322,7 +7475,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) Additional metadata for the step" } }, "additionalProperties": false, @@ -7331,13 +7485,15 @@ "step_type", "step_id" ], - "title": "AgentTurnResponseStepStartPayload" + "title": "AgentTurnResponseStepStartPayload", + "description": "Payload for step start events in agent turn responses." }, "AgentTurnResponseStreamChunk": { "type": "object", "properties": { "event": { - "$ref": "#/components/schemas/AgentTurnResponseEvent" + "$ref": "#/components/schemas/AgentTurnResponseEvent", + "description": "Individual event in the agent turn response stream" } }, "additionalProperties": false, @@ -7345,7 +7501,7 @@ "event" ], "title": "AgentTurnResponseStreamChunk", - "description": "streamed agent turn completion response." + "description": "Streamed agent turn completion response." }, "AgentTurnResponseTurnAwaitingInputPayload": { "type": "object", @@ -7360,12 +7516,13 @@ "turn_complete", "turn_awaiting_input" ], - "title": "AgentTurnResponseEventType", "const": "turn_awaiting_input", - "default": "turn_awaiting_input" + "default": "turn_awaiting_input", + "description": "Type of event being reported" }, "turn": { - "$ref": "#/components/schemas/Turn" + "$ref": "#/components/schemas/Turn", + "description": "Turn data when waiting for external tool responses" } }, "additionalProperties": false, @@ -7373,7 +7530,8 @@ "event_type", "turn" ], - "title": "AgentTurnResponseTurnAwaitingInputPayload" + "title": "AgentTurnResponseTurnAwaitingInputPayload", + "description": "Payload for turn awaiting input events in agent turn responses." }, "AgentTurnResponseTurnCompletePayload": { "type": "object", @@ -7388,12 +7546,13 @@ "turn_complete", "turn_awaiting_input" ], - "title": "AgentTurnResponseEventType", "const": "turn_complete", - "default": "turn_complete" + "default": "turn_complete", + "description": "Type of event being reported" }, "turn": { - "$ref": "#/components/schemas/Turn" + "$ref": "#/components/schemas/Turn", + "description": "Complete turn data including all steps and results" } }, "additionalProperties": false, @@ -7401,7 +7560,8 @@ "event_type", "turn" ], - "title": "AgentTurnResponseTurnCompletePayload" + "title": "AgentTurnResponseTurnCompletePayload", + "description": "Payload for turn completion events in agent turn responses." }, "AgentTurnResponseTurnStartPayload": { "type": "object", @@ -7416,12 +7576,13 @@ "turn_complete", "turn_awaiting_input" ], - "title": "AgentTurnResponseEventType", "const": "turn_start", - "default": "turn_start" + "default": "turn_start", + "description": "Type of event being reported" }, "turn_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the turn within a session" } }, "additionalProperties": false, @@ -7429,7 +7590,8 @@ "event_type", "turn_id" ], - "title": "AgentTurnResponseTurnStartPayload" + "title": "AgentTurnResponseTurnStartPayload", + "description": "Payload for turn start events in agent turn responses." 
}, "OpenAIResponseAnnotationCitation": { "type": "object", @@ -7437,19 +7599,24 @@ "type": { "type": "string", "const": "url_citation", - "default": "url_citation" + "default": "url_citation", + "description": "Annotation type identifier, always \"url_citation\"" }, "end_index": { - "type": "integer" + "type": "integer", + "description": "End position of the citation span in the content" }, "start_index": { - "type": "integer" + "type": "integer", + "description": "Start position of the citation span in the content" }, "title": { - "type": "string" + "type": "string", + "description": "Title of the referenced web resource" }, "url": { - "type": "string" + "type": "string", + "description": "URL of the referenced web resource" } }, "additionalProperties": false, @@ -7460,7 +7627,8 @@ "title", "url" ], - "title": "OpenAIResponseAnnotationCitation" + "title": "OpenAIResponseAnnotationCitation", + "description": "URL citation annotation for referencing external web resources." }, "OpenAIResponseAnnotationContainerFileCitation": { "type": "object", @@ -7503,16 +7671,20 @@ "type": { "type": "string", "const": "file_citation", - "default": "file_citation" + "default": "file_citation", + "description": "Annotation type identifier, always \"file_citation\"" }, "file_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the referenced file" }, "filename": { - "type": "string" + "type": "string", + "description": "Name of the referenced file" }, "index": { - "type": "integer" + "type": "integer", + "description": "Position index of the citation within the content" } }, "additionalProperties": false, @@ -7522,7 +7694,8 @@ "filename", "index" ], - "title": "OpenAIResponseAnnotationFileCitation" + "title": "OpenAIResponseAnnotationFileCitation", + "description": "File citation annotation for referencing specific files in response content." }, "OpenAIResponseAnnotationFilePath": { "type": "object", @@ -7656,15 +7829,18 @@ "const": "auto" } ], - "default": "auto" + "default": "auto", + "description": "Level of detail for image processing, can be \"low\", \"high\", or \"auto\"" }, "type": { "type": "string", "const": "input_image", - "default": "input_image" + "default": "input_image", + "description": "Content type identifier, always \"input_image\"" }, "image_url": { - "type": "string" + "type": "string", + "description": "(Optional) URL of the image content" } }, "additionalProperties": false, @@ -7672,18 +7848,21 @@ "detail", "type" ], - "title": "OpenAIResponseInputMessageContentImage" + "title": "OpenAIResponseInputMessageContentImage", + "description": "Image content for input messages in OpenAI response format." }, "OpenAIResponseInputMessageContentText": { "type": "object", "properties": { "text": { - "type": "string" + "type": "string", + "description": "The text content of the input message" }, "type": { "type": "string", "const": "input_text", - "default": "input_text" + "default": "input_text", + "description": "Content type identifier, always \"input_text\"" } }, "additionalProperties": false, @@ -7691,7 +7870,8 @@ "text", "type" ], - "title": "OpenAIResponseInputMessageContentText" + "title": "OpenAIResponseInputMessageContentText", + "description": "Text content for input messages in OpenAI response format." 
}, "OpenAIResponseInputTool": { "oneOf": [ @@ -7724,13 +7904,15 @@ "type": { "type": "string", "const": "file_search", - "default": "file_search" + "default": "file_search", + "description": "Tool type identifier, always \"file_search\"" }, "vector_store_ids": { "type": "array", "items": { "type": "string" - } + }, + "description": "List of vector store identifiers to search within" }, "filters": { "type": "object", @@ -7755,25 +7937,29 @@ "type": "object" } ] - } + }, + "description": "(Optional) Additional filters to apply to the search" }, "max_num_results": { "type": "integer", - "default": 10 + "default": 10, + "description": "(Optional) Maximum number of search results to return (1-50)" }, "ranking_options": { "type": "object", "properties": { "ranker": { - "type": "string" + "type": "string", + "description": "(Optional) Name of the ranking algorithm to use" }, "score_threshold": { "type": "number", - "default": 0.0 + "default": 0.0, + "description": "(Optional) Minimum relevance score threshold for results" } }, "additionalProperties": false, - "title": "SearchRankingOptions" + "description": "(Optional) Options for ranking and scoring search results" } }, "additionalProperties": false, @@ -7781,7 +7967,8 @@ "type", "vector_store_ids" ], - "title": "OpenAIResponseInputToolFileSearch" + "title": "OpenAIResponseInputToolFileSearch", + "description": "File search tool configuration for OpenAI response inputs." }, "OpenAIResponseInputToolFunction": { "type": "object", @@ -7789,13 +7976,16 @@ "type": { "type": "string", "const": "function", - "default": "function" + "default": "function", + "description": "Tool type identifier, always \"function\"" }, "name": { - "type": "string" + "type": "string", + "description": "Name of the function that can be called" }, "description": { - "type": "string" + "type": "string", + "description": "(Optional) Description of what the function does" }, "parameters": { "type": "object", @@ -7820,10 +8010,12 @@ "type": "object" } ] - } + }, + "description": "(Optional) JSON schema defining the function's parameters" }, "strict": { - "type": "boolean" + "type": "boolean", + "description": "(Optional) Whether to enforce strict parameter validation" } }, "additionalProperties": false, @@ -7831,7 +8023,8 @@ "type", "name" ], - "title": "OpenAIResponseInputToolFunction" + "title": "OpenAIResponseInputToolFunction", + "description": "Function tool configuration for OpenAI response inputs." 
}, "OpenAIResponseInputToolMCP": { "type": "object", @@ -7839,13 +8032,16 @@ "type": { "type": "string", "const": "mcp", - "default": "mcp" + "default": "mcp", + "description": "Tool type identifier, always \"mcp\"" }, "server_label": { - "type": "string" + "type": "string", + "description": "Label to identify this MCP server" }, "server_url": { - "type": "string" + "type": "string", + "description": "URL endpoint of the MCP server" }, "headers": { "type": "object", @@ -7870,7 +8066,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) HTTP headers to include when connecting to the server" }, "require_approval": { "oneOf": [ @@ -7889,20 +8086,24 @@ "type": "array", "items": { "type": "string" - } + }, + "description": "(Optional) List of tool names that always require approval" }, "never": { "type": "array", "items": { "type": "string" - } + }, + "description": "(Optional) List of tool names that never require approval" } }, "additionalProperties": false, - "title": "ApprovalFilter" + "title": "ApprovalFilter", + "description": "Filter configuration for MCP tool approval requirements." } ], - "default": "never" + "default": "never", + "description": "Approval requirement for tool calls (\"always\", \"never\", or filter)" }, "allowed_tools": { "oneOf": [ @@ -7919,13 +8120,16 @@ "type": "array", "items": { "type": "string" - } + }, + "description": "(Optional) List of specific tool names that are allowed" } }, "additionalProperties": false, - "title": "AllowedToolsFilter" + "title": "AllowedToolsFilter", + "description": "Filter configuration for restricting which MCP tools can be used." } - ] + ], + "description": "(Optional) Restriction on which tools can be used from this server" } }, "additionalProperties": false, @@ -7935,7 +8139,8 @@ "server_url", "require_approval" ], - "title": "OpenAIResponseInputToolMCP" + "title": "OpenAIResponseInputToolMCP", + "description": "Model Context Protocol (MCP) tool configuration for OpenAI response inputs." }, "OpenAIResponseInputToolWebSearch": { "type": "object", @@ -7955,18 +8160,21 @@ "const": "web_search_preview_2025_03_11" } ], - "default": "web_search" + "default": "web_search", + "description": "Web search tool type variant to use" }, "search_context_size": { "type": "string", - "default": "medium" + "default": "medium", + "description": "(Optional) Size of search context, must be \"low\", \"medium\", or \"high\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "OpenAIResponseInputToolWebSearch" + "title": "OpenAIResponseInputToolWebSearch", + "description": "Web search tool configuration for OpenAI response inputs." 
}, "OpenAIResponseMessage": { "type": "object", @@ -8061,21 +8269,25 @@ "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for this tool call" }, "queries": { "type": "array", "items": { "type": "string" - } + }, + "description": "List of search queries executed" }, "status": { - "type": "string" + "type": "string", + "description": "Current status of the file search operation" }, "type": { "type": "string", "const": "file_search_call", - "default": "file_search_call" + "default": "file_search_call", + "description": "Tool call type identifier, always \"file_search_call\"" }, "results": { "type": "array", @@ -8103,7 +8315,8 @@ } ] } - } + }, + "description": "(Optional) Search results returned by the file search operation" } }, "additionalProperties": false, @@ -8113,30 +8326,37 @@ "status", "type" ], - "title": "OpenAIResponseOutputMessageFileSearchToolCall" + "title": "OpenAIResponseOutputMessageFileSearchToolCall", + "description": "File search tool call output message for OpenAI responses." }, "OpenAIResponseOutputMessageFunctionToolCall": { "type": "object", "properties": { "call_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the function call" }, "name": { - "type": "string" + "type": "string", + "description": "Name of the function being called" }, "arguments": { - "type": "string" + "type": "string", + "description": "JSON string containing the function arguments" }, "type": { "type": "string", "const": "function_call", - "default": "function_call" + "default": "function_call", + "description": "Tool call type identifier, always \"function_call\"" }, "id": { - "type": "string" + "type": "string", + "description": "(Optional) Additional identifier for the tool call" }, "status": { - "type": "string" + "type": "string", + "description": "(Optional) Current status of the function call execution" } }, "additionalProperties": false, @@ -8146,21 +8366,25 @@ "arguments", "type" ], - "title": "OpenAIResponseOutputMessageFunctionToolCall" + "title": "OpenAIResponseOutputMessageFunctionToolCall", + "description": "Function tool call output message for OpenAI responses." }, "OpenAIResponseOutputMessageWebSearchToolCall": { "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for this tool call" }, "status": { - "type": "string" + "type": "string", + "description": "Current status of the web search operation" }, "type": { "type": "string", "const": "web_search_call", - "default": "web_search_call" + "default": "web_search_call", + "description": "Tool call type identifier, always \"web_search_call\"" } }, "additionalProperties": false, @@ -8169,7 +8393,8 @@ "status", "type" ], - "title": "OpenAIResponseOutputMessageWebSearchToolCall" + "title": "OpenAIResponseOutputMessageWebSearchToolCall", + "description": "Web search tool call output message for OpenAI responses." }, "OpenAIResponseText": { "type": "object", @@ -8237,12 +8462,12 @@ "required": [ "type" ], - "title": "OpenAIResponseTextFormat", - "description": "Configuration for Responses API text format." + "description": "(Optional) Text format configuration specifying output format requirements" } }, "additionalProperties": false, - "title": "OpenAIResponseText" + "title": "OpenAIResponseText", + "description": "Text response configuration for OpenAI responses." 
}, "CreateOpenaiResponseRequest": { "type": "object", @@ -8305,10 +8530,12 @@ "type": "object", "properties": { "code": { - "type": "string" + "type": "string", + "description": "Error code identifying the type of failure" }, "message": { - "type": "string" + "type": "string", + "description": "Human-readable error message describing the failure" } }, "additionalProperties": false, @@ -8316,58 +8543,73 @@ "code", "message" ], - "title": "OpenAIResponseError" + "title": "OpenAIResponseError", + "description": "Error details for failed OpenAI response requests." }, "OpenAIResponseObject": { "type": "object", "properties": { "created_at": { - "type": "integer" + "type": "integer", + "description": "Unix timestamp when the response was created" }, "error": { - "$ref": "#/components/schemas/OpenAIResponseError" + "$ref": "#/components/schemas/OpenAIResponseError", + "description": "(Optional) Error details if the response generation failed" }, "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for this response" }, "model": { - "type": "string" + "type": "string", + "description": "Model identifier used for generation" }, "object": { "type": "string", "const": "response", - "default": "response" + "default": "response", + "description": "Object type identifier, always \"response\"" }, "output": { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseOutput" - } + }, + "description": "List of generated output items (messages, tool calls, etc.)" }, "parallel_tool_calls": { "type": "boolean", - "default": false + "default": false, + "description": "Whether tool calls can be executed in parallel" }, "previous_response_id": { - "type": "string" + "type": "string", + "description": "(Optional) ID of the previous response in a conversation" }, "status": { - "type": "string" + "type": "string", + "description": "Current status of the response generation" }, "temperature": { - "type": "number" + "type": "number", + "description": "(Optional) Sampling temperature used for generation" }, "text": { - "$ref": "#/components/schemas/OpenAIResponseText" + "$ref": "#/components/schemas/OpenAIResponseText", + "description": "Text formatting configuration for the response" }, "top_p": { - "type": "number" + "type": "number", + "description": "(Optional) Nucleus sampling parameter used for generation" }, "truncation": { - "type": "string" + "type": "string", + "description": "(Optional) Truncation strategy applied to the response" }, "user": { - "type": "string" + "type": "string", + "description": "(Optional) User identifier associated with the request" } }, "additionalProperties": false, @@ -8381,7 +8623,8 @@ "status", "text" ], - "title": "OpenAIResponseObject" + "title": "OpenAIResponseObject", + "description": "Complete OpenAI response object containing generation results and metadata." 
}, "OpenAIResponseOutput": { "oneOf": [ @@ -8420,27 +8663,34 @@ "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for this MCP call" }, "type": { "type": "string", "const": "mcp_call", - "default": "mcp_call" + "default": "mcp_call", + "description": "Tool call type identifier, always \"mcp_call\"" }, "arguments": { - "type": "string" + "type": "string", + "description": "JSON string containing the MCP call arguments" }, "name": { - "type": "string" + "type": "string", + "description": "Name of the MCP method being called" }, "server_label": { - "type": "string" + "type": "string", + "description": "Label identifying the MCP server handling the call" }, "error": { - "type": "string" + "type": "string", + "description": "(Optional) Error message if the MCP call failed" }, "output": { - "type": "string" + "type": "string", + "description": "(Optional) Output result from the successful MCP call" } }, "additionalProperties": false, @@ -8451,21 +8701,25 @@ "name", "server_label" ], - "title": "OpenAIResponseOutputMessageMCPCall" + "title": "OpenAIResponseOutputMessageMCPCall", + "description": "Model Context Protocol (MCP) call output message for OpenAI responses." }, "OpenAIResponseOutputMessageMCPListTools": { "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for this MCP list tools operation" }, "type": { "type": "string", "const": "mcp_list_tools", - "default": "mcp_list_tools" + "default": "mcp_list_tools", + "description": "Tool call type identifier, always \"mcp_list_tools\"" }, "server_label": { - "type": "string" + "type": "string", + "description": "Label identifying the MCP server providing the tools" }, "tools": { "type": "array", @@ -8495,13 +8749,16 @@ "type": "object" } ] - } + }, + "description": "JSON schema defining the tool's input parameters" }, "name": { - "type": "string" + "type": "string", + "description": "Name of the tool" }, "description": { - "type": "string" + "type": "string", + "description": "(Optional) Description of what the tool does" } }, "additionalProperties": false, @@ -8509,8 +8766,10 @@ "input_schema", "name" ], - "title": "MCPListToolsTool" - } + "title": "MCPListToolsTool", + "description": "Tool definition returned by MCP list tools operation." + }, + "description": "List of available tools provided by the MCP server" } }, "additionalProperties": false, @@ -8520,7 +8779,8 @@ "server_label", "tools" ], - "title": "OpenAIResponseOutputMessageMCPListTools" + "title": "OpenAIResponseOutputMessageMCPListTools", + "description": "MCP list tools output message containing available tools from an MCP server." }, "OpenAIResponseObjectStream": { "oneOf": [ @@ -8611,12 +8871,14 @@ "type": "object", "properties": { "response": { - "$ref": "#/components/schemas/OpenAIResponseObject" + "$ref": "#/components/schemas/OpenAIResponseObject", + "description": "The completed response object" }, "type": { "type": "string", "const": "response.completed", - "default": "response.completed" + "default": "response.completed", + "description": "Event type identifier, always \"response.completed\"" } }, "additionalProperties": false, @@ -8624,18 +8886,21 @@ "response", "type" ], - "title": "OpenAIResponseObjectStreamResponseCompleted" + "title": "OpenAIResponseObjectStreamResponseCompleted", + "description": "Streaming event indicating a response has been completed." 
}, "OpenAIResponseObjectStreamResponseCreated": { "type": "object", "properties": { "response": { - "$ref": "#/components/schemas/OpenAIResponseObject" + "$ref": "#/components/schemas/OpenAIResponseObject", + "description": "The newly created response object" }, "type": { "type": "string", "const": "response.created", - "default": "response.created" + "default": "response.created", + "description": "Event type identifier, always \"response.created\"" } }, "additionalProperties": false, @@ -8643,27 +8908,33 @@ "response", "type" ], - "title": "OpenAIResponseObjectStreamResponseCreated" + "title": "OpenAIResponseObjectStreamResponseCreated", + "description": "Streaming event indicating a new response has been created." }, "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta": { "type": "object", "properties": { "delta": { - "type": "string" + "type": "string", + "description": "Incremental function call arguments being added" }, "item_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the function call being updated" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of the item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.function_call_arguments.delta", - "default": "response.function_call_arguments.delta" + "default": "response.function_call_arguments.delta", + "description": "Event type identifier, always \"response.function_call_arguments.delta\"" } }, "additionalProperties": false, @@ -8674,27 +8945,33 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta" + "title": "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta", + "description": "Streaming event for incremental function call argument updates." }, "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone": { "type": "object", "properties": { "arguments": { - "type": "string" + "type": "string", + "description": "Final complete arguments JSON string for the function call" }, "item_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the completed function call" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of the item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.function_call_arguments.done", - "default": "response.function_call_arguments.done" + "default": "response.function_call_arguments.done", + "description": "Event type identifier, always \"response.function_call_arguments.done\"" } }, "additionalProperties": false, @@ -8705,7 +8982,8 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone" + "title": "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone", + "description": "Streaming event for when function call arguments are completed." 
}, "OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta": { "type": "object", @@ -8773,12 +9051,14 @@ "type": "object", "properties": { "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.mcp_call.completed", - "default": "response.mcp_call.completed" + "default": "response.mcp_call.completed", + "description": "Event type identifier, always \"response.mcp_call.completed\"" } }, "additionalProperties": false, @@ -8786,18 +9066,21 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseMcpCallCompleted" + "title": "OpenAIResponseObjectStreamResponseMcpCallCompleted", + "description": "Streaming event for completed MCP calls." }, "OpenAIResponseObjectStreamResponseMcpCallFailed": { "type": "object", "properties": { "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.mcp_call.failed", - "default": "response.mcp_call.failed" + "default": "response.mcp_call.failed", + "description": "Event type identifier, always \"response.mcp_call.failed\"" } }, "additionalProperties": false, @@ -8805,24 +9088,29 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseMcpCallFailed" + "title": "OpenAIResponseObjectStreamResponseMcpCallFailed", + "description": "Streaming event for failed MCP calls." }, "OpenAIResponseObjectStreamResponseMcpCallInProgress": { "type": "object", "properties": { "item_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the MCP call" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of the item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.mcp_call.in_progress", - "default": "response.mcp_call.in_progress" + "default": "response.mcp_call.in_progress", + "description": "Event type identifier, always \"response.mcp_call.in_progress\"" } }, "additionalProperties": false, @@ -8832,7 +9120,8 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseMcpCallInProgress" + "title": "OpenAIResponseObjectStreamResponseMcpCallInProgress", + "description": "Streaming event for MCP calls in progress." 
}, "OpenAIResponseObjectStreamResponseMcpListToolsCompleted": { "type": "object", @@ -8895,21 +9184,26 @@ "type": "object", "properties": { "response_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the response containing this output" }, "item": { - "$ref": "#/components/schemas/OpenAIResponseOutput" + "$ref": "#/components/schemas/OpenAIResponseOutput", + "description": "The output item that was added (message, tool call, etc.)" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of this item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.output_item.added", - "default": "response.output_item.added" + "default": "response.output_item.added", + "description": "Event type identifier, always \"response.output_item.added\"" } }, "additionalProperties": false, @@ -8920,27 +9214,33 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseOutputItemAdded" + "title": "OpenAIResponseObjectStreamResponseOutputItemAdded", + "description": "Streaming event for when a new output item is added to the response." }, "OpenAIResponseObjectStreamResponseOutputItemDone": { "type": "object", "properties": { "response_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the response containing this output" }, "item": { - "$ref": "#/components/schemas/OpenAIResponseOutput" + "$ref": "#/components/schemas/OpenAIResponseOutput", + "description": "The completed output item (message, tool call, etc.)" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of this item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.output_item.done", - "default": "response.output_item.done" + "default": "response.output_item.done", + "description": "Event type identifier, always \"response.output_item.done\"" } }, "additionalProperties": false, @@ -8951,30 +9251,37 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseOutputItemDone" + "title": "OpenAIResponseObjectStreamResponseOutputItemDone", + "description": "Streaming event for when an output item is completed." 
}, "OpenAIResponseObjectStreamResponseOutputTextDelta": { "type": "object", "properties": { "content_index": { - "type": "integer" + "type": "integer", + "description": "Index position within the text content" }, "delta": { - "type": "string" + "type": "string", + "description": "Incremental text content being added" }, "item_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the output item being updated" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of the item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.output_text.delta", - "default": "response.output_text.delta" + "default": "response.output_text.delta", + "description": "Event type identifier, always \"response.output_text.delta\"" } }, "additionalProperties": false, @@ -8986,30 +9293,37 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseOutputTextDelta" + "title": "OpenAIResponseObjectStreamResponseOutputTextDelta", + "description": "Streaming event for incremental text content updates." }, "OpenAIResponseObjectStreamResponseOutputTextDone": { "type": "object", "properties": { "content_index": { - "type": "integer" + "type": "integer", + "description": "Index position within the text content" }, "text": { - "type": "string" + "type": "string", + "description": "Final complete text content of the output item" }, "item_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the completed output item" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of the item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.output_text.done", - "default": "response.output_text.done" + "default": "response.output_text.done", + "description": "Event type identifier, always \"response.output_text.done\"" } }, "additionalProperties": false, @@ -9021,24 +9335,29 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseOutputTextDone" + "title": "OpenAIResponseObjectStreamResponseOutputTextDone", + "description": "Streaming event for when text output is completed." }, "OpenAIResponseObjectStreamResponseWebSearchCallCompleted": { "type": "object", "properties": { "item_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the completed web search call" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of the item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.web_search_call.completed", - "default": "response.web_search_call.completed" + "default": "response.web_search_call.completed", + "description": "Event type identifier, always \"response.web_search_call.completed\"" } }, "additionalProperties": false, @@ -9048,24 +9367,29 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseWebSearchCallCompleted" + "title": "OpenAIResponseObjectStreamResponseWebSearchCallCompleted", + "description": "Streaming event for completed web search calls." 
}, "OpenAIResponseObjectStreamResponseWebSearchCallInProgress": { "type": "object", "properties": { "item_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the web search call" }, "output_index": { - "type": "integer" + "type": "integer", + "description": "Index position of the item in the output list" }, "sequence_number": { - "type": "integer" + "type": "integer", + "description": "Sequential number for ordering streaming events" }, "type": { "type": "string", "const": "response.web_search_call.in_progress", - "default": "response.web_search_call.in_progress" + "default": "response.web_search_call.in_progress", + "description": "Event type identifier, always \"response.web_search_call.in_progress\"" } }, "additionalProperties": false, @@ -9075,7 +9399,8 @@ "sequence_number", "type" ], - "title": "OpenAIResponseObjectStreamResponseWebSearchCallInProgress" + "title": "OpenAIResponseObjectStreamResponseWebSearchCallInProgress", + "description": "Streaming event for web search calls in progress." }, "OpenAIResponseObjectStreamResponseWebSearchCallSearching": { "type": "object", @@ -9108,16 +9433,19 @@ "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the deleted response" }, "object": { "type": "string", "const": "response", - "default": "response" + "default": "response", + "description": "Object type identifier, always \"response\"" }, "deleted": { "type": "boolean", - "default": true + "default": true, + "description": "Deletion confirmation flag, always True" } }, "additionalProperties": false, @@ -9126,7 +9454,8 @@ "object", "deleted" ], - "title": "OpenAIDeleteResponseObject" + "title": "OpenAIDeleteResponseObject", + "description": "Response object confirming deletion of an OpenAI response." }, "EmbeddingsRequest": { "type": "object", @@ -9232,7 +9561,8 @@ "categorical_count", "accuracy" ], - "title": "AggregationFunctionType" + "title": "AggregationFunctionType", + "description": "Types of aggregation functions for scoring results." }, "BasicScoringFnParams": { "type": "object", @@ -9240,13 +9570,15 @@ "type": { "$ref": "#/components/schemas/ScoringFnParamsType", "const": "basic", - "default": "basic" + "default": "basic", + "description": "The type of scoring function parameters, always basic" }, "aggregation_functions": { "type": "array", "items": { "$ref": "#/components/schemas/AggregationFunctionType" - } + }, + "description": "Aggregation functions to apply to the scores of each row" } }, "additionalProperties": false, @@ -9254,7 +9586,8 @@ "type", "aggregation_functions" ], - "title": "BasicScoringFnParams" + "title": "BasicScoringFnParams", + "description": "Parameters for basic scoring function configuration." 
}, "BenchmarkConfig": { "type": "object", @@ -9306,25 +9639,30 @@ "type": { "$ref": "#/components/schemas/ScoringFnParamsType", "const": "llm_as_judge", - "default": "llm_as_judge" + "default": "llm_as_judge", + "description": "The type of scoring function parameters, always llm_as_judge" }, "judge_model": { - "type": "string" + "type": "string", + "description": "Identifier of the LLM model to use as a judge for scoring" }, "prompt_template": { - "type": "string" + "type": "string", + "description": "(Optional) Custom prompt template for the judge model" }, "judge_score_regexes": { "type": "array", "items": { "type": "string" - } + }, + "description": "Regexes to extract the answer from generated response" }, "aggregation_functions": { "type": "array", "items": { "$ref": "#/components/schemas/AggregationFunctionType" - } + }, + "description": "Aggregation functions to apply to the scores of each row" } }, "additionalProperties": false, @@ -9334,7 +9672,8 @@ "judge_score_regexes", "aggregation_functions" ], - "title": "LLMAsJudgeScoringFnParams" + "title": "LLMAsJudgeScoringFnParams", + "description": "Parameters for LLM-as-judge scoring function configuration." }, "ModelCandidate": { "type": "object", @@ -9372,19 +9711,22 @@ "type": { "$ref": "#/components/schemas/ScoringFnParamsType", "const": "regex_parser", - "default": "regex_parser" + "default": "regex_parser", + "description": "The type of scoring function parameters, always regex_parser" }, "parsing_regexes": { "type": "array", "items": { "type": "string" - } + }, + "description": "Regex to extract the answer from generated response" }, "aggregation_functions": { "type": "array", "items": { "$ref": "#/components/schemas/AggregationFunctionType" - } + }, + "description": "Aggregation functions to apply to the scores of each row" } }, "additionalProperties": false, @@ -9393,7 +9735,8 @@ "parsing_regexes", "aggregation_functions" ], - "title": "RegexParserScoringFnParams" + "title": "RegexParserScoringFnParams", + "description": "Parameters for regex parser scoring function configuration." }, "ScoringFnParams": { "oneOf": [ @@ -9423,7 +9766,8 @@ "regex_parser", "basic" ], - "title": "ScoringFnParamsType" + "title": "ScoringFnParamsType", + "description": "Types of scoring function parameter configurations." }, "EvaluateRowsRequest": { "type": "object", @@ -9596,14 +9940,17 @@ "type": "object", "properties": { "agent_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the agent" }, "agent_config": { - "$ref": "#/components/schemas/AgentConfig" + "$ref": "#/components/schemas/AgentConfig", + "description": "Configuration settings for the agent" }, "created_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the agent was created" } }, "additionalProperties": false, @@ -9612,26 +9959,31 @@ "agent_config", "created_at" ], - "title": "Agent" + "title": "Agent", + "description": "An agent instance with configuration and metadata." 
}, "Session": { "type": "object", "properties": { "session_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the conversation session" }, "session_name": { - "type": "string" + "type": "string", + "description": "Human-readable name for the session" }, "turns": { "type": "array", "items": { "$ref": "#/components/schemas/Turn" - } + }, + "description": "List of all turns that have occurred in this session" }, "started_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the session was created" } }, "additionalProperties": false, @@ -9670,14 +10022,16 @@ "shield_call": "#/components/schemas/ShieldCallStep", "memory_retrieval": "#/components/schemas/MemoryRetrievalStep" } - } + }, + "description": "The complete step data and execution details" } }, "additionalProperties": false, "required": [ "step" ], - "title": "AgentStepResponse" + "title": "AgentStepResponse", + "description": "Response containing details of a specific agent step." }, "Benchmark": { "type": "object", @@ -9703,18 +10057,20 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "benchmark", - "default": "benchmark" + "default": "benchmark", + "description": "The resource type, always benchmark" }, "dataset_id": { - "type": "string" + "type": "string", + "description": "Identifier of the dataset to use for the benchmark evaluation" }, "scoring_functions": { "type": "array", "items": { "type": "string" - } + }, + "description": "List of scoring function identifiers to apply during evaluation" }, "metadata": { "type": "object", @@ -9739,7 +10095,8 @@ "type": "object" } ] - } + }, + "description": "Metadata for this evaluation task" } }, "additionalProperties": false, @@ -9751,7 +10108,8 @@ "scoring_functions", "metadata" ], - "title": "Benchmark" + "title": "Benchmark", + "description": "A benchmark resource for evaluating model performance." }, "OpenAIAssistantMessageParam": { "type": "object", @@ -9801,10 +10159,12 @@ "type": { "type": "string", "const": "image_url", - "default": "image_url" + "default": "image_url", + "description": "Must be \"image_url\" to identify this as image content" }, "image_url": { - "$ref": "#/components/schemas/OpenAIImageURL" + "$ref": "#/components/schemas/OpenAIImageURL", + "description": "Image URL specification and processing details" } }, "additionalProperties": false, @@ -9812,7 +10172,8 @@ "type", "image_url" ], - "title": "OpenAIChatCompletionContentPartImageParam" + "title": "OpenAIChatCompletionContentPartImageParam", + "description": "Image content part for OpenAI-compatible chat completion messages." }, "OpenAIChatCompletionContentPartParam": { "oneOf": [ @@ -9841,10 +10202,12 @@ "type": { "type": "string", "const": "text", - "default": "text" + "default": "text", + "description": "Must be \"text\" to identify this as text content" }, "text": { - "type": "string" + "type": "string", + "description": "The text content of the message" } }, "additionalProperties": false, @@ -9852,44 +10215,53 @@ "type", "text" ], - "title": "OpenAIChatCompletionContentPartTextParam" + "title": "OpenAIChatCompletionContentPartTextParam", + "description": "Text content part for OpenAI-compatible chat completion messages." 
}, "OpenAIChatCompletionToolCall": { "type": "object", "properties": { "index": { - "type": "integer" + "type": "integer", + "description": "(Optional) Index of the tool call in the list" }, "id": { - "type": "string" + "type": "string", + "description": "(Optional) Unique identifier for the tool call" }, "type": { "type": "string", "const": "function", - "default": "function" + "default": "function", + "description": "Must be \"function\" to identify this as a function call" }, "function": { - "$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction" + "$ref": "#/components/schemas/OpenAIChatCompletionToolCallFunction", + "description": "(Optional) Function call details" } }, "additionalProperties": false, "required": [ "type" ], - "title": "OpenAIChatCompletionToolCall" + "title": "OpenAIChatCompletionToolCall", + "description": "Tool call specification for OpenAI-compatible chat completion responses." }, "OpenAIChatCompletionToolCallFunction": { "type": "object", "properties": { "name": { - "type": "string" + "type": "string", + "description": "(Optional) Name of the function to call" }, "arguments": { - "type": "string" + "type": "string", + "description": "(Optional) Arguments to pass to the function as a JSON string" } }, "additionalProperties": false, - "title": "OpenAIChatCompletionToolCallFunction" + "title": "OpenAIChatCompletionToolCallFunction", + "description": "Function call details for OpenAI-compatible tool calls." }, "OpenAIChoice": { "type": "object", @@ -10017,17 +10389,20 @@ "type": "object", "properties": { "url": { - "type": "string" + "type": "string", + "description": "URL of the image to include in the message" }, "detail": { - "type": "string" + "type": "string", + "description": "(Optional) Level of detail for image processing. Can be \"low\", \"high\", or \"auto\"" } }, "additionalProperties": false, "required": [ "url" ], - "title": "OpenAIImageURL" + "title": "OpenAIImageURL", + "description": "Image URL specification for OpenAI-compatible chat completion messages." }, "OpenAIMessageParam": { "oneOf": [ @@ -10309,9 +10684,9 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "dataset", - "default": "dataset" + "default": "dataset", + "description": "Type of resource, always 'dataset' for datasets" }, "purpose": { "type": "string", @@ -10320,11 +10695,11 @@ "eval/question-answer", "eval/messages-answer" ], - "title": "DatasetPurpose", - "description": "Purpose of the dataset. Each purpose has a required input data schema." + "description": "Purpose of the dataset indicating its intended use" }, "source": { - "$ref": "#/components/schemas/DataSource" + "$ref": "#/components/schemas/DataSource", + "description": "Data source configuration for the dataset" }, "metadata": { "type": "object", @@ -10349,7 +10724,8 @@ "type": "object" } ] - } + }, + "description": "Additional metadata for the dataset" } }, "additionalProperties": false, @@ -10361,7 +10737,8 @@ "source", "metadata" ], - "title": "Dataset" + "title": "Dataset", + "description": "Dataset resource for storing and accessing training or evaluation data." 
}, "RowsDataSource": { "type": "object", @@ -10434,13 +10811,16 @@ "type": "object", "properties": { "identifier": { - "type": "string" + "type": "string", + "description": "Unique identifier for this resource in llama stack" }, "provider_resource_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for this resource in the provider" }, "provider_id": { - "type": "string" + "type": "string", + "description": "ID of the provider that owns this resource" }, "type": { "type": "string", @@ -10454,9 +10834,9 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "model", - "default": "model" + "default": "model", + "description": "The resource type, always 'model' for model resources" }, "metadata": { "type": "object", @@ -10481,11 +10861,13 @@ "type": "object" } ] - } + }, + "description": "Any additional metadata for this model" }, "model_type": { "$ref": "#/components/schemas/ModelType", - "default": "llm" + "default": "llm", + "description": "The type of model (LLM or embedding model)" } }, "additionalProperties": false, @@ -10496,7 +10878,8 @@ "metadata", "model_type" ], - "title": "Model" + "title": "Model", + "description": "A model resource representing an AI model registered in Llama Stack." }, "ModelType": { "type": "string", @@ -10504,7 +10887,8 @@ "llm", "embedding" ], - "title": "ModelType" + "title": "ModelType", + "description": "Enumeration of supported model types in Llama Stack." }, "AgentTurnInputType": { "type": "object", @@ -10512,14 +10896,16 @@ "type": { "type": "string", "const": "agent_turn_input", - "default": "agent_turn_input" + "default": "agent_turn_input", + "description": "Discriminator type. Always \"agent_turn_input\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "AgentTurnInputType" + "title": "AgentTurnInputType", + "description": "Parameter type for agent turn input." }, "ArrayType": { "type": "object", @@ -10527,14 +10913,16 @@ "type": { "type": "string", "const": "array", - "default": "array" + "default": "array", + "description": "Discriminator type. Always \"array\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "ArrayType" + "title": "ArrayType", + "description": "Parameter type for array values." }, "BooleanType": { "type": "object", @@ -10542,14 +10930,16 @@ "type": { "type": "string", "const": "boolean", - "default": "boolean" + "default": "boolean", + "description": "Discriminator type. Always \"boolean\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "BooleanType" + "title": "BooleanType", + "description": "Parameter type for boolean values." }, "ChatCompletionInputType": { "type": "object", @@ -10557,14 +10947,16 @@ "type": { "type": "string", "const": "chat_completion_input", - "default": "chat_completion_input" + "default": "chat_completion_input", + "description": "Discriminator type. Always \"chat_completion_input\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "ChatCompletionInputType" + "title": "ChatCompletionInputType", + "description": "Parameter type for chat completion input." }, "CompletionInputType": { "type": "object", @@ -10572,14 +10964,16 @@ "type": { "type": "string", "const": "completion_input", - "default": "completion_input" + "default": "completion_input", + "description": "Discriminator type. 
Always \"completion_input\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "CompletionInputType" + "title": "CompletionInputType", + "description": "Parameter type for completion input." }, "JsonType": { "type": "object", @@ -10587,14 +10981,16 @@ "type": { "type": "string", "const": "json", - "default": "json" + "default": "json", + "description": "Discriminator type. Always \"json\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "JsonType" + "title": "JsonType", + "description": "Parameter type for JSON values." }, "NumberType": { "type": "object", @@ -10602,14 +10998,16 @@ "type": { "type": "string", "const": "number", - "default": "number" + "default": "number", + "description": "Discriminator type. Always \"number\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "NumberType" + "title": "NumberType", + "description": "Parameter type for numeric values." }, "ObjectType": { "type": "object", @@ -10617,14 +11015,16 @@ "type": { "type": "string", "const": "object", - "default": "object" + "default": "object", + "description": "Discriminator type. Always \"object\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "ObjectType" + "title": "ObjectType", + "description": "Parameter type for object values." }, "ParamType": { "oneOf": [ @@ -10699,9 +11099,9 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "scoring_function", - "default": "scoring_function" + "default": "scoring_function", + "description": "The resource type, always scoring_function" }, "description": { "type": "string" @@ -10746,7 +11146,8 @@ "metadata", "return_type" ], - "title": "ScoringFn" + "title": "ScoringFn", + "description": "A scoring function resource for evaluating model outputs." }, "StringType": { "type": "object", @@ -10754,14 +11155,16 @@ "type": { "type": "string", "const": "string", - "default": "string" + "default": "string", + "description": "Discriminator type. Always \"string\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "StringType" + "title": "StringType", + "description": "Parameter type for string values." }, "UnionType": { "type": "object", @@ -10769,14 +11172,16 @@ "type": { "type": "string", "const": "union", - "default": "union" + "default": "union", + "description": "Discriminator type. Always \"union\"" } }, "additionalProperties": false, "required": [ "type" ], - "title": "UnionType" + "title": "UnionType", + "description": "Parameter type for union values." }, "Shield": { "type": "object", @@ -10802,9 +11207,9 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "shield", - "default": "shield" + "default": "shield", + "description": "The resource type, always shield" }, "params": { "type": "object", @@ -10829,7 +11234,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) Configuration parameters for the shield" } }, "additionalProperties": false, @@ -10839,30 +11245,36 @@ "type" ], "title": "Shield", - "description": "A safety shield resource that can be used to check content" + "description": "A safety shield resource that can be used to check content." 
}, "Span": { "type": "object", "properties": { "span_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the span" }, "trace_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the trace this span belongs to" }, "parent_span_id": { - "type": "string" + "type": "string", + "description": "(Optional) Unique identifier for the parent span, if this is a child span" }, "name": { - "type": "string" + "type": "string", + "description": "Human-readable name describing the operation this span represents" }, "start_time": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the operation began" }, "end_time": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "(Optional) Timestamp when the operation finished, if completed" }, "attributes": { "type": "object", @@ -10887,7 +11299,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) Key-value pairs containing additional metadata about the span" } }, "additionalProperties": false, @@ -10897,7 +11310,8 @@ "name", "start_time" ], - "title": "Span" + "title": "Span", + "description": "A span representing a single operation within a trace." }, "GetSpanTreeRequest": { "type": "object", @@ -10923,30 +11337,37 @@ "ok", "error" ], - "title": "SpanStatus" + "title": "SpanStatus", + "description": "The status of a span indicating whether it completed successfully or with an error." }, "SpanWithStatus": { "type": "object", "properties": { "span_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the span" }, "trace_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the trace this span belongs to" }, "parent_span_id": { - "type": "string" + "type": "string", + "description": "(Optional) Unique identifier for the parent span, if this is a child span" }, "name": { - "type": "string" + "type": "string", + "description": "Human-readable name describing the operation this span represents" }, "start_time": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the operation began" }, "end_time": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "(Optional) Timestamp when the operation finished, if completed" }, "attributes": { "type": "object", @@ -10971,10 +11392,12 @@ "type": "object" } ] - } + }, + "description": "(Optional) Key-value pairs containing additional metadata about the span" }, "status": { - "$ref": "#/components/schemas/SpanStatus" + "$ref": "#/components/schemas/SpanStatus", + "description": "(Optional) The current status of the span" } }, "additionalProperties": false, @@ -10984,7 +11407,8 @@ "name", "start_time" ], - "title": "SpanWithStatus" + "title": "SpanWithStatus", + "description": "A span that includes status information." }, "QuerySpanTreeResponse": { "type": "object", @@ -10993,14 +11417,16 @@ "type": "object", "additionalProperties": { "$ref": "#/components/schemas/SpanWithStatus" - } + }, + "description": "Dictionary mapping span IDs to spans with status information" } }, "additionalProperties": false, "required": [ "data" ], - "title": "QuerySpanTreeResponse" + "title": "QuerySpanTreeResponse", + "description": "Response containing a tree structure of spans." 
}, "Tool": { "type": "object", @@ -11026,21 +11452,24 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "tool", - "default": "tool" + "default": "tool", + "description": "Type of resource, always 'tool'" }, "toolgroup_id": { - "type": "string" + "type": "string", + "description": "ID of the tool group this tool belongs to" }, "description": { - "type": "string" + "type": "string", + "description": "Human-readable description of what the tool does" }, "parameters": { "type": "array", "items": { "$ref": "#/components/schemas/ToolParameter" - } + }, + "description": "List of parameters this tool accepts" }, "metadata": { "type": "object", @@ -11065,7 +11494,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) Additional metadata about the tool" } }, "additionalProperties": false, @@ -11077,7 +11507,8 @@ "description", "parameters" ], - "title": "Tool" + "title": "Tool", + "description": "A tool that can be invoked by agents." }, "ToolGroup": { "type": "object", @@ -11103,12 +11534,13 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "tool_group", - "default": "tool_group" + "default": "tool_group", + "description": "Type of resource, always 'tool_group'" }, "mcp_endpoint": { - "$ref": "#/components/schemas/URL" + "$ref": "#/components/schemas/URL", + "description": "(Optional) Model Context Protocol endpoint for remote tools" }, "args": { "type": "object", @@ -11133,7 +11565,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) Additional arguments for the tool group" } }, "additionalProperties": false, @@ -11142,24 +11575,29 @@ "provider_id", "type" ], - "title": "ToolGroup" + "title": "ToolGroup", + "description": "A group of related tools managed together." }, "Trace": { "type": "object", "properties": { "trace_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the trace" }, "root_span_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the root span that started this trace" }, "start_time": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the trace began" }, "end_time": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "(Optional) Timestamp when the trace finished, if completed" } }, "additionalProperties": false, @@ -11168,29 +11606,36 @@ "root_span_id", "start_time" ], - "title": "Trace" + "title": "Trace", + "description": "A trace representing the complete execution path of a request across multiple operations." 
}, "Checkpoint": { "type": "object", "properties": { "identifier": { - "type": "string" + "type": "string", + "description": "Unique identifier for the checkpoint" }, "created_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the checkpoint was created" }, "epoch": { - "type": "integer" + "type": "integer", + "description": "Training epoch when the checkpoint was saved" }, "post_training_job_id": { - "type": "string" + "type": "string", + "description": "Identifier of the training job that created this checkpoint" }, "path": { - "type": "string" + "type": "string", + "description": "File system path where the checkpoint is stored" }, "training_metrics": { - "$ref": "#/components/schemas/PostTrainingMetric" + "$ref": "#/components/schemas/PostTrainingMetric", + "description": "(Optional) Training metrics associated with this checkpoint" } }, "additionalProperties": false, @@ -11202,19 +11647,21 @@ "path" ], "title": "Checkpoint", - "description": "Checkpoint created during training runs" + "description": "Checkpoint created during training runs." }, "PostTrainingJobArtifactsResponse": { "type": "object", "properties": { "job_uuid": { - "type": "string" + "type": "string", + "description": "Unique identifier for the training job" }, "checkpoints": { "type": "array", "items": { "$ref": "#/components/schemas/Checkpoint" - } + }, + "description": "List of model checkpoints created during training" } }, "additionalProperties": false, @@ -11229,16 +11676,20 @@ "type": "object", "properties": { "epoch": { - "type": "integer" + "type": "integer", + "description": "Training epoch number" }, "train_loss": { - "type": "number" + "type": "number", + "description": "Loss value on the training dataset" }, "validation_loss": { - "type": "number" + "type": "number", + "description": "Loss value on the validation dataset" }, "perplexity": { - "type": "number" + "type": "number", + "description": "Perplexity metric indicating model confidence" } }, "additionalProperties": false, @@ -11248,13 +11699,15 @@ "validation_loss", "perplexity" ], - "title": "PostTrainingMetric" + "title": "PostTrainingMetric", + "description": "Training metrics captured during post-training jobs." 
}, "PostTrainingJobStatusResponse": { "type": "object", "properties": { "job_uuid": { - "type": "string" + "type": "string", + "description": "Unique identifier for the training job" }, "status": { "type": "string", @@ -11265,19 +11718,22 @@ "scheduled", "cancelled" ], - "title": "JobStatus" + "description": "Current status of the training job" }, "scheduled_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "(Optional) Timestamp when the job was scheduled" }, "started_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "(Optional) Timestamp when the job execution began" }, "completed_at": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "(Optional) Timestamp when the job finished, if completed" }, "resources_allocated": { "type": "object", @@ -11302,13 +11758,15 @@ "type": "object" } ] - } + }, + "description": "(Optional) Information about computational resources allocated to the job" }, "checkpoints": { "type": "array", "items": { "$ref": "#/components/schemas/Checkpoint" - } + }, + "description": "List of model checkpoints created during training" } }, "additionalProperties": false, @@ -11370,15 +11828,17 @@ "tool", "tool_group" ], - "title": "ResourceType", "const": "vector_db", - "default": "vector_db" + "default": "vector_db", + "description": "Type of resource, always 'vector_db' for vector databases" }, "embedding_model": { - "type": "string" + "type": "string", + "description": "Name of the embedding model to use for vector generation" }, "embedding_dimension": { - "type": "integer" + "type": "integer", + "description": "Dimension of the embedding vectors" }, "vector_db_name": { "type": "string" @@ -11392,7 +11852,8 @@ "embedding_model", "embedding_dimension" ], - "title": "VectorDB" + "title": "VectorDB", + "description": "Vector database resource for storing and querying vector embeddings." }, "HealthInfo": { "type": "object", @@ -11404,14 +11865,15 @@ "Error", "Not Implemented" ], - "title": "HealthStatus" + "description": "Current health status of the service" } }, "additionalProperties": false, "required": [ "status" ], - "title": "HealthInfo" + "title": "HealthInfo", + "description": "Health status information for the service." 
}, "RAGDocument": { "type": "object", @@ -11487,13 +11949,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/RAGDocument" - } + }, + "description": "List of documents to index in the RAG system" }, "vector_db_id": { - "type": "string" + "type": "string", + "description": "ID of the vector database to store the document embeddings" }, "chunk_size_in_tokens": { - "type": "integer" + "type": "integer", + "description": "(Optional) Size in tokens for document chunking during indexing" } }, "additionalProperties": false, @@ -11643,13 +12108,16 @@ "type": "object", "properties": { "api": { - "type": "string" + "type": "string", + "description": "The API name this provider implements" }, "provider_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the provider" }, "provider_type": { - "type": "string" + "type": "string", + "description": "The type of provider implementation" }, "config": { "type": "object", @@ -11674,7 +12142,8 @@ "type": "object" } ] - } + }, + "description": "Configuration parameters for the provider" }, "health": { "type": "object", @@ -11699,7 +12168,8 @@ "type": "object" } ] - } + }, + "description": "Current health status of the provider" } }, "additionalProperties": false, @@ -11710,7 +12180,8 @@ "config", "health" ], - "title": "ProviderInfo" + "title": "ProviderInfo", + "description": "Information about a registered provider including its configuration and health status." }, "InvokeToolRequest": { "type": "object", @@ -11757,13 +12228,16 @@ "type": "object", "properties": { "content": { - "$ref": "#/components/schemas/InterleavedContent" + "$ref": "#/components/schemas/InterleavedContent", + "description": "(Optional) The output content from the tool execution" }, "error_message": { - "type": "string" + "type": "string", + "description": "(Optional) Error message if the tool execution failed" }, "error_code": { - "type": "integer" + "type": "integer", + "description": "(Optional) Numeric error code if the tool execution failed" }, "metadata": { "type": "object", @@ -11788,11 +12262,13 @@ "type": "object" } ] - } + }, + "description": "(Optional) Additional metadata about the tool execution" } }, "additionalProperties": false, - "title": "ToolInvocationResult" + "title": "ToolInvocationResult", + "description": "Result of a tool invocation." }, "PaginatedResponse": { "type": "object", @@ -11847,7 +12323,8 @@ "type": "object", "properties": { "job_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the job" }, "status": { "type": "string", @@ -11858,7 +12335,7 @@ "scheduled", "cancelled" ], - "title": "JobStatus" + "description": "Current execution status of the job" } }, "additionalProperties": false, @@ -11866,7 +12343,8 @@ "job_id", "status" ], - "title": "Job" + "title": "Job", + "description": "A job execution instance with status tracking." }, "ListBenchmarksResponse": { "type": "object", @@ -11890,7 +12368,8 @@ "asc", "desc" ], - "title": "Order" + "title": "Order", + "description": "Sort order for paginated responses." 
}, "ListOpenAIChatCompletionResponse": { "type": "object", @@ -11942,21 +12421,26 @@ "input_messages" ], "title": "OpenAICompletionWithInputMessages" - } + }, + "description": "List of chat completion objects with their input messages" }, "has_more": { - "type": "boolean" + "type": "boolean", + "description": "Whether there are more completions available beyond this list" }, "first_id": { - "type": "string" + "type": "string", + "description": "ID of the first completion in this list" }, "last_id": { - "type": "string" + "type": "string", + "description": "ID of the last completion in this list" }, "object": { "type": "string", "const": "list", - "default": "list" + "default": "list", + "description": "Must be \"list\" to identify this as a list response" } }, "additionalProperties": false, @@ -11967,7 +12451,8 @@ "last_id", "object" ], - "title": "ListOpenAIChatCompletionResponse" + "title": "ListOpenAIChatCompletionResponse", + "description": "Response from listing OpenAI-compatible chat completions." }, "ListDatasetsResponse": { "type": "object", @@ -11976,14 +12461,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/Dataset" - } + }, + "description": "List of datasets" } }, "additionalProperties": false, "required": [ "data" ], - "title": "ListDatasetsResponse" + "title": "ListDatasetsResponse", + "description": "Response from listing datasets." }, "ListModelsResponse": { "type": "object", @@ -12008,12 +12495,14 @@ "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseInput" - } + }, + "description": "List of input items" }, "object": { "type": "string", "const": "list", - "default": "list" + "default": "list", + "description": "Object type identifier, always \"list\"" } }, "additionalProperties": false, @@ -12021,7 +12510,8 @@ "data", "object" ], - "title": "ListOpenAIResponseInputItem" + "title": "ListOpenAIResponseInputItem", + "description": "List container for OpenAI response input items." }, "ListOpenAIResponseObject": { "type": "object", @@ -12030,21 +12520,26 @@ "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseObjectWithInput" - } + }, + "description": "List of response objects with their input context" }, "has_more": { - "type": "boolean" + "type": "boolean", + "description": "Whether there are more results available beyond this page" }, "first_id": { - "type": "string" + "type": "string", + "description": "Identifier of the first item in this page" }, "last_id": { - "type": "string" + "type": "string", + "description": "Identifier of the last item in this page" }, "object": { "type": "string", "const": "list", - "default": "list" + "default": "list", + "description": "Object type identifier, always \"list\"" } }, "additionalProperties": false, @@ -12055,64 +12550,80 @@ "last_id", "object" ], - "title": "ListOpenAIResponseObject" + "title": "ListOpenAIResponseObject", + "description": "Paginated list of OpenAI response objects with navigation metadata." 
}, "OpenAIResponseObjectWithInput": { "type": "object", "properties": { "created_at": { - "type": "integer" + "type": "integer", + "description": "Unix timestamp when the response was created" }, "error": { - "$ref": "#/components/schemas/OpenAIResponseError" + "$ref": "#/components/schemas/OpenAIResponseError", + "description": "(Optional) Error details if the response generation failed" }, "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for this response" }, "model": { - "type": "string" + "type": "string", + "description": "Model identifier used for generation" }, "object": { "type": "string", "const": "response", - "default": "response" + "default": "response", + "description": "Object type identifier, always \"response\"" }, "output": { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseOutput" - } + }, + "description": "List of generated output items (messages, tool calls, etc.)" }, "parallel_tool_calls": { "type": "boolean", - "default": false + "default": false, + "description": "Whether tool calls can be executed in parallel" }, "previous_response_id": { - "type": "string" + "type": "string", + "description": "(Optional) ID of the previous response in a conversation" }, "status": { - "type": "string" + "type": "string", + "description": "Current status of the response generation" }, "temperature": { - "type": "number" + "type": "number", + "description": "(Optional) Sampling temperature used for generation" }, "text": { - "$ref": "#/components/schemas/OpenAIResponseText" + "$ref": "#/components/schemas/OpenAIResponseText", + "description": "Text formatting configuration for the response" }, "top_p": { - "type": "number" + "type": "number", + "description": "(Optional) Nucleus sampling parameter used for generation" }, "truncation": { - "type": "string" + "type": "string", + "description": "(Optional) Truncation strategy applied to the response" }, "user": { - "type": "string" + "type": "string", + "description": "(Optional) User identifier associated with the request" }, "input": { "type": "array", "items": { "$ref": "#/components/schemas/OpenAIResponseInput" - } + }, + "description": "List of input items that led to this response" } }, "additionalProperties": false, @@ -12127,7 +12638,8 @@ "text", "input" ], - "title": "OpenAIResponseObjectWithInput" + "title": "OpenAIResponseObjectWithInput", + "description": "OpenAI response object extended with input context information." }, "ListProvidersResponse": { "type": "object", @@ -12136,29 +12648,34 @@ "type": "array", "items": { "$ref": "#/components/schemas/ProviderInfo" - } + }, + "description": "List of provider information objects" } }, "additionalProperties": false, "required": [ "data" ], - "title": "ListProvidersResponse" + "title": "ListProvidersResponse", + "description": "Response containing a list of all available providers." }, "RouteInfo": { "type": "object", "properties": { "route": { - "type": "string" + "type": "string", + "description": "The API endpoint path" }, "method": { - "type": "string" + "type": "string", + "description": "HTTP method for the route" }, "provider_types": { "type": "array", "items": { "type": "string" - } + }, + "description": "List of provider types that implement this route" } }, "additionalProperties": false, @@ -12167,7 +12684,8 @@ "method", "provider_types" ], - "title": "RouteInfo" + "title": "RouteInfo", + "description": "Information about an API route including its path, method, and implementing providers." 
}, "ListRoutesResponse": { "type": "object", @@ -12176,14 +12694,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/RouteInfo" - } + }, + "description": "List of available route information objects" } }, "additionalProperties": false, "required": [ "data" ], - "title": "ListRoutesResponse" + "title": "ListRoutesResponse", + "description": "Response containing a list of all available API routes." }, "ListToolDefsResponse": { "type": "object", @@ -12192,14 +12712,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/ToolDef" - } + }, + "description": "List of tool definitions" } }, "additionalProperties": false, "required": [ "data" ], - "title": "ListToolDefsResponse" + "title": "ListToolDefsResponse", + "description": "Response containing a list of tool definitions." }, "ListScoringFunctionsResponse": { "type": "object", @@ -12240,14 +12762,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/ToolGroup" - } + }, + "description": "List of tool groups" } }, "additionalProperties": false, "required": [ "data" ], - "title": "ListToolGroupsResponse" + "title": "ListToolGroupsResponse", + "description": "Response containing a list of tool groups." }, "ListToolsResponse": { "type": "object", @@ -12256,14 +12780,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/Tool" - } + }, + "description": "List of tools" } }, "additionalProperties": false, "required": [ "data" ], - "title": "ListToolsResponse" + "title": "ListToolsResponse", + "description": "Response containing a list of tools." }, "ListVectorDBsResponse": { "type": "object", @@ -12272,14 +12798,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/VectorDB" - } + }, + "description": "List of vector databases" } }, "additionalProperties": false, "required": [ "data" ], - "title": "ListVectorDBsResponse" + "title": "ListVectorDBsResponse", + "description": "Response from listing vector databases." }, "Event": { "oneOf": [ @@ -12309,7 +12837,8 @@ "structured_log", "metric" ], - "title": "EventType" + "title": "EventType", + "description": "The type of telemetry event being logged." }, "LogSeverity": { "type": "string", @@ -12321,20 +12850,24 @@ "error", "critical" ], - "title": "LogSeverity" + "title": "LogSeverity", + "description": "The severity level of a log message." 
}, "MetricEvent": { "type": "object", "properties": { "trace_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the trace this event belongs to" }, "span_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the span this event belongs to" }, "timestamp": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the event occurred" }, "attributes": { "type": "object", @@ -12356,15 +12889,18 @@ "type": "null" } ] - } + }, + "description": "(Optional) Key-value pairs containing additional metadata about the event" }, "type": { "$ref": "#/components/schemas/EventType", "const": "metric", - "default": "metric" + "default": "metric", + "description": "Event type identifier set to METRIC" }, "metric": { - "type": "string" + "type": "string", + "description": "The name of the metric being measured" }, "value": { "oneOf": [ @@ -12374,10 +12910,12 @@ { "type": "number" } - ] + ], + "description": "The numeric value of the metric measurement" }, "unit": { - "type": "string" + "type": "string", + "description": "The unit of measurement for the metric value" } }, "additionalProperties": false, @@ -12390,7 +12928,8 @@ "value", "unit" ], - "title": "MetricEvent" + "title": "MetricEvent", + "description": "A metric event containing a measured value." }, "SpanEndPayload": { "type": "object", @@ -12398,10 +12937,12 @@ "type": { "$ref": "#/components/schemas/StructuredLogType", "const": "span_end", - "default": "span_end" + "default": "span_end", + "description": "Payload type identifier set to SPAN_END" }, "status": { - "$ref": "#/components/schemas/SpanStatus" + "$ref": "#/components/schemas/SpanStatus", + "description": "The final status of the span indicating success or failure" } }, "additionalProperties": false, @@ -12409,7 +12950,8 @@ "type", "status" ], - "title": "SpanEndPayload" + "title": "SpanEndPayload", + "description": "Payload for a span end event." }, "SpanStartPayload": { "type": "object", @@ -12417,13 +12959,16 @@ "type": { "$ref": "#/components/schemas/StructuredLogType", "const": "span_start", - "default": "span_start" + "default": "span_start", + "description": "Payload type identifier set to SPAN_START" }, "name": { - "type": "string" + "type": "string", + "description": "Human-readable name describing the operation this span represents" }, "parent_span_id": { - "type": "string" + "type": "string", + "description": "(Optional) Unique identifier for the parent span, if this is a child span" } }, "additionalProperties": false, @@ -12431,20 +12976,24 @@ "type", "name" ], - "title": "SpanStartPayload" + "title": "SpanStartPayload", + "description": "Payload for a span start event." 
}, "StructuredLogEvent": { "type": "object", "properties": { "trace_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the trace this event belongs to" }, "span_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the span this event belongs to" }, "timestamp": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the event occurred" }, "attributes": { "type": "object", @@ -12466,15 +13015,18 @@ "type": "null" } ] - } + }, + "description": "(Optional) Key-value pairs containing additional metadata about the event" }, "type": { "$ref": "#/components/schemas/EventType", "const": "structured_log", - "default": "structured_log" + "default": "structured_log", + "description": "Event type identifier set to STRUCTURED_LOG" }, "payload": { - "$ref": "#/components/schemas/StructuredLogPayload" + "$ref": "#/components/schemas/StructuredLogPayload", + "description": "The structured payload data for the log event" } }, "additionalProperties": false, @@ -12485,7 +13037,8 @@ "type", "payload" ], - "title": "StructuredLogEvent" + "title": "StructuredLogEvent", + "description": "A structured log event containing typed payload data." }, "StructuredLogPayload": { "oneOf": [ @@ -12510,20 +13063,24 @@ "span_start", "span_end" ], - "title": "StructuredLogType" + "title": "StructuredLogType", + "description": "The type of structured log event payload." }, "UnstructuredLogEvent": { "type": "object", "properties": { "trace_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the trace this event belongs to" }, "span_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the span this event belongs to" }, "timestamp": { "type": "string", - "format": "date-time" + "format": "date-time", + "description": "Timestamp when the event occurred" }, "attributes": { "type": "object", @@ -12545,18 +13102,22 @@ "type": "null" } ] - } + }, + "description": "(Optional) Key-value pairs containing additional metadata about the event" }, "type": { "$ref": "#/components/schemas/EventType", "const": "unstructured_log", - "default": "unstructured_log" + "default": "unstructured_log", + "description": "Event type identifier set to UNSTRUCTURED_LOG" }, "message": { - "type": "string" + "type": "string", + "description": "The log message text" }, "severity": { - "$ref": "#/components/schemas/LogSeverity" + "$ref": "#/components/schemas/LogSeverity", + "description": "The severity level of the log message" } }, "additionalProperties": false, @@ -12568,7 +13129,8 @@ "message", "severity" ], - "title": "UnstructuredLogEvent" + "title": "UnstructuredLogEvent", + "description": "An unstructured log event containing a simple text message." }, "LogEventRequest": { "type": "object", @@ -12612,14 +13174,16 @@ "type": { "type": "string", "const": "auto", - "default": "auto" + "default": "auto", + "description": "Strategy type, always \"auto\" for automatic chunking" } }, "additionalProperties": false, "required": [ "type" ], - "title": "VectorStoreChunkingStrategyAuto" + "title": "VectorStoreChunkingStrategyAuto", + "description": "Automatic chunking strategy for vector store files." 
}, "VectorStoreChunkingStrategyStatic": { "type": "object", @@ -12627,10 +13191,12 @@ "type": { "type": "string", "const": "static", - "default": "static" + "default": "static", + "description": "Strategy type, always \"static\" for static chunking" }, "static": { - "$ref": "#/components/schemas/VectorStoreChunkingStrategyStaticConfig" + "$ref": "#/components/schemas/VectorStoreChunkingStrategyStaticConfig", + "description": "Configuration parameters for the static chunking strategy" } }, "additionalProperties": false, @@ -12638,18 +13204,21 @@ "type", "static" ], - "title": "VectorStoreChunkingStrategyStatic" + "title": "VectorStoreChunkingStrategyStatic", + "description": "Static chunking strategy with configurable parameters." }, "VectorStoreChunkingStrategyStaticConfig": { "type": "object", "properties": { "chunk_overlap_tokens": { "type": "integer", - "default": 400 + "default": 400, + "description": "Number of tokens to overlap between adjacent chunks" }, "max_chunk_size_tokens": { "type": "integer", - "default": 800 + "default": 800, + "description": "Maximum number of tokens per chunk, must be between 100 and 4096" } }, "additionalProperties": false, @@ -12657,7 +13226,8 @@ "chunk_overlap_tokens", "max_chunk_size_tokens" ], - "title": "VectorStoreChunkingStrategyStaticConfig" + "title": "VectorStoreChunkingStrategyStaticConfig", + "description": "Configuration for static chunking strategy." }, "OpenaiAttachFileToVectorStoreRequest": { "type": "object", @@ -12716,10 +13286,12 @@ "type": "string", "const": "rate_limit_exceeded" } - ] + ], + "description": "Error code indicating the type of failure" }, "message": { - "type": "string" + "type": "string", + "description": "Human-readable error message describing the failure" } }, "additionalProperties": false, @@ -12727,17 +13299,20 @@ "code", "message" ], - "title": "VectorStoreFileLastError" + "title": "VectorStoreFileLastError", + "description": "Error information for failed vector store file processing." 
}, "VectorStoreFileObject": { "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the file" }, "object": { "type": "string", - "default": "vector_store.file" + "default": "vector_store.file", + "description": "Object type identifier, always \"vector_store.file\"" }, "attributes": { "type": "object", @@ -12762,26 +13337,33 @@ "type": "object" } ] - } + }, + "description": "Key-value attributes associated with the file" }, "chunking_strategy": { - "$ref": "#/components/schemas/VectorStoreChunkingStrategy" + "$ref": "#/components/schemas/VectorStoreChunkingStrategy", + "description": "Strategy used for splitting the file into chunks" }, "created_at": { - "type": "integer" + "type": "integer", + "description": "Timestamp when the file was added to the vector store" }, "last_error": { - "$ref": "#/components/schemas/VectorStoreFileLastError" + "$ref": "#/components/schemas/VectorStoreFileLastError", + "description": "(Optional) Error information if file processing failed" }, "status": { - "$ref": "#/components/schemas/VectorStoreFileStatus" + "$ref": "#/components/schemas/VectorStoreFileStatus", + "description": "Current processing status of the file" }, "usage_bytes": { "type": "integer", - "default": 0 + "default": 0, + "description": "Storage space used by this file in bytes" }, "vector_store_id": { - "type": "string" + "type": "string", + "description": "ID of the vector store containing this file" } }, "additionalProperties": false, @@ -12822,13 +13404,16 @@ "type": "object", "properties": { "name": { - "type": "string" + "type": "string", + "description": "Name of the schema" }, "description": { - "type": "string" + "type": "string", + "description": "(Optional) Description of the schema" }, "strict": { - "type": "boolean" + "type": "boolean", + "description": "(Optional) Whether to enforce strict adherence to the schema" }, "schema": { "type": "object", @@ -12853,14 +13438,16 @@ "type": "object" } ] - } + }, + "description": "(Optional) The JSON schema definition" } }, "additionalProperties": false, "required": [ "name" ], - "title": "OpenAIJSONSchema" + "title": "OpenAIJSONSchema", + "description": "JSON schema specification for OpenAI-compatible structured response format." }, "OpenAIResponseFormatJSONObject": { "type": "object", @@ -12868,14 +13455,16 @@ "type": { "type": "string", "const": "json_object", - "default": "json_object" + "default": "json_object", + "description": "Must be \"json_object\" to indicate generic JSON object response format" } }, "additionalProperties": false, "required": [ "type" ], - "title": "OpenAIResponseFormatJSONObject" + "title": "OpenAIResponseFormatJSONObject", + "description": "JSON object response format for OpenAI-compatible chat completion requests." }, "OpenAIResponseFormatJSONSchema": { "type": "object", @@ -12883,10 +13472,12 @@ "type": { "type": "string", "const": "json_schema", - "default": "json_schema" + "default": "json_schema", + "description": "Must be \"json_schema\" to indicate structured JSON response format" }, "json_schema": { - "$ref": "#/components/schemas/OpenAIJSONSchema" + "$ref": "#/components/schemas/OpenAIJSONSchema", + "description": "The JSON schema specification for the response" } }, "additionalProperties": false, @@ -12894,7 +13485,8 @@ "type", "json_schema" ], - "title": "OpenAIResponseFormatJSONSchema" + "title": "OpenAIResponseFormatJSONSchema", + "description": "JSON schema response format for OpenAI-compatible chat completion requests." 
}, "OpenAIResponseFormatParam": { "oneOf": [ @@ -12923,14 +13515,16 @@ "type": { "type": "string", "const": "text", - "default": "text" + "default": "text", + "description": "Must be \"text\" to indicate plain text response format" } }, "additionalProperties": false, "required": [ "type" ], - "title": "OpenAIResponseFormatText" + "title": "OpenAIResponseFormatText", + "description": "Text response format for OpenAI-compatible chat completion requests." }, "OpenaiChatCompletionRequest": { "type": "object", @@ -13641,19 +14235,24 @@ "type": "object", "properties": { "completed": { - "type": "integer" + "type": "integer", + "description": "Number of files that have been successfully processed" }, "cancelled": { - "type": "integer" + "type": "integer", + "description": "Number of files that had their processing cancelled" }, "failed": { - "type": "integer" + "type": "integer", + "description": "Number of files that failed to process" }, "in_progress": { - "type": "integer" + "type": "integer", + "description": "Number of files currently being processed" }, "total": { - "type": "integer" + "type": "integer", + "description": "Total number of files in the vector store" } }, "additionalProperties": false, @@ -13664,34 +14263,42 @@ "in_progress", "total" ], - "title": "VectorStoreFileCounts" + "title": "VectorStoreFileCounts", + "description": "File processing status counts for a vector store." }, "VectorStoreObject": { "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the vector store" }, "object": { "type": "string", - "default": "vector_store" + "default": "vector_store", + "description": "Object type identifier, always \"vector_store\"" }, "created_at": { - "type": "integer" + "type": "integer", + "description": "Timestamp when the vector store was created" }, "name": { - "type": "string" + "type": "string", + "description": "(Optional) Name of the vector store" }, "usage_bytes": { "type": "integer", - "default": 0 + "default": 0, + "description": "Storage space used by the vector store in bytes" }, "file_counts": { - "$ref": "#/components/schemas/VectorStoreFileCounts" + "$ref": "#/components/schemas/VectorStoreFileCounts", + "description": "File processing status counts for the vector store" }, "status": { "type": "string", - "default": "completed" + "default": "completed", + "description": "Current status of the vector store" }, "expires_after": { "type": "object", @@ -13716,13 +14323,16 @@ "type": "object" } ] - } + }, + "description": "(Optional) Expiration policy for the vector store" }, "expires_at": { - "type": "integer" + "type": "integer", + "description": "(Optional) Timestamp when the vector store will expire" }, "last_active_at": { - "type": "integer" + "type": "integer", + "description": "(Optional) Timestamp of last activity on the vector store" }, "metadata": { "type": "object", @@ -13747,7 +14357,8 @@ "type": "object" } ] - } + }, + "description": "Set of key-value pairs that can be attached to the vector store" } }, "additionalProperties": false, @@ -13794,15 +14405,18 @@ "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the deleted vector store" }, "object": { "type": "string", - "default": "vector_store.deleted" + "default": "vector_store.deleted", + "description": "Object type identifier for the deletion response" }, "deleted": { "type": "boolean", - "default": true + "default": true, + "description": "Whether the deletion operation 
was successful" } }, "additionalProperties": false, @@ -13818,15 +14432,18 @@ "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the deleted file" }, "object": { "type": "string", - "default": "vector_store.file.deleted" + "default": "vector_store.file.deleted", + "description": "Object type identifier for the deletion response" }, "deleted": { "type": "boolean", - "default": true + "default": true, + "description": "Whether the deletion operation was successful" } }, "additionalProperties": false, @@ -13990,13 +14607,16 @@ "description": "List of file objects" }, "has_more": { - "type": "boolean" + "type": "boolean", + "description": "Whether there are more files available beyond this page" }, "first_id": { - "type": "string" + "type": "string", + "description": "ID of the first file in the list for pagination" }, "last_id": { - "type": "string" + "type": "string", + "description": "ID of the last file in the list for pagination" }, "object": { "type": "string", @@ -14071,23 +14691,28 @@ "properties": { "object": { "type": "string", - "default": "list" + "default": "list", + "description": "Object type identifier, always \"list\"" }, "data": { "type": "array", "items": { "$ref": "#/components/schemas/VectorStoreFileObject" - } + }, + "description": "List of vector store file objects" }, "first_id": { - "type": "string" + "type": "string", + "description": "(Optional) ID of the first file in the list for pagination" }, "last_id": { - "type": "string" + "type": "string", + "description": "(Optional) ID of the last file in the list for pagination" }, "has_more": { "type": "boolean", - "default": false + "default": false, + "description": "Whether there are more files available beyond this page" } }, "additionalProperties": false, @@ -14097,7 +14722,7 @@ "has_more" ], "title": "VectorStoreListFilesResponse", - "description": "Response from listing vector stores." + "description": "Response from listing files in a vector store." }, "OpenAIModel": { "type": "object", @@ -14148,23 +14773,28 @@ "properties": { "object": { "type": "string", - "default": "list" + "default": "list", + "description": "Object type identifier, always \"list\"" }, "data": { "type": "array", "items": { "$ref": "#/components/schemas/VectorStoreObject" - } + }, + "description": "List of vector store objects" }, "first_id": { - "type": "string" + "type": "string", + "description": "(Optional) ID of the first vector store in the list for pagination" }, "last_id": { - "type": "string" + "type": "string", + "description": "(Optional) ID of the last vector store in the list for pagination" }, "has_more": { "type": "boolean", - "default": false + "default": false, + "description": "Whether there are more vector stores available beyond this page" } }, "additionalProperties": false, @@ -14185,10 +14815,12 @@ "properties": { "type": { "type": "string", - "const": "text" + "const": "text", + "description": "Content type, currently only \"text\" is supported" }, "text": { - "type": "string" + "type": "string", + "description": "The actual text content" } }, "additionalProperties": false, @@ -14196,16 +14828,19 @@ "type", "text" ], - "title": "VectorStoreContent" + "title": "VectorStoreContent", + "description": "Content item from a vector store file or search result." 
}, "VectorStoreFileContentsResponse": { "type": "object", "properties": { "file_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the file" }, "filename": { - "type": "string" + "type": "string", + "description": "Name of the file" }, "attributes": { "type": "object", @@ -14230,13 +14865,15 @@ "type": "object" } ] - } + }, + "description": "Key-value attributes associated with the file" }, "content": { "type": "array", "items": { "$ref": "#/components/schemas/VectorStoreContent" - } + }, + "description": "List of content items from the file" } }, "additionalProperties": false, @@ -14300,11 +14937,13 @@ "type": "object", "properties": { "ranker": { - "type": "string" + "type": "string", + "description": "(Optional) Name of the ranking algorithm to use" }, "score_threshold": { "type": "number", - "default": 0.0 + "default": 0.0, + "description": "(Optional) Minimum relevance score threshold for results" } }, "additionalProperties": false, @@ -14329,13 +14968,16 @@ "type": "object", "properties": { "file_id": { - "type": "string" + "type": "string", + "description": "Unique identifier of the file containing the result" }, "filename": { - "type": "string" + "type": "string", + "description": "Name of the file containing the result" }, "score": { - "type": "number" + "type": "number", + "description": "Relevance score for this search result" }, "attributes": { "type": "object", @@ -14351,13 +14993,15 @@ "type": "boolean" } ] - } + }, + "description": "(Optional) Key-value attributes associated with the file" }, "content": { "type": "array", "items": { "$ref": "#/components/schemas/VectorStoreContent" - } + }, + "description": "List of content items matching the search query" } }, "additionalProperties": false, @@ -14375,23 +15019,28 @@ "properties": { "object": { "type": "string", - "default": "vector_store.search_results.page" + "default": "vector_store.search_results.page", + "description": "Object type identifier for the search results page" }, "search_query": { - "type": "string" + "type": "string", + "description": "The original search query that was executed" }, "data": { "type": "array", "items": { "$ref": "#/components/schemas/VectorStoreSearchResponse" - } + }, + "description": "List of search result objects" }, "has_more": { "type": "boolean", - "default": false + "default": false, + "description": "Whether there are more results available beyond this page" }, "next_page": { - "type": "string" + "type": "string", + "description": "(Optional) Token for retrieving the next page of results" } }, "additionalProperties": false, @@ -14402,7 +15051,7 @@ "has_more" ], "title": "VectorStoreSearchResponsePage", - "description": "Response from searching a vector store." + "description": "Paginated response from searching a vector store." }, "OpenaiUpdateVectorStoreRequest": { "type": "object", @@ -14507,11 +15156,13 @@ "type": "object", "properties": { "beta": { - "type": "number" + "type": "number", + "description": "Temperature parameter for the DPO loss" }, "loss_type": { "$ref": "#/components/schemas/DPOLossType", - "default": "sigmoid" + "default": "sigmoid", + "description": "The type of loss function to use for DPO" } }, "additionalProperties": false, @@ -14519,7 +15170,8 @@ "beta", "loss_type" ], - "title": "DPOAlignmentConfig" + "title": "DPOAlignmentConfig", + "description": "Configuration for Direct Preference Optimization (DPO) alignment." 
}, "DPOLossType": { "type": "string", @@ -14535,27 +15187,34 @@ "type": "object", "properties": { "dataset_id": { - "type": "string" + "type": "string", + "description": "Unique identifier for the training dataset" }, "batch_size": { - "type": "integer" + "type": "integer", + "description": "Number of samples per training batch" }, "shuffle": { - "type": "boolean" + "type": "boolean", + "description": "Whether to shuffle the dataset during training" }, "data_format": { - "$ref": "#/components/schemas/DatasetFormat" + "$ref": "#/components/schemas/DatasetFormat", + "description": "Format of the dataset (instruct or dialog)" }, "validation_dataset_id": { - "type": "string" + "type": "string", + "description": "(Optional) Unique identifier for the validation dataset" }, "packed": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to pack multiple samples into a single sequence for efficiency" }, "train_on_input": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to compute loss on input tokens as well as output tokens" } }, "additionalProperties": false, @@ -14565,7 +15224,8 @@ "shuffle", "data_format" ], - "title": "DataConfig" + "title": "DataConfig", + "description": "Configuration for training data and data loading." }, "DatasetFormat": { "type": "string", @@ -14573,45 +15233,55 @@ "instruct", "dialog" ], - "title": "DatasetFormat" + "title": "DatasetFormat", + "description": "Format of the training dataset." }, "EfficiencyConfig": { "type": "object", "properties": { "enable_activation_checkpointing": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to use activation checkpointing to reduce memory usage" }, "enable_activation_offloading": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to offload activations to CPU to save GPU memory" }, "memory_efficient_fsdp_wrap": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to use memory-efficient FSDP wrapping" }, "fsdp_cpu_offload": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to offload FSDP parameters to CPU" } }, "additionalProperties": false, - "title": "EfficiencyConfig" + "title": "EfficiencyConfig", + "description": "Configuration for memory and compute efficiency optimizations." }, "OptimizerConfig": { "type": "object", "properties": { "optimizer_type": { - "$ref": "#/components/schemas/OptimizerType" + "$ref": "#/components/schemas/OptimizerType", + "description": "Type of optimizer to use (adam, adamw, or sgd)" }, "lr": { - "type": "number" + "type": "number", + "description": "Learning rate for the optimizer" }, "weight_decay": { - "type": "number" + "type": "number", + "description": "Weight decay coefficient for regularization" }, "num_warmup_steps": { - "type": "integer" + "type": "integer", + "description": "Number of steps for learning rate warmup" } }, "additionalProperties": false, @@ -14621,7 +15291,8 @@ "weight_decay", "num_warmup_steps" ], - "title": "OptimizerConfig" + "title": "OptimizerConfig", + "description": "Configuration parameters for the optimization algorithm." }, "OptimizerType": { "type": "string", @@ -14630,38 +15301,47 @@ "adamw", "sgd" ], - "title": "OptimizerType" + "title": "OptimizerType", + "description": "Available optimizer algorithms for training." 
}, "TrainingConfig": { "type": "object", "properties": { "n_epochs": { - "type": "integer" + "type": "integer", + "description": "Number of training epochs to run" }, "max_steps_per_epoch": { "type": "integer", - "default": 1 + "default": 1, + "description": "Maximum number of steps to run per epoch" }, "gradient_accumulation_steps": { "type": "integer", - "default": 1 + "default": 1, + "description": "Number of steps to accumulate gradients before updating" }, "max_validation_steps": { "type": "integer", - "default": 1 + "default": 1, + "description": "(Optional) Maximum number of validation steps per epoch" }, "data_config": { - "$ref": "#/components/schemas/DataConfig" + "$ref": "#/components/schemas/DataConfig", + "description": "(Optional) Configuration for data loading and formatting" }, "optimizer_config": { - "$ref": "#/components/schemas/OptimizerConfig" + "$ref": "#/components/schemas/OptimizerConfig", + "description": "(Optional) Configuration for the optimization algorithm" }, "efficiency_config": { - "$ref": "#/components/schemas/EfficiencyConfig" + "$ref": "#/components/schemas/EfficiencyConfig", + "description": "(Optional) Configuration for memory and compute optimizations" }, "dtype": { "type": "string", - "default": "bf16" + "default": "bf16", + "description": "(Optional) Data type for model parameters (bf16, fp16, fp32)" } }, "additionalProperties": false, @@ -14670,7 +15350,8 @@ "max_steps_per_epoch", "gradient_accumulation_steps" ], - "title": "TrainingConfig" + "title": "TrainingConfig", + "description": "Comprehensive configuration for the training process." }, "PreferenceOptimizeRequest": { "type": "object", @@ -14774,11 +15455,13 @@ "type": { "type": "string", "const": "default", - "default": "default" + "default": "default", + "description": "Type of query generator, always 'default'" }, "separator": { "type": "string", - "default": " " + "default": " ", + "description": "String separator used to join query terms" } }, "additionalProperties": false, @@ -14786,7 +15469,8 @@ "type", "separator" ], - "title": "DefaultRAGQueryGeneratorConfig" + "title": "DefaultRAGQueryGeneratorConfig", + "description": "Configuration for the default RAG query generator." }, "LLMRAGQueryGeneratorConfig": { "type": "object", @@ -14794,13 +15478,16 @@ "type": { "type": "string", "const": "llm", - "default": "llm" + "default": "llm", + "description": "Type of query generator, always 'llm'" }, "model": { - "type": "string" + "type": "string", + "description": "Name of the language model to use for query generation" }, "template": { - "type": "string" + "type": "string", + "description": "Template string for formatting the query generation prompt" } }, "additionalProperties": false, @@ -14809,7 +15496,8 @@ "model", "template" ], - "title": "LLMRAGQueryGeneratorConfig" + "title": "LLMRAGQueryGeneratorConfig", + "description": "Configuration for the LLM-based RAG query generator." }, "RAGQueryConfig": { "type": "object", @@ -14892,7 +15580,7 @@ "impact_factor": { "type": "number", "default": 60.0, - "description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009)." + "description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. 
Must be greater than 0" } }, "additionalProperties": false, @@ -14947,16 +15635,19 @@ "type": "object", "properties": { "content": { - "$ref": "#/components/schemas/InterleavedContent" + "$ref": "#/components/schemas/InterleavedContent", + "description": "The query content to search for in the indexed documents" }, "vector_db_ids": { "type": "array", "items": { "type": "string" - } + }, + "description": "List of vector database IDs to search within" }, "query_config": { - "$ref": "#/components/schemas/RAGQueryConfig" + "$ref": "#/components/schemas/RAGQueryConfig", + "description": "(Optional) Configuration parameters for the query operation" } }, "additionalProperties": false, @@ -14970,7 +15661,8 @@ "type": "object", "properties": { "content": { - "$ref": "#/components/schemas/InterleavedContent" + "$ref": "#/components/schemas/InterleavedContent", + "description": "(Optional) The retrieved content from the query" }, "metadata": { "type": "object", @@ -14995,14 +15687,16 @@ "type": "object" } ] - } + }, + "description": "Additional metadata about the query result" } }, "additionalProperties": false, "required": [ "metadata" ], - "title": "RAGQueryResult" + "title": "RAGQueryResult", + "description": "Result of a RAG query containing retrieved content and metadata." }, "QueryChunksRequest": { "type": "object", @@ -15056,13 +15750,15 @@ "type": "array", "items": { "$ref": "#/components/schemas/Chunk" - } + }, + "description": "List of content chunks returned from the query" }, "scores": { "type": "array", "items": { "type": "number" - } + }, + "description": "Relevance scores corresponding to each returned chunk" } }, "additionalProperties": false, @@ -15070,7 +15766,8 @@ "chunks", "scores" ], - "title": "QueryChunksResponse" + "title": "QueryChunksResponse", + "description": "Response from querying chunks in a vector database." }, "QueryMetricsRequest": { "type": "object", @@ -15101,10 +15798,12 @@ "type": "object", "properties": { "name": { - "type": "string" + "type": "string", + "description": "The name of the label to match" }, "value": { - "type": "string" + "type": "string", + "description": "The value to match against" }, "operator": { "type": "string", @@ -15114,7 +15813,7 @@ "=~", "!~" ], - "title": "MetricLabelOperator", + "description": "The comparison operator to use for matching", "default": "=" } }, @@ -15124,7 +15823,8 @@ "value", "operator" ], - "title": "MetricLabelMatcher" + "title": "MetricLabelMatcher", + "description": "A matcher for filtering metrics by label values." }, "description": "The label matchers to apply to the metric." } @@ -15140,10 +15840,12 @@ "type": "object", "properties": { "timestamp": { - "type": "integer" + "type": "integer", + "description": "Unix timestamp when the metric value was recorded" }, "value": { - "type": "number" + "type": "number", + "description": "The numeric value of the metric at this timestamp" } }, "additionalProperties": false, @@ -15151,16 +15853,19 @@ "timestamp", "value" ], - "title": "MetricDataPoint" + "title": "MetricDataPoint", + "description": "A single data point in a metric time series." }, "MetricLabel": { "type": "object", "properties": { "name": { - "type": "string" + "type": "string", + "description": "The name of the label" }, "value": { - "type": "string" + "type": "string", + "description": "The value of the label" } }, "additionalProperties": false, @@ -15168,25 +15873,29 @@ "name", "value" ], - "title": "MetricLabel" + "title": "MetricLabel", + "description": "A label associated with a metric." 
}, "MetricSeries": { "type": "object", "properties": { "metric": { - "type": "string" + "type": "string", + "description": "The name of the metric" }, "labels": { "type": "array", "items": { "$ref": "#/components/schemas/MetricLabel" - } + }, + "description": "List of labels associated with this metric series" }, "values": { "type": "array", "items": { "$ref": "#/components/schemas/MetricDataPoint" - } + }, + "description": "List of data points in chronological order" } }, "additionalProperties": false, @@ -15195,7 +15904,8 @@ "labels", "values" ], - "title": "MetricSeries" + "title": "MetricSeries", + "description": "A time series of metric data points." }, "QueryMetricsResponse": { "type": "object", @@ -15204,23 +15914,27 @@ "type": "array", "items": { "$ref": "#/components/schemas/MetricSeries" - } + }, + "description": "List of metric series matching the query criteria" } }, "additionalProperties": false, "required": [ "data" ], - "title": "QueryMetricsResponse" + "title": "QueryMetricsResponse", + "description": "Response containing metric time series data." }, "QueryCondition": { "type": "object", "properties": { "key": { - "type": "string" + "type": "string", + "description": "The attribute key to filter on" }, "op": { - "$ref": "#/components/schemas/QueryConditionOp" + "$ref": "#/components/schemas/QueryConditionOp", + "description": "The comparison operator to apply" }, "value": { "oneOf": [ @@ -15242,7 +15956,8 @@ { "type": "object" } - ] + ], + "description": "The value to compare against" } }, "additionalProperties": false, @@ -15251,7 +15966,8 @@ "op", "value" ], - "title": "QueryCondition" + "title": "QueryCondition", + "description": "A condition for filtering query results." }, "QueryConditionOp": { "type": "string", @@ -15261,7 +15977,8 @@ "gt", "lt" ], - "title": "QueryConditionOp" + "title": "QueryConditionOp", + "description": "Comparison operators for query conditions." }, "QuerySpansRequest": { "type": "object", @@ -15299,14 +16016,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/Span" - } + }, + "description": "List of spans matching the query criteria" } }, "additionalProperties": false, "required": [ "data" ], - "title": "QuerySpansResponse" + "title": "QuerySpansResponse", + "description": "Response containing a list of spans." }, "QueryTracesRequest": { "type": "object", @@ -15344,14 +16063,16 @@ "type": "array", "items": { "$ref": "#/components/schemas/Trace" - } + }, + "description": "List of traces matching the query criteria" } }, "additionalProperties": false, "required": [ "data" ], - "title": "QueryTracesResponse" + "title": "QueryTracesResponse", + "description": "Response containing a list of traces." }, "RegisterBenchmarkRequest": { "type": "object", @@ -15723,6 +16444,131 @@ ], "title": "RunEvalRequest" }, + "RunModerationRequest": { + "type": "object", + "properties": { + "input": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array", + "items": { + "type": "string" + } + } + ], + "description": "Input (or inputs) to classify. Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models." + }, + "model": { + "type": "string", + "description": "The content moderation model you would like to use." + } + }, + "additionalProperties": false, + "required": [ + "input", + "model" + ], + "title": "RunModerationRequest" + }, + "ModerationObject": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The unique identifier for the moderation request." 
+ }, + "model": { + "type": "string", + "description": "The model used to generate the moderation results." + }, + "results": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ModerationObjectResults" + }, + "description": "A list of moderation objects" + } + }, + "additionalProperties": false, + "required": [ + "id", + "model", + "results" + ], + "title": "ModerationObject", + "description": "A moderation object." + }, + "ModerationObjectResults": { + "type": "object", + "properties": { + "flagged": { + "type": "boolean", + "description": "Whether any of the below categories are flagged." + }, + "categories": { + "type": "object", + "additionalProperties": { + "type": "boolean" + }, + "description": "A list of the categories, and whether they are flagged or not." + }, + "category_applied_input_types": { + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + }, + "description": "A list of the categories along with the input type(s) that the score applies to." + }, + "category_scores": { + "type": "object", + "additionalProperties": { + "type": "number" + }, + "description": "A list of the categories along with their scores as predicted by model. Required set of categories that need to be in response - violence - violence/graphic - harassment - harassment/threatening - hate - hate/threatening - illicit - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent - self-harm/instructions" + }, + "user_message": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + } + }, + "additionalProperties": false, + "required": [ + "flagged", + "metadata" + ], + "title": "ModerationObjectResults", + "description": "A moderation object." + }, "RunShieldRequest": { "type": "object", "properties": { @@ -15776,11 +16622,13 @@ "type": "object", "properties": { "violation": { - "$ref": "#/components/schemas/SafetyViolation" + "$ref": "#/components/schemas/SafetyViolation", + "description": "(Optional) Safety violation detected by the shield, if any" } }, "additionalProperties": false, - "title": "RunShieldResponse" + "title": "RunShieldResponse", + "description": "Response from running a safety shield." }, "SaveSpansToDatasetRequest": { "type": "object", @@ -15926,20 +16774,23 @@ "type": "object", "properties": { "dataset_id": { - "type": "string" + "type": "string", + "description": "(Optional) The identifier of the dataset that was scored" }, "results": { "type": "object", "additionalProperties": { "$ref": "#/components/schemas/ScoringResult" - } + }, + "description": "A map of scoring function name to ScoringResult" } }, "additionalProperties": false, "required": [ "results" ], - "title": "ScoreBatchResponse" + "title": "ScoreBatchResponse", + "description": "Response from batch scoring operations on datasets." 
}, "AlgorithmConfig": { "oneOf": [ @@ -15964,33 +16815,41 @@ "type": { "type": "string", "const": "LoRA", - "default": "LoRA" + "default": "LoRA", + "description": "Algorithm type identifier, always \"LoRA\"" }, "lora_attn_modules": { "type": "array", "items": { "type": "string" - } + }, + "description": "List of attention module names to apply LoRA to" }, "apply_lora_to_mlp": { - "type": "boolean" + "type": "boolean", + "description": "Whether to apply LoRA to MLP layers" }, "apply_lora_to_output": { - "type": "boolean" + "type": "boolean", + "description": "Whether to apply LoRA to output projection layers" }, "rank": { - "type": "integer" + "type": "integer", + "description": "Rank of the LoRA adaptation (lower rank = fewer parameters)" }, "alpha": { - "type": "integer" + "type": "integer", + "description": "LoRA scaling parameter that controls adaptation strength" }, "use_dora": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation)" }, "quantize_base": { "type": "boolean", - "default": false + "default": false, + "description": "(Optional) Whether to quantize the base model weights" } }, "additionalProperties": false, @@ -16002,7 +16861,8 @@ "rank", "alpha" ], - "title": "LoraFinetuningConfig" + "title": "LoraFinetuningConfig", + "description": "Configuration for Low-Rank Adaptation (LoRA) fine-tuning." }, "QATFinetuningConfig": { "type": "object", @@ -16010,13 +16870,16 @@ "type": { "type": "string", "const": "QAT", - "default": "QAT" + "default": "QAT", + "description": "Algorithm type identifier, always \"QAT\"" }, "quantizer_name": { - "type": "string" + "type": "string", + "description": "Name of the quantization algorithm to use" }, "group_size": { - "type": "integer" + "type": "integer", + "description": "Size of groups for grouped quantization" } }, "additionalProperties": false, @@ -16025,7 +16888,8 @@ "quantizer_name", "group_size" ], - "title": "QATFinetuningConfig" + "title": "QATFinetuningConfig", + "description": "Configuration for Quantization-Aware Training (QAT) fine-tuning." }, "SupervisedFineTuneRequest": { "type": "object", @@ -16119,7 +16983,8 @@ "type": "array", "items": { "$ref": "#/components/schemas/Message" - } + }, + "description": "List of conversation messages to use as input for synthetic data generation" }, "filtering_function": { "type": "string", @@ -16131,11 +16996,11 @@ "top_k_top_p", "sigmoid" ], - "title": "FilteringFunction", - "description": "The type of filtering function." + "description": "Type of filtering to apply to generated synthetic data samples" }, "model": { - "type": "string" + "type": "string", + "description": "(Optional) The identifier of the model to use. 
The model must be registered with Llama Stack and available via the /models endpoint" } }, "additionalProperties": false, @@ -16174,7 +17039,8 @@ } ] } - } + }, + "description": "List of generated synthetic data samples that passed the filtering criteria" }, "statistics": { "type": "object", @@ -16199,7 +17065,8 @@ "type": "object" } ] - } + }, + "description": "(Optional) Statistical information about the generation process and filtering results" } }, "additionalProperties": false, @@ -16213,14 +17080,16 @@ "type": "object", "properties": { "version": { - "type": "string" + "type": "string", + "description": "Version number of the service" } }, "additionalProperties": false, "required": [ "version" ], - "title": "VersionInfo" + "title": "VersionInfo", + "description": "Version information for the service." } }, "responses": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 9ac29034d..9c0fba554 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -999,6 +999,31 @@ paths: required: true schema: type: string + delete: + responses: + '200': + description: OK + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Shields + description: Unregister a shield. + parameters: + - name: identifier + in: path + description: >- + The identifier of the shield to unregister. + required: true + schema: + type: string /v1/telemetry/traces/{trace_id}/spans/{span_id}: get: responses: @@ -1323,7 +1348,8 @@ paths: get: responses: '200': - description: A HealthInfo. + description: >- + Health information indicating if the service is operational. content: application/json: schema: @@ -1340,7 +1366,8 @@ paths: $ref: '#/components/responses/DefaultError' tags: - Inspect - description: Get the health of the service. + description: >- + Get the current health status of the service. parameters: [] /v1/tool-runtime/rag-tool/insert: post: @@ -1360,7 +1387,7 @@ paths: tags: - ToolRuntime description: >- - Index documents so they can be used by the RAG system + Index documents so they can be used by the RAG system. parameters: [] requestBody: content: @@ -1984,7 +2011,8 @@ paths: get: responses: '200': - description: A ListRoutesResponse. + description: >- + Response containing information about all available routes. content: application/json: schema: @@ -2001,7 +2029,8 @@ paths: $ref: '#/components/responses/DefaultError' tags: - Inspect - description: List all routes. + description: >- + List all available API routes with their methods and implementing providers. parameters: [] /v1/tool-runtime/list-tools: get: @@ -2324,26 +2353,41 @@ paths: type: string - name: limit in: query + description: >- + (Optional) A limit on the number of objects to be returned. Limit can + range between 1 and 100, and the default is 20. required: false schema: type: integer - name: order in: query + description: >- + (Optional) Sort order by the `created_at` timestamp of the objects. `asc` + for ascending order and `desc` for descending order. required: false schema: type: string - name: after in: query + description: >- + (Optional) A cursor for use in pagination. `after` is an object ID that + defines your place in the list. required: false schema: type: string - name: before in: query + description: >- + (Optional) A cursor for use in pagination. 
`before` is an object ID that + defines your place in the list. required: false schema: type: string - name: filter in: query + description: >- + (Optional) Filter by file status to only return files with the specified + status. required: false schema: $ref: '#/components/schemas/VectorStoreFileStatus' @@ -3071,7 +3115,8 @@ paths: post: responses: '200': - description: OK + description: >- + RAGQueryResult containing the retrieved content and metadata content: application/json: schema: @@ -3089,7 +3134,7 @@ paths: tags: - ToolRuntime description: >- - Query the RAG system for context; typically invoked by the agent + Query the RAG system for context; typically invoked by the agent. parameters: [] requestBody: content: @@ -3313,6 +3358,36 @@ paths: schema: $ref: '#/components/schemas/RunEvalRequest' required: true + /v1/openai/v1/moderations: + post: + responses: + '200': + description: A moderation object. + content: + application/json: + schema: + $ref: '#/components/schemas/ModerationObject' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Safety + description: >- + Classifies if text and/or image inputs are potentially harmful. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RunModerationRequest' + required: true /v1/safety/run-shield: post: responses: @@ -3459,7 +3534,8 @@ paths: post: responses: '200': - description: OK + description: >- + Response containing filtered synthetic data samples and optional statistics content: application/json: schema: @@ -3476,7 +3552,8 @@ paths: $ref: '#/components/responses/DefaultError' tags: - SyntheticDataGeneration (Coming Soon) - description: '' + description: >- + Generate synthetic data based on input dialogs and apply filtering. parameters: [] requestBody: content: @@ -3488,7 +3565,8 @@ paths: get: responses: '200': - description: A VersionInfo. + description: >- + Version information containing the service version number. content: application/json: schema: @@ -3636,10 +3714,15 @@ components: type: string const: greedy default: greedy + description: >- + Must be "greedy" to identify this sampling strategy additionalProperties: false required: - type title: GreedySamplingStrategy + description: >- + Greedy sampling strategy that selects the highest probability token at each + step. ImageContentItem: type: object properties: @@ -3997,13 +4080,19 @@ components: type: string const: top_k default: top_k + description: >- + Must be "top_k" to identify this sampling strategy top_k: type: integer + description: >- + Number of top tokens to consider for sampling. Must be at least 1 additionalProperties: false required: - type - top_k title: TopKSamplingStrategy + description: >- + Top-k sampling strategy that restricts sampling to the k most likely tokens. TopPSamplingStrategy: type: object properties: @@ -4011,24 +4100,35 @@ components: type: string const: top_p default: top_p + description: >- + Must be "top_p" to identify this sampling strategy temperature: type: number + description: >- + Controls randomness in sampling. Higher values increase randomness top_p: type: number default: 0.95 + description: >- + Cumulative probability threshold for nucleus sampling. 
Defaults to 0.95 additionalProperties: false required: - type title: TopPSamplingStrategy + description: >- + Top-p (nucleus) sampling strategy that samples from the smallest set of tokens + with cumulative probability >= p. URL: type: object properties: uri: type: string + description: The URL string pointing to the resource additionalProperties: false required: - uri title: URL + description: A URL reference to external content. UserMessage: type: object properties: @@ -4111,10 +4211,14 @@ components: type: array items: $ref: '#/components/schemas/ChatCompletionResponse' + description: >- + List of chat completion responses, one for each conversation in the batch additionalProperties: false required: - batch title: BatchChatCompletionResponse + description: >- + Response from a batch chat completion request. ChatCompletionResponse: type: object properties: @@ -4122,6 +4226,8 @@ components: type: array items: $ref: '#/components/schemas/MetricInResponse' + description: >- + (Optional) List of metrics associated with the API response completion_message: $ref: '#/components/schemas/CompletionMessage' description: The complete response message @@ -4141,17 +4247,23 @@ components: properties: metric: type: string + description: The name of the metric value: oneOf: - type: integer - type: number + description: The numeric value of the metric unit: type: string + description: >- + (Optional) The unit of measurement for the metric value additionalProperties: false required: - metric - value title: MetricInResponse + description: >- + A metric value included in API responses. TokenLogProbs: type: object properties: @@ -4211,10 +4323,14 @@ components: type: array items: $ref: '#/components/schemas/CompletionResponse' + description: >- + List of completion responses, one for each input in the batch additionalProperties: false required: - batch title: BatchCompletionResponse + description: >- + Response from a batch completion request. CompletionResponse: type: object properties: @@ -4222,6 +4338,8 @@ components: type: array items: $ref: '#/components/schemas/MetricInResponse' + description: >- + (Optional) List of metrics associated with the API response content: type: string description: The generated completion text @@ -4375,6 +4493,8 @@ components: type: array items: $ref: '#/components/schemas/MetricInResponse' + description: >- + (Optional) List of metrics associated with the API response event: $ref: '#/components/schemas/ChatCompletionResponseEvent' description: The event containing the new content @@ -4402,14 +4522,19 @@ components: type: string const: image default: image + description: >- + Discriminator type of the delta. Always "image" image: type: string contentEncoding: base64 + description: The incremental image data as bytes additionalProperties: false required: - type - image title: ImageDelta + description: >- + An image content delta for streaming responses. TextDelta: type: object properties: @@ -4417,13 +4542,18 @@ components: type: string const: text default: text + description: >- + Discriminator type of the delta. Always "text" text: type: string + description: The incremental text content additionalProperties: false required: - type - text title: TextDelta + description: >- + A text content delta for streaming responses. ToolCallDelta: type: object properties: @@ -4431,10 +4561,14 @@ components: type: string const: tool_call default: tool_call + description: >- + Discriminator type of the delta. 
Always "tool_call" tool_call: oneOf: - type: string - $ref: '#/components/schemas/ToolCall' + description: >- + Either an in-progress tool call string or the final parsed tool call parse_status: type: string enum: @@ -4442,13 +4576,15 @@ components: - in_progress - failed - succeeded - title: ToolCallParseStatus + description: Current parsing status of the tool call additionalProperties: false required: - type - tool_call - parse_status title: ToolCallDelta + description: >- + A tool call content delta for streaming responses. CompletionRequest: type: object properties: @@ -4498,6 +4634,8 @@ components: type: array items: $ref: '#/components/schemas/MetricInResponse' + description: >- + (Optional) List of metrics associated with the API response delta: type: string description: >- @@ -4622,12 +4760,17 @@ components: properties: name: type: string + description: Name of the tool description: type: string + description: >- + (Optional) Human-readable description of what the tool does parameters: type: array items: $ref: '#/components/schemas/ToolParameter' + description: >- + (Optional) List of parameters this tool accepts metadata: type: object additionalProperties: @@ -4638,22 +4781,33 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Additional metadata about the tool additionalProperties: false required: - name title: ToolDef + description: >- + Tool definition used in runtime contexts. ToolParameter: type: object properties: name: type: string + description: Name of the parameter parameter_type: type: string + description: >- + Type of the parameter (e.g., string, integer) description: type: string + description: >- + Human-readable description of what the parameter does required: type: boolean default: true + description: >- + Whether this parameter is required for tool invocation default: oneOf: - type: 'null' @@ -4662,6 +4816,8 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Default value for the parameter if not provided additionalProperties: false required: - name @@ -4669,6 +4825,7 @@ components: - description - required title: ToolParameter + description: Parameter definition for a tool. CreateAgentRequest: type: object properties: @@ -4684,10 +4841,13 @@ components: properties: agent_id: type: string + description: Unique identifier for the created agent additionalProperties: false required: - agent_id title: AgentCreateResponse + description: >- + Response returned when creating a new agent. CreateAgentSessionRequest: type: object properties: @@ -4703,10 +4863,14 @@ components: properties: session_id: type: string + description: >- + Unique identifier for the created session additionalProperties: false required: - session_id title: AgentSessionCreateResponse + description: >- + Response returned when creating a new agent session. 
CreateAgentTurnRequest: type: object properties: @@ -4853,8 +5017,11 @@ components: properties: violation_level: $ref: '#/components/schemas/ViolationLevel' + description: Severity level of the violation user_message: type: string + description: >- + (Optional) Message to convey to the user about the violation metadata: type: object additionalProperties: @@ -4865,11 +5032,16 @@ components: - type: string - type: array - type: object + description: >- + Additional metadata including specific violation codes for debugging and + telemetry additionalProperties: false required: - violation_level - metadata title: SafetyViolation + description: >- + Details of a safety violation detected by content moderation. ShieldCallStep: type: object properties: @@ -4960,6 +5132,8 @@ components: properties: call_id: type: string + description: >- + Unique identifier for the tool call this response is for tool_name: oneOf: - type: string @@ -4970,8 +5144,10 @@ components: - code_interpreter title: BuiltinTool - type: string + description: Name of the tool that was invoked content: $ref: '#/components/schemas/InterleavedContent' + description: The response content from the tool metadata: type: object additionalProperties: @@ -4982,25 +5158,34 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Additional metadata about the tool response additionalProperties: false required: - call_id - tool_name - content title: ToolResponse + description: Response from a tool invocation. Turn: type: object properties: turn_id: type: string + description: >- + Unique identifier for the turn within a session session_id: type: string + description: >- + Unique identifier for the conversation session input_messages: type: array items: oneOf: - $ref: '#/components/schemas/UserMessage' - $ref: '#/components/schemas/ToolResponseMessage' + description: >- + List of messages that initiated this turn steps: type: array items: @@ -5016,8 +5201,12 @@ components: tool_execution: '#/components/schemas/ToolExecutionStep' shield_call: '#/components/schemas/ShieldCallStep' memory_retrieval: '#/components/schemas/MemoryRetrievalStep' + description: >- + Ordered list of processing steps executed during this turn output_message: $ref: '#/components/schemas/CompletionMessage' + description: >- + The model's generated response containing content and metadata output_attachments: type: array items: @@ -5041,12 +5230,17 @@ components: - mime_type title: Attachment description: An attachment to an agent turn. + description: >- + (Optional) Files or media attached to the agent's response started_at: type: string format: date-time + description: Timestamp when the turn began completed_at: type: string format: date-time + description: >- + (Optional) Timestamp when the turn finished, if completed additionalProperties: false required: - turn_id @@ -5065,15 +5259,20 @@ components: - warn - error title: ViolationLevel + description: Severity level of a safety violation. AgentTurnResponseEvent: type: object properties: payload: $ref: '#/components/schemas/AgentTurnResponseEventPayload' + description: >- + Event-specific payload containing event data additionalProperties: false required: - payload title: AgentTurnResponseEvent + description: >- + An event in an agent turn response stream. 
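A hypothetical SafetyViolation payload consistent with the schema above (the violation code in metadata is invented for illustration):

```json
{
  "violation_level": "error",
  "user_message": "I can't help with that request.",
  "metadata": {"violation_type": "S1"}
}
```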
AgentTurnResponseEventPayload: oneOf: - $ref: '#/components/schemas/AgentTurnResponseStepStartPayload' @@ -5103,9 +5302,9 @@ components: - turn_start - turn_complete - turn_awaiting_input - title: AgentTurnResponseEventType const: step_complete default: step_complete + description: Type of event being reported step_type: type: string enum: @@ -5113,10 +5312,11 @@ components: - tool_execution - shield_call - memory_retrieval - title: StepType - description: Type of the step in an agent turn. + description: Type of step being executed step_id: type: string + description: >- + Unique identifier for the step within a turn step_details: oneOf: - $ref: '#/components/schemas/InferenceStep' @@ -5130,6 +5330,7 @@ components: tool_execution: '#/components/schemas/ToolExecutionStep' shield_call: '#/components/schemas/ShieldCallStep' memory_retrieval: '#/components/schemas/MemoryRetrievalStep' + description: Complete details of the executed step additionalProperties: false required: - event_type @@ -5137,6 +5338,8 @@ components: - step_id - step_details title: AgentTurnResponseStepCompletePayload + description: >- + Payload for step completion events in agent turn responses. AgentTurnResponseStepProgressPayload: type: object properties: @@ -5149,9 +5352,9 @@ components: - turn_start - turn_complete - turn_awaiting_input - title: AgentTurnResponseEventType const: step_progress default: step_progress + description: Type of event being reported step_type: type: string enum: @@ -5159,12 +5362,15 @@ components: - tool_execution - shield_call - memory_retrieval - title: StepType - description: Type of the step in an agent turn. + description: Type of step being executed step_id: type: string + description: >- + Unique identifier for the step within a turn delta: $ref: '#/components/schemas/ContentDelta' + description: >- + Incremental content changes during step execution additionalProperties: false required: - event_type @@ -5172,6 +5378,8 @@ components: - step_id - delta title: AgentTurnResponseStepProgressPayload + description: >- + Payload for step progress events in agent turn responses. AgentTurnResponseStepStartPayload: type: object properties: @@ -5184,9 +5392,9 @@ components: - turn_start - turn_complete - turn_awaiting_input - title: AgentTurnResponseEventType const: step_start default: step_start + description: Type of event being reported step_type: type: string enum: @@ -5194,10 +5402,11 @@ components: - tool_execution - shield_call - memory_retrieval - title: StepType - description: Type of the step in an agent turn. + description: Type of step being executed step_id: type: string + description: >- + Unique identifier for the step within a turn metadata: type: object additionalProperties: @@ -5208,22 +5417,28 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Additional metadata for the step additionalProperties: false required: - event_type - step_type - step_id title: AgentTurnResponseStepStartPayload + description: >- + Payload for step start events in agent turn responses. AgentTurnResponseStreamChunk: type: object properties: event: $ref: '#/components/schemas/AgentTurnResponseEvent' + description: >- + Individual event in the agent turn response stream additionalProperties: false required: - event title: AgentTurnResponseStreamChunk - description: streamed agent turn completion response. + description: Streamed agent turn completion response. 
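A hypothetical AgentTurnResponseStreamChunk carrying a step_progress payload with a text delta, sketched from the schemas above (the step identifier and text are invented):

```json
{
  "event": {
    "payload": {
      "event_type": "step_progress",
      "step_type": "inference",
      "step_id": "step-1",
      "delta": {
        "type": "text",
        "text": "Hello"
      }
    }
  }
}
```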
"AgentTurnResponseTurnAwaitingInputPayload": type: object properties: @@ -5236,17 +5451,21 @@ components: - turn_start - turn_complete - turn_awaiting_input - title: AgentTurnResponseEventType const: turn_awaiting_input default: turn_awaiting_input + description: Type of event being reported turn: $ref: '#/components/schemas/Turn' + description: >- + Turn data when waiting for external tool responses additionalProperties: false required: - event_type - turn title: >- AgentTurnResponseTurnAwaitingInputPayload + description: >- + Payload for turn awaiting input events in agent turn responses. AgentTurnResponseTurnCompletePayload: type: object properties: @@ -5259,16 +5478,20 @@ components: - turn_start - turn_complete - turn_awaiting_input - title: AgentTurnResponseEventType const: turn_complete default: turn_complete + description: Type of event being reported turn: $ref: '#/components/schemas/Turn' + description: >- + Complete turn data including all steps and results additionalProperties: false required: - event_type - turn title: AgentTurnResponseTurnCompletePayload + description: >- + Payload for turn completion events in agent turn responses. AgentTurnResponseTurnStartPayload: type: object properties: @@ -5281,16 +5504,20 @@ components: - turn_start - turn_complete - turn_awaiting_input - title: AgentTurnResponseEventType const: turn_start default: turn_start + description: Type of event being reported turn_id: type: string + description: >- + Unique identifier for the turn within a session additionalProperties: false required: - event_type - turn_id title: AgentTurnResponseTurnStartPayload + description: >- + Payload for turn start events in agent turn responses. OpenAIResponseAnnotationCitation: type: object properties: @@ -5298,14 +5525,22 @@ components: type: string const: url_citation default: url_citation + description: >- + Annotation type identifier, always "url_citation" end_index: type: integer + description: >- + End position of the citation span in the content start_index: type: integer + description: >- + Start position of the citation span in the content title: type: string + description: Title of the referenced web resource url: type: string + description: URL of the referenced web resource additionalProperties: false required: - type @@ -5314,6 +5549,8 @@ components: - title - url title: OpenAIResponseAnnotationCitation + description: >- + URL citation annotation for referencing external web resources. "OpenAIResponseAnnotationContainerFileCitation": type: object properties: @@ -5348,12 +5585,18 @@ components: type: string const: file_citation default: file_citation + description: >- + Annotation type identifier, always "file_citation" file_id: type: string + description: Unique identifier of the referenced file filename: type: string + description: Name of the referenced file index: type: integer + description: >- + Position index of the citation within the content additionalProperties: false required: - type @@ -5361,6 +5604,8 @@ components: - filename - index title: OpenAIResponseAnnotationFileCitation + description: >- + File citation annotation for referencing specific files in response content. 
OpenAIResponseAnnotationFilePath: type: object properties: @@ -5444,31 +5689,43 @@ components: - type: string const: auto default: auto + description: >- + Level of detail for image processing, can be "low", "high", or "auto" type: type: string const: input_image default: input_image + description: >- + Content type identifier, always "input_image" image_url: type: string + description: (Optional) URL of the image content additionalProperties: false required: - detail - type title: OpenAIResponseInputMessageContentImage + description: >- + Image content for input messages in OpenAI response format. OpenAIResponseInputMessageContentText: type: object properties: text: type: string + description: The text content of the input message type: type: string const: input_text default: input_text + description: >- + Content type identifier, always "input_text" additionalProperties: false required: - text - type title: OpenAIResponseInputMessageContentText + description: >- + Text content for input messages in OpenAI response format. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' @@ -5489,10 +5746,14 @@ components: type: string const: file_search default: file_search + description: >- + Tool type identifier, always "file_search" vector_store_ids: type: array items: type: string + description: >- + List of vector store identifiers to search within filters: type: object additionalProperties: @@ -5503,24 +5764,35 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Additional filters to apply to the search max_num_results: type: integer default: 10 + description: >- + (Optional) Maximum number of search results to return (1-50) ranking_options: type: object properties: ranker: type: string + description: >- + (Optional) Name of the ranking algorithm to use score_threshold: type: number default: 0.0 + description: >- + (Optional) Minimum relevance score threshold for results additionalProperties: false - title: SearchRankingOptions + description: >- + (Optional) Options for ranking and scoring search results additionalProperties: false required: - type - vector_store_ids title: OpenAIResponseInputToolFileSearch + description: >- + File search tool configuration for OpenAI response inputs. OpenAIResponseInputToolFunction: type: object properties: @@ -5528,10 +5800,14 @@ components: type: string const: function default: function + description: Tool type identifier, always "function" name: type: string + description: Name of the function that can be called description: type: string + description: >- + (Optional) Description of what the function does parameters: type: object additionalProperties: @@ -5542,13 +5818,19 @@ components: - type: string - type: array - type: object + description: >- + (Optional) JSON schema defining the function's parameters strict: type: boolean + description: >- + (Optional) Whether to enforce strict parameter validation additionalProperties: false required: - type - name title: OpenAIResponseInputToolFunction + description: >- + Function tool configuration for OpenAI response inputs. 
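A hypothetical function tool definition consistent with OpenAIResponseInputToolFunction above (the function name and parameter schema are invented):

```json
{
  "type": "function",
  "name": "lookup_order",
  "description": "Fetch the status of an order by ID",
  "parameters": {
    "type": "object",
    "properties": {
      "order_id": {"type": "string"}
    },
    "required": ["order_id"]
  },
  "strict": true
}
```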
OpenAIResponseInputToolMCP: type: object properties: @@ -5556,10 +5838,13 @@ components: type: string const: mcp default: mcp + description: Tool type identifier, always "mcp" server_label: type: string + description: Label to identify this MCP server server_url: type: string + description: URL endpoint of the MCP server headers: type: object additionalProperties: @@ -5570,6 +5855,8 @@ components: - type: string - type: array - type: object + description: >- + (Optional) HTTP headers to include when connecting to the server require_approval: oneOf: - type: string @@ -5582,13 +5869,21 @@ components: type: array items: type: string + description: >- + (Optional) List of tool names that always require approval never: type: array items: type: string + description: >- + (Optional) List of tool names that never require approval additionalProperties: false title: ApprovalFilter + description: >- + Filter configuration for MCP tool approval requirements. default: never + description: >- + Approval requirement for tool calls ("always", "never", or filter) allowed_tools: oneOf: - type: array @@ -5600,8 +5895,14 @@ components: type: array items: type: string + description: >- + (Optional) List of specific tool names that are allowed additionalProperties: false title: AllowedToolsFilter + description: >- + Filter configuration for restricting which MCP tools can be used. + description: >- + (Optional) Restriction on which tools can be used from this server additionalProperties: false required: - type @@ -5609,6 +5910,8 @@ components: - server_url - require_approval title: OpenAIResponseInputToolMCP + description: >- + Model Context Protocol (MCP) tool configuration for OpenAI response inputs. OpenAIResponseInputToolWebSearch: type: object properties: @@ -5621,13 +5924,18 @@ components: - type: string const: web_search_preview_2025_03_11 default: web_search + description: Web search tool type variant to use search_context_size: type: string default: medium + description: >- + (Optional) Size of search context, must be "low", "medium", or "high" additionalProperties: false required: - type title: OpenAIResponseInputToolWebSearch + description: >- + Web search tool configuration for OpenAI response inputs. OpenAIResponseMessage: type: object properties: @@ -5693,16 +6001,22 @@ components: properties: id: type: string + description: Unique identifier for this tool call queries: type: array items: type: string + description: List of search queries executed status: type: string + description: >- + Current status of the file search operation type: type: string const: file_search_call default: file_search_call + description: >- + Tool call type identifier, always "file_search_call" results: type: array items: @@ -5715,6 +6029,8 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Search results returned by the file search operation additionalProperties: false required: - id @@ -5723,23 +6039,35 @@ components: - type title: >- OpenAIResponseOutputMessageFileSearchToolCall + description: >- + File search tool call output message for OpenAI responses. 
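A hypothetical MCP tool configuration consistent with OpenAIResponseInputToolMCP above (the server label, URL, and tool names are invented):

```json
{
  "type": "mcp",
  "server_label": "deepwiki",
  "server_url": "https://example.com/mcp",
  "require_approval": "never",
  "allowed_tools": ["search", "fetch"]
}
```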
"OpenAIResponseOutputMessageFunctionToolCall": type: object properties: call_id: type: string + description: Unique identifier for the function call name: type: string + description: Name of the function being called arguments: type: string + description: >- + JSON string containing the function arguments type: type: string const: function_call default: function_call + description: >- + Tool call type identifier, always "function_call" id: type: string + description: >- + (Optional) Additional identifier for the tool call status: type: string + description: >- + (Optional) Current status of the function call execution additionalProperties: false required: - call_id @@ -5748,17 +6076,24 @@ components: - type title: >- OpenAIResponseOutputMessageFunctionToolCall + description: >- + Function tool call output message for OpenAI responses. "OpenAIResponseOutputMessageWebSearchToolCall": type: object properties: id: type: string + description: Unique identifier for this tool call status: type: string + description: >- + Current status of the web search operation type: type: string const: web_search_call default: web_search_call + description: >- + Tool call type identifier, always "web_search_call" additionalProperties: false required: - id @@ -5766,6 +6101,8 @@ components: - type title: >- OpenAIResponseOutputMessageWebSearchToolCall + description: >- + Web search tool call output message for OpenAI responses. OpenAIResponseText: type: object properties: @@ -5812,11 +6149,12 @@ components: additionalProperties: false required: - type - title: OpenAIResponseTextFormat description: >- - Configuration for Responses API text format. + (Optional) Text format configuration specifying output format requirements additionalProperties: false title: OpenAIResponseText + description: >- + Text response configuration for OpenAI responses. CreateOpenaiResponseRequest: type: object properties: @@ -5862,49 +6200,81 @@ components: properties: code: type: string + description: >- + Error code identifying the type of failure message: type: string + description: >- + Human-readable error message describing the failure additionalProperties: false required: - code - message title: OpenAIResponseError + description: >- + Error details for failed OpenAI response requests. OpenAIResponseObject: type: object properties: created_at: type: integer + description: >- + Unix timestamp when the response was created error: $ref: '#/components/schemas/OpenAIResponseError' + description: >- + (Optional) Error details if the response generation failed id: type: string + description: Unique identifier for this response model: type: string + description: Model identifier used for generation object: type: string const: response default: response + description: >- + Object type identifier, always "response" output: type: array items: $ref: '#/components/schemas/OpenAIResponseOutput' + description: >- + List of generated output items (messages, tool calls, etc.) 
parallel_tool_calls: type: boolean default: false + description: >- + Whether tool calls can be executed in parallel previous_response_id: type: string + description: >- + (Optional) ID of the previous response in a conversation status: type: string + description: >- + Current status of the response generation temperature: type: number + description: >- + (Optional) Sampling temperature used for generation text: $ref: '#/components/schemas/OpenAIResponseText' + description: >- + Text formatting configuration for the response top_p: type: number + description: >- + (Optional) Nucleus sampling parameter used for generation truncation: type: string + description: >- + (Optional) Truncation strategy applied to the response user: type: string + description: >- + (Optional) User identifier associated with the request additionalProperties: false required: - created_at @@ -5916,6 +6286,8 @@ components: - status - text title: OpenAIResponseObject + description: >- + Complete OpenAI response object containing generation results and metadata. OpenAIResponseOutput: oneOf: - $ref: '#/components/schemas/OpenAIResponseMessage' @@ -5938,20 +6310,32 @@ components: properties: id: type: string + description: Unique identifier for this MCP call type: type: string const: mcp_call default: mcp_call + description: >- + Tool call type identifier, always "mcp_call" arguments: type: string + description: >- + JSON string containing the MCP call arguments name: type: string + description: Name of the MCP method being called server_label: type: string + description: >- + Label identifying the MCP server handling the call error: type: string + description: >- + (Optional) Error message if the MCP call failed output: type: string + description: >- + (Optional) Output result from the successful MCP call additionalProperties: false required: - id @@ -5960,17 +6344,25 @@ components: - name - server_label title: OpenAIResponseOutputMessageMCPCall + description: >- + Model Context Protocol (MCP) call output message for OpenAI responses. OpenAIResponseOutputMessageMCPListTools: type: object properties: id: type: string + description: >- + Unique identifier for this MCP list tools operation type: type: string const: mcp_list_tools default: mcp_list_tools + description: >- + Tool call type identifier, always "mcp_list_tools" server_label: type: string + description: >- + Label identifying the MCP server providing the tools tools: type: array items: @@ -5986,15 +6378,24 @@ components: - type: string - type: array - type: object + description: >- + JSON schema defining the tool's input parameters name: type: string + description: Name of the tool description: type: string + description: >- + (Optional) Description of what the tool does additionalProperties: false required: - input_schema - name title: MCPListToolsTool + description: >- + Tool definition returned by MCP list tools operation. + description: >- + List of available tools provided by the MCP server additionalProperties: false required: - id @@ -6002,6 +6403,8 @@ components: - server_label - tools title: OpenAIResponseOutputMessageMCPListTools + description: >- + MCP list tools output message containing available tools from an MCP server. 
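A hypothetical mcp_list_tools output item consistent with OpenAIResponseOutputMessageMCPListTools above (the identifiers, server label, and tool entry are invented):

```json
{
  "id": "mcp_list_tools_123",
  "type": "mcp_list_tools",
  "server_label": "deepwiki",
  "tools": [
    {
      "name": "search_wiki",
      "description": "Search pages on the wiki",
      "input_schema": {
        "type": "object",
        "properties": {"query": {"type": "string"}},
        "required": ["query"]
      }
    }
  ]
}
```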
OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' @@ -6050,46 +6453,66 @@ components: properties: response: $ref: '#/components/schemas/OpenAIResponseObject' + description: The completed response object type: type: string const: response.completed default: response.completed + description: >- + Event type identifier, always "response.completed" additionalProperties: false required: - response - type title: >- OpenAIResponseObjectStreamResponseCompleted + description: >- + Streaming event indicating a response has been completed. "OpenAIResponseObjectStreamResponseCreated": type: object properties: response: $ref: '#/components/schemas/OpenAIResponseObject' + description: The newly created response object type: type: string const: response.created default: response.created + description: >- + Event type identifier, always "response.created" additionalProperties: false required: - response - type title: >- OpenAIResponseObjectStreamResponseCreated + description: >- + Streaming event indicating a new response has been created. "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta": type: object properties: delta: type: string + description: >- + Incremental function call arguments being added item_id: type: string + description: >- + Unique identifier of the function call being updated output_index: type: integer + description: >- + Index position of the item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.function_call_arguments.delta default: response.function_call_arguments.delta + description: >- + Event type identifier, always "response.function_call_arguments.delta" additionalProperties: false required: - delta @@ -6099,21 +6522,33 @@ components: - type title: >- OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta + description: >- + Streaming event for incremental function call argument updates. "OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone": type: object properties: arguments: type: string + description: >- + Final complete arguments JSON string for the function call item_id: type: string + description: >- + Unique identifier of the completed function call output_index: type: integer + description: >- + Index position of the item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.function_call_arguments.done default: response.function_call_arguments.done + description: >- + Event type identifier, always "response.function_call_arguments.done" additionalProperties: false required: - arguments @@ -6123,6 +6558,8 @@ components: - type title: >- OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone + description: >- + Streaming event for when function call arguments are completed. "OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta": type: object properties: @@ -6176,44 +6613,61 @@ components: properties: sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.mcp_call.completed default: response.mcp_call.completed + description: >- + Event type identifier, always "response.mcp_call.completed" additionalProperties: false required: - sequence_number - type title: >- OpenAIResponseObjectStreamResponseMcpCallCompleted + description: Streaming event for completed MCP calls. 
"OpenAIResponseObjectStreamResponseMcpCallFailed": type: object properties: sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.mcp_call.failed default: response.mcp_call.failed + description: >- + Event type identifier, always "response.mcp_call.failed" additionalProperties: false required: - sequence_number - type title: >- OpenAIResponseObjectStreamResponseMcpCallFailed + description: Streaming event for failed MCP calls. "OpenAIResponseObjectStreamResponseMcpCallInProgress": type: object properties: item_id: type: string + description: Unique identifier of the MCP call output_index: type: integer + description: >- + Index position of the item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.mcp_call.in_progress default: response.mcp_call.in_progress + description: >- + Event type identifier, always "response.mcp_call.in_progress" additionalProperties: false required: - item_id @@ -6222,6 +6676,8 @@ components: - type title: >- OpenAIResponseObjectStreamResponseMcpCallInProgress + description: >- + Streaming event for MCP calls in progress. "OpenAIResponseObjectStreamResponseMcpListToolsCompleted": type: object properties: @@ -6272,16 +6728,26 @@ components: properties: response_id: type: string + description: >- + Unique identifier of the response containing this output item: $ref: '#/components/schemas/OpenAIResponseOutput' + description: >- + The output item that was added (message, tool call, etc.) output_index: type: integer + description: >- + Index position of this item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.output_item.added default: response.output_item.added + description: >- + Event type identifier, always "response.output_item.added" additionalProperties: false required: - response_id @@ -6291,21 +6757,33 @@ components: - type title: >- OpenAIResponseObjectStreamResponseOutputItemAdded + description: >- + Streaming event for when a new output item is added to the response. "OpenAIResponseObjectStreamResponseOutputItemDone": type: object properties: response_id: type: string + description: >- + Unique identifier of the response containing this output item: $ref: '#/components/schemas/OpenAIResponseOutput' + description: >- + The completed output item (message, tool call, etc.) output_index: type: integer + description: >- + Index position of this item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.output_item.done default: response.output_item.done + description: >- + Event type identifier, always "response.output_item.done" additionalProperties: false required: - response_id @@ -6315,23 +6793,35 @@ components: - type title: >- OpenAIResponseObjectStreamResponseOutputItemDone + description: >- + Streaming event for when an output item is completed. 
"OpenAIResponseObjectStreamResponseOutputTextDelta": type: object properties: content_index: type: integer + description: Index position within the text content delta: type: string + description: Incremental text content being added item_id: type: string + description: >- + Unique identifier of the output item being updated output_index: type: integer + description: >- + Index position of the item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.output_text.delta default: response.output_text.delta + description: >- + Event type identifier, always "response.output_text.delta" additionalProperties: false required: - content_index @@ -6342,23 +6832,36 @@ components: - type title: >- OpenAIResponseObjectStreamResponseOutputTextDelta + description: >- + Streaming event for incremental text content updates. "OpenAIResponseObjectStreamResponseOutputTextDone": type: object properties: content_index: type: integer + description: Index position within the text content text: type: string + description: >- + Final complete text content of the output item item_id: type: string + description: >- + Unique identifier of the completed output item output_index: type: integer + description: >- + Index position of the item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.output_text.done default: response.output_text.done + description: >- + Event type identifier, always "response.output_text.done" additionalProperties: false required: - content_index @@ -6369,19 +6872,29 @@ components: - type title: >- OpenAIResponseObjectStreamResponseOutputTextDone + description: >- + Streaming event for when text output is completed. "OpenAIResponseObjectStreamResponseWebSearchCallCompleted": type: object properties: item_id: type: string + description: >- + Unique identifier of the completed web search call output_index: type: integer + description: >- + Index position of the item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.web_search_call.completed default: response.web_search_call.completed + description: >- + Event type identifier, always "response.web_search_call.completed" additionalProperties: false required: - item_id @@ -6390,19 +6903,28 @@ components: - type title: >- OpenAIResponseObjectStreamResponseWebSearchCallCompleted + description: >- + Streaming event for completed web search calls. "OpenAIResponseObjectStreamResponseWebSearchCallInProgress": type: object properties: item_id: type: string + description: Unique identifier of the web search call output_index: type: integer + description: >- + Index position of the item in the output list sequence_number: type: integer + description: >- + Sequential number for ordering streaming events type: type: string const: response.web_search_call.in_progress default: response.web_search_call.in_progress + description: >- + Event type identifier, always "response.web_search_call.in_progress" additionalProperties: false required: - item_id @@ -6411,6 +6933,8 @@ components: - type title: >- OpenAIResponseObjectStreamResponseWebSearchCallInProgress + description: >- + Streaming event for web search calls in progress. 
"OpenAIResponseObjectStreamResponseWebSearchCallSearching": type: object properties: @@ -6437,19 +6961,26 @@ components: properties: id: type: string + description: >- + Unique identifier of the deleted response object: type: string const: response default: response + description: >- + Object type identifier, always "response" deleted: type: boolean default: true + description: Deletion confirmation flag, always True additionalProperties: false required: - id - object - deleted title: OpenAIDeleteResponseObject + description: >- + Response object confirming deletion of an OpenAI response. EmbeddingsRequest: type: object properties: @@ -6542,6 +7073,8 @@ components: - categorical_count - accuracy title: AggregationFunctionType + description: >- + Types of aggregation functions for scoring results. BasicScoringFnParams: type: object properties: @@ -6549,15 +7082,21 @@ components: $ref: '#/components/schemas/ScoringFnParamsType' const: basic default: basic + description: >- + The type of scoring function parameters, always basic aggregation_functions: type: array items: $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row additionalProperties: false required: - type - aggregation_functions title: BasicScoringFnParams + description: >- + Parameters for basic scoring function configuration. BenchmarkConfig: type: object properties: @@ -6599,18 +7138,28 @@ components: $ref: '#/components/schemas/ScoringFnParamsType' const: llm_as_judge default: llm_as_judge + description: >- + The type of scoring function parameters, always llm_as_judge judge_model: type: string + description: >- + Identifier of the LLM model to use as a judge for scoring prompt_template: type: string + description: >- + (Optional) Custom prompt template for the judge model judge_score_regexes: type: array items: type: string + description: >- + Regexes to extract the answer from generated response aggregation_functions: type: array items: $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row additionalProperties: false required: - type @@ -6618,6 +7167,8 @@ components: - judge_score_regexes - aggregation_functions title: LLMAsJudgeScoringFnParams + description: >- + Parameters for LLM-as-judge scoring function configuration. ModelCandidate: type: object properties: @@ -6650,20 +7201,28 @@ components: $ref: '#/components/schemas/ScoringFnParamsType' const: regex_parser default: regex_parser + description: >- + The type of scoring function parameters, always regex_parser parsing_regexes: type: array items: type: string + description: >- + Regex to extract the answer from generated response aggregation_functions: type: array items: $ref: '#/components/schemas/AggregationFunctionType' + description: >- + Aggregation functions to apply to the scores of each row additionalProperties: false required: - type - parsing_regexes - aggregation_functions title: RegexParserScoringFnParams + description: >- + Parameters for regex parser scoring function configuration. ScoringFnParams: oneOf: - $ref: '#/components/schemas/LLMAsJudgeScoringFnParams' @@ -6682,6 +7241,8 @@ components: - regex_parser - basic title: ScoringFnParamsType + description: >- + Types of scoring function parameter configurations. 
EvaluateRowsRequest: type: object properties: @@ -6779,31 +7340,42 @@ components: properties: agent_id: type: string + description: Unique identifier for the agent agent_config: $ref: '#/components/schemas/AgentConfig' + description: Configuration settings for the agent created_at: type: string format: date-time + description: Timestamp when the agent was created additionalProperties: false required: - agent_id - agent_config - created_at title: Agent + description: >- + An agent instance with configuration and metadata. Session: type: object properties: session_id: type: string + description: >- + Unique identifier for the conversation session session_name: type: string + description: Human-readable name for the session turns: type: array items: $ref: '#/components/schemas/Turn' + description: >- + List of all turns that have occurred in this session started_at: type: string format: date-time + description: Timestamp when the session was created additionalProperties: false required: - session_id @@ -6829,10 +7401,14 @@ components: tool_execution: '#/components/schemas/ToolExecutionStep' shield_call: '#/components/schemas/ShieldCallStep' memory_retrieval: '#/components/schemas/MemoryRetrievalStep' + description: >- + The complete step data and execution details additionalProperties: false required: - step title: AgentStepResponse + description: >- + Response containing details of a specific agent step. Benchmark: type: object properties: @@ -6853,15 +7429,19 @@ components: - benchmark - tool - tool_group - title: ResourceType const: benchmark default: benchmark + description: The resource type, always benchmark dataset_id: type: string + description: >- + Identifier of the dataset to use for the benchmark evaluation scoring_functions: type: array items: type: string + description: >- + List of scoring function identifiers to apply during evaluation metadata: type: object additionalProperties: @@ -6872,6 +7452,7 @@ components: - type: string - type: array - type: object + description: Metadata for this evaluation task additionalProperties: false required: - identifier @@ -6881,6 +7462,8 @@ components: - scoring_functions - metadata title: Benchmark + description: >- + A benchmark resource for evaluating model performance. OpenAIAssistantMessageParam: type: object properties: @@ -6922,14 +7505,20 @@ components: type: string const: image_url default: image_url + description: >- + Must be "image_url" to identify this as image content image_url: $ref: '#/components/schemas/OpenAIImageURL' + description: >- + Image URL specification and processing details additionalProperties: false required: - type - image_url title: >- OpenAIChatCompletionContentPartImageParam + description: >- + Image content part for OpenAI-compatible chat completion messages. OpenAIChatCompletionContentPartParam: oneOf: - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' @@ -6948,39 +7537,58 @@ components: type: string const: text default: text + description: >- + Must be "text" to identify this as text content text: type: string + description: The text content of the message additionalProperties: false required: - type - text title: OpenAIChatCompletionContentPartTextParam + description: >- + Text content part for OpenAI-compatible chat completion messages. 
OpenAIChatCompletionToolCall: type: object properties: index: type: integer + description: >- + (Optional) Index of the tool call in the list id: type: string + description: >- + (Optional) Unique identifier for the tool call type: type: string const: function default: function + description: >- + Must be "function" to identify this as a function call function: $ref: '#/components/schemas/OpenAIChatCompletionToolCallFunction' + description: (Optional) Function call details additionalProperties: false required: - type title: OpenAIChatCompletionToolCall + description: >- + Tool call specification for OpenAI-compatible chat completion responses. OpenAIChatCompletionToolCallFunction: type: object properties: name: type: string + description: (Optional) Name of the function to call arguments: type: string + description: >- + (Optional) Arguments to pass to the function as a JSON string additionalProperties: false title: OpenAIChatCompletionToolCallFunction + description: >- + Function call details for OpenAI-compatible tool calls. OpenAIChoice: type: object properties: @@ -7082,12 +7690,19 @@ components: properties: url: type: string + description: >- + URL of the image to include in the message detail: type: string + description: >- + (Optional) Level of detail for image processing. Can be "low", "high", + or "auto" additionalProperties: false required: - url title: OpenAIImageURL + description: >- + Image URL specification for OpenAI-compatible chat completion messages. OpenAIMessageParam: oneOf: - $ref: '#/components/schemas/OpenAIUserMessageParam' @@ -7300,20 +7915,22 @@ components: - benchmark - tool - tool_group - title: ResourceType const: dataset default: dataset + description: >- + Type of resource, always 'dataset' for datasets purpose: type: string enum: - post-training/messages - eval/question-answer - eval/messages-answer - title: DatasetPurpose description: >- - Purpose of the dataset. Each purpose has a required input data schema. + Purpose of the dataset indicating its intended use source: $ref: '#/components/schemas/DataSource' + description: >- + Data source configuration for the dataset metadata: type: object additionalProperties: @@ -7324,6 +7941,7 @@ components: - type: string - type: array - type: object + description: Additional metadata for the dataset additionalProperties: false required: - identifier @@ -7333,6 +7951,8 @@ components: - source - metadata title: Dataset + description: >- + Dataset resource for storing and accessing training or evaluation data. 
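A hypothetical tool call consistent with OpenAIChatCompletionToolCall above (the call identifier, function name, and arguments are invented):

```json
{
  "index": 0,
  "id": "call_abc123",
  "type": "function",
  "function": {
    "name": "get_weather",
    "arguments": "{\"city\": \"Tokyo\"}"
  }
}
```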
RowsDataSource: type: object properties: @@ -7386,10 +8006,16 @@ components: properties: identifier: type: string + description: >- + Unique identifier for this resource in llama stack provider_resource_id: type: string + description: >- + Unique identifier for this resource in the provider provider_id: type: string + description: >- + ID of the provider that owns this resource type: type: string enum: @@ -7401,9 +8027,10 @@ components: - benchmark - tool - tool_group - title: ResourceType const: model default: model + description: >- + The resource type, always 'model' for model resources metadata: type: object additionalProperties: @@ -7414,9 +8041,12 @@ components: - type: string - type: array - type: object + description: Any additional metadata for this model model_type: $ref: '#/components/schemas/ModelType' default: llm + description: >- + The type of model (LLM or embedding model) additionalProperties: false required: - identifier @@ -7425,12 +8055,16 @@ components: - metadata - model_type title: Model + description: >- + A model resource representing an AI model registered in Llama Stack. ModelType: type: string enum: - llm - embedding title: ModelType + description: >- + Enumeration of supported model types in Llama Stack. AgentTurnInputType: type: object properties: @@ -7438,10 +8072,13 @@ components: type: string const: agent_turn_input default: agent_turn_input + description: >- + Discriminator type. Always "agent_turn_input" additionalProperties: false required: - type title: AgentTurnInputType + description: Parameter type for agent turn input. ArrayType: type: object properties: @@ -7449,10 +8086,12 @@ components: type: string const: array default: array + description: Discriminator type. Always "array" additionalProperties: false required: - type title: ArrayType + description: Parameter type for array values. BooleanType: type: object properties: @@ -7460,10 +8099,12 @@ components: type: string const: boolean default: boolean + description: Discriminator type. Always "boolean" additionalProperties: false required: - type title: BooleanType + description: Parameter type for boolean values. ChatCompletionInputType: type: object properties: @@ -7471,10 +8112,14 @@ components: type: string const: chat_completion_input default: chat_completion_input + description: >- + Discriminator type. Always "chat_completion_input" additionalProperties: false required: - type title: ChatCompletionInputType + description: >- + Parameter type for chat completion input. CompletionInputType: type: object properties: @@ -7482,10 +8127,13 @@ components: type: string const: completion_input default: completion_input + description: >- + Discriminator type. Always "completion_input" additionalProperties: false required: - type title: CompletionInputType + description: Parameter type for completion input. JsonType: type: object properties: @@ -7493,10 +8141,12 @@ components: type: string const: json default: json + description: Discriminator type. Always "json" additionalProperties: false required: - type title: JsonType + description: Parameter type for JSON values. NumberType: type: object properties: @@ -7504,10 +8154,12 @@ components: type: string const: number default: number + description: Discriminator type. Always "number" additionalProperties: false required: - type title: NumberType + description: Parameter type for numeric values. ObjectType: type: object properties: @@ -7515,10 +8167,12 @@ components: type: string const: object default: object + description: Discriminator type. 
Always "object" additionalProperties: false required: - type title: ObjectType + description: Parameter type for object values. ParamType: oneOf: - $ref: '#/components/schemas/StringType' @@ -7564,9 +8218,10 @@ components: - benchmark - tool - tool_group - title: ResourceType const: scoring_function default: scoring_function + description: >- + The resource type, always scoring_function description: type: string metadata: @@ -7591,6 +8246,8 @@ components: - metadata - return_type title: ScoringFn + description: >- + A scoring function resource for evaluating model outputs. StringType: type: object properties: @@ -7598,10 +8255,12 @@ components: type: string const: string default: string + description: Discriminator type. Always "string" additionalProperties: false required: - type title: StringType + description: Parameter type for string values. UnionType: type: object properties: @@ -7609,10 +8268,12 @@ components: type: string const: union default: union + description: Discriminator type. Always "union" additionalProperties: false required: - type title: UnionType + description: Parameter type for union values. Shield: type: object properties: @@ -7633,9 +8294,9 @@ components: - benchmark - tool - tool_group - title: ResourceType const: shield default: shield + description: The resource type, always shield params: type: object additionalProperties: @@ -7646,6 +8307,8 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Configuration parameters for the shield additionalProperties: false required: - identifier @@ -7653,24 +8316,34 @@ components: - type title: Shield description: >- - A safety shield resource that can be used to check content + A safety shield resource that can be used to check content. Span: type: object properties: span_id: type: string + description: Unique identifier for the span trace_id: type: string + description: >- + Unique identifier for the trace this span belongs to parent_span_id: type: string + description: >- + (Optional) Unique identifier for the parent span, if this is a child span name: type: string + description: >- + Human-readable name describing the operation this span represents start_time: type: string format: date-time + description: Timestamp when the operation began end_time: type: string format: date-time + description: >- + (Optional) Timestamp when the operation finished, if completed attributes: type: object additionalProperties: @@ -7681,6 +8354,8 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Key-value pairs containing additional metadata about the span additionalProperties: false required: - span_id @@ -7688,6 +8363,8 @@ components: - name - start_time title: Span + description: >- + A span representing a single operation within a trace. GetSpanTreeRequest: type: object properties: @@ -7707,23 +8384,36 @@ components: - ok - error title: SpanStatus + description: >- + The status of a span indicating whether it completed successfully or with + an error. 
SpanWithStatus: type: object properties: span_id: type: string + description: Unique identifier for the span trace_id: type: string + description: >- + Unique identifier for the trace this span belongs to parent_span_id: type: string + description: >- + (Optional) Unique identifier for the parent span, if this is a child span name: type: string + description: >- + Human-readable name describing the operation this span represents start_time: type: string format: date-time + description: Timestamp when the operation began end_time: type: string format: date-time + description: >- + (Optional) Timestamp when the operation finished, if completed attributes: type: object additionalProperties: @@ -7734,8 +8424,12 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Key-value pairs containing additional metadata about the span status: $ref: '#/components/schemas/SpanStatus' + description: >- + (Optional) The current status of the span additionalProperties: false required: - span_id @@ -7743,6 +8437,7 @@ components: - name - start_time title: SpanWithStatus + description: A span that includes status information. QuerySpanTreeResponse: type: object properties: @@ -7750,10 +8445,14 @@ components: type: object additionalProperties: $ref: '#/components/schemas/SpanWithStatus' + description: >- + Dictionary mapping span IDs to spans with status information additionalProperties: false required: - data title: QuerySpanTreeResponse + description: >- + Response containing a tree structure of spans. Tool: type: object properties: @@ -7774,17 +8473,22 @@ components: - benchmark - tool - tool_group - title: ResourceType const: tool default: tool + description: Type of resource, always 'tool' toolgroup_id: type: string + description: >- + ID of the tool group this tool belongs to description: type: string + description: >- + Human-readable description of what the tool does parameters: type: array items: $ref: '#/components/schemas/ToolParameter' + description: List of parameters this tool accepts metadata: type: object additionalProperties: @@ -7795,6 +8499,8 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Additional metadata about the tool additionalProperties: false required: - identifier @@ -7804,6 +8510,7 @@ components: - description - parameters title: Tool + description: A tool that can be invoked by agents. ToolGroup: type: object properties: @@ -7824,11 +8531,13 @@ components: - benchmark - tool - tool_group - title: ResourceType const: tool_group default: tool_group + description: Type of resource, always 'tool_group' mcp_endpoint: $ref: '#/components/schemas/URL' + description: >- + (Optional) Model Context Protocol endpoint for remote tools args: type: object additionalProperties: @@ -7839,47 +8548,71 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Additional arguments for the tool group additionalProperties: false required: - identifier - provider_id - type title: ToolGroup + description: >- + A group of related tools managed together. 
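To make the Tool resource above concrete, a sketch of one possible Tool entry. The identifier, provider, and tool group names are invented, and the ToolParameter entries are left empty because that schema is defined elsewhere in the spec.

```python
# Hypothetical Tool resource shaped like the schema above; all names are
# illustrative, not real registrations.
tool = {
    "identifier": "web_search",          # unique id in llama stack (example)
    "provider_id": "example-provider",   # hypothetical provider id
    "type": "tool",                      # const resource type
    "toolgroup_id": "example::search",   # hypothetical tool group
    "description": "Search the web for a query and return result snippets.",
    "parameters": [],                    # list of ToolParameter objects (omitted)
    "metadata": {"version": "1"},        # optional free-form metadata
}
```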
Trace: type: object properties: trace_id: type: string + description: Unique identifier for the trace root_span_id: type: string + description: >- + Unique identifier for the root span that started this trace start_time: type: string format: date-time + description: Timestamp when the trace began end_time: type: string format: date-time + description: >- + (Optional) Timestamp when the trace finished, if completed additionalProperties: false required: - trace_id - root_span_id - start_time title: Trace + description: >- + A trace representing the complete execution path of a request across multiple + operations. Checkpoint: type: object properties: identifier: type: string + description: Unique identifier for the checkpoint created_at: type: string format: date-time + description: >- + Timestamp when the checkpoint was created epoch: type: integer + description: >- + Training epoch when the checkpoint was saved post_training_job_id: type: string + description: >- + Identifier of the training job that created this checkpoint path: type: string + description: >- + File system path where the checkpoint is stored training_metrics: $ref: '#/components/schemas/PostTrainingMetric' + description: >- + (Optional) Training metrics associated with this checkpoint additionalProperties: false required: - identifier @@ -7888,16 +8621,19 @@ components: - post_training_job_id - path title: Checkpoint - description: Checkpoint created during training runs + description: Checkpoint created during training runs. PostTrainingJobArtifactsResponse: type: object properties: job_uuid: type: string + description: Unique identifier for the training job checkpoints: type: array items: $ref: '#/components/schemas/Checkpoint' + description: >- + List of model checkpoints created during training additionalProperties: false required: - job_uuid @@ -7909,12 +8645,17 @@ components: properties: epoch: type: integer + description: Training epoch number train_loss: type: number + description: Loss value on the training dataset validation_loss: type: number + description: Loss value on the validation dataset perplexity: type: number + description: >- + Perplexity metric indicating model confidence additionalProperties: false required: - epoch @@ -7922,11 +8663,14 @@ components: - validation_loss - perplexity title: PostTrainingMetric + description: >- + Training metrics captured during post-training jobs. 
PostTrainingJobStatusResponse: type: object properties: job_uuid: type: string + description: Unique identifier for the training job status: type: string enum: @@ -7935,16 +8679,22 @@ components: - failed - scheduled - cancelled - title: JobStatus + description: Current status of the training job scheduled_at: type: string format: date-time + description: >- + (Optional) Timestamp when the job was scheduled started_at: type: string format: date-time + description: >- + (Optional) Timestamp when the job execution began completed_at: type: string format: date-time + description: >- + (Optional) Timestamp when the job finished, if completed resources_allocated: type: object additionalProperties: @@ -7955,10 +8705,15 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Information about computational resources allocated to the + job checkpoints: type: array items: $ref: '#/components/schemas/Checkpoint' + description: >- + List of model checkpoints created during training additionalProperties: false required: - job_uuid @@ -8004,13 +8759,17 @@ components: - benchmark - tool - tool_group - title: ResourceType const: vector_db default: vector_db + description: >- + Type of resource, always 'vector_db' for vector databases embedding_model: type: string + description: >- + Name of the embedding model to use for vector generation embedding_dimension: type: integer + description: Dimension of the embedding vectors vector_db_name: type: string additionalProperties: false @@ -8021,6 +8780,8 @@ components: - embedding_model - embedding_dimension title: VectorDB + description: >- + Vector database resource for storing and querying vector embeddings. HealthInfo: type: object properties: @@ -8030,11 +8791,13 @@ components: - OK - Error - Not Implemented - title: HealthStatus + description: Current health status of the service additionalProperties: false required: - status title: HealthInfo + description: >- + Health status information for the service. RAGDocument: type: object properties: @@ -8079,10 +8842,16 @@ components: type: array items: $ref: '#/components/schemas/RAGDocument' + description: >- + List of documents to index in the RAG system vector_db_id: type: string + description: >- + ID of the vector database to store the document embeddings chunk_size_in_tokens: type: integer + description: >- + (Optional) Size in tokens for document chunking during indexing additionalProperties: false required: - documents @@ -8220,10 +8989,13 @@ components: properties: api: type: string + description: The API name this provider implements provider_id: type: string + description: Unique identifier for the provider provider_type: type: string + description: The type of provider implementation config: type: object additionalProperties: @@ -8234,6 +9006,8 @@ components: - type: string - type: array - type: object + description: >- + Configuration parameters for the provider health: type: object additionalProperties: @@ -8244,6 +9018,7 @@ components: - type: string - type: array - type: object + description: Current health status of the provider additionalProperties: false required: - api @@ -8252,6 +9027,9 @@ components: - config - health title: ProviderInfo + description: >- + Information about a registered provider including its configuration and health + status. 
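A hedged sketch of a ProviderInfo entry as defined above. The provider id, type, config, and health payloads are placeholders chosen for illustration and are not taken from a real registration.

```python
# Example ProviderInfo entry conforming to the schema above; values are
# placeholders, not an actual provider listing.
provider_info = {
    "api": "inference",                           # API this provider implements
    "provider_id": "ollama",                      # unique provider id (example)
    "provider_type": "remote::ollama",            # implementation type (example)
    "config": {"url": "http://localhost:11434"},  # provider-specific settings
    "health": {"status": "OK"},                   # current health status
}
```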
InvokeToolRequest: type: object properties: @@ -8280,10 +9058,16 @@ components: properties: content: $ref: '#/components/schemas/InterleavedContent' + description: >- + (Optional) The output content from the tool execution error_message: type: string + description: >- + (Optional) Error message if the tool execution failed error_code: type: integer + description: >- + (Optional) Numeric error code if the tool execution failed metadata: type: object additionalProperties: @@ -8294,8 +9078,11 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Additional metadata about the tool execution additionalProperties: false title: ToolInvocationResult + description: Result of a tool invocation. PaginatedResponse: type: object properties: @@ -8331,6 +9118,7 @@ components: properties: job_id: type: string + description: Unique identifier for the job status: type: string enum: @@ -8339,12 +9127,14 @@ components: - failed - scheduled - cancelled - title: JobStatus + description: Current execution status of the job additionalProperties: false required: - job_id - status title: Job + description: >- + A job execution instance with status tracking. ListBenchmarksResponse: type: object properties: @@ -8362,6 +9152,7 @@ components: - asc - desc title: Order + description: Sort order for paginated responses. ListOpenAIChatCompletionResponse: type: object properties: @@ -8405,16 +9196,24 @@ components: - model - input_messages title: OpenAICompletionWithInputMessages + description: >- + List of chat completion objects with their input messages has_more: type: boolean + description: >- + Whether there are more completions available beyond this list first_id: type: string + description: ID of the first completion in this list last_id: type: string + description: ID of the last completion in this list object: type: string const: list default: list + description: >- + Must be "list" to identify this as a list response additionalProperties: false required: - data @@ -8423,6 +9222,8 @@ components: - last_id - object title: ListOpenAIChatCompletionResponse + description: >- + Response from listing OpenAI-compatible chat completions. ListDatasetsResponse: type: object properties: @@ -8430,10 +9231,12 @@ components: type: array items: $ref: '#/components/schemas/Dataset' + description: List of datasets additionalProperties: false required: - data title: ListDatasetsResponse + description: Response from listing datasets. ListModelsResponse: type: object properties: @@ -8452,15 +9255,19 @@ components: type: array items: $ref: '#/components/schemas/OpenAIResponseInput' + description: List of input items object: type: string const: list default: list + description: Object type identifier, always "list" additionalProperties: false required: - data - object title: ListOpenAIResponseInputItem + description: >- + List container for OpenAI response input items. 
ListOpenAIResponseObject: type: object properties: @@ -8468,16 +9275,24 @@ components: type: array items: $ref: '#/components/schemas/OpenAIResponseObjectWithInput' + description: >- + List of response objects with their input context has_more: type: boolean + description: >- + Whether there are more results available beyond this page first_id: type: string + description: >- + Identifier of the first item in this page last_id: type: string + description: Identifier of the last item in this page object: type: string const: list default: list + description: Object type identifier, always "list" additionalProperties: false required: - data @@ -8486,46 +9301,76 @@ components: - last_id - object title: ListOpenAIResponseObject + description: >- + Paginated list of OpenAI response objects with navigation metadata. OpenAIResponseObjectWithInput: type: object properties: created_at: type: integer + description: >- + Unix timestamp when the response was created error: $ref: '#/components/schemas/OpenAIResponseError' + description: >- + (Optional) Error details if the response generation failed id: type: string + description: Unique identifier for this response model: type: string + description: Model identifier used for generation object: type: string const: response default: response + description: >- + Object type identifier, always "response" output: type: array items: $ref: '#/components/schemas/OpenAIResponseOutput' + description: >- + List of generated output items (messages, tool calls, etc.) parallel_tool_calls: type: boolean default: false + description: >- + Whether tool calls can be executed in parallel previous_response_id: type: string + description: >- + (Optional) ID of the previous response in a conversation status: type: string + description: >- + Current status of the response generation temperature: type: number + description: >- + (Optional) Sampling temperature used for generation text: $ref: '#/components/schemas/OpenAIResponseText' + description: >- + Text formatting configuration for the response top_p: type: number + description: >- + (Optional) Nucleus sampling parameter used for generation truncation: type: string + description: >- + (Optional) Truncation strategy applied to the response user: type: string + description: >- + (Optional) User identifier associated with the request input: type: array items: $ref: '#/components/schemas/OpenAIResponseInput' + description: >- + List of input items that led to this response additionalProperties: false required: - created_at @@ -8538,6 +9383,8 @@ components: - text - input title: OpenAIResponseObjectWithInput + description: >- + OpenAI response object extended with input context information. ListProvidersResponse: type: object properties: @@ -8545,27 +9392,37 @@ components: type: array items: $ref: '#/components/schemas/ProviderInfo' + description: List of provider information objects additionalProperties: false required: - data title: ListProvidersResponse + description: >- + Response containing a list of all available providers. RouteInfo: type: object properties: route: type: string + description: The API endpoint path method: type: string + description: HTTP method for the route provider_types: type: array items: type: string + description: >- + List of provider types that implement this route additionalProperties: false required: - route - method - provider_types title: RouteInfo + description: >- + Information about an API route including its path, method, and implementing + providers. 
ListRoutesResponse: type: object properties: @@ -8573,10 +9430,14 @@ components: type: array items: $ref: '#/components/schemas/RouteInfo' + description: >- + List of available route information objects additionalProperties: false required: - data title: ListRoutesResponse + description: >- + Response containing a list of all available API routes. ListToolDefsResponse: type: object properties: @@ -8584,10 +9445,13 @@ components: type: array items: $ref: '#/components/schemas/ToolDef' + description: List of tool definitions additionalProperties: false required: - data title: ListToolDefsResponse + description: >- + Response containing a list of tool definitions. ListScoringFunctionsResponse: type: object properties: @@ -8617,10 +9481,13 @@ components: type: array items: $ref: '#/components/schemas/ToolGroup' + description: List of tool groups additionalProperties: false required: - data title: ListToolGroupsResponse + description: >- + Response containing a list of tool groups. ListToolsResponse: type: object properties: @@ -8628,10 +9495,12 @@ components: type: array items: $ref: '#/components/schemas/Tool' + description: List of tools additionalProperties: false required: - data title: ListToolsResponse + description: Response containing a list of tools. ListVectorDBsResponse: type: object properties: @@ -8639,10 +9508,12 @@ components: type: array items: $ref: '#/components/schemas/VectorDB' + description: List of vector databases additionalProperties: false required: - data title: ListVectorDBsResponse + description: Response from listing vector databases. Event: oneOf: - $ref: '#/components/schemas/UnstructuredLogEvent' @@ -8661,6 +9532,8 @@ components: - structured_log - metric title: EventType + description: >- + The type of telemetry event being logged. LogSeverity: type: string enum: @@ -8671,16 +9544,22 @@ components: - error - critical title: LogSeverity + description: The severity level of a log message. MetricEvent: type: object properties: trace_id: type: string + description: >- + Unique identifier for the trace this event belongs to span_id: type: string + description: >- + Unique identifier for the span this event belongs to timestamp: type: string format: date-time + description: Timestamp when the event occurred attributes: type: object additionalProperties: @@ -8690,18 +9569,26 @@ components: - type: number - type: boolean - type: 'null' + description: >- + (Optional) Key-value pairs containing additional metadata about the event type: $ref: '#/components/schemas/EventType' const: metric default: metric + description: Event type identifier set to METRIC metric: type: string + description: The name of the metric being measured value: oneOf: - type: integer - type: number + description: >- + The numeric value of the metric measurement unit: type: string + description: >- + The unit of measurement for the metric value additionalProperties: false required: - trace_id @@ -8712,6 +9599,8 @@ components: - value - unit title: MetricEvent + description: >- + A metric event containing a measured value. SpanEndPayload: type: object properties: @@ -8719,13 +9608,17 @@ components: $ref: '#/components/schemas/StructuredLogType' const: span_end default: span_end + description: Payload type identifier set to SPAN_END status: $ref: '#/components/schemas/SpanStatus' + description: >- + The final status of the span indicating success or failure additionalProperties: false required: - type - status title: SpanEndPayload + description: Payload for a span end event. 
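As an illustration of the telemetry event schemas above, a MetricEvent-shaped record. The trace and span ids, metric name, and attribute values are invented for the example.

```python
from datetime import datetime, timezone

# Sketch of a MetricEvent as defined above; identifiers and the metric name
# are placeholders.
metric_event = {
    "trace_id": "trace-1234",
    "span_id": "span-5678",
    "timestamp": datetime.now(timezone.utc).isoformat(),  # date-time string
    "type": "metric",                       # EventType discriminator
    "metric": "prompt_tokens",              # name of the measurement (example)
    "value": 128,                           # integer or float
    "unit": "tokens",
    "attributes": {"model": "example-model"},  # optional metadata
}
```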
SpanStartPayload: type: object properties: @@ -8733,25 +9626,37 @@ components: $ref: '#/components/schemas/StructuredLogType' const: span_start default: span_start + description: >- + Payload type identifier set to SPAN_START name: type: string + description: >- + Human-readable name describing the operation this span represents parent_span_id: type: string + description: >- + (Optional) Unique identifier for the parent span, if this is a child span additionalProperties: false required: - type - name title: SpanStartPayload + description: Payload for a span start event. StructuredLogEvent: type: object properties: trace_id: type: string + description: >- + Unique identifier for the trace this event belongs to span_id: type: string + description: >- + Unique identifier for the span this event belongs to timestamp: type: string format: date-time + description: Timestamp when the event occurred attributes: type: object additionalProperties: @@ -8761,12 +9666,18 @@ components: - type: number - type: boolean - type: 'null' + description: >- + (Optional) Key-value pairs containing additional metadata about the event type: $ref: '#/components/schemas/EventType' const: structured_log default: structured_log + description: >- + Event type identifier set to STRUCTURED_LOG payload: $ref: '#/components/schemas/StructuredLogPayload' + description: >- + The structured payload data for the log event additionalProperties: false required: - trace_id @@ -8775,6 +9686,8 @@ components: - type - payload title: StructuredLogEvent + description: >- + A structured log event containing typed payload data. StructuredLogPayload: oneOf: - $ref: '#/components/schemas/SpanStartPayload' @@ -8790,16 +9703,23 @@ components: - span_start - span_end title: StructuredLogType + description: >- + The type of structured log event payload. UnstructuredLogEvent: type: object properties: trace_id: type: string + description: >- + Unique identifier for the trace this event belongs to span_id: type: string + description: >- + Unique identifier for the span this event belongs to timestamp: type: string format: date-time + description: Timestamp when the event occurred attributes: type: object additionalProperties: @@ -8809,14 +9729,20 @@ components: - type: number - type: boolean - type: 'null' + description: >- + (Optional) Key-value pairs containing additional metadata about the event type: $ref: '#/components/schemas/EventType' const: unstructured_log default: unstructured_log + description: >- + Event type identifier set to UNSTRUCTURED_LOG message: type: string + description: The log message text severity: $ref: '#/components/schemas/LogSeverity' + description: The severity level of the log message additionalProperties: false required: - trace_id @@ -8826,6 +9752,8 @@ components: - message - severity title: UnstructuredLogEvent + description: >- + An unstructured log event containing a simple text message. LogEventRequest: type: object properties: @@ -8856,10 +9784,14 @@ components: type: string const: auto default: auto + description: >- + Strategy type, always "auto" for automatic chunking additionalProperties: false required: - type title: VectorStoreChunkingStrategyAuto + description: >- + Automatic chunking strategy for vector store files. 
VectorStoreChunkingStrategyStatic: type: object properties: @@ -8867,27 +9799,39 @@ components: type: string const: static default: static + description: >- + Strategy type, always "static" for static chunking static: $ref: '#/components/schemas/VectorStoreChunkingStrategyStaticConfig' + description: >- + Configuration parameters for the static chunking strategy additionalProperties: false required: - type - static title: VectorStoreChunkingStrategyStatic + description: >- + Static chunking strategy with configurable parameters. VectorStoreChunkingStrategyStaticConfig: type: object properties: chunk_overlap_tokens: type: integer default: 400 + description: >- + Number of tokens to overlap between adjacent chunks max_chunk_size_tokens: type: integer default: 800 + description: >- + Maximum number of tokens per chunk, must be between 100 and 4096 additionalProperties: false required: - chunk_overlap_tokens - max_chunk_size_tokens title: VectorStoreChunkingStrategyStaticConfig + description: >- + Configuration for static chunking strategy. OpenaiAttachFileToVectorStoreRequest: type: object properties: @@ -8924,21 +9868,30 @@ components: const: server_error - type: string const: rate_limit_exceeded + description: >- + Error code indicating the type of failure message: type: string + description: >- + Human-readable error message describing the failure additionalProperties: false required: - code - message title: VectorStoreFileLastError + description: >- + Error information for failed vector store file processing. VectorStoreFileObject: type: object properties: id: type: string + description: Unique identifier for the file object: type: string default: vector_store.file + description: >- + Object type identifier, always "vector_store.file" attributes: type: object additionalProperties: @@ -8949,19 +9902,31 @@ components: - type: string - type: array - type: object + description: >- + Key-value attributes associated with the file chunking_strategy: $ref: '#/components/schemas/VectorStoreChunkingStrategy' + description: >- + Strategy used for splitting the file into chunks created_at: type: integer + description: >- + Timestamp when the file was added to the vector store last_error: $ref: '#/components/schemas/VectorStoreFileLastError' + description: >- + (Optional) Error information if file processing failed status: $ref: '#/components/schemas/VectorStoreFileStatus' + description: Current processing status of the file usage_bytes: type: integer default: 0 + description: Storage space used by this file in bytes vector_store_id: type: string + description: >- + ID of the vector store containing this file additionalProperties: false required: - id @@ -8989,10 +9954,14 @@ components: properties: name: type: string + description: Name of the schema description: type: string + description: (Optional) Description of the schema strict: type: boolean + description: >- + (Optional) Whether to enforce strict adherence to the schema schema: type: object additionalProperties: @@ -9003,10 +9972,13 @@ components: - type: string - type: array - type: object + description: (Optional) The JSON schema definition additionalProperties: false required: - name title: OpenAIJSONSchema + description: >- + JSON schema specification for OpenAI-compatible structured response format. 
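The two chunking-strategy schemas above can be summarized with a pair of example payloads; the numeric values below simply restate the documented defaults and the 100 to 4096 token constraint.

```python
# Illustrative chunking-strategy payloads built from the schemas above.

auto_strategy = {"type": "auto"}          # let the server choose chunking params

static_strategy = {
    "type": "static",
    "static": {
        "max_chunk_size_tokens": 800,     # must be between 100 and 4096
        "chunk_overlap_tokens": 400,      # overlap between adjacent chunks
    },
}

# Basic sanity check mirroring the documented constraint.
assert 100 <= static_strategy["static"]["max_chunk_size_tokens"] <= 4096
```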
OpenAIResponseFormatJSONObject: type: object properties: @@ -9014,10 +9986,14 @@ components: type: string const: json_object default: json_object + description: >- + Must be "json_object" to indicate generic JSON object response format additionalProperties: false required: - type title: OpenAIResponseFormatJSONObject + description: >- + JSON object response format for OpenAI-compatible chat completion requests. OpenAIResponseFormatJSONSchema: type: object properties: @@ -9025,13 +10001,19 @@ components: type: string const: json_schema default: json_schema + description: >- + Must be "json_schema" to indicate structured JSON response format json_schema: $ref: '#/components/schemas/OpenAIJSONSchema' + description: >- + The JSON schema specification for the response additionalProperties: false required: - type - json_schema title: OpenAIResponseFormatJSONSchema + description: >- + JSON schema response format for OpenAI-compatible chat completion requests. OpenAIResponseFormatParam: oneOf: - $ref: '#/components/schemas/OpenAIResponseFormatText' @@ -9050,10 +10032,14 @@ components: type: string const: text default: text + description: >- + Must be "text" to indicate plain text response format additionalProperties: false required: - type title: OpenAIResponseFormatText + description: >- + Text response format for OpenAI-compatible chat completion requests. OpenaiChatCompletionRequest: type: object properties: @@ -9530,14 +10516,23 @@ components: properties: completed: type: integer + description: >- + Number of files that have been successfully processed cancelled: type: integer + description: >- + Number of files that had their processing cancelled failed: type: integer + description: Number of files that failed to process in_progress: type: integer + description: >- + Number of files currently being processed total: type: integer + description: >- + Total number of files in the vector store additionalProperties: false required: - completed @@ -9546,26 +10541,39 @@ components: - in_progress - total title: VectorStoreFileCounts + description: >- + File processing status counts for a vector store. 
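Tying together the OpenAIResponseFormat* and OpenAIJSONSchema schemas defined earlier in this hunk, here is a sketch of the three possible `response_format` values. The schema name and its fields are invented for the example.

```python
# Illustrative response_format values; the "weather_report" schema is a
# made-up example, not part of the spec.
text_format = {"type": "text"}

json_object_format = {"type": "json_object"}

json_schema_format = {
    "type": "json_schema",
    "json_schema": {                        # OpenAIJSONSchema
        "name": "weather_report",           # required
        "description": "Structured weather summary",  # optional
        "strict": True,                      # optional: enforce exact adherence
        "schema": {                          # optional JSON schema definition
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "temp_c": {"type": "number"},
            },
            "required": ["city", "temp_c"],
        },
    },
}
```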
VectorStoreObject: type: object properties: id: type: string + description: Unique identifier for the vector store object: type: string default: vector_store + description: >- + Object type identifier, always "vector_store" created_at: type: integer + description: >- + Timestamp when the vector store was created name: type: string + description: (Optional) Name of the vector store usage_bytes: type: integer default: 0 + description: >- + Storage space used by the vector store in bytes file_counts: $ref: '#/components/schemas/VectorStoreFileCounts' + description: >- + File processing status counts for the vector store status: type: string default: completed + description: Current status of the vector store expires_after: type: object additionalProperties: @@ -9576,10 +10584,16 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Expiration policy for the vector store expires_at: type: integer + description: >- + (Optional) Timestamp when the vector store will expire last_active_at: type: integer + description: >- + (Optional) Timestamp of last activity on the vector store metadata: type: object additionalProperties: @@ -9590,6 +10604,8 @@ components: - type: string - type: array - type: object + description: >- + Set of key-value pairs that can be attached to the vector store additionalProperties: false required: - id @@ -9629,12 +10645,18 @@ components: properties: id: type: string + description: >- + Unique identifier of the deleted vector store object: type: string default: vector_store.deleted + description: >- + Object type identifier for the deletion response deleted: type: boolean default: true + description: >- + Whether the deletion operation was successful additionalProperties: false required: - id @@ -9647,12 +10669,17 @@ components: properties: id: type: string + description: Unique identifier of the deleted file object: type: string default: vector_store.file.deleted + description: >- + Object type identifier for the deletion response deleted: type: boolean default: true + description: >- + Whether the deletion operation was successful additionalProperties: false required: - id @@ -9790,10 +10817,16 @@ components: description: List of file objects has_more: type: boolean + description: >- + Whether there are more files available beyond this page first_id: type: string + description: >- + ID of the first file in the list for pagination last_id: type: string + description: >- + ID of the last file in the list for pagination object: type: string const: list @@ -9858,24 +10891,33 @@ components: object: type: string default: list + description: Object type identifier, always "list" data: type: array items: $ref: '#/components/schemas/VectorStoreFileObject' + description: List of vector store file objects first_id: type: string + description: >- + (Optional) ID of the first file in the list for pagination last_id: type: string + description: >- + (Optional) ID of the last file in the list for pagination has_more: type: boolean default: false + description: >- + Whether there are more files available beyond this page additionalProperties: false required: - object - data - has_more title: VectorStoreListFilesResponse - description: Response from listing vector stores. + description: >- + Response from listing files in a vector store. 
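A hypothetical VectorStoreObject assembled from the schema above and the VectorStoreFileCounts structure; the id, timestamps, sizes, and metadata are invented placeholders.

```python
# Sketch of a vector store record conforming to the schemas above.
vector_store = {
    "id": "vs_001",                         # example identifier
    "object": "vector_store",
    "created_at": 1700000000,               # unix timestamp (example)
    "name": "docs-store",                   # optional
    "usage_bytes": 2048,
    "file_counts": {                        # VectorStoreFileCounts
        "completed": 3,
        "in_progress": 0,
        "failed": 0,
        "cancelled": 0,
        "total": 3,
    },
    "status": "completed",
    "metadata": {"owner": "example-team"},  # free-form key-value pairs
}
```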
OpenAIModel: type: object properties: @@ -9914,17 +10956,25 @@ components: object: type: string default: list + description: Object type identifier, always "list" data: type: array items: $ref: '#/components/schemas/VectorStoreObject' + description: List of vector store objects first_id: type: string + description: >- + (Optional) ID of the first vector store in the list for pagination last_id: type: string + description: >- + (Optional) ID of the last vector store in the list for pagination has_more: type: boolean default: false + description: >- + Whether there are more vector stores available beyond this page additionalProperties: false required: - object @@ -9941,20 +10991,27 @@ components: type: type: string const: text + description: >- + Content type, currently only "text" is supported text: type: string + description: The actual text content additionalProperties: false required: - type - text title: VectorStoreContent + description: >- + Content item from a vector store file or search result. VectorStoreFileContentsResponse: type: object properties: file_id: type: string + description: Unique identifier for the file filename: type: string + description: Name of the file attributes: type: object additionalProperties: @@ -9965,10 +11022,13 @@ components: - type: string - type: array - type: object + description: >- + Key-value attributes associated with the file content: type: array items: $ref: '#/components/schemas/VectorStoreContent' + description: List of content items from the file additionalProperties: false required: - file_id @@ -10010,9 +11070,13 @@ components: properties: ranker: type: string + description: >- + (Optional) Name of the ranking algorithm to use score_threshold: type: number default: 0.0 + description: >- + (Optional) Minimum relevance score threshold for results additionalProperties: false description: >- Ranking options for fine-tuning the search results. @@ -10034,10 +11098,14 @@ components: properties: file_id: type: string + description: >- + Unique identifier of the file containing the result filename: type: string + description: Name of the file containing the result score: type: number + description: Relevance score for this search result attributes: type: object additionalProperties: @@ -10045,10 +11113,14 @@ components: - type: string - type: number - type: boolean + description: >- + (Optional) Key-value attributes associated with the file content: type: array items: $ref: '#/components/schemas/VectorStoreContent' + description: >- + List of content items matching the search query additionalProperties: false required: - file_id @@ -10063,17 +11135,26 @@ components: object: type: string default: vector_store.search_results.page + description: >- + Object type identifier for the search results page search_query: type: string + description: >- + The original search query that was executed data: type: array items: $ref: '#/components/schemas/VectorStoreSearchResponse' + description: List of search result objects has_more: type: boolean default: false + description: >- + Whether there are more results available beyond this page next_page: type: string + description: >- + (Optional) Token for retrieving the next page of results additionalProperties: false required: - object @@ -10081,7 +11162,8 @@ components: - data - has_more title: VectorStoreSearchResponsePage - description: Response from searching a vector store. + description: >- + Paginated response from searching a vector store. 
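For the search response schemas above, a sketch of a single results page. The file id, filename, score, and snippet text are fabricated for illustration only.

```python
# Illustrative VectorStoreSearchResponsePage with one VectorStoreSearchResponse
# entry; all values are placeholders.
search_page = {
    "object": "vector_store.search_results.page",
    "search_query": "how do I configure chunking?",
    "data": [
        {
            "file_id": "file-123",
            "filename": "docs.md",
            "score": 0.87,                        # relevance score
            "attributes": {"source": "docs"},     # optional file attributes
            "content": [
                {"type": "text", "text": "Chunking is configured via ..."},
            ],
        },
    ],
    "has_more": False,
    # "next_page" would carry a pagination token when has_more is true
}
```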
OpenaiUpdateVectorStoreRequest: type: object properties: @@ -10138,14 +11220,18 @@ components: properties: beta: type: number + description: Temperature parameter for the DPO loss loss_type: $ref: '#/components/schemas/DPOLossType' default: sigmoid + description: The type of loss function to use for DPO additionalProperties: false required: - beta - loss_type title: DPOAlignmentConfig + description: >- + Configuration for Direct Preference Optimization (DPO) alignment. DPOLossType: type: string enum: @@ -10159,20 +11245,34 @@ components: properties: dataset_id: type: string + description: >- + Unique identifier for the training dataset batch_size: type: integer + description: Number of samples per training batch shuffle: type: boolean + description: >- + Whether to shuffle the dataset during training data_format: $ref: '#/components/schemas/DatasetFormat' + description: >- + Format of the dataset (instruct or dialog) validation_dataset_id: type: string + description: >- + (Optional) Unique identifier for the validation dataset packed: type: boolean default: false + description: >- + (Optional) Whether to pack multiple samples into a single sequence for + efficiency train_on_input: type: boolean default: false + description: >- + (Optional) Whether to compute loss on input tokens as well as output tokens additionalProperties: false required: - dataset_id @@ -10180,40 +11280,59 @@ components: - shuffle - data_format title: DataConfig + description: >- + Configuration for training data and data loading. DatasetFormat: type: string enum: - instruct - dialog title: DatasetFormat + description: Format of the training dataset. EfficiencyConfig: type: object properties: enable_activation_checkpointing: type: boolean default: false + description: >- + (Optional) Whether to use activation checkpointing to reduce memory usage enable_activation_offloading: type: boolean default: false + description: >- + (Optional) Whether to offload activations to CPU to save GPU memory memory_efficient_fsdp_wrap: type: boolean default: false + description: >- + (Optional) Whether to use memory-efficient FSDP wrapping fsdp_cpu_offload: type: boolean default: false + description: >- + (Optional) Whether to offload FSDP parameters to CPU additionalProperties: false title: EfficiencyConfig + description: >- + Configuration for memory and compute efficiency optimizations. OptimizerConfig: type: object properties: optimizer_type: $ref: '#/components/schemas/OptimizerType' + description: >- + Type of optimizer to use (adam, adamw, or sgd) lr: type: number + description: Learning rate for the optimizer weight_decay: type: number + description: >- + Weight decay coefficient for regularization num_warmup_steps: type: integer + description: Number of steps for learning rate warmup additionalProperties: false required: - optimizer_type @@ -10221,6 +11340,8 @@ components: - weight_decay - num_warmup_steps title: OptimizerConfig + description: >- + Configuration parameters for the optimization algorithm. OptimizerType: type: string enum: @@ -10228,35 +11349,53 @@ components: - adamw - sgd title: OptimizerType + description: >- + Available optimizer algorithms for training. 
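A hedged sketch of the data and optimizer sub-configs defined above; the dataset id and hyperparameter values are arbitrary examples, not recommended settings.

```python
# Hypothetical training sub-configs assembled from DataConfig and
# OptimizerConfig above; values are illustrative only.
data_config = {
    "dataset_id": "my-instruct-dataset",    # example identifier
    "batch_size": 8,
    "shuffle": True,
    "data_format": "instruct",              # or "dialog"
    "packed": False,                        # optional
    "train_on_input": False,                # optional
}

optimizer_config = {
    "optimizer_type": "adamw",              # adam, adamw, or sgd
    "lr": 1e-4,
    "weight_decay": 0.01,
    "num_warmup_steps": 100,
}
```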
TrainingConfig: type: object properties: n_epochs: type: integer + description: Number of training epochs to run max_steps_per_epoch: type: integer default: 1 + description: Maximum number of steps to run per epoch gradient_accumulation_steps: type: integer default: 1 + description: >- + Number of steps to accumulate gradients before updating max_validation_steps: type: integer default: 1 + description: >- + (Optional) Maximum number of validation steps per epoch data_config: $ref: '#/components/schemas/DataConfig' + description: >- + (Optional) Configuration for data loading and formatting optimizer_config: $ref: '#/components/schemas/OptimizerConfig' + description: >- + (Optional) Configuration for the optimization algorithm efficiency_config: $ref: '#/components/schemas/EfficiencyConfig' + description: >- + (Optional) Configuration for memory and compute optimizations dtype: type: string default: bf16 + description: >- + (Optional) Data type for model parameters (bf16, fp16, fp32) additionalProperties: false required: - n_epochs - max_steps_per_epoch - gradient_accumulation_steps title: TrainingConfig + description: >- + Comprehensive configuration for the training process. PreferenceOptimizeRequest: type: object properties: @@ -10319,14 +11458,20 @@ components: type: string const: default default: default + description: >- + Type of query generator, always 'default' separator: type: string default: ' ' + description: >- + String separator used to join query terms additionalProperties: false required: - type - separator title: DefaultRAGQueryGeneratorConfig + description: >- + Configuration for the default RAG query generator. LLMRAGQueryGeneratorConfig: type: object properties: @@ -10334,16 +11479,23 @@ components: type: string const: llm default: llm + description: Type of query generator, always 'llm' model: type: string + description: >- + Name of the language model to use for query generation template: type: string + description: >- + Template string for formatting the query generation prompt additionalProperties: false required: - type - model - template title: LLMRAGQueryGeneratorConfig + description: >- + Configuration for the LLM-based RAG query generator. RAGQueryConfig: type: object properties: @@ -10424,8 +11576,7 @@ components: default: 60.0 description: >- The impact factor for RRF scoring. Higher values give more weight to higher-ranked - results. Must be greater than 0. Default of 60 is from the original RRF - paper (Cormack et al., 2009). + results. 
Must be greater than 0 additionalProperties: false required: - type @@ -10468,12 +11619,18 @@ components: properties: content: $ref: '#/components/schemas/InterleavedContent' + description: >- + The query content to search for in the indexed documents vector_db_ids: type: array items: type: string + description: >- + List of vector database IDs to search within query_config: $ref: '#/components/schemas/RAGQueryConfig' + description: >- + (Optional) Configuration parameters for the query operation additionalProperties: false required: - content @@ -10484,6 +11641,8 @@ components: properties: content: $ref: '#/components/schemas/InterleavedContent' + description: >- + (Optional) The retrieved content from the query metadata: type: object additionalProperties: @@ -10494,10 +11653,14 @@ components: - type: string - type: array - type: object + description: >- + Additional metadata about the query result additionalProperties: false required: - metadata title: RAGQueryResult + description: >- + Result of a RAG query containing retrieved content and metadata. QueryChunksRequest: type: object properties: @@ -10531,15 +11694,21 @@ components: type: array items: $ref: '#/components/schemas/Chunk' + description: >- + List of content chunks returned from the query scores: type: array items: type: number + description: >- + Relevance scores corresponding to each returned chunk additionalProperties: false required: - chunks - scores title: QueryChunksResponse + description: >- + Response from querying chunks in a vector database. QueryMetricsRequest: type: object properties: @@ -10565,8 +11734,10 @@ components: properties: name: type: string + description: The name of the label to match value: type: string + description: The value to match against operator: type: string enum: @@ -10574,7 +11745,8 @@ components: - '!=' - =~ - '!~' - title: MetricLabelOperator + description: >- + The comparison operator to use for matching default: '=' additionalProperties: false required: @@ -10582,6 +11754,8 @@ components: - value - operator title: MetricLabelMatcher + description: >- + A matcher for filtering metrics by label values. description: >- The label matchers to apply to the metric. additionalProperties: false @@ -10594,44 +11768,59 @@ components: properties: timestamp: type: integer + description: >- + Unix timestamp when the metric value was recorded value: type: number + description: >- + The numeric value of the metric at this timestamp additionalProperties: false required: - timestamp - value title: MetricDataPoint + description: >- + A single data point in a metric time series. MetricLabel: type: object properties: name: type: string + description: The name of the label value: type: string + description: The value of the label additionalProperties: false required: - name - value title: MetricLabel + description: A label associated with a metric. MetricSeries: type: object properties: metric: type: string + description: The name of the metric labels: type: array items: $ref: '#/components/schemas/MetricLabel' + description: >- + List of labels associated with this metric series values: type: array items: $ref: '#/components/schemas/MetricDataPoint' + description: >- + List of data points in chronological order additionalProperties: false required: - metric - labels - values title: MetricSeries + description: A time series of metric data points. 
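As a small worked example of the MetricLabel, MetricDataPoint, and MetricSeries schemas above, one possible series; the metric name, label, timestamps, and values are made up.

```python
# Illustrative MetricSeries; all values are placeholders.
series = {
    "metric": "request_latency_ms",                     # example metric name
    "labels": [
        {"name": "provider_id", "value": "example-provider"},
    ],
    "values": [                                         # chronological order
        {"timestamp": 1700000000, "value": 42.0},
        {"timestamp": 1700000060, "value": 39.5},
    ],
}

# e.g. average the recorded values
avg = sum(p["value"] for p in series["values"]) / len(series["values"])
print(avg)
```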
QueryMetricsResponse: type: object properties: @@ -10639,17 +11828,23 @@ components: type: array items: $ref: '#/components/schemas/MetricSeries' + description: >- + List of metric series matching the query criteria additionalProperties: false required: - data title: QueryMetricsResponse + description: >- + Response containing metric time series data. QueryCondition: type: object properties: key: type: string + description: The attribute key to filter on op: $ref: '#/components/schemas/QueryConditionOp' + description: The comparison operator to apply value: oneOf: - type: 'null' - type: boolean - type: number - type: string - type: array - type: object + description: The value to compare against additionalProperties: false required: - key - op - value title: QueryCondition + description: A condition for filtering query results. QueryConditionOp: type: string enum: - eq - ne - gt - lt title: QueryConditionOp + description: >- + Comparison operators for query conditions. QuerySpansRequest: type: object properties: @@ -10701,10 +11900,13 @@ components: type: array items: $ref: '#/components/schemas/Span' + description: >- + List of spans matching the query criteria additionalProperties: false required: - data title: QuerySpansResponse + description: Response containing a list of spans. QueryTracesRequest: type: object properties: @@ -10734,10 +11936,13 @@ components: type: array items: $ref: '#/components/schemas/Trace' + description: >- + List of traces matching the query criteria additionalProperties: false required: - data title: QueryTracesResponse + description: Response containing a list of traces. RegisterBenchmarkRequest: type: object properties: @@ -11009,6 +12214,100 @@ components: required: - benchmark_config title: RunEvalRequest + RunModerationRequest: + type: object + properties: + input: + oneOf: + - type: string + - type: array + items: + type: string + description: >- + Input (or inputs) to classify. Can be a single string, an array of strings, + or an array of multi-modal input objects similar to other models. + model: + type: string + description: >- + The content moderation model you would like to use. + additionalProperties: false + required: + - input + - model + title: RunModerationRequest + ModerationObject: + type: object + properties: + id: + type: string + description: >- + The unique identifier for the moderation request. + model: + type: string + description: >- + The model used to generate the moderation results. + results: + type: array + items: + $ref: '#/components/schemas/ModerationObjectResults' + description: A list of moderation objects + additionalProperties: false + required: + - id + - model + - results + title: ModerationObject + description: A moderation object. + ModerationObjectResults: + type: object + properties: + flagged: + type: boolean + description: >- + Whether any of the below categories are flagged. + categories: + type: object + additionalProperties: + type: boolean + description: >- + A list of the categories, and whether they are flagged or not. + category_applied_input_types: + type: object + additionalProperties: + type: array + items: + type: string + description: >- + A list of the categories along with the input type(s) that the score applies + to. + category_scores: + type: object + additionalProperties: + type: number + description: >- + A list of the categories along with their scores as predicted by the model.
+ Required set of categories that need to be in the response - violence - violence/graphic + - harassment - harassment/threatening - hate - hate/threatening - illicit + - illicit/violent - sexual - sexual/minors - self-harm - self-harm/intent + - self-harm/instructions + user_message: + type: string + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + additionalProperties: false + required: + - flagged + - metadata + title: ModerationObjectResults + description: A single moderation result. RunShieldRequest: type: object properties: @@ -11042,8 +12341,11 @@ components: properties: violation: $ref: '#/components/schemas/SafetyViolation' + description: >- + (Optional) Safety violation detected by the shield, if any additionalProperties: false title: RunShieldResponse + description: Response from running a safety shield. SaveSpansToDatasetRequest: type: object properties: @@ -11143,14 +12445,20 @@ components: properties: dataset_id: type: string + description: >- + (Optional) The identifier of the dataset that was scored results: type: object additionalProperties: $ref: '#/components/schemas/ScoringResult' + description: >- + A map of scoring function name to ScoringResult additionalProperties: false required: - results title: ScoreBatchResponse + description: >- + Response from batch scoring operations on datasets. AlgorithmConfig: oneOf: - $ref: '#/components/schemas/LoraFinetuningConfig' @@ -11167,24 +12475,38 @@ components: type: string const: LoRA default: LoRA + description: Algorithm type identifier, always "LoRA" lora_attn_modules: type: array items: type: string + description: >- + List of attention module names to apply LoRA to apply_lora_to_mlp: type: boolean + description: Whether to apply LoRA to MLP layers apply_lora_to_output: type: boolean + description: >- + Whether to apply LoRA to output projection layers rank: type: integer + description: >- + Rank of the LoRA adaptation (lower rank = fewer parameters) alpha: type: integer + description: >- + LoRA scaling parameter that controls adaptation strength use_dora: type: boolean default: false + description: >- + (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation) quantize_base: type: boolean default: false + description: >- + (Optional) Whether to quantize the base model weights additionalProperties: false required: - type - lora_attn_modules - apply_lora_to_mlp - apply_lora_to_output - rank - alpha title: LoraFinetuningConfig + description: >- + Configuration for Low-Rank Adaptation (LoRA) fine-tuning. QATFinetuningConfig: type: object properties: type: type: string const: QAT default: QAT + description: Algorithm type identifier, always "QAT" quantizer_name: type: string + description: >- + Name of the quantization algorithm to use group_size: type: integer + description: Size of groups for grouped quantization additionalProperties: false required: - type - quantizer_name - group_size title: QATFinetuningConfig + description: >- + Configuration for Quantization-Aware Training (QAT) fine-tuning.
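To summarize the moderation schemas above, a sketch of one request and one response shaped like RunModerationRequest and ModerationObject; the model name, id, and scores are placeholders, not real moderation output.

```python
# Illustrative moderation round trip built from the schemas above.
moderation_request = {
    "input": "some text to classify",        # a single string or a list of strings
    "model": "example-moderation-model",     # hypothetical model name
}

moderation_response = {
    "id": "modr-0001",                       # example identifier
    "model": "example-moderation-model",
    "results": [
        {
            "flagged": False,
            "categories": {"violence": False, "hate": False},
            "category_scores": {"violence": 0.01, "hate": 0.002},
            "category_applied_input_types": {"violence": ["text"]},
            "metadata": {},                  # required, may be empty
        },
    ],
}
```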
SupervisedFineTuneRequest: type: object properties: @@ -11265,6 +12595,8 @@ components: type: array items: $ref: '#/components/schemas/Message' + description: >- + List of conversation messages to use as input for synthetic data generation filtering_function: type: string enum: @@ -11274,10 +12606,13 @@ components: - top_p - top_k_top_p - sigmoid - title: FilteringFunction - description: The type of filtering function. + description: >- + Type of filtering to apply to generated synthetic data samples model: type: string + description: >- + (Optional) The identifier of the model to use. The model must be registered + with Llama Stack and available via the /models endpoint additionalProperties: false required: - dialogs @@ -11298,6 +12633,8 @@ components: - type: string - type: array - type: object + description: >- + List of generated synthetic data samples that passed the filtering criteria statistics: type: object additionalProperties: @@ -11308,6 +12645,9 @@ components: - type: string - type: array - type: object + description: >- + (Optional) Statistical information about the generation process and filtering + results additionalProperties: false required: - synthetic_data @@ -11320,10 +12660,12 @@ components: properties: version: type: string + description: Version number of the service additionalProperties: false required: - version title: VersionInfo + description: Version information for the service. responses: BadRequest400: description: The request was invalid or malformed diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 88878c9be..eeebf12d9 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -123,7 +123,7 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server with the together inference provider\n", - "!uv run --with llama-stack llama stack build --template together --image-type venv \n", + "!uv run --with llama-stack llama stack build --distro together --image-type venv \n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", @@ -165,7 +165,7 @@ "# use this helper if needed to kill the server \n", "def kill_llama_stack_server():\n", " # Kill any existing llama stack server processes\n", - " os.system(\"ps aux | grep -v grep | grep llama_stack.distribution.server.server | awk '{print $2}' | xargs kill -9\")\n" + " os.system(\"ps aux | grep -v grep | grep llama_stack.core.server.server | awk '{print $2}' | xargs kill -9\")\n" ] }, { diff --git a/docs/getting_started_llama4.ipynb b/docs/getting_started_llama4.ipynb index 82aef6039..1913330fe 100644 --- a/docs/getting_started_llama4.ipynb +++ b/docs/getting_started_llama4.ipynb @@ -233,7 +233,7 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server \n", - "!uv run --with llama-stack llama stack build --template meta-reference-gpu --image-type venv \n", + "!uv run --with llama-stack llama stack build --distro meta-reference-gpu --image-type venv \n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", @@ -275,7 +275,7 @@ "# use this helper if needed to kill the server \n", "def kill_llama_stack_server():\n", " # Kill any existing llama stack server processes\n", - " os.system(\"ps aux | grep -v grep | grep llama_stack.distribution.server.server | awk '{print $2}' | xargs kill -9\")\n" + " os.system(\"ps aux | grep -v grep 
| grep llama_stack.core.server.server | awk '{print $2}' | xargs kill -9\")\n" ] }, { diff --git a/docs/getting_started_llama_api.ipynb b/docs/getting_started_llama_api.ipynb index e6c74986b..5a4283117 100644 --- a/docs/getting_started_llama_api.ipynb +++ b/docs/getting_started_llama_api.ipynb @@ -223,7 +223,7 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server \n", - "!uv run --with llama-stack llama stack build --template llama_api --image-type venv \n", + "!uv run --with llama-stack llama stack build --distro llama_api --image-type venv \n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", @@ -265,7 +265,7 @@ "# use this helper if needed to kill the server \n", "def kill_llama_stack_server():\n", " # Kill any existing llama stack server processes\n", - " os.system(\"ps aux | grep -v grep | grep llama_stack.distribution.server.server | awk '{print $2}' | xargs kill -9\")\n" + " os.system(\"ps aux | grep -v grep | grep llama_stack.core.server.server | awk '{print $2}' | xargs kill -9\")\n" ] }, { diff --git a/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb b/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb index b7d769b51..9b1893f9d 100644 --- a/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb +++ b/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb @@ -37,7 +37,7 @@ "\n", "To learn more about torchtune: https://github.com/pytorch/torchtune\n", "\n", - "We will use [experimental-post-training](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/templates/experimental-post-training) as the distribution template\n", + "We will use [experimental-post-training](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/distributions/experimental-post-training) as the distribution template\n", "\n", "#### 0.0. Prerequisite: Have an OpenAI API key\n", "In this showcase, we will use [braintrust](https://www.braintrust.dev/) as scoring provider for eval and it uses OpenAI model as judge model for scoring. 
So, you need to get an API key from [OpenAI developer platform](https://platform.openai.com/docs/overview).\n", @@ -2864,7 +2864,7 @@ } ], "source": [ - "!llama stack build --template experimental-post-training --image-type venv --image-name __system__" + "!llama stack build --distro experimental-post-training --image-type venv --image-name __system__" ] }, { @@ -3216,19 +3216,19 @@ "INFO:datasets:Duckdb version 1.1.3 available.\n", "INFO:datasets:TensorFlow version 2.18.0 available.\n", "INFO:datasets:JAX version 0.4.33 available.\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: basic::equality served by basic\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: basic::subset_of served by basic\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: basic::regex_parser_multiple_choice_answer served by basic\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::factuality served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::answer-correctness served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::answer-relevancy served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::answer-similarity served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::faithfulness served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::context-entity-recall served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::context-precision served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::context-recall served by braintrust\n", - "INFO:llama_stack.distribution.stack:Scoring_fns: braintrust::context-relevancy served by braintrust\n", - "INFO:llama_stack.distribution.stack:\n" + "INFO:llama_stack.core.stack:Scoring_fns: basic::equality served by basic\n", + "INFO:llama_stack.core.stack:Scoring_fns: basic::subset_of served by basic\n", + "INFO:llama_stack.core.stack:Scoring_fns: basic::regex_parser_multiple_choice_answer served by basic\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::factuality served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::answer-correctness served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::answer-relevancy served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::answer-similarity served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::faithfulness served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::context-entity-recall served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::context-precision served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::context-recall served by braintrust\n", + "INFO:llama_stack.core.stack:Scoring_fns: braintrust::context-relevancy served by braintrust\n", + "INFO:llama_stack.core.stack:\n" ] }, { @@ -3448,7 +3448,7 @@ "\n", "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')\n", "\n", - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "client = LlamaStackAsLibraryClient(\"experimental-post-training\")\n", "_ = client.initialize()" ] diff --git a/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb b/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb index cad28ab82..82f8566ba 100644 --- 
a/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb +++ b/docs/notebooks/Llama_Stack_Agent_Workflows.ipynb @@ -38,7 +38,7 @@ "source": [ "# NBVAL_SKIP\n", "!pip install -U llama-stack\n", - "!UV_SYSTEM_PYTHON=1 llama stack build --template fireworks --image-type venv" + "!UV_SYSTEM_PYTHON=1 llama stack build --distro fireworks --image-type venv" ] }, { @@ -48,7 +48,7 @@ "outputs": [], "source": [ "from llama_stack_client import LlamaStackClient, Agent\n", - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "from rich.pretty import pprint\n", "import json\n", "import uuid\n", diff --git a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb index 93f78d268..6e7d37cf2 100644 --- a/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb +++ b/docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb @@ -57,7 +57,7 @@ "outputs": [], "source": [ "# NBVAL_SKIP\n", - "!UV_SYSTEM_PYTHON=1 llama stack build --template together --image-type venv" + "!UV_SYSTEM_PYTHON=1 llama stack build --distro together --image-type venv" ] }, { @@ -661,7 +661,7 @@ "except ImportError:\n", " print(\"Not in Google Colab environment\")\n", "\n", - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "\n", "client = LlamaStackAsLibraryClient(\"together\")\n", "_ = client.initialize()" diff --git a/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb b/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb index e70cc3bbe..769c91dfd 100644 --- a/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb +++ b/docs/notebooks/Llama_Stack_RAG_Lifecycle.ipynb @@ -35,7 +35,7 @@ ], "source": [ "from llama_stack_client import LlamaStackClient, Agent\n", - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "from rich.pretty import pprint\n", "import json\n", "import uuid\n", diff --git a/docs/notebooks/nvidia/beginner_e2e/Llama_Stack_NVIDIA_E2E_Flow.ipynb b/docs/notebooks/nvidia/beginner_e2e/Llama_Stack_NVIDIA_E2E_Flow.ipynb index 583870404..d8f29d999 100644 --- a/docs/notebooks/nvidia/beginner_e2e/Llama_Stack_NVIDIA_E2E_Flow.ipynb +++ b/docs/notebooks/nvidia/beginner_e2e/Llama_Stack_NVIDIA_E2E_Flow.ipynb @@ -92,7 +92,7 @@ "metadata": {}, "source": [ "```bash\n", - "LLAMA_STACK_DIR=$(pwd) llama stack build --template nvidia --image-type venv\n", + "LLAMA_STACK_DIR=$(pwd) llama stack build --distro nvidia --image-type venv\n", "```" ] }, @@ -194,7 +194,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "\n", "client = LlamaStackAsLibraryClient(\"nvidia\")\n", "client.initialize()" diff --git a/docs/notebooks/nvidia/tool_calling/1_data_preparation.ipynb b/docs/notebooks/nvidia/tool_calling/1_data_preparation.ipynb index 6c7d61fbe..5fa5ef26b 100644 --- a/docs/notebooks/nvidia/tool_calling/1_data_preparation.ipynb +++ b/docs/notebooks/nvidia/tool_calling/1_data_preparation.ipynb @@ -81,7 +81,7 @@ "metadata": {}, "source": [ "```bash\n", - "LLAMA_STACK_DIR=$(pwd) llama stack build --template nvidia --image-type venv\n", + "LLAMA_STACK_DIR=$(pwd) llama stack build --distro nvidia --image-type venv\n", "```" ] }, diff --git 
a/docs/notebooks/nvidia/tool_calling/2_finetuning_and_inference.ipynb b/docs/notebooks/nvidia/tool_calling/2_finetuning_and_inference.ipynb index 647a16b6d..a80720a5f 100644 --- a/docs/notebooks/nvidia/tool_calling/2_finetuning_and_inference.ipynb +++ b/docs/notebooks/nvidia/tool_calling/2_finetuning_and_inference.ipynb @@ -56,7 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "\n", "client = LlamaStackAsLibraryClient(\"nvidia\")\n", "client.initialize()" diff --git a/docs/notebooks/nvidia/tool_calling/3_model_evaluation.ipynb b/docs/notebooks/nvidia/tool_calling/3_model_evaluation.ipynb index 5a1316adb..91d1db88f 100644 --- a/docs/notebooks/nvidia/tool_calling/3_model_evaluation.ipynb +++ b/docs/notebooks/nvidia/tool_calling/3_model_evaluation.ipynb @@ -56,7 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "\n", "client = LlamaStackAsLibraryClient(\"nvidia\")\n", "client.initialize()" diff --git a/docs/notebooks/nvidia/tool_calling/4_adding_safety_guardrails.ipynb b/docs/notebooks/nvidia/tool_calling/4_adding_safety_guardrails.ipynb index 699a561f9..25bcd0b69 100644 --- a/docs/notebooks/nvidia/tool_calling/4_adding_safety_guardrails.ipynb +++ b/docs/notebooks/nvidia/tool_calling/4_adding_safety_guardrails.ipynb @@ -56,7 +56,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_stack.distribution.library_client import LlamaStackAsLibraryClient\n", + "from llama_stack.core.library_client import LlamaStackAsLibraryClient\n", "\n", "client = LlamaStackAsLibraryClient(\"nvidia\")\n", "client.initialize()" diff --git a/docs/openapi_generator/README.md b/docs/openapi_generator/README.md index 7888e7828..85021d911 100644 --- a/docs/openapi_generator/README.md +++ b/docs/openapi_generator/README.md @@ -1 +1 @@ -The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility. +The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/core/server/endpoints.py` using the `generate.py` utility.
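For context, here is a minimal sketch of what the renamed import path above means for downstream code, mirroring the notebook cells patched in this diff. The `starter` distribution name and the `models.list()` call are illustrative assumptions and are not part of the patch.

```python
# Library-mode usage after the llama_stack.distribution -> llama_stack.core rename.
# Assumes a distribution has already been built, e.g.:
#   llama stack build --distro starter --image-type venv
from llama_stack.core.library_client import LlamaStackAsLibraryClient  # formerly llama_stack.distribution.library_client

client = LlamaStackAsLibraryClient("starter")
_ = client.initialize()

# The library client mirrors the HTTP client surface; listing registered models
# is shown here only as an illustrative smoke test.
for model in client.models.list():
    print(model.identifier)
```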
diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index 9fc375175..c27bc6440 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -17,7 +17,7 @@ import fire import ruamel.yaml as yaml from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402 -from llama_stack.distribution.stack import LlamaStack # noqa: E402 +from llama_stack.core.stack import LlamaStack # noqa: E402 from .pyopenapi.options import Options # noqa: E402 from .pyopenapi.specification import Info, Server # noqa: E402 diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index 57f92403d..d302b114f 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ b/docs/openapi_generator/pyopenapi/utility.py @@ -12,7 +12,7 @@ from typing import TextIO from typing import Any, List, Optional, Union, get_type_hints, get_origin, get_args from llama_stack.strong_typing.schema import object_to_json, StrictJsonType -from llama_stack.distribution.resolver import api_protocol_map +from llama_stack.core.resolver import api_protocol_map from .generator import Generator from .options import Options diff --git a/docs/original_rfc.md b/docs/original_rfc.md index dc95a04cb..e9191cb6d 100644 --- a/docs/original_rfc.md +++ b/docs/original_rfc.md @@ -73,7 +73,7 @@ The API is defined in the [YAML](_static/llama-stack-spec.yaml) and [HTML](_stat To prove out the API, we implemented a handful of use cases to make things more concrete. The [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps) repository contains [6 different examples](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) ranging from very basic to a multi turn agent. -There is also a sample inference endpoint implementation in the [llama-stack](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/distribution/server/server.py) repository. +There is also a sample inference endpoint implementation in the [llama-stack](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/core/server/server.py) repository.
## Limitations diff --git a/docs/quick_start.ipynb b/docs/quick_start.ipynb index 482815aa5..757824578 100644 --- a/docs/quick_start.ipynb +++ b/docs/quick_start.ipynb @@ -145,12 +145,12 @@ " del os.environ[\"UV_SYSTEM_PYTHON\"]\n", "\n", "# this command installs all the dependencies needed for the llama stack server with the ollama inference provider\n", - "!uv run --with llama-stack llama stack build --template starter --image-type venv\n", + "!uv run --with llama-stack llama stack build --distro starter --image-type venv\n", "\n", "def run_llama_stack_server_background():\n", " log_file = open(\"llama_stack_server.log\", \"w\")\n", " process = subprocess.Popen(\n", - " f\"uv run --with llama-stack llama stack run starter --image-type venv --env INFERENCE_MODEL=llama3.2:3b\",\n", + " f\"OLLAMA_URL=http://localhost:11434 uv run --with llama-stack llama stack run starter --image-type venv\",\n", " shell=True,\n", " stdout=log_file,\n", " stderr=log_file,\n", @@ -187,7 +187,7 @@ "# use this helper if needed to kill the server \n", "def kill_llama_stack_server():\n", " # Kill any existing llama stack server processes\n", - " os.system(\"ps aux | grep -v grep | grep llama_stack.distribution.server.server | awk '{print $2}' | xargs kill -9\")\n" + " os.system(\"ps aux | grep -v grep | grep llama_stack.core.server.server | awk '{print $2}' | xargs kill -9\")\n" ] }, { diff --git a/docs/source/advanced_apis/eval/inline_meta-reference.md b/docs/source/advanced_apis/eval/inline_meta-reference.md index 606883c72..5bec89cfc 100644 --- a/docs/source/advanced_apis/eval/inline_meta-reference.md +++ b/docs/source/advanced_apis/eval/inline_meta-reference.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # inline::meta-reference ## Description diff --git a/docs/source/advanced_apis/eval/remote_nvidia.md b/docs/source/advanced_apis/eval/remote_nvidia.md index cb764b511..ab91767d6 100644 --- a/docs/source/advanced_apis/eval/remote_nvidia.md +++ b/docs/source/advanced_apis/eval/remote_nvidia.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # remote::nvidia ## Description diff --git a/docs/source/advanced_apis/evaluation_concepts.md b/docs/source/advanced_apis/evaluation_concepts.md index 3f03d098f..c26ec8f5e 100644 --- a/docs/source/advanced_apis/evaluation_concepts.md +++ b/docs/source/advanced_apis/evaluation_concepts.md @@ -43,7 +43,7 @@ We have built-in functionality to run the supported open-benckmarks using llama- Spin up llama stack server with 'open-benchmark' template ``` -llama stack run llama_stack/templates/open-benchmark/run.yaml +llama stack run llama_stack/distributions/open-benchmark/run.yaml ``` diff --git a/docs/source/advanced_apis/post_training/huggingface.md b/docs/source/advanced_apis/post_training/huggingface.md index c7896aaf4..a7609d6da 100644 --- a/docs/source/advanced_apis/post_training/huggingface.md +++ b/docs/source/advanced_apis/post_training/huggingface.md @@ -23,7 +23,7 @@ To use the HF SFTTrainer in your Llama Stack project, follow these steps: You can access the HuggingFace trainer via the `ollama` distribution: ```bash -llama stack build --template starter --image-type venv +llama stack build --distro starter --image-type venv llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml ``` diff --git a/docs/source/advanced_apis/post_training/inline_huggingface.md b/docs/source/advanced_apis/post_training/inline_huggingface.md index 367258a1d..4d2201c99 100644 --- a/docs/source/advanced_apis/post_training/inline_huggingface.md +++
b/docs/source/advanced_apis/post_training/inline_huggingface.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # inline::huggingface ## Description diff --git a/docs/source/advanced_apis/post_training/inline_torchtune.md b/docs/source/advanced_apis/post_training/inline_torchtune.md index 82730e54b..6684c99ac 100644 --- a/docs/source/advanced_apis/post_training/inline_torchtune.md +++ b/docs/source/advanced_apis/post_training/inline_torchtune.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # inline::torchtune ## Description diff --git a/docs/source/advanced_apis/post_training/remote_nvidia.md b/docs/source/advanced_apis/post_training/remote_nvidia.md index 9a381d872..9840fa3c4 100644 --- a/docs/source/advanced_apis/post_training/remote_nvidia.md +++ b/docs/source/advanced_apis/post_training/remote_nvidia.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # remote::nvidia ## Description diff --git a/docs/source/advanced_apis/scoring/inline_basic.md b/docs/source/advanced_apis/scoring/inline_basic.md index e9e50cff4..b56b36013 100644 --- a/docs/source/advanced_apis/scoring/inline_basic.md +++ b/docs/source/advanced_apis/scoring/inline_basic.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # inline::basic ## Description diff --git a/docs/source/advanced_apis/scoring/inline_braintrust.md b/docs/source/advanced_apis/scoring/inline_braintrust.md index 70a6a1e26..d1278217c 100644 --- a/docs/source/advanced_apis/scoring/inline_braintrust.md +++ b/docs/source/advanced_apis/scoring/inline_braintrust.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # inline::braintrust ## Description diff --git a/docs/source/advanced_apis/scoring/inline_llm-as-judge.md b/docs/source/advanced_apis/scoring/inline_llm-as-judge.md index 971e02897..c7fcddf37 100644 --- a/docs/source/advanced_apis/scoring/inline_llm-as-judge.md +++ b/docs/source/advanced_apis/scoring/inline_llm-as-judge.md @@ -1,3 +1,7 @@ +--- +orphan: true +--- + # inline::llm-as-judge ## Description diff --git a/docs/source/apis/external.md b/docs/source/apis/external.md index 025267c33..cc13deb9b 100644 --- a/docs/source/apis/external.md +++ b/docs/source/apis/external.md @@ -355,7 +355,7 @@ server: 8. Run the server: ```bash -python -m llama_stack.distribution.server.server --yaml-config ~/.llama/run-byoa.yaml +python -m llama_stack.core.server.server --yaml-config ~/.llama/run-byoa.yaml ``` 9. Test the API: diff --git a/docs/source/building_applications/playground/index.md b/docs/source/building_applications/playground/index.md index 85895f6a5..fd2b92434 100644 --- a/docs/source/building_applications/playground/index.md +++ b/docs/source/building_applications/playground/index.md @@ -97,11 +97,11 @@ To start the Llama Stack Playground, run the following commands: 1. Start up the Llama Stack API server ```bash -llama stack build --template together --image-type conda +llama stack build --distro together --image-type venv llama stack run together ``` 2. 
Start Streamlit UI ```bash -uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py +uv run --with ".[ui]" streamlit run llama_stack/core/ui/app.py ``` diff --git a/docs/source/contributing/index.md b/docs/source/contributing/index.md index 8e4f5e867..1e067ea6c 100644 --- a/docs/source/contributing/index.md +++ b/docs/source/contributing/index.md @@ -11,4 +11,5 @@ See the [Adding a New API Provider](new_api_provider.md) which describes how to :hidden: new_api_provider +testing ``` diff --git a/docs/source/contributing/new_api_provider.md b/docs/source/contributing/new_api_provider.md index 01a8ec093..6f8f59a47 100644 --- a/docs/source/contributing/new_api_provider.md +++ b/docs/source/contributing/new_api_provider.md @@ -6,7 +6,7 @@ This guide will walk you through the process of adding a new API provider to Lla - Begin by reviewing the [core concepts](../concepts/index.md) of Llama Stack and choose the API your provider belongs to (Inference, Safety, VectorIO, etc.) - Determine the provider type ({repopath}`Remote::llama_stack/providers/remote` or {repopath}`Inline::llama_stack/providers/inline`). Remote providers make requests to external services, while inline providers execute implementation locally. - Add your provider to the appropriate {repopath}`Registry::llama_stack/providers/registry/`. Specify pip dependencies necessary. -- Update any distribution {repopath}`Templates::llama_stack/templates/` `build.yaml` and `run.yaml` files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation. +- Update any distribution {repopath}`Templates::llama_stack/distributions/` `build.yaml` and `run.yaml` files if they should include your provider by default. Run {repopath}`./scripts/distro_codegen.py` if necessary. Note that `distro_codegen.py` will fail if the new provider causes any distribution template to attempt to import provider-specific dependencies. This usually means the distribution's `get_distribution_template()` code path should only import any necessary Config or model alias definitions from each provider and not the provider's actual implementation. Here are some example PRs to help you get started: @@ -52,7 +52,7 @@ def get_base_url(self) -> str: ## Testing the Provider -Before running tests, you must have required dependencies installed. This depends on the providers or distributions you are testing. For example, if you are testing the `together` distribution, you should install dependencies via `llama stack build --template together`. +Before running tests, you must have required dependencies installed. This depends on the providers or distributions you are testing. For example, if you are testing the `together` distribution, you should install dependencies via `llama stack build --distro together`. ### 1.
Integration Testing diff --git a/docs/source/deploying/kubernetes_deployment.md b/docs/source/deploying/kubernetes_deployment.md index 7e9791d8d..4bdd87b24 100644 --- a/docs/source/deploying/kubernetes_deployment.md +++ b/docs/source/deploying/kubernetes_deployment.md @@ -174,7 +174,7 @@ spec: - name: llama-stack image: localhost/llama-stack-run-k8s:latest imagePullPolicy: IfNotPresent - command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/app/config.yaml"] + command: ["python", "-m", "llama_stack.core.server.server", "--config", "/app/config.yaml"] ports: - containerPort: 5000 volumeMounts: diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index cd2c6b6a8..24098708f 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -47,30 +47,37 @@ pip install -e . ``` Use the CLI to build your distribution. The main points to consider are: -1. **Image Type** - Do you want a Conda / venv environment or a Container (eg. Docker) +1. **Image Type** - Do you want a venv environment or a Container (eg. Docker) 2. **Template** - Do you want to use a template to build your distribution? or start from scratch ? 3. **Config** - Do you want to use a pre-existing config file to build your distribution? ``` llama stack build -h -usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--list-templates] [--image-type {conda,container,venv}] [--image-name IMAGE_NAME] [--print-deps-only] [--run] +usage: llama stack build [-h] [--config CONFIG] [--template TEMPLATE] [--distro DISTRIBUTION] [--list-distros] [--image-type {container,venv}] [--image-name IMAGE_NAME] [--print-deps-only] + [--run] [--providers PROVIDERS] Build a Llama stack container options: -h, --help show this help message and exit - --config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will - be prompted to enter information interactively (default: None) - --template TEMPLATE Name of the example template config to use for build. You may use `llama stack build --list-templates` to check out the available templates (default: None) - --list-templates Show the available templates for building a Llama Stack distribution (default: False) - --image-type {conda,container,venv} + --config CONFIG Path to a config file to use for the build. You can find example configs in llama_stack.cores/**/build.yaml. If this argument is not provided, you will be prompted to + enter information interactively (default: None) + --template TEMPLATE (deprecated) Name of the example template config to use for build. You may use `llama stack build --list-distros` to check out the available distributions (default: + None) + --distro DISTRIBUTION, --distribution DISTRIBUTION + Name of the distribution to use for build. You may use `llama stack build --list-distros` to check out the available distributions (default: None) + --list-distros, --list-distributions + Show the available distributions for building a Llama Stack distribution (default: False) + --image-type {container,venv} Image Type to use for the build. If not specified, will use the image type from the template config. (default: None) --image-name IMAGE_NAME - [for image-type=conda|container|venv] Name of the conda or virtual environment to use for the build. If not specified, currently active environment will be used if - found. 
(default: None) + [for image-type=container|venv] Name of the virtual environment to use for the build. If not specified, currently active environment will be used if found. (default: + None) --print-deps-only Print the dependencies for the stack only, without building the stack (default: False) --run Run the stack after building using the same image type, name, and other applicable arguments (default: False) - + --providers PROVIDERS + Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per + API. (default: None) ``` After this step is complete, a file named `-build.yaml` and template file `-run.yaml` will be generated and saved at the output file path specified at the end of the command. @@ -141,7 +148,7 @@ You may then pick a template to build your distribution with providers fitted to For example, to build a distribution with TGI as the inference provider, you can run: ``` -$ llama stack build --template starter +$ llama stack build --distro starter ... You can now edit ~/.llama/distributions/llamastack-starter/starter-run.yaml and run `llama stack run ~/.llama/distributions/llamastack-starter/starter-run.yaml` ``` @@ -159,7 +166,7 @@ It would be best to start with a template and understand the structure of the co llama stack build > Enter a name for your Llama Stack (e.g. my-local-stack): my-stack -> Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda +> Enter the image type you want your Llama Stack to be built as (container or venv): venv Llama Stack is composed of several APIs working together. Let's select the provider types (implementations) you want to use for these APIs. @@ -184,10 +191,10 @@ You can now edit ~/.llama/distributions/llamastack-my-local-stack/my-local-stack :::{tab-item} Building from a pre-existing build config file - In addition to templates, you may customize the build to your liking through editing config files and build from config files with the following command. -- The config file will be of contents like the ones in `llama_stack/templates/*build.yaml`. +- The config file will be of contents like the ones in `llama_stack/distributions/*build.yaml`. ``` -llama stack build --config llama_stack/templates/starter/build.yaml +llama stack build --config llama_stack/distributions/starter/build.yaml ``` ::: @@ -253,11 +260,11 @@ Podman is supported as an alternative to Docker. Set `CONTAINER_BINARY` to `podm To build a container image, you may start off from a template and use the `--image-type container` flag to specify `container` as the build image type. ``` -llama stack build --template starter --image-type container +llama stack build --distro starter --image-type container ``` ``` -$ llama stack build --template starter --image-type container +$ llama stack build --distro starter --image-type container ... Containerfile created successfully in /tmp/tmp.viA3a3Rdsg/ContainerfileFROM python:3.10-slim ... @@ -312,7 +319,7 @@ Now, let's start the Llama Stack Distribution Server. You will need the YAML con ``` llama stack run -h usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] - [--image-type {conda,venv}] [--enable-ui] + [--image-type {venv}] [--enable-ui] [config | template] Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution. 
@@ -326,8 +333,8 @@ options: --image-name IMAGE_NAME Name of the image to run. Defaults to the current environment (default: None) --env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: None) - --image-type {conda,venv} - Image Type used during the build. This can be either conda or venv. (default: None) + --image-type {venv} + Image Type used during the build. This should be venv. (default: None) --enable-ui Start the UI server (default: False) ``` @@ -342,9 +349,6 @@ llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack- # Start using a venv llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml - -# Start using a conda environment -llama stack run --image-type conda ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml ``` ``` diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 775749dd6..335fa3a68 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -10,7 +10,6 @@ The default `run.yaml` files generated by templates are starting points for your ```yaml version: 2 -conda_env: ollama apis: - agents - inference diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index 3427356a7..fbc48dd95 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -6,11 +6,11 @@ This avoids the overhead of setting up a server. ```bash # setup uv pip install llama-stack -llama stack build --template starter --image-type venv +llama stack build --distro starter --image-type venv ``` ```python -from llama_stack.distribution.library_client import LlamaStackAsLibraryClient +from llama_stack.core.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient( "starter", diff --git a/docs/source/distributions/index.md b/docs/source/distributions/index.md index fce0347d3..2a702c282 100644 --- a/docs/source/distributions/index.md +++ b/docs/source/distributions/index.md @@ -9,6 +9,7 @@ This section provides an overview of the distributions available in Llama Stack. 
list_of_distributions building_distro customizing_run_yaml +starting_llama_stack_server importing_as_library configuration ``` diff --git a/docs/source/distributions/k8s/stack-configmap.yaml b/docs/source/distributions/k8s/stack-configmap.yaml index c505cba49..4f95554e3 100644 --- a/docs/source/distributions/k8s/stack-configmap.yaml +++ b/docs/source/distributions/k8s/stack-configmap.yaml @@ -34,6 +34,13 @@ data: provider_type: remote::chromadb config: url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/docs/source/distributions/k8s/stack-k8s.yaml.template b/docs/source/distributions/k8s/stack-k8s.yaml.template index 912445f68..ad5d2c716 100644 --- a/docs/source/distributions/k8s/stack-k8s.yaml.template +++ b/docs/source/distributions/k8s/stack-k8s.yaml.template @@ -52,7 +52,7 @@ spec: value: "${SAFETY_MODEL}" - name: TAVILY_SEARCH_API_KEY value: "${TAVILY_SEARCH_API_KEY}" - command: ["python", "-m", "llama_stack.distribution.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"] + command: ["python", "-m", "llama_stack.core.server.server", "--config", "/etc/config/stack_run_config.yaml", "--port", "8321"] ports: - containerPort: 8321 volumeMounts: diff --git a/docs/source/distributions/k8s/stack_run_config.yaml b/docs/source/distributions/k8s/stack_run_config.yaml index 4da1bd8b4..a2d65e1a9 100644 --- a/docs/source/distributions/k8s/stack_run_config.yaml +++ b/docs/source/distributions/k8s/stack_run_config.yaml @@ -31,6 +31,13 @@ providers: provider_type: remote::chromadb config: url: ${env.CHROMADB_URL:=} + kvstore: + type: postgres + host: ${env.POSTGRES_HOST:=localhost} + port: ${env.POSTGRES_PORT:=5432} + db: ${env.POSTGRES_DB:=llamastack} + user: ${env.POSTGRES_USER:=llamastack} + password: ${env.POSTGRES_PASSWORD:=llamastack} safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/docs/source/distributions/ondevice_distro/android_sdk.md b/docs/source/distributions/ondevice_distro/android_sdk.md index 1cddf1d1f..9d16d07d7 100644 --- a/docs/source/distributions/ondevice_distro/android_sdk.md +++ b/docs/source/distributions/ondevice_distro/android_sdk.md @@ -56,12 +56,12 @@ Breaking down the demo app, this section will show the core pieces that are used ### Setup Remote Inferencing Start a Llama Stack server on localhost. Here is an example of how you can do this using the firework.ai distribution: ``` -conda create -n stack-fireworks python=3.10 -conda activate stack-fireworks +uv venv starter --python 3.12 +source starter/bin/activate # On Windows: starter\Scripts\activate pip install --no-cache llama-stack==0.2.2 -llama stack build --template fireworks --image-type conda +llama stack build --distro starter --image-type venv export FIREWORKS_API_KEY= -llama stack run fireworks --port 5050 +llama stack run starter --port 5050 ``` Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility. 
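As a sanity check for the remote setup above (a server started on port 5050 with the `starter` distribution), the Python client can confirm the server is reachable before the Kotlin SDK is wired in. This sketch is not part of the patch; the `inspect.version()` call is an assumed client method corresponding to the version endpoint documented earlier in this diff.

```python
# Hypothetical smoke test for the remote Llama Stack server started on port 5050.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5050")

# Assumed method name: returns the VersionInfo payload ({"version": ...})
# described in the spec changes at the top of this patch.
print(client.inspect.version())
```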
diff --git a/docs/source/distributions/remote_hosted_distro/watsonx.md b/docs/source/distributions/remote_hosted_distro/watsonx.md index ec1b98059..977af90dd 100644 --- a/docs/source/distributions/remote_hosted_distro/watsonx.md +++ b/docs/source/distributions/remote_hosted_distro/watsonx.md @@ -57,7 +57,7 @@ Make sure you have access to a watsonx API Key. You can get one by referring [wa ## Running Llama Stack with watsonx -You can do this via Conda (build code), venv or Docker which has a pre-built image. +You can do this via venv or Docker which has a pre-built image. ### Via Docker @@ -76,13 +76,3 @@ docker run \ --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ --env WATSONX_BASE_URL=$WATSONX_BASE_URL ``` - -### Via Conda - -```bash -llama stack build --template watsonx --image-type conda -llama stack run ./run.yaml \ - --port $LLAMA_STACK_PORT \ - --env WATSONX_API_KEY=$WATSONX_API_KEY \ - --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID -``` diff --git a/docs/source/distributions/self_hosted_distro/dell.md b/docs/source/distributions/self_hosted_distro/dell.md index eded3bdc4..68e7b6f58 100644 --- a/docs/source/distributions/self_hosted_distro/dell.md +++ b/docs/source/distributions/self_hosted_distro/dell.md @@ -114,7 +114,7 @@ podman run --rm -it \ ## Running Llama Stack -Now you are ready to run Llama Stack with TGI as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image. +Now you are ready to run Llama Stack with TGI as the inference provider. You can do this via venv or Docker which has a pre-built image. ### Via Docker @@ -153,7 +153,7 @@ docker run \ --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v $HOME/.llama:/root/.llama \ - -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ + -v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-dell \ --config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ @@ -164,12 +164,12 @@ docker run \ --env CHROMA_URL=$CHROMA_URL ``` -### Via Conda +### Via venv Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available. ```bash -llama stack build --template dell --image-type conda +llama stack build --distro dell --image-type venv llama stack run dell --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md index 8b9dcec55..7e50a4161 100644 --- a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md @@ -70,7 +70,7 @@ $ llama model list --downloaded ## Running the Distribution -You can do this via Conda (build code) or Docker which has a pre-built image. +You can do this via venv or Docker which has a pre-built image. ### Via Docker @@ -104,12 +104,12 @@ docker run \ --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B ``` -### Via Conda +### Via venv Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. 
```bash -llama stack build --template meta-reference-gpu --image-type conda +llama stack build --distro meta-reference-gpu --image-type venv llama stack run distributions/meta-reference-gpu/run.yaml \ --port 8321 \ --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md index aeb14e6a6..6e399e6ce 100644 --- a/docs/source/distributions/self_hosted_distro/nvidia.md +++ b/docs/source/distributions/self_hosted_distro/nvidia.md @@ -133,7 +133,7 @@ curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-inst ## Running Llama Stack with NVIDIA -You can do this via Conda or venv (build code), or Docker which has a pre-built image. +You can do this via venv (build code), or Docker which has a pre-built image. ### Via Docker @@ -152,24 +152,13 @@ docker run \ --env NVIDIA_API_KEY=$NVIDIA_API_KEY ``` -### Via Conda - -```bash -INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct -llama stack build --template nvidia --image-type conda -llama stack run ./run.yaml \ - --port 8321 \ - --env NVIDIA_API_KEY=$NVIDIA_API_KEY \ - --env INFERENCE_MODEL=$INFERENCE_MODEL -``` - ### Via venv If you've set up your local development environment, you can also build the image using your local virtual environment. ```bash INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct -llama stack build --template nvidia --image-type venv +llama stack build --distro nvidia --image-type venv llama stack run ./run.yaml \ --port 8321 \ --env NVIDIA_API_KEY=$NVIDIA_API_KEY \ diff --git a/docs/source/distributions/self_hosted_distro/starter.md b/docs/source/distributions/self_hosted_distro/starter.md index 58a3e4411..9218f7f81 100644 --- a/docs/source/distributions/self_hosted_distro/starter.md +++ b/docs/source/distributions/self_hosted_distro/starter.md @@ -100,10 +100,6 @@ The following environment variables can be configured: ### Model Configuration - `INFERENCE_MODEL`: HuggingFace model for serverless inference - `INFERENCE_ENDPOINT_NAME`: HuggingFace endpoint name -- `OLLAMA_INFERENCE_MODEL`: Ollama model name -- `OLLAMA_EMBEDDING_MODEL`: Ollama embedding model name -- `OLLAMA_EMBEDDING_DIMENSION`: Ollama embedding dimension (default: `384`) -- `VLLM_INFERENCE_MODEL`: vLLM model name ### Vector Database Configuration - `SQLITE_STORE_DIR`: SQLite store directory (default: `~/.llama/distributions/starter`) @@ -127,47 +123,29 @@ The following environment variables can be configured: ## Enabling Providers -You can enable specific providers by setting their provider ID to a valid value using environment variables. This is useful when you want to use certain providers or don't have the required API keys. +You can enable specific providers by setting appropriate environment variables. 
For example, -### Examples of Enabling Providers - -#### Enable FAISS Vector Provider ```bash -export ENABLE_FAISS=faiss +# self-hosted +export OLLAMA_URL=http://localhost:11434 # enables the Ollama inference provider +export VLLM_URL=http://localhost:8000/v1 # enables the vLLM inference provider +export TGI_URL=http://localhost:8000/v1 # enables the TGI inference provider + +# cloud-hosted requiring API key configuration on the server +export CEREBRAS_API_KEY=your_cerebras_api_key # enables the Cerebras inference provider +export NVIDIA_API_KEY=your_nvidia_api_key # enables the NVIDIA inference provider + +# vector providers +export MILVUS_URL=http://localhost:19530 # enables the Milvus vector provider +export CHROMADB_URL=http://localhost:8000/v1 # enables the ChromaDB vector provider +export PGVECTOR_DB=llama_stack_db # enables the PGVector vector provider ``` -#### Enable Ollama Models -```bash -export ENABLE_OLLAMA=ollama -``` - -#### Disable vLLM Models -```bash -export VLLM_INFERENCE_MODEL=__disabled__ -``` - -#### Disable Optional Vector Providers -```bash -export ENABLE_SQLITE_VEC=__disabled__ -export ENABLE_CHROMADB=__disabled__ -export ENABLE_PGVECTOR=__disabled__ -``` - -### Provider ID Patterns - -The starter distribution uses several patterns for provider IDs: - -1. **Direct provider IDs**: `faiss`, `ollama`, `vllm` -2. **Environment-based provider IDs**: `${env.ENABLE_SQLITE_VEC:+sqlite-vec}` -3. **Model-based provider IDs**: `${env.OLLAMA_INFERENCE_MODEL:__disabled__}` - -When using the `+` pattern (like `${env.ENABLE_SQLITE_VEC+sqlite-vec}`), the provider is enabled by default and can be disabled by setting the environment variable to `__disabled__`. - -When using the `:` pattern (like `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`), the provider is disabled by default and can be enabled by setting the environment variable to a valid value. +This distribution comes with a default "llama-guard" shield that can be enabled by setting the `SAFETY_MODEL` environment variable to point to an appropriate Llama Guard model id. Use `llama-stack-client models list` to see the list of available models. ## Running the Distribution -You can run the starter distribution via Docker, Conda, or venv. +You can run the starter distribution via Docker or venv. ### Via Docker @@ -186,12 +164,12 @@ docker run \ --port $LLAMA_STACK_PORT ``` -### Via Conda or venv +### Via venv Ensure you have configured the starter distribution using the environment variables explained above. ```bash -uv run --with llama-stack llama stack build --template starter --image-type --run +uv run --with llama-stack llama stack build --distro starter --image-type venv --run ``` ## Example Usage diff --git a/docs/source/distributions/starting_llama_stack_server.md b/docs/source/distributions/starting_llama_stack_server.md index 91cb1fe88..1a26694a6 100644 --- a/docs/source/distributions/starting_llama_stack_server.md +++ b/docs/source/distributions/starting_llama_stack_server.md @@ -11,12 +11,6 @@ This is the simplest way to get started. Using Llama Stack as a library means yo Another simple way to start interacting with Llama Stack is to just spin up a container (via Docker or Podman) which is pre-built with all the providers you need. We provide a number of pre-built images so you can start a Llama Stack server instantly. You can also build your own custom container. Which distribution to choose depends on the hardware you have. See [Selection of a Distribution](selection) for more details. 
- -## Conda: - -If you have a custom or an advanced setup or you are developing on Llama Stack you can also build a custom Llama Stack server. Using `llama stack build` and `llama stack run` you can build/run a custom Llama Stack server containing the exact combination of providers you wish. We have also provided various templates to make getting started easier. See [Building a Custom Distribution](building_distro) for more details. - - ## Kubernetes: If you have built a container image and want to deploy it in a Kubernetes cluster instead of starting the Llama Stack server locally. See [Kubernetes Deployment Guide](kubernetes_deployment) for more details. diff --git a/docs/source/getting_started/detailed_tutorial.md b/docs/source/getting_started/detailed_tutorial.md index c6589e758..14f888628 100644 --- a/docs/source/getting_started/detailed_tutorial.md +++ b/docs/source/getting_started/detailed_tutorial.md @@ -59,10 +59,10 @@ Now let's build and run the Llama Stack config for Ollama. We use `starter` as template. By default all providers are disabled, this requires enable ollama by passing environment variables. ```bash -llama stack build --template starter --image-type venv --run +llama stack build --distro starter --image-type venv --run ``` ::: -:::{tab-item} Using `conda` +:::{tab-item} Using `venv` You can use Python to build and run the Llama Stack server, which is useful for testing and development. Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup, @@ -70,7 +70,7 @@ which defines the providers and their settings. Now let's build and run the Llama Stack config for Ollama. ```bash -llama stack build --template starter --image-type conda --run +llama stack build --distro starter --image-type venv --run ``` ::: :::{tab-item} Using a Container @@ -150,13 +150,7 @@ pip install llama-stack-client ``` ::: -:::{tab-item} Install with `conda` -```bash -yes | conda create -n stack-client python=3.12 -conda activate stack-client -pip install llama-stack-client -``` -::: + :::: Now let's use the `llama-stack-client` [CLI](../references/llama_stack_client_cli_reference.md) to check the diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index b66fabc77..0136a7fba 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -16,10 +16,13 @@ as the inference [provider](../providers/inference/index) for a Llama Model. ```bash ollama run llama3.2:3b --keepalive 60m ``` + #### Step 2: Run the Llama Stack server + We will use `uv` to run the Llama Stack server. ```bash -uv run --with llama-stack llama stack build --template starter --image-type venv --run +OLLAMA_URL=http://localhost:11434 \ + uv run --with llama-stack llama stack build --distro starter --image-type venv --run ``` #### Step 3: Run the demo Now open up a new terminal and copy the following script into a file named `demo_script.py`. diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index ebc134ce9..92bf9edc0 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -1,5 +1,13 @@ -# Agents Providers +# Agents + +## Overview This section contains documentation for all available providers for the **agents** API. 
-- [inline::meta-reference](inline_meta-reference.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_meta-reference +``` diff --git a/docs/source/providers/datasetio/index.md b/docs/source/providers/datasetio/index.md index 726bc75b8..94a97e2ed 100644 --- a/docs/source/providers/datasetio/index.md +++ b/docs/source/providers/datasetio/index.md @@ -1,7 +1,15 @@ -# Datasetio Providers +# Datasetio + +## Overview This section contains documentation for all available providers for the **datasetio** API. -- [inline::localfs](inline_localfs.md) -- [remote::huggingface](remote_huggingface.md) -- [remote::nvidia](remote_nvidia.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_localfs +remote_huggingface +remote_nvidia +``` diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md index 330380670..d180d256c 100644 --- a/docs/source/providers/eval/index.md +++ b/docs/source/providers/eval/index.md @@ -1,6 +1,14 @@ -# Eval Providers +# Eval + +## Overview This section contains documentation for all available providers for the **eval** API. -- [inline::meta-reference](inline_meta-reference.md) -- [remote::nvidia](remote_nvidia.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_meta-reference +remote_nvidia +``` diff --git a/docs/source/providers/external.md b/docs/source/providers/external/external-providers-guide.md similarity index 85% rename from docs/source/providers/external.md rename to docs/source/providers/external/external-providers-guide.md index f906890f1..2479d406f 100644 --- a/docs/source/providers/external.md +++ b/docs/source/providers/external/external-providers-guide.md @@ -1,9 +1,4 @@ -# External Providers Guide - -Llama Stack supports external providers that live outside of the main codebase. This allows you to: -- Create and maintain your own providers independently -- Share providers with others without contributing to the main codebase -- Keep provider-specific code separate from the core Llama Stack code +# Creating External Providers ## Configuration @@ -55,17 +50,6 @@ Llama Stack supports two types of external providers: 1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs) 2. **Inline Providers**: Providers that run locally within the Llama Stack process -## Known External Providers - -Here's a list of known external providers that you can use with Llama Stack: - -| Name | Description | API | Type | Repository | -|------|-------------|-----|------|------------| -| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | -| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) | -| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) | -| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) | - ### Remote Provider Specification Remote providers are used when you need to communicate with external services. 
Here's an example for a custom Ollama provider: @@ -119,9 +103,9 @@ container_image: custom-vector-store:latest # optional - `provider_data_validator`: Optional validator for provider data - `container_image`: Optional container image to use instead of pip packages -## Required Implementation +## Required Fields -## All Providers +### All Providers All providers must contain a `get_provider_spec` function in their `provider` module. This is a standardized structure that Llama Stack expects and is necessary for getting things such as the config class. The `get_provider_spec` method returns a structure identical to the `adapter`. An example function may look like: @@ -146,7 +130,7 @@ def get_provider_spec() -> ProviderSpec: ) ``` -### Remote Providers +#### Remote Providers Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments: 1. `config`: An instance of the provider's config class @@ -162,7 +146,7 @@ async def get_adapter_impl( return OllamaInferenceAdapter(config) ``` -### Inline Providers +#### Inline Providers Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments: 1. `config`: An instance of the provider's config class @@ -189,7 +173,40 @@ Version: 0.1.0 Location: /path/to/venv/lib/python3.10/site-packages ``` -## Example using `external_providers_dir`: Custom Ollama Provider +## Best Practices + +1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable. + +2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using. + +3. **Dependencies**: Only include the minimum required dependencies in your provider package. + +4. **Documentation**: Include clear documentation in your provider package about: + - Installation requirements + - Configuration options + - Usage examples + - Any limitations or known issues + +5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack. +You can refer to the [integration tests +guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more +information. Execute the test for the Provider type you are developing. + +## Troubleshooting + +If your external provider isn't being loaded: + +1. Check that `module` points to a published pip package with a top level `provider` module including `get_provider_spec`. +1. Check that the `external_providers_dir` path is correct and accessible. +2. Verify that the YAML files are properly formatted. +3. Ensure all required Python packages are installed. +4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more + information using `LLAMA_STACK_LOGGING=all=debug`. +5. Verify that the provider package is installed in your Python environment if using `external_providers_dir`. + +## Examples + +### Example using `external_providers_dir`: Custom Ollama Provider Here's a complete example of creating and using a custom Ollama provider: @@ -241,7 +258,7 @@ external_providers_dir: ~/.llama/providers.d/ The provider will now be available in Llama Stack with the type `remote::custom_ollama`. -## Example using `module`: ramalama-stack +### Example using `module`: ramalama-stack [ramalama-stack](https://github.com/containers/ramalama-stack) is a recognized external provider that supports installation via module. 
@@ -266,35 +283,4 @@ additional_pip_packages: No other steps are required other than `llama stack build` and `llama stack run`. The build process will use `module` to install all of the provider dependencies, retrieve the spec, etc. -The provider will now be available in Llama Stack with the type `remote::ramalama`. - -## Best Practices - -1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable. - -2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using. - -3. **Dependencies**: Only include the minimum required dependencies in your provider package. - -4. **Documentation**: Include clear documentation in your provider package about: - - Installation requirements - - Configuration options - - Usage examples - - Any limitations or known issues - -5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack. -You can refer to the [integration tests -guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more -information. Execute the test for the Provider type you are developing. - -## Troubleshooting - -If your external provider isn't being loaded: - -1. Check that `module` points to a published pip package with a top level `provider` module including `get_provider_spec`. -1. Check that the `external_providers_dir` path is correct and accessible. -2. Verify that the YAML files are properly formatted. -3. Ensure all required Python packages are installed. -4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more - information using `LLAMA_STACK_LOGGING=all=debug`. -5. Verify that the provider package is installed in your Python environment if using `external_providers_dir`. +The provider will now be available in Llama Stack with the type `remote::ramalama`. \ No newline at end of file diff --git a/docs/source/providers/external/external-providers-list.md b/docs/source/providers/external/external-providers-list.md new file mode 100644 index 000000000..49f49076b --- /dev/null +++ b/docs/source/providers/external/external-providers-list.md @@ -0,0 +1,10 @@ +# Known External Providers + +Here's a list of known external providers that you can use with Llama Stack: + +| Name | Description | API | Type | Repository | +|------|-------------|-----|------|------------| +| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | +| KubeFlow Pipelines | Train models with KubeFlow Pipelines | Post Training | Inline **and** Remote | [llama-stack-provider-kfp-trainer](https://github.com/opendatahub-io/llama-stack-provider-kfp-trainer) | +| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) | +| TrustyAI LM-Eval | Evaluate models with TrustyAI LM-Eval | Eval | Remote | [llama-stack-provider-lmeval](https://github.com/trustyai-explainability/llama-stack-provider-lmeval) | \ No newline at end of file diff --git a/docs/source/providers/external/index.md b/docs/source/providers/external/index.md new file mode 100644 index 000000000..989a7f5b8 --- /dev/null +++ b/docs/source/providers/external/index.md @@ -0,0 +1,13 @@ +# External Providers + +Llama Stack supports external providers that live outside of the main codebase. 
This allows you to: +- Create and maintain your own providers independently +- Share providers with others without contributing to the main codebase +- Keep provider-specific code separate from the core Llama Stack code + +```{toctree} +:maxdepth: 1 + +external-providers-list +external-providers-guide +``` \ No newline at end of file diff --git a/docs/source/providers/files/index.md b/docs/source/providers/files/index.md index 25d9b05ba..692aad3ca 100644 --- a/docs/source/providers/files/index.md +++ b/docs/source/providers/files/index.md @@ -1,5 +1,13 @@ -# Files Providers +# Files + +## Overview This section contains documentation for all available providers for the **files** API. -- [inline::localfs](inline_localfs.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_localfs +``` diff --git a/docs/source/providers/files/inline_localfs.md b/docs/source/providers/files/inline_localfs.md index 54c489c7d..09267b7d8 100644 --- a/docs/source/providers/files/inline_localfs.md +++ b/docs/source/providers/files/inline_localfs.md @@ -8,7 +8,7 @@ Local filesystem-based file storage provider for managing files and documents lo | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `storage_dir` | `` | No | PydanticUndefined | Directory to store uploaded files | +| `storage_dir` | `` | No | | Directory to store uploaded files | | `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata | | `ttl_secs` | `` | No | 31536000 | | diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md index 596daa9ba..3f66ecd0c 100644 --- a/docs/source/providers/index.md +++ b/docs/source/providers/index.md @@ -1,4 +1,4 @@ -# API Providers Overview +# API Providers The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: - LLM inference providers (e.g., Meta Reference, Ollama, Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, OpenAI, Anthropic, Gemini, WatsonX, etc.), @@ -12,81 +12,17 @@ Providers come in two flavors: Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally. -## External Providers -Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. - -```{toctree} -:maxdepth: 1 - -external.md -``` - -```{include} openai.md -:start-after: ## OpenAI API Compatibility -``` - -## Inference -Runs inference with an LLM. - ```{toctree} :maxdepth: 1 +external/index +openai inference/index -``` - -## Agents -Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc. - -```{toctree} -:maxdepth: 1 - agents/index -``` - -## DatasetIO -Interfaces with datasets and data loaders. - -```{toctree} -:maxdepth: 1 - datasetio/index -``` - -## Safety -Applies safety policies to the output at a Systems (not only model) level. - -```{toctree} -:maxdepth: 1 - safety/index -``` - -## Telemetry -Collects telemetry data from the system. - -```{toctree} -:maxdepth: 1 - telemetry/index -``` - -## Vector IO - -Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents. 
-Vector IO plays a crucial role in [Retreival Augmented Generation (RAG)](../..//building_applications/rag), where the vector -io and database are used to store and retrieve documents for retrieval. - -```{toctree} -:maxdepth: 1 - vector_io/index -``` - -## Tool Runtime -Is associated with the ToolGroup resources. - -```{toctree} -:maxdepth: 1 - tool_runtime/index -``` \ No newline at end of file +files/index +``` diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index dcc6da5b5..1c7bc86b9 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -1,26 +1,34 @@ -# Inference Providers +# Inference + +## Overview This section contains documentation for all available providers for the **inference** API. -- [inline::meta-reference](inline_meta-reference.md) -- [inline::sentence-transformers](inline_sentence-transformers.md) -- [remote::anthropic](remote_anthropic.md) -- [remote::bedrock](remote_bedrock.md) -- [remote::cerebras](remote_cerebras.md) -- [remote::databricks](remote_databricks.md) -- [remote::fireworks](remote_fireworks.md) -- [remote::gemini](remote_gemini.md) -- [remote::groq](remote_groq.md) -- [remote::hf::endpoint](remote_hf_endpoint.md) -- [remote::hf::serverless](remote_hf_serverless.md) -- [remote::llama-openai-compat](remote_llama-openai-compat.md) -- [remote::nvidia](remote_nvidia.md) -- [remote::ollama](remote_ollama.md) -- [remote::openai](remote_openai.md) -- [remote::passthrough](remote_passthrough.md) -- [remote::runpod](remote_runpod.md) -- [remote::sambanova](remote_sambanova.md) -- [remote::tgi](remote_tgi.md) -- [remote::together](remote_together.md) -- [remote::vllm](remote_vllm.md) -- [remote::watsonx](remote_watsonx.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_meta-reference +inline_sentence-transformers +remote_anthropic +remote_bedrock +remote_cerebras +remote_databricks +remote_fireworks +remote_gemini +remote_groq +remote_hf_endpoint +remote_hf_serverless +remote_llama-openai-compat +remote_nvidia +remote_ollama +remote_openai +remote_passthrough +remote_runpod +remote_sambanova +remote_tgi +remote_together +remote_vllm +remote_watsonx +``` diff --git a/docs/source/providers/inference/remote_cerebras-openai-compat.md b/docs/source/providers/inference/remote_cerebras-openai-compat.md deleted file mode 100644 index 64b899246..000000000 --- a/docs/source/providers/inference/remote_cerebras-openai-compat.md +++ /dev/null @@ -1,21 +0,0 @@ -# remote::cerebras-openai-compat - -## Description - -Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format. 
- -## Configuration - -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `api_key` | `str \| None` | No | | The Cerebras API key | -| `openai_compat_api_base` | `` | No | https://api.cerebras.ai/v1 | The URL for the Cerebras API server | - -## Sample Configuration - -```yaml -openai_compat_api_base: https://api.cerebras.ai/v1 -api_key: ${env.CEREBRAS_API_KEY} - -``` - diff --git a/docs/source/providers/inference/remote_fireworks-openai-compat.md b/docs/source/providers/inference/remote_fireworks-openai-compat.md deleted file mode 100644 index 0a2bd0fe8..000000000 --- a/docs/source/providers/inference/remote_fireworks-openai-compat.md +++ /dev/null @@ -1,21 +0,0 @@ -# remote::fireworks-openai-compat - -## Description - -Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format. - -## Configuration - -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `api_key` | `str \| None` | No | | The Fireworks API key | -| `openai_compat_api_base` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks API server | - -## Sample Configuration - -```yaml -openai_compat_api_base: https://api.fireworks.ai/inference/v1 -api_key: ${env.FIREWORKS_API_KEY} - -``` - diff --git a/docs/source/providers/inference/remote_groq-openai-compat.md b/docs/source/providers/inference/remote_groq-openai-compat.md deleted file mode 100644 index e424bedd2..000000000 --- a/docs/source/providers/inference/remote_groq-openai-compat.md +++ /dev/null @@ -1,21 +0,0 @@ -# remote::groq-openai-compat - -## Description - -Groq OpenAI-compatible provider for using Groq models with OpenAI API format. - -## Configuration - -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `api_key` | `str \| None` | No | | The Groq API key | -| `openai_compat_api_base` | `` | No | https://api.groq.com/openai/v1 | The URL for the Groq API server | - -## Sample Configuration - -```yaml -openai_compat_api_base: https://api.groq.com/openai/v1 -api_key: ${env.GROQ_API_KEY} - -``` - diff --git a/docs/source/providers/inference/remote_hf_endpoint.md b/docs/source/providers/inference/remote_hf_endpoint.md index f9ca6b538..8aaf13476 100644 --- a/docs/source/providers/inference/remote_hf_endpoint.md +++ b/docs/source/providers/inference/remote_hf_endpoint.md @@ -8,7 +8,7 @@ HuggingFace Inference Endpoints provider for dedicated model serving. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `endpoint_name` | `` | No | PydanticUndefined | The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided. | +| `endpoint_name` | `` | No | | The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided. 
| | `api_token` | `pydantic.types.SecretStr \| None` | No | | Your Hugging Face user access token (will default to locally saved token if not provided) | ## Sample Configuration diff --git a/docs/source/providers/inference/remote_hf_serverless.md b/docs/source/providers/inference/remote_hf_serverless.md index 345af3e49..6764590b8 100644 --- a/docs/source/providers/inference/remote_hf_serverless.md +++ b/docs/source/providers/inference/remote_hf_serverless.md @@ -8,7 +8,7 @@ HuggingFace Inference API serverless provider for on-demand model inference. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `huggingface_repo` | `` | No | PydanticUndefined | The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct') | +| `huggingface_repo` | `` | No | | The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct') | | `api_token` | `pydantic.types.SecretStr \| None` | No | | Your Hugging Face user access token (will default to locally saved token if not provided) | ## Sample Configuration diff --git a/docs/source/providers/inference/remote_tgi.md b/docs/source/providers/inference/remote_tgi.md index 125984fab..104bb4aab 100644 --- a/docs/source/providers/inference/remote_tgi.md +++ b/docs/source/providers/inference/remote_tgi.md @@ -8,7 +8,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `url` | `` | No | PydanticUndefined | The URL for the TGI serving endpoint | +| `url` | `` | No | | The URL for the TGI serving endpoint | ## Sample Configuration diff --git a/docs/source/providers/inference/remote_together-openai-compat.md b/docs/source/providers/inference/remote_together-openai-compat.md deleted file mode 100644 index 833fa8cb0..000000000 --- a/docs/source/providers/inference/remote_together-openai-compat.md +++ /dev/null @@ -1,21 +0,0 @@ -# remote::together-openai-compat - -## Description - -Together AI OpenAI-compatible provider for using Together models with OpenAI API format. - -## Configuration - -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `api_key` | `str \| None` | No | | The Together API key | -| `openai_compat_api_base` | `` | No | https://api.together.xyz/v1 | The URL for the Together API server | - -## Sample Configuration - -```yaml -openai_compat_api_base: https://api.together.xyz/v1 -api_key: ${env.TOGETHER_API_KEY} - -``` - diff --git a/docs/source/providers/post_training/index.md b/docs/source/providers/post_training/index.md index 35d10d14b..c6c92c40e 100644 --- a/docs/source/providers/post_training/index.md +++ b/docs/source/providers/post_training/index.md @@ -1,7 +1,15 @@ -# Post_Training Providers +# Post_Training + +## Overview This section contains documentation for all available providers for the **post_training** API. 
-- [inline::huggingface](inline_huggingface.md) -- [inline::torchtune](inline_torchtune.md) -- [remote::nvidia](remote_nvidia.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_huggingface +inline_torchtune +remote_nvidia +``` diff --git a/docs/source/providers/post_training/inline_huggingface.md b/docs/source/providers/post_training/inline_huggingface.md index 82b08bf7a..8b10fe79c 100644 --- a/docs/source/providers/post_training/inline_huggingface.md +++ b/docs/source/providers/post_training/inline_huggingface.md @@ -24,6 +24,10 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin | `weight_decay` | `` | No | 0.01 | | | `dataloader_num_workers` | `` | No | 4 | | | `dataloader_pin_memory` | `` | No | True | | +| `dpo_beta` | `` | No | 0.1 | | +| `use_reference_model` | `` | No | True | | +| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | | +| `dpo_output_dir` | `` | No | | | ## Sample Configuration @@ -31,6 +35,7 @@ HuggingFace-based post-training provider for fine-tuning models using the Huggin checkpoint_format: huggingface distributed_backend: null device: cpu +dpo_output_dir: ~/.llama/dummy/dpo_output ``` diff --git a/docs/source/providers/safety/index.md b/docs/source/providers/safety/index.md index 1a245c13d..5ddda2242 100644 --- a/docs/source/providers/safety/index.md +++ b/docs/source/providers/safety/index.md @@ -1,10 +1,18 @@ -# Safety Providers +# Safety + +## Overview This section contains documentation for all available providers for the **safety** API. -- [inline::code-scanner](inline_code-scanner.md) -- [inline::llama-guard](inline_llama-guard.md) -- [inline::prompt-guard](inline_prompt-guard.md) -- [remote::bedrock](remote_bedrock.md) -- [remote::nvidia](remote_nvidia.md) -- [remote::sambanova](remote_sambanova.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_code-scanner +inline_llama-guard +inline_prompt-guard +remote_bedrock +remote_nvidia +remote_sambanova +``` diff --git a/docs/source/providers/scoring/index.md b/docs/source/providers/scoring/index.md index 3cf7af537..f3bd48eb0 100644 --- a/docs/source/providers/scoring/index.md +++ b/docs/source/providers/scoring/index.md @@ -1,7 +1,15 @@ -# Scoring Providers +# Scoring + +## Overview This section contains documentation for all available providers for the **scoring** API. -- [inline::basic](inline_basic.md) -- [inline::braintrust](inline_braintrust.md) -- [inline::llm-as-judge](inline_llm-as-judge.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_basic +inline_braintrust +inline_llm-as-judge +``` diff --git a/docs/source/providers/telemetry/index.md b/docs/source/providers/telemetry/index.md index e2b221b50..c7fbfed73 100644 --- a/docs/source/providers/telemetry/index.md +++ b/docs/source/providers/telemetry/index.md @@ -1,5 +1,13 @@ -# Telemetry Providers +# Telemetry + +## Overview This section contains documentation for all available providers for the **telemetry** API. 
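Returning to the DPO options that the inline HuggingFace post-training provider gains above: `dpo_loss_type` is constrained to the four listed strategies with `sigmoid` as the default, while `dpo_beta` and `use_reference_model` default to `0.1` and `True`. A minimal sketch of how such options might be modeled with Pydantic follows; `DPOOptions` is a hypothetical class for illustration, not the provider's actual config.

```python
# Illustrative model mirroring the documented DPO fields; not the provider's real config class.
from typing import Literal

from pydantic import BaseModel


class DPOOptions(BaseModel):
    dpo_beta: float = 0.1
    use_reference_model: bool = True
    dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
    dpo_output_dir: str  # no default; the sample configuration sets it explicitly


# Accepted: "ipo" is one of the four allowed strategies.
print(DPOOptions(dpo_loss_type="ipo", dpo_output_dir="~/.llama/dummy/dpo_output"))
# DPOOptions(dpo_loss_type="mse", dpo_output_dir="...") would raise a ValidationError.
```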
-- [inline::meta-reference](inline_meta-reference.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_meta-reference +``` diff --git a/docs/source/providers/tool_runtime/index.md b/docs/source/providers/tool_runtime/index.md index f162c4f9c..8d29aed43 100644 --- a/docs/source/providers/tool_runtime/index.md +++ b/docs/source/providers/tool_runtime/index.md @@ -1,10 +1,18 @@ -# Tool_Runtime Providers +# Tool_Runtime + +## Overview This section contains documentation for all available providers for the **tool_runtime** API. -- [inline::rag-runtime](inline_rag-runtime.md) -- [remote::bing-search](remote_bing-search.md) -- [remote::brave-search](remote_brave-search.md) -- [remote::model-context-protocol](remote_model-context-protocol.md) -- [remote::tavily-search](remote_tavily-search.md) -- [remote::wolfram-alpha](remote_wolfram-alpha.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_rag-runtime +remote_bing-search +remote_brave-search +remote_model-context-protocol +remote_tavily-search +remote_wolfram-alpha +``` diff --git a/docs/source/providers/vector_io/index.md b/docs/source/providers/vector_io/index.md index 870d04401..28ae523d7 100644 --- a/docs/source/providers/vector_io/index.md +++ b/docs/source/providers/vector_io/index.md @@ -1,16 +1,24 @@ -# Vector_Io Providers +# Vector_Io + +## Overview This section contains documentation for all available providers for the **vector_io** API. -- [inline::chromadb](inline_chromadb.md) -- [inline::faiss](inline_faiss.md) -- [inline::meta-reference](inline_meta-reference.md) -- [inline::milvus](inline_milvus.md) -- [inline::qdrant](inline_qdrant.md) -- [inline::sqlite-vec](inline_sqlite-vec.md) -- [inline::sqlite_vec](inline_sqlite_vec.md) -- [remote::chromadb](remote_chromadb.md) -- [remote::milvus](remote_milvus.md) -- [remote::pgvector](remote_pgvector.md) -- [remote::qdrant](remote_qdrant.md) -- [remote::weaviate](remote_weaviate.md) \ No newline at end of file +## Providers + +```{toctree} +:maxdepth: 1 + +inline_chromadb +inline_faiss +inline_meta-reference +inline_milvus +inline_qdrant +inline_sqlite-vec +inline_sqlite_vec +remote_chromadb +remote_milvus +remote_pgvector +remote_qdrant +remote_weaviate +``` diff --git a/docs/source/providers/vector_io/inline_chromadb.md b/docs/source/providers/vector_io/inline_chromadb.md index 679c82830..518e3f689 100644 --- a/docs/source/providers/vector_io/inline_chromadb.md +++ b/docs/source/providers/vector_io/inline_chromadb.md @@ -41,7 +41,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | | +| `db_path` | `` | No | | | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | ## Sample Configuration diff --git a/docs/source/providers/vector_io/inline_milvus.md b/docs/source/providers/vector_io/inline_milvus.md index 3b3aad3fc..33ea4d179 100644 --- a/docs/source/providers/vector_io/inline_milvus.md +++ b/docs/source/providers/vector_io/inline_milvus.md @@ -10,7 +10,7 @@ Please refer to the remote provider documentation. 
| Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | | +| `db_path` | `` | No | | | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | | `consistency_level` | `` | No | Strong | The consistency level of the Milvus server | diff --git a/docs/source/providers/vector_io/inline_qdrant.md b/docs/source/providers/vector_io/inline_qdrant.md index 63e2d81d8..b5072d220 100644 --- a/docs/source/providers/vector_io/inline_qdrant.md +++ b/docs/source/providers/vector_io/inline_qdrant.md @@ -50,12 +50,16 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `path` | `` | No | PydanticUndefined | | +| `path` | `` | No | | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | | ## Sample Configuration ```yaml path: ${env.QDRANT_PATH:=~/.llama/~/.llama/dummy}/qdrant.db +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/qdrant_registry.db ``` diff --git a/docs/source/providers/vector_io/inline_sqlite-vec.md b/docs/source/providers/vector_io/inline_sqlite-vec.md index ae7c45b21..854bb9d08 100644 --- a/docs/source/providers/vector_io/inline_sqlite-vec.md +++ b/docs/source/providers/vector_io/inline_sqlite-vec.md @@ -205,7 +205,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | Path to the SQLite database file | +| `db_path` | `` | No | | Path to the SQLite database file | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | ## Sample Configuration diff --git a/docs/source/providers/vector_io/inline_sqlite_vec.md b/docs/source/providers/vector_io/inline_sqlite_vec.md index 7e14bb8bd..7ad8eb252 100644 --- a/docs/source/providers/vector_io/inline_sqlite_vec.md +++ b/docs/source/providers/vector_io/inline_sqlite_vec.md @@ -10,7 +10,7 @@ Please refer to the sqlite-vec provider documentation. 
| Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `db_path` | `` | No | PydanticUndefined | Path to the SQLite database file | +| `db_path` | `` | No | | Path to the SQLite database file | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) | ## Sample Configuration diff --git a/docs/source/providers/vector_io/remote_chromadb.md b/docs/source/providers/vector_io/remote_chromadb.md index 447ea6cd6..badfebe90 100644 --- a/docs/source/providers/vector_io/remote_chromadb.md +++ b/docs/source/providers/vector_io/remote_chromadb.md @@ -40,7 +40,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `url` | `str \| None` | No | PydanticUndefined | | +| `url` | `str \| None` | No | | | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | ## Sample Configuration diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index 6734d8315..3646f4acc 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -111,8 +111,8 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| -| `uri` | `` | No | PydanticUndefined | The URI of the Milvus server | -| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server | +| `uri` | `` | No | | The URI of the Milvus server | +| `token` | `str \| None` | No | | The token of the Milvus server | | `consistency_level` | `` | No | Strong | The consistency level of the Milvus server | | `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend | | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. | diff --git a/docs/source/providers/vector_io/remote_qdrant.md b/docs/source/providers/vector_io/remote_qdrant.md index 14c821f35..043141007 100644 --- a/docs/source/providers/vector_io/remote_qdrant.md +++ b/docs/source/providers/vector_io/remote_qdrant.md @@ -20,11 +20,15 @@ Please refer to the inline provider documentation. 
| `prefix` | `str \| None` | No | | | | `timeout` | `int \| None` | No | | | | `host` | `str \| None` | No | | | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | | ## Sample Configuration ```yaml -api_key: ${env.QDRANT_API_KEY} +api_key: ${env.QDRANT_API_KEY:=} +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/qdrant_registry.db ``` diff --git a/docs/source/providers/vector_io/remote_weaviate.md b/docs/source/providers/vector_io/remote_weaviate.md index d930515d5..c59487cf6 100644 --- a/docs/source/providers/vector_io/remote_weaviate.md +++ b/docs/source/providers/vector_io/remote_weaviate.md @@ -33,9 +33,19 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `weaviate_api_key` | `str \| None` | No | | The API key for the Weaviate instance | +| `weaviate_cluster_url` | `str \| None` | No | localhost:8080 | The URL of the Weaviate cluster | +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) | + ## Sample Configuration ```yaml +weaviate_api_key: null +weaviate_cluster_url: ${env.WEAVIATE_CLUSTER_URL:=localhost:8080} kvstore: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db diff --git a/docs/source/references/evals_reference/index.md b/docs/source/references/evals_reference/index.md index 0294d83ea..054a0b809 100644 --- a/docs/source/references/evals_reference/index.md +++ b/docs/source/references/evals_reference/index.md @@ -366,7 +366,7 @@ The purpose of scoring function is to calculate the score for each example based Firstly, you can see if the existing [llama stack scoring functions](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline/scoring) can fulfill your need. If not, you need to write a new scoring function based on what benchmark author / other open source repo describe. ### Add new benchmark into template -Firstly, you need to add the evaluation dataset associated with your benchmark under `datasets` resource in the [open-benchmark](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/open-benchmark/run.yaml) +Firstly, you need to add the evaluation dataset associated with your benchmark under `datasets` resource in the [open-benchmark](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/distributions/open-benchmark/run.yaml) Secondly, you need to add the new benchmark you just created under the `benchmarks` resource in the same template. 
To add the new benchmark, you need to have - `benchmark_id`: identifier of the benchmark @@ -378,7 +378,7 @@ Secondly, you need to add the new benchmark you just created under the `benchmar Spin up llama stack server with 'open-benchmark' templates ``` -llama stack run llama_stack/templates/open-benchmark/run.yaml +llama stack run llama_stack/distributions/open-benchmark/run.yaml ``` diff --git a/docs/source/references/llama_cli_reference/download_models.md b/docs/source/references/llama_cli_reference/download_models.md index ca470f8c2..e32099023 100644 --- a/docs/source/references/llama_cli_reference/download_models.md +++ b/docs/source/references/llama_cli_reference/download_models.md @@ -19,11 +19,11 @@ You have two ways to install Llama Stack: cd ~/local git clone git@github.com:meta-llama/llama-stack.git - conda create -n myenv python=3.10 - conda activate myenv + uv venv myenv --python 3.12 + source myenv/bin/activate # On Windows: myenv\Scripts\activate cd llama-stack - $CONDA_PREFIX/bin/pip install -e . + pip install -e . ## Downloading models via CLI diff --git a/docs/source/references/llama_cli_reference/index.md b/docs/source/references/llama_cli_reference/index.md index 7b7abdf88..4ef76fe7d 100644 --- a/docs/source/references/llama_cli_reference/index.md +++ b/docs/source/references/llama_cli_reference/index.md @@ -19,11 +19,11 @@ You have two ways to install Llama Stack: cd ~/local git clone git@github.com:meta-llama/llama-stack.git - conda create -n myenv python=3.10 - conda activate myenv + uv venv myenv --python 3.12 + source myenv/bin/activate # On Windows: myenv\Scripts\activate cd llama-stack - $CONDA_PREFIX/bin/pip install -e . + pip install -e . ## `llama` subcommands diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index 5d7763924..91b809621 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -66,7 +66,7 @@ "from pydantic import BaseModel\n", "from termcolor import cprint\n", "\n", - "from llama_stack.distribution.datatypes import RemoteProviderConfig\n", + "from llama_stack.core.datatypes import RemoteProviderConfig\n", "from llama_stack.apis.safety import Safety\n", "from llama_stack_client import LlamaStackClient\n", "\n", diff --git a/docs/zero_to_hero_guide/README.md b/docs/zero_to_hero_guide/README.md index cc3adc706..9f1f42b30 100644 --- a/docs/zero_to_hero_guide/README.md +++ b/docs/zero_to_hero_guide/README.md @@ -47,20 +47,20 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next ## Install Dependencies and Set Up Environment -1. **Create a Conda Environment**: - Create a new Conda environment with Python 3.12: +1. **Install uv**: + Install [uv](https://docs.astral.sh/uv/) for managing dependencies: ```bash - conda create -n ollama python=3.12 - ``` - Activate the environment: - ```bash - conda activate ollama + # macOS and Linux + curl -LsSf https://astral.sh/uv/install.sh | sh + + # Windows + powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" ``` 2. **Install ChromaDB**: - Install `chromadb` using `pip`: + Install `chromadb` using `uv`: ```bash - pip install chromadb + uv pip install chromadb ``` 3. **Run ChromaDB**: @@ -69,28 +69,21 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next chroma run --host localhost --port 8000 --path ./my_chroma_data ``` -4. 
**Install Llama Stack**: - Open a new terminal and install `llama-stack`: - ```bash - conda activate ollama - pip install -U llama-stack - ``` - --- ## Build, Configure, and Run Llama Stack 1. **Build the Llama Stack**: - Build the Llama Stack using the `ollama` template: + Build the Llama Stack using the `starter` template: ```bash - llama stack build --template starter --image-type conda + uv run --with llama-stack llama stack build --distro starter --image-type venv ``` **Expected Output:** ```bash ... Build Successful! - You can find the newly-built template here: ~/.llama/distributions/ollama/ollama-run.yaml - You can run the new Llama Stack Distro via: llama stack run ~/.llama/distributions/ollama/ollama-run.yaml --image-type conda + You can find the newly-built template here: ~/.llama/distributions/starter/starter-run.yaml + You can run the new Llama Stack Distro via: uv run --with llama-stack llama stack run starter --image-type venv ``` 3. **Set the ENV variables by exporting them to the terminal**: @@ -102,12 +95,13 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next ``` 3. **Run the Llama Stack**: - Run the stack with command shared by the API from earlier: + Run the stack using uv: ```bash - llama stack run ollama - --port $LLAMA_STACK_PORT - --env INFERENCE_MODEL=$INFERENCE_MODEL - --env SAFETY_MODEL=$SAFETY_MODEL + uv run --with llama-stack llama stack run starter \ + --image-type venv \ + --port $LLAMA_STACK_PORT \ + --env INFERENCE_MODEL=$INFERENCE_MODEL \ + --env SAFETY_MODEL=$SAFETY_MODEL \ --env OLLAMA_URL=$OLLAMA_URL ``` Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model. @@ -120,7 +114,7 @@ After setting up the server, open a new terminal window and configure the llama- 1. Configure the CLI to point to the llama-stack server. ```bash - llama-stack-client configure --endpoint http://localhost:8321 + uv run --with llama-stack-client llama-stack-client configure --endpoint http://localhost:8321 ``` **Expected Output:** ```bash @@ -128,7 +122,7 @@ After setting up the server, open a new terminal window and configure the llama- ``` 2. Test the CLI by running inference: ```bash - llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon" + uv run --with llama-stack-client llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon" ``` **Expected Output:** ```bash @@ -170,7 +164,7 @@ curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion EOF ``` -You can check the available models with the command `llama-stack-client models list`. +You can check the available models with the command `uv run --with llama-stack-client llama-stack-client models list`. **Expected Output:** ```json @@ -191,18 +185,12 @@ You can check the available models with the command `llama-stack-client models l You can also interact with the Llama Stack server using a simple Python script. Below is an example: -### 1. Activate Conda Environment - -```bash -conda activate ollama -``` - -### 2. Create Python Script (`test_llama_stack.py`) +### 1. Create Python Script (`test_llama_stack.py`) ```bash touch test_llama_stack.py ``` -### 3. Create a Chat Completion Request in Python +### 2. Create a Chat Completion Request in Python In `test_llama_stack.py`, write the following code: @@ -233,10 +221,10 @@ response = client.inference.chat_completion( print(response.completion_message.content) ``` -### 4. 
Run the Python Script +### 3. Run the Python Script ```bash -python test_llama_stack.py +uv run --with llama-stack-client python test_llama_stack.py ``` **Expected Output:** diff --git a/llama_stack/__init__.py b/llama_stack/__init__.py index 98f2441c0..1c2ce7123 100644 --- a/llama_stack/__init__.py +++ b/llama_stack/__init__.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.distribution.library_client import ( # noqa: F401 +from llama_stack.core.library_client import ( # noqa: F401 AsyncLlamaStackAsLibraryClient, LlamaStackAsLibraryClient, ) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 64b162e9e..e816da766 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -152,7 +152,17 @@ Step = Annotated[ @json_schema_type class Turn(BaseModel): - """A single turn in an interaction with an Agentic System.""" + """A single turn in an interaction with an Agentic System. + + :param turn_id: Unique identifier for the turn within a session + :param session_id: Unique identifier for the conversation session + :param input_messages: List of messages that initiated this turn + :param steps: Ordered list of processing steps executed during this turn + :param output_message: The model's generated response containing content and metadata + :param output_attachments: (Optional) Files or media attached to the agent's response + :param started_at: Timestamp when the turn began + :param completed_at: (Optional) Timestamp when the turn finished, if completed + """ turn_id: str session_id: str @@ -167,7 +177,13 @@ class Turn(BaseModel): @json_schema_type class Session(BaseModel): - """A single session of an interaction with an Agentic System.""" + """A single session of an interaction with an Agentic System. + + :param session_id: Unique identifier for the conversation session + :param session_name: Human-readable name for the session + :param turns: List of all turns that have occurred in this session + :param started_at: Timestamp when the session was created + """ session_id: str session_name: str @@ -232,6 +248,13 @@ class AgentConfig(AgentConfigCommon): @json_schema_type class Agent(BaseModel): + """An agent instance with configuration and metadata. + + :param agent_id: Unique identifier for the agent + :param agent_config: Configuration settings for the agent + :param created_at: Timestamp when the agent was created + """ + agent_id: str agent_config: AgentConfig created_at: datetime @@ -253,6 +276,14 @@ class AgentTurnResponseEventType(StrEnum): @json_schema_type class AgentTurnResponseStepStartPayload(BaseModel): + """Payload for step start events in agent turn responses. + + :param event_type: Type of event being reported + :param step_type: Type of step being executed + :param step_id: Unique identifier for the step within a turn + :param metadata: (Optional) Additional metadata for the step + """ + event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start step_type: StepType step_id: str @@ -261,6 +292,14 @@ class AgentTurnResponseStepStartPayload(BaseModel): @json_schema_type class AgentTurnResponseStepCompletePayload(BaseModel): + """Payload for step completion events in agent turn responses. 
+ + :param event_type: Type of event being reported + :param step_type: Type of step being executed + :param step_id: Unique identifier for the step within a turn + :param step_details: Complete details of the executed step + """ + event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete step_type: StepType step_id: str @@ -269,6 +308,14 @@ class AgentTurnResponseStepCompletePayload(BaseModel): @json_schema_type class AgentTurnResponseStepProgressPayload(BaseModel): + """Payload for step progress events in agent turn responses. + + :param event_type: Type of event being reported + :param step_type: Type of step being executed + :param step_id: Unique identifier for the step within a turn + :param delta: Incremental content changes during step execution + """ + model_config = ConfigDict(protected_namespaces=()) event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress @@ -280,18 +327,36 @@ class AgentTurnResponseStepProgressPayload(BaseModel): @json_schema_type class AgentTurnResponseTurnStartPayload(BaseModel): + """Payload for turn start events in agent turn responses. + + :param event_type: Type of event being reported + :param turn_id: Unique identifier for the turn within a session + """ + event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start turn_id: str @json_schema_type class AgentTurnResponseTurnCompletePayload(BaseModel): + """Payload for turn completion events in agent turn responses. + + :param event_type: Type of event being reported + :param turn: Complete turn data including all steps and results + """ + event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete turn: Turn @json_schema_type class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): + """Payload for turn awaiting input events in agent turn responses. + + :param event_type: Type of event being reported + :param turn: Turn data when waiting for external tool responses + """ + event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input turn: Turn @@ -310,21 +375,47 @@ register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPaylo @json_schema_type class AgentTurnResponseEvent(BaseModel): + """An event in an agent turn response stream. + + :param payload: Event-specific payload containing event data + """ + payload: AgentTurnResponseEventPayload @json_schema_type class AgentCreateResponse(BaseModel): + """Response returned when creating a new agent. + + :param agent_id: Unique identifier for the created agent + """ + agent_id: str @json_schema_type class AgentSessionCreateResponse(BaseModel): + """Response returned when creating a new agent session. + + :param session_id: Unique identifier for the created session + """ + session_id: str @json_schema_type class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): + """Request to create a new turn for an agent. 
+ + :param agent_id: Unique identifier for the agent + :param session_id: Unique identifier for the conversation session + :param messages: List of messages to start the turn with + :param documents: (Optional) List of documents to provide to the agent + :param toolgroups: (Optional) List of tool groups to make available for this turn + :param stream: (Optional) Whether to stream the response + :param tool_config: (Optional) Tool configuration to override agent defaults + """ + agent_id: str session_id: str @@ -342,6 +433,15 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): @json_schema_type class AgentTurnResumeRequest(BaseModel): + """Request to resume an agent turn with tool responses. + + :param agent_id: Unique identifier for the agent + :param session_id: Unique identifier for the conversation session + :param turn_id: Unique identifier for the turn within a session + :param tool_responses: List of tool responses to submit to continue the turn + :param stream: (Optional) Whether to stream the response + """ + agent_id: str session_id: str turn_id: str @@ -351,13 +451,21 @@ class AgentTurnResumeRequest(BaseModel): @json_schema_type class AgentTurnResponseStreamChunk(BaseModel): - """streamed agent turn completion response.""" + """Streamed agent turn completion response. + + :param event: Individual event in the agent turn response stream + """ event: AgentTurnResponseEvent @json_schema_type class AgentStepResponse(BaseModel): + """Response containing details of a specific agent step. + + :param step: The complete step data and execution details + """ + step: Step diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 10843a3fe..10cadf38f 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -18,18 +18,37 @@ from llama_stack.schema_utils import json_schema_type, register_schema @json_schema_type class OpenAIResponseError(BaseModel): + """Error details for failed OpenAI response requests. + + :param code: Error code identifying the type of failure + :param message: Human-readable error message describing the failure + """ + code: str message: str @json_schema_type class OpenAIResponseInputMessageContentText(BaseModel): + """Text content for input messages in OpenAI response format. + + :param text: The text content of the input message + :param type: Content type identifier, always "input_text" + """ + text: str type: Literal["input_text"] = "input_text" @json_schema_type class OpenAIResponseInputMessageContentImage(BaseModel): + """Image content for input messages in OpenAI response format. + + :param detail: Level of detail for image processing, can be "low", "high", or "auto" + :param type: Content type identifier, always "input_image" + :param image_url: (Optional) URL of the image content + """ + detail: Literal["low"] | Literal["high"] | Literal["auto"] = "auto" type: Literal["input_image"] = "input_image" # TODO: handle file_id @@ -46,6 +65,14 @@ register_schema(OpenAIResponseInputMessageContent, name="OpenAIResponseInputMess @json_schema_type class OpenAIResponseAnnotationFileCitation(BaseModel): + """File citation annotation for referencing specific files in response content. 
+ + :param type: Annotation type identifier, always "file_citation" + :param file_id: Unique identifier of the referenced file + :param filename: Name of the referenced file + :param index: Position index of the citation within the content + """ + type: Literal["file_citation"] = "file_citation" file_id: str filename: str @@ -54,6 +81,15 @@ class OpenAIResponseAnnotationFileCitation(BaseModel): @json_schema_type class OpenAIResponseAnnotationCitation(BaseModel): + """URL citation annotation for referencing external web resources. + + :param type: Annotation type identifier, always "url_citation" + :param end_index: End position of the citation span in the content + :param start_index: Start position of the citation span in the content + :param title: Title of the referenced web resource + :param url: URL of the referenced web resource + """ + type: Literal["url_citation"] = "url_citation" end_index: int start_index: int @@ -122,6 +158,13 @@ class OpenAIResponseMessage(BaseModel): @json_schema_type class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): + """Web search tool call output message for OpenAI responses. + + :param id: Unique identifier for this tool call + :param status: Current status of the web search operation + :param type: Tool call type identifier, always "web_search_call" + """ + id: str status: str type: Literal["web_search_call"] = "web_search_call" @@ -129,6 +172,15 @@ class OpenAIResponseOutputMessageWebSearchToolCall(BaseModel): @json_schema_type class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): + """File search tool call output message for OpenAI responses. + + :param id: Unique identifier for this tool call + :param queries: List of search queries executed + :param status: Current status of the file search operation + :param type: Tool call type identifier, always "file_search_call" + :param results: (Optional) Search results returned by the file search operation + """ + id: str queries: list[str] status: str @@ -138,6 +190,16 @@ class OpenAIResponseOutputMessageFileSearchToolCall(BaseModel): @json_schema_type class OpenAIResponseOutputMessageFunctionToolCall(BaseModel): + """Function tool call output message for OpenAI responses. + + :param call_id: Unique identifier for the function call + :param name: Name of the function being called + :param arguments: JSON string containing the function arguments + :param type: Tool call type identifier, always "function_call" + :param id: (Optional) Additional identifier for the tool call + :param status: (Optional) Current status of the function call execution + """ + call_id: str name: str arguments: str @@ -148,6 +210,17 @@ class OpenAIResponseOutputMessageFunctionToolCall(BaseModel): @json_schema_type class OpenAIResponseOutputMessageMCPCall(BaseModel): + """Model Context Protocol (MCP) call output message for OpenAI responses. + + :param id: Unique identifier for this MCP call + :param type: Tool call type identifier, always "mcp_call" + :param arguments: JSON string containing the MCP call arguments + :param name: Name of the MCP method being called + :param server_label: Label identifying the MCP server handling the call + :param error: (Optional) Error message if the MCP call failed + :param output: (Optional) Output result from the successful MCP call + """ + id: str type: Literal["mcp_call"] = "mcp_call" arguments: str @@ -158,6 +231,13 @@ class OpenAIResponseOutputMessageMCPCall(BaseModel): class MCPListToolsTool(BaseModel): + """Tool definition returned by MCP list tools operation. 
+ + :param input_schema: JSON schema defining the tool's input parameters + :param name: Name of the tool + :param description: (Optional) Description of what the tool does + """ + input_schema: dict[str, Any] name: str description: str | None = None @@ -165,6 +245,14 @@ class MCPListToolsTool(BaseModel): @json_schema_type class OpenAIResponseOutputMessageMCPListTools(BaseModel): + """MCP list tools output message containing available tools from an MCP server. + + :param id: Unique identifier for this MCP list tools operation + :param type: Tool call type identifier, always "mcp_list_tools" + :param server_label: Label identifying the MCP server providing the tools + :param tools: List of available tools provided by the MCP server + """ + id: str type: Literal["mcp_list_tools"] = "mcp_list_tools" server_label: str @@ -206,11 +294,34 @@ class OpenAIResponseTextFormat(TypedDict, total=False): @json_schema_type class OpenAIResponseText(BaseModel): + """Text response configuration for OpenAI responses. + + :param format: (Optional) Text format configuration specifying output format requirements + """ + format: OpenAIResponseTextFormat | None = None @json_schema_type class OpenAIResponseObject(BaseModel): + """Complete OpenAI response object containing generation results and metadata. + + :param created_at: Unix timestamp when the response was created + :param error: (Optional) Error details if the response generation failed + :param id: Unique identifier for this response + :param model: Model identifier used for generation + :param object: Object type identifier, always "response" + :param output: List of generated output items (messages, tool calls, etc.) + :param parallel_tool_calls: Whether tool calls can be executed in parallel + :param previous_response_id: (Optional) ID of the previous response in a conversation + :param status: Current status of the response generation + :param temperature: (Optional) Sampling temperature used for generation + :param text: Text formatting configuration for the response + :param top_p: (Optional) Nucleus sampling parameter used for generation + :param truncation: (Optional) Truncation strategy applied to the response + :param user: (Optional) User identifier associated with the request + """ + created_at: int error: OpenAIResponseError | None = None id: str @@ -231,6 +342,13 @@ class OpenAIResponseObject(BaseModel): @json_schema_type class OpenAIDeleteResponseObject(BaseModel): + """Response object confirming deletion of an OpenAI response. + + :param id: Unique identifier of the deleted response + :param object: Object type identifier, always "response" + :param deleted: Deletion confirmation flag, always True + """ + id: str object: Literal["response"] = "response" deleted: bool = True @@ -238,18 +356,39 @@ class OpenAIDeleteResponseObject(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseCreated(BaseModel): + """Streaming event indicating a new response has been created. + + :param response: The newly created response object + :param type: Event type identifier, always "response.created" + """ + response: OpenAIResponseObject type: Literal["response.created"] = "response.created" @json_schema_type class OpenAIResponseObjectStreamResponseCompleted(BaseModel): + """Streaming event indicating a response has been completed. 
+ + :param response: The completed response object + :param type: Event type identifier, always "response.completed" + """ + response: OpenAIResponseObject type: Literal["response.completed"] = "response.completed" @json_schema_type class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel): + """Streaming event for when a new output item is added to the response. + + :param response_id: Unique identifier of the response containing this output + :param item: The output item that was added (message, tool call, etc.) + :param output_index: Index position of this item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.output_item.added" + """ + response_id: str item: OpenAIResponseOutput output_index: int @@ -259,6 +398,15 @@ class OpenAIResponseObjectStreamResponseOutputItemAdded(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseOutputItemDone(BaseModel): + """Streaming event for when an output item is completed. + + :param response_id: Unique identifier of the response containing this output + :param item: The completed output item (message, tool call, etc.) + :param output_index: Index position of this item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.output_item.done" + """ + response_id: str item: OpenAIResponseOutput output_index: int @@ -268,6 +416,16 @@ class OpenAIResponseObjectStreamResponseOutputItemDone(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel): + """Streaming event for incremental text content updates. + + :param content_index: Index position within the text content + :param delta: Incremental text content being added + :param item_id: Unique identifier of the output item being updated + :param output_index: Index position of the item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.output_text.delta" + """ + content_index: int delta: str item_id: str @@ -278,6 +436,16 @@ class OpenAIResponseObjectStreamResponseOutputTextDelta(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseOutputTextDone(BaseModel): + """Streaming event for when text output is completed. + + :param content_index: Index position within the text content + :param text: Final complete text content of the output item + :param item_id: Unique identifier of the completed output item + :param output_index: Index position of the item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.output_text.done" + """ + content_index: int text: str # final text of the output item item_id: str @@ -288,6 +456,15 @@ class OpenAIResponseObjectStreamResponseOutputTextDone(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel): + """Streaming event for incremental function call argument updates. 
+ + :param delta: Incremental function call arguments being added + :param item_id: Unique identifier of the function call being updated + :param output_index: Index position of the item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.function_call_arguments.delta" + """ + delta: str item_id: str output_index: int @@ -297,6 +474,15 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel): + """Streaming event for when function call arguments are completed. + + :param arguments: Final complete arguments JSON string for the function call + :param item_id: Unique identifier of the completed function call + :param output_index: Index position of the item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.function_call_arguments.done" + """ + arguments: str # final arguments of the function call item_id: str output_index: int @@ -306,6 +492,14 @@ class OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseWebSearchCallInProgress(BaseModel): + """Streaming event for web search calls in progress. + + :param item_id: Unique identifier of the web search call + :param output_index: Index position of the item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.web_search_call.in_progress" + """ + item_id: str output_index: int sequence_number: int @@ -322,6 +516,14 @@ class OpenAIResponseObjectStreamResponseWebSearchCallSearching(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseWebSearchCallCompleted(BaseModel): + """Streaming event for completed web search calls. + + :param item_id: Unique identifier of the completed web search call + :param output_index: Index position of the item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.web_search_call.completed" + """ + item_id: str output_index: int sequence_number: int @@ -366,6 +568,14 @@ class OpenAIResponseObjectStreamResponseMcpCallArgumentsDone(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseMcpCallInProgress(BaseModel): + """Streaming event for MCP calls in progress. + + :param item_id: Unique identifier of the MCP call + :param output_index: Index position of the item in the output list + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.mcp_call.in_progress" + """ + item_id: str output_index: int sequence_number: int @@ -374,12 +584,24 @@ class OpenAIResponseObjectStreamResponseMcpCallInProgress(BaseModel): @json_schema_type class OpenAIResponseObjectStreamResponseMcpCallFailed(BaseModel): + """Streaming event for failed MCP calls. + + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.mcp_call.failed" + """ + sequence_number: int type: Literal["response.mcp_call.failed"] = "response.mcp_call.failed" @json_schema_type class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel): + """Streaming event for completed MCP calls. 
+ + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.mcp_call.completed" + """ + sequence_number: int type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed" @@ -442,6 +664,12 @@ WebSearchToolTypes = ["web_search", "web_search_preview", "web_search_preview_20 @json_schema_type class OpenAIResponseInputToolWebSearch(BaseModel): + """Web search tool configuration for OpenAI response inputs. + + :param type: Web search tool type variant to use + :param search_context_size: (Optional) Size of search context, must be "low", "medium", or "high" + """ + # Must match values of WebSearchToolTypes above type: Literal["web_search"] | Literal["web_search_preview"] | Literal["web_search_preview_2025_03_11"] = ( "web_search" @@ -453,6 +681,15 @@ class OpenAIResponseInputToolWebSearch(BaseModel): @json_schema_type class OpenAIResponseInputToolFunction(BaseModel): + """Function tool configuration for OpenAI response inputs. + + :param type: Tool type identifier, always "function" + :param name: Name of the function that can be called + :param description: (Optional) Description of what the function does + :param parameters: (Optional) JSON schema defining the function's parameters + :param strict: (Optional) Whether to enforce strict parameter validation + """ + type: Literal["function"] = "function" name: str description: str | None = None @@ -462,6 +699,15 @@ class OpenAIResponseInputToolFunction(BaseModel): @json_schema_type class OpenAIResponseInputToolFileSearch(BaseModel): + """File search tool configuration for OpenAI response inputs. + + :param type: Tool type identifier, always "file_search" + :param vector_store_ids: List of vector store identifiers to search within + :param filters: (Optional) Additional filters to apply to the search + :param max_num_results: (Optional) Maximum number of search results to return (1-50) + :param ranking_options: (Optional) Options for ranking and scoring search results + """ + type: Literal["file_search"] = "file_search" vector_store_ids: list[str] filters: dict[str, Any] | None = None @@ -470,16 +716,37 @@ class OpenAIResponseInputToolFileSearch(BaseModel): class ApprovalFilter(BaseModel): + """Filter configuration for MCP tool approval requirements. + + :param always: (Optional) List of tool names that always require approval + :param never: (Optional) List of tool names that never require approval + """ + always: list[str] | None = None never: list[str] | None = None class AllowedToolsFilter(BaseModel): + """Filter configuration for restricting which MCP tools can be used. + + :param tool_names: (Optional) List of specific tool names that are allowed + """ + tool_names: list[str] | None = None @json_schema_type class OpenAIResponseInputToolMCP(BaseModel): + """Model Context Protocol (MCP) tool configuration for OpenAI response inputs. 
+ + :param type: Tool type identifier, always "mcp" + :param server_label: Label to identify this MCP server + :param server_url: URL endpoint of the MCP server + :param headers: (Optional) HTTP headers to include when connecting to the server + :param require_approval: Approval requirement for tool calls ("always", "never", or filter) + :param allowed_tools: (Optional) Restriction on which tools can be used from this server + """ + type: Literal["mcp"] = "mcp" server_label: str server_url: str @@ -500,17 +767,37 @@ register_schema(OpenAIResponseInputTool, name="OpenAIResponseInputTool") class ListOpenAIResponseInputItem(BaseModel): + """List container for OpenAI response input items. + + :param data: List of input items + :param object: Object type identifier, always "list" + """ + data: list[OpenAIResponseInput] object: Literal["list"] = "list" @json_schema_type class OpenAIResponseObjectWithInput(OpenAIResponseObject): + """OpenAI response object extended with input context information. + + :param input: List of input items that led to this response + """ + input: list[OpenAIResponseInput] @json_schema_type class ListOpenAIResponseObject(BaseModel): + """Paginated list of OpenAI response objects with navigation metadata. + + :param data: List of response objects with their input context + :param has_more: Whether there are more results available beyond this page + :param first_id: Identifier of the first item in this page + :param last_id: Identifier of the last item in this page + :param object: Object type identifier, always "list" + """ + data: list[OpenAIResponseObjectWithInput] has_more: bool first_id: str diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py index d80c767f8..706eaed6c 100644 --- a/llama_stack/apis/benchmarks/benchmarks.py +++ b/llama_stack/apis/benchmarks/benchmarks.py @@ -22,6 +22,14 @@ class CommonBenchmarkFields(BaseModel): @json_schema_type class Benchmark(CommonBenchmarkFields, Resource): + """A benchmark resource for evaluating model performance. + + :param dataset_id: Identifier of the dataset to use for the benchmark evaluation + :param scoring_functions: List of scoring function identifiers to apply during evaluation + :param metadata: Metadata for this evaluation task + :param type: The resource type, always benchmark + """ + type: Literal[ResourceType.benchmark] = ResourceType.benchmark @property diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 8bcb781f7..950dd17ff 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -15,6 +15,11 @@ from llama_stack.schema_utils import json_schema_type, register_schema @json_schema_type class URL(BaseModel): + """A URL reference to external content. + + :param uri: The URL string pointing to the resource + """ + uri: str @@ -76,17 +81,36 @@ register_schema(InterleavedContent, name="InterleavedContent") @json_schema_type class TextDelta(BaseModel): + """A text content delta for streaming responses. + + :param type: Discriminator type of the delta. Always "text" + :param text: The incremental text content + """ + type: Literal["text"] = "text" text: str @json_schema_type class ImageDelta(BaseModel): + """An image content delta for streaming responses. + + :param type: Discriminator type of the delta. 
Always "image" + :param image: The incremental image data as bytes + """ + type: Literal["image"] = "image" image: bytes class ToolCallParseStatus(Enum): + """Status of tool call parsing during streaming. + :cvar started: Tool call parsing has begun + :cvar in_progress: Tool call parsing is ongoing + :cvar failed: Tool call parsing failed + :cvar succeeded: Tool call parsing completed successfully + """ + started = "started" in_progress = "in_progress" failed = "failed" @@ -95,6 +119,13 @@ class ToolCallParseStatus(Enum): @json_schema_type class ToolCallDelta(BaseModel): + """A tool call content delta for streaming responses. + + :param type: Discriminator type of the delta. Always "tool_call" + :param tool_call: Either an in-progress tool call string or the final parsed tool call + :param parse_status: Current parsing status of the tool call + """ + type: Literal["tool_call"] = "tool_call" # you either send an in-progress tool call so the client can stream a long diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 80f297bce..95d6ac18e 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -4,6 +4,21 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +# Custom Llama Stack Exception classes should follow the following schema +# 1. All classes should inherit from an existing Built-In Exception class: https://docs.python.org/3/library/exceptions.html +# 2. All classes should have a custom error message with the goal of informing the Llama Stack user specifically +# 3. All classes should propogate the inherited __init__ function otherwise via 'super().__init__(message)' + + +class ResourceNotFoundError(ValueError): + """generic exception for a missing Llama Stack resource""" + + def __init__(self, resource_name: str, resource_type: str, client_list: str) -> None: + message = ( + f"{resource_type} '{resource_name}' not found. Use '{client_list}' to list available {resource_type}s." + ) + super().__init__(message) + class UnsupportedModelError(ValueError): """raised when model is not present in the list of supported models""" @@ -11,3 +26,39 @@ class UnsupportedModelError(ValueError): def __init__(self, model_name: str, supported_models_list: list[str]): message = f"'{model_name}' model is not supported. 
Supported models are: {', '.join(supported_models_list)}" super().__init__(message) + + +class ModelNotFoundError(ResourceNotFoundError): + """raised when Llama Stack cannot find a referenced model""" + + def __init__(self, model_name: str) -> None: + super().__init__(model_name, "Model", "client.models.list()") + + +class VectorStoreNotFoundError(ResourceNotFoundError): + """raised when Llama Stack cannot find a referenced vector store""" + + def __init__(self, vector_store_name: str) -> None: + super().__init__(vector_store_name, "Vector Store", "client.vector_dbs.list()") + + +class DatasetNotFoundError(ResourceNotFoundError): + """raised when Llama Stack cannot find a referenced dataset""" + + def __init__(self, dataset_name: str) -> None: + super().__init__(dataset_name, "Dataset", "client.datasets.list()") + + +class ToolGroupNotFoundError(ResourceNotFoundError): + """raised when Llama Stack cannot find a referenced tool group""" + + def __init__(self, toolgroup_name: str) -> None: + super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()") + + +class SessionNotFoundError(ValueError): + """raised when Llama Stack cannot find a referenced session or access is denied""" + + def __init__(self, session_name: str) -> None: + message = f"Session '{session_name}' not found or access denied." + super().__init__(message) diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py index ca6bcaf63..5da42bfd3 100644 --- a/llama_stack/apis/common/job_types.py +++ b/llama_stack/apis/common/job_types.py @@ -11,6 +11,14 @@ from llama_stack.schema_utils import json_schema_type class JobStatus(Enum): + """Status of a job execution. + :cvar completed: Job has finished successfully + :cvar in_progress: Job is currently running + :cvar failed: Job has failed during execution + :cvar scheduled: Job is scheduled but not yet started + :cvar cancelled: Job was cancelled before completion + """ + completed = "completed" in_progress = "in_progress" failed = "failed" @@ -20,5 +28,11 @@ class JobStatus(Enum): @json_schema_type class Job(BaseModel): + """A job execution instance with status tracking. + + :param job_id: Unique identifier for the job + :param status: Current execution status of the job + """ + job_id: str status: JobStatus diff --git a/llama_stack/apis/common/responses.py b/llama_stack/apis/common/responses.py index e4cf21a54..616bee73a 100644 --- a/llama_stack/apis/common/responses.py +++ b/llama_stack/apis/common/responses.py @@ -13,6 +13,11 @@ from llama_stack.schema_utils import json_schema_type class Order(Enum): + """Sort order for paginated responses. + :cvar asc: Ascending order + :cvar desc: Descending order + """ + asc = "asc" desc = "desc" diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py index a2c3b78f1..5c236a25d 100644 --- a/llama_stack/apis/common/training_types.py +++ b/llama_stack/apis/common/training_types.py @@ -13,6 +13,14 @@ from llama_stack.schema_utils import json_schema_type @json_schema_type class PostTrainingMetric(BaseModel): + """Training metrics captured during post-training jobs. 
+ + :param epoch: Training epoch number + :param train_loss: Loss value on the training dataset + :param validation_loss: Loss value on the validation dataset + :param perplexity: Perplexity metric indicating model confidence + """ + epoch: int train_loss: float validation_loss: float @@ -21,7 +29,15 @@ class PostTrainingMetric(BaseModel): @json_schema_type class Checkpoint(BaseModel): - """Checkpoint created during training runs""" + """Checkpoint created during training runs. + + :param identifier: Unique identifier for the checkpoint + :param created_at: Timestamp when the checkpoint was created + :param epoch: Training epoch when the checkpoint was saved + :param post_training_job_id: Identifier of the training job that created this checkpoint + :param path: File system path where the checkpoint is stored + :param training_metrics: (Optional) Training metrics associated with this checkpoint + """ identifier: str created_at: datetime diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index db4aab4c5..0e62ee484 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -13,59 +13,114 @@ from llama_stack.schema_utils import json_schema_type, register_schema @json_schema_type class StringType(BaseModel): + """Parameter type for string values. + + :param type: Discriminator type. Always "string" + """ + type: Literal["string"] = "string" @json_schema_type class NumberType(BaseModel): + """Parameter type for numeric values. + + :param type: Discriminator type. Always "number" + """ + type: Literal["number"] = "number" @json_schema_type class BooleanType(BaseModel): + """Parameter type for boolean values. + + :param type: Discriminator type. Always "boolean" + """ + type: Literal["boolean"] = "boolean" @json_schema_type class ArrayType(BaseModel): + """Parameter type for array values. + + :param type: Discriminator type. Always "array" + """ + type: Literal["array"] = "array" @json_schema_type class ObjectType(BaseModel): + """Parameter type for object values. + + :param type: Discriminator type. Always "object" + """ + type: Literal["object"] = "object" @json_schema_type class JsonType(BaseModel): + """Parameter type for JSON values. + + :param type: Discriminator type. Always "json" + """ + type: Literal["json"] = "json" @json_schema_type class UnionType(BaseModel): + """Parameter type for union values. + + :param type: Discriminator type. Always "union" + """ + type: Literal["union"] = "union" @json_schema_type class ChatCompletionInputType(BaseModel): + """Parameter type for chat completion input. + + :param type: Discriminator type. Always "chat_completion_input" + """ + # expects List[Message] for messages type: Literal["chat_completion_input"] = "chat_completion_input" @json_schema_type class CompletionInputType(BaseModel): + """Parameter type for completion input. + + :param type: Discriminator type. Always "completion_input" + """ + # expects InterleavedTextMedia for content type: Literal["completion_input"] = "completion_input" @json_schema_type class AgentTurnInputType(BaseModel): + """Parameter type for agent turn input. + + :param type: Discriminator type. Always "agent_turn_input" + """ + # expects List[Message] for messages (may also include attachments?) type: Literal["agent_turn_input"] = "agent_turn_input" @json_schema_type class DialogType(BaseModel): + """Parameter type for dialog data with semantic output labels. + + :param type: Discriminator type. 
Always "dialog" + """ + # expects List[Message] for messages # this type semantically contains the output label whereas ChatCompletionInputType does not type: Literal["dialog"] = "dialog" diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 8bf7a48d0..f347e0e29 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -94,6 +94,10 @@ register_schema(DataSource, name="DataSource") class CommonDatasetFields(BaseModel): """ Common fields for a dataset. + + :param purpose: Purpose of the dataset indicating its intended use + :param source: Data source configuration for the dataset + :param metadata: Additional metadata for the dataset """ purpose: DatasetPurpose @@ -106,6 +110,11 @@ class CommonDatasetFields(BaseModel): @json_schema_type class Dataset(CommonDatasetFields, Resource): + """Dataset resource for storing and accessing training or evaluation data. + + :param type: Type of resource, always 'dataset' for datasets + """ + type: Literal[ResourceType.dataset] = ResourceType.dataset @property @@ -118,10 +127,20 @@ class Dataset(CommonDatasetFields, Resource): class DatasetInput(CommonDatasetFields, BaseModel): + """Input parameters for dataset operations. + + :param dataset_id: Unique identifier for the dataset + """ + dataset_id: str class ListDatasetsResponse(BaseModel): + """Response from listing datasets. + + :param data: List of datasets + """ + data: list[Dataset] diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index e6628f5d7..cabe46a2f 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -81,6 +81,29 @@ class DynamicApiMeta(EnumMeta): @json_schema_type class Api(Enum, metaclass=DynamicApiMeta): + """Enumeration of all available APIs in the Llama Stack system. + :cvar providers: Provider management and configuration + :cvar inference: Text generation, chat completions, and embeddings + :cvar safety: Content moderation and safety shields + :cvar agents: Agent orchestration and execution + :cvar vector_io: Vector database operations and queries + :cvar datasetio: Dataset input/output operations + :cvar scoring: Model output evaluation and scoring + :cvar eval: Model evaluation and benchmarking framework + :cvar post_training: Fine-tuning and model training + :cvar tool_runtime: Tool execution and management + :cvar telemetry: Observability and system monitoring + :cvar models: Model metadata and management + :cvar shields: Safety shield implementations + :cvar vector_dbs: Vector database management + :cvar datasets: Dataset creation and management + :cvar scoring_functions: Scoring function definitions + :cvar benchmarks: Benchmark suite management + :cvar tool_groups: Tool group organization + :cvar files: File storage and management + :cvar inspect: Built-in system inspection and introspection + """ + providers = "providers" inference = "inference" safety = "safety" diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index a72dcd8d4..ba8701e23 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -54,6 +54,9 @@ class ListOpenAIFileResponse(BaseModel): Response for listing files in OpenAI Files API. 
:param data: List of file objects + :param has_more: Whether there are more files available beyond this page + :param first_id: ID of the first file in the list for pagination + :param last_id: ID of the last file in the list for pagination :param object: The object type, which is always "list" """ diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index aabb41839..7e7bd0a3d 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -41,11 +41,23 @@ from enum import StrEnum @json_schema_type class GreedySamplingStrategy(BaseModel): + """Greedy sampling strategy that selects the highest probability token at each step. + + :param type: Must be "greedy" to identify this sampling strategy + """ + type: Literal["greedy"] = "greedy" @json_schema_type class TopPSamplingStrategy(BaseModel): + """Top-p (nucleus) sampling strategy that samples from the smallest set of tokens with cumulative probability >= p. + + :param type: Must be "top_p" to identify this sampling strategy + :param temperature: Controls randomness in sampling. Higher values increase randomness + :param top_p: Cumulative probability threshold for nucleus sampling. Defaults to 0.95 + """ + type: Literal["top_p"] = "top_p" temperature: float | None = Field(..., gt=0.0) top_p: float | None = 0.95 @@ -53,6 +65,12 @@ class TopPSamplingStrategy(BaseModel): @json_schema_type class TopKSamplingStrategy(BaseModel): + """Top-k sampling strategy that restricts sampling to the k most likely tokens. + + :param type: Must be "top_k" to identify this sampling strategy + :param top_k: Number of top tokens to consider for sampling. Must be at least 1 + """ + type: Literal["top_k"] = "top_k" top_k: int = Field(..., ge=1) @@ -108,11 +126,21 @@ class QuantizationType(Enum): @json_schema_type class Fp8QuantizationConfig(BaseModel): + """Configuration for 8-bit floating point quantization. + + :param type: Must be "fp8_mixed" to identify this quantization type + """ + type: Literal["fp8_mixed"] = "fp8_mixed" @json_schema_type class Bf16QuantizationConfig(BaseModel): + """Configuration for BFloat16 precision (typically no quantization). + + :param type: Must be "bf16" to identify this quantization type + """ + type: Literal["bf16"] = "bf16" @@ -202,6 +230,14 @@ register_schema(Message, name="Message") @json_schema_type class ToolResponse(BaseModel): + """Response from a tool invocation. + + :param call_id: Unique identifier for the tool call this response is for + :param tool_name: Name of the tool that was invoked + :param content: The response content from the tool + :param metadata: (Optional) Additional metadata about the tool response + """ + call_id: str tool_name: BuiltinTool | str content: InterleavedContent @@ -439,18 +475,36 @@ class EmbeddingsResponse(BaseModel): @json_schema_type class OpenAIChatCompletionContentPartTextParam(BaseModel): + """Text content part for OpenAI-compatible chat completion messages. + + :param type: Must be "text" to identify this as text content + :param text: The text content of the message + """ + type: Literal["text"] = "text" text: str @json_schema_type class OpenAIImageURL(BaseModel): + """Image URL specification for OpenAI-compatible chat completion messages. + + :param url: URL of the image to include in the message + :param detail: (Optional) Level of detail for image processing. 
Can be "low", "high", or "auto" + """ + url: str detail: str | None = None @json_schema_type class OpenAIChatCompletionContentPartImageParam(BaseModel): + """Image content part for OpenAI-compatible chat completion messages. + + :param type: Must be "image_url" to identify this as image content + :param image_url: Image URL specification and processing details + """ + type: Literal["image_url"] = "image_url" image_url: OpenAIImageURL @@ -510,12 +564,26 @@ class OpenAISystemMessageParam(BaseModel): @json_schema_type class OpenAIChatCompletionToolCallFunction(BaseModel): + """Function call details for OpenAI-compatible tool calls. + + :param name: (Optional) Name of the function to call + :param arguments: (Optional) Arguments to pass to the function as a JSON string + """ + name: str | None = None arguments: str | None = None @json_schema_type class OpenAIChatCompletionToolCall(BaseModel): + """Tool call specification for OpenAI-compatible chat completion responses. + + :param index: (Optional) Index of the tool call in the list + :param id: (Optional) Unique identifier for the tool call + :param type: Must be "function" to identify this as a function call + :param function: (Optional) Function call details + """ + index: int | None = None id: str | None = None type: Literal["function"] = "function" @@ -579,11 +647,24 @@ register_schema(OpenAIMessageParam, name="OpenAIMessageParam") @json_schema_type class OpenAIResponseFormatText(BaseModel): + """Text response format for OpenAI-compatible chat completion requests. + + :param type: Must be "text" to indicate plain text response format + """ + type: Literal["text"] = "text" @json_schema_type class OpenAIJSONSchema(TypedDict, total=False): + """JSON schema specification for OpenAI-compatible structured response format. + + :param name: Name of the schema + :param description: (Optional) Description of the schema + :param strict: (Optional) Whether to enforce strict adherence to the schema + :param schema: (Optional) The JSON schema definition + """ + name: str description: str | None strict: bool | None @@ -597,12 +678,23 @@ class OpenAIJSONSchema(TypedDict, total=False): @json_schema_type class OpenAIResponseFormatJSONSchema(BaseModel): + """JSON schema response format for OpenAI-compatible chat completion requests. + + :param type: Must be "json_schema" to indicate structured JSON response format + :param json_schema: The JSON schema specification for the response + """ + type: Literal["json_schema"] = "json_schema" json_schema: OpenAIJSONSchema @json_schema_type class OpenAIResponseFormatJSONObject(BaseModel): + """JSON object response format for OpenAI-compatible chat completion requests. + + :param type: Must be "json_object" to indicate generic JSON object response format + """ + type: Literal["json_object"] = "json_object" @@ -861,11 +953,21 @@ class EmbeddingTaskType(Enum): @json_schema_type class BatchCompletionResponse(BaseModel): + """Response from a batch completion request. + + :param batch: List of completion responses, one for each input in the batch + """ + batch: list[CompletionResponse] @json_schema_type class BatchChatCompletionResponse(BaseModel): + """Response from a batch chat completion request. 
+ + :param batch: List of chat completion responses, one for each conversation in the batch + """ + batch: list[ChatCompletionResponse] @@ -875,6 +977,15 @@ class OpenAICompletionWithInputMessages(OpenAIChatCompletion): @json_schema_type class ListOpenAIChatCompletionResponse(BaseModel): + """Response from listing OpenAI-compatible chat completions. + + :param data: List of chat completion objects with their input messages + :param has_more: Whether there are more completions available beyond this list + :param first_id: ID of the first completion in this list + :param last_id: ID of the last completion in this list + :param object: Must be "list" to identify this as a list response + """ + data: list[OpenAICompletionWithInputMessages] has_more: bool first_id: str diff --git a/llama_stack/apis/inspect/inspect.py b/llama_stack/apis/inspect/inspect.py index 44a5e95b2..91d9c3da7 100644 --- a/llama_stack/apis/inspect/inspect.py +++ b/llama_stack/apis/inspect/inspect.py @@ -14,6 +14,13 @@ from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type class RouteInfo(BaseModel): + """Information about an API route including its path, method, and implementing providers. + + :param route: The API endpoint path + :param method: HTTP method for the route + :param provider_types: List of provider types that implement this route + """ + route: str method: str provider_types: list[str] @@ -21,15 +28,30 @@ class RouteInfo(BaseModel): @json_schema_type class HealthInfo(BaseModel): + """Health status information for the service. + + :param status: Current health status of the service + """ + status: HealthStatus @json_schema_type class VersionInfo(BaseModel): + """Version information for the service. + + :param version: Version number of the service + """ + version: str class ListRoutesResponse(BaseModel): + """Response containing a list of all available API routes. + + :param data: List of available route information objects + """ + data: list[RouteInfo] @@ -37,17 +59,17 @@ class ListRoutesResponse(BaseModel): class Inspect(Protocol): @webmethod(route="/inspect/routes", method="GET") async def list_routes(self) -> ListRoutesResponse: - """List all routes. + """List all available API routes with their methods and implementing providers. - :returns: A ListRoutesResponse. + :returns: Response containing information about all available routes. """ ... @webmethod(route="/health", method="GET") async def health(self) -> HealthInfo: - """Get the health of the service. + """Get the current health status of the service. - :returns: A HealthInfo. + :returns: Health information indicating if the service is operational. """ ... @@ -55,6 +77,6 @@ class Inspect(Protocol): async def version(self) -> VersionInfo: """Get the version of the service. - :returns: A VersionInfo. + :returns: Version information containing the service version number. """ ... diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 2143346d9..1af6fc9df 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -23,12 +23,27 @@ class CommonModelFields(BaseModel): @json_schema_type class ModelType(StrEnum): + """Enumeration of supported model types in Llama Stack. 
+ :cvar llm: Large language model for text generation and completion + :cvar embedding: Embedding model for converting text to vector representations + """ + llm = "llm" embedding = "embedding" @json_schema_type class Model(CommonModelFields, Resource): + """A model resource representing an AI model registered in Llama Stack. + + :param type: The resource type, always 'model' for model resources + :param model_type: The type of model (LLM or embedding model) + :param metadata: Any additional metadata for this model + :param identifier: Unique identifier for this resource in llama stack + :param provider_resource_id: Unique identifier for this resource in the provider + :param provider_id: ID of the provider that owns this resource + """ + type: Literal[ResourceType.model] = ResourceType.model @property diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index f6860ea4b..c16221289 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -18,6 +18,12 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho @json_schema_type class OptimizerType(Enum): + """Available optimizer algorithms for training. + :cvar adam: Adaptive Moment Estimation optimizer + :cvar adamw: AdamW optimizer with weight decay + :cvar sgd: Stochastic Gradient Descent optimizer + """ + adam = "adam" adamw = "adamw" sgd = "sgd" @@ -25,12 +31,28 @@ class OptimizerType(Enum): @json_schema_type class DatasetFormat(Enum): + """Format of the training dataset. + :cvar instruct: Instruction-following format with prompt and completion + :cvar dialog: Multi-turn conversation format with messages + """ + instruct = "instruct" dialog = "dialog" @json_schema_type class DataConfig(BaseModel): + """Configuration for training data and data loading. + + :param dataset_id: Unique identifier for the training dataset + :param batch_size: Number of samples per training batch + :param shuffle: Whether to shuffle the dataset during training + :param data_format: Format of the dataset (instruct or dialog) + :param validation_dataset_id: (Optional) Unique identifier for the validation dataset + :param packed: (Optional) Whether to pack multiple samples into a single sequence for efficiency + :param train_on_input: (Optional) Whether to compute loss on input tokens as well as output tokens + """ + dataset_id: str batch_size: int shuffle: bool @@ -42,6 +64,14 @@ class DataConfig(BaseModel): @json_schema_type class OptimizerConfig(BaseModel): + """Configuration parameters for the optimization algorithm. + + :param optimizer_type: Type of optimizer to use (adam, adamw, or sgd) + :param lr: Learning rate for the optimizer + :param weight_decay: Weight decay coefficient for regularization + :param num_warmup_steps: Number of steps for learning rate warmup + """ + optimizer_type: OptimizerType lr: float weight_decay: float @@ -50,6 +80,14 @@ class OptimizerConfig(BaseModel): @json_schema_type class EfficiencyConfig(BaseModel): + """Configuration for memory and compute efficiency optimizations. 
+ + :param enable_activation_checkpointing: (Optional) Whether to use activation checkpointing to reduce memory usage + :param enable_activation_offloading: (Optional) Whether to offload activations to CPU to save GPU memory + :param memory_efficient_fsdp_wrap: (Optional) Whether to use memory-efficient FSDP wrapping + :param fsdp_cpu_offload: (Optional) Whether to offload FSDP parameters to CPU + """ + enable_activation_checkpointing: bool | None = False enable_activation_offloading: bool | None = False memory_efficient_fsdp_wrap: bool | None = False @@ -58,6 +96,18 @@ class EfficiencyConfig(BaseModel): @json_schema_type class TrainingConfig(BaseModel): + """Comprehensive configuration for the training process. + + :param n_epochs: Number of training epochs to run + :param max_steps_per_epoch: Maximum number of steps to run per epoch + :param gradient_accumulation_steps: Number of steps to accumulate gradients before updating + :param max_validation_steps: (Optional) Maximum number of validation steps per epoch + :param data_config: (Optional) Configuration for data loading and formatting + :param optimizer_config: (Optional) Configuration for the optimization algorithm + :param efficiency_config: (Optional) Configuration for memory and compute optimizations + :param dtype: (Optional) Data type for model parameters (bf16, fp16, fp32) + """ + n_epochs: int max_steps_per_epoch: int = 1 gradient_accumulation_steps: int = 1 @@ -70,6 +120,18 @@ class TrainingConfig(BaseModel): @json_schema_type class LoraFinetuningConfig(BaseModel): + """Configuration for Low-Rank Adaptation (LoRA) fine-tuning. + + :param type: Algorithm type identifier, always "LoRA" + :param lora_attn_modules: List of attention module names to apply LoRA to + :param apply_lora_to_mlp: Whether to apply LoRA to MLP layers + :param apply_lora_to_output: Whether to apply LoRA to output projection layers + :param rank: Rank of the LoRA adaptation (lower rank = fewer parameters) + :param alpha: LoRA scaling parameter that controls adaptation strength + :param use_dora: (Optional) Whether to use DoRA (Weight-Decomposed Low-Rank Adaptation) + :param quantize_base: (Optional) Whether to quantize the base model weights + """ + type: Literal["LoRA"] = "LoRA" lora_attn_modules: list[str] apply_lora_to_mlp: bool @@ -82,6 +144,13 @@ class LoraFinetuningConfig(BaseModel): @json_schema_type class QATFinetuningConfig(BaseModel): + """Configuration for Quantization-Aware Training (QAT) fine-tuning. + + :param type: Algorithm type identifier, always "QAT" + :param quantizer_name: Name of the quantization algorithm to use + :param group_size: Size of groups for grouped quantization + """ + type: Literal["QAT"] = "QAT" quantizer_name: str group_size: int @@ -93,7 +162,11 @@ register_schema(AlgorithmConfig, name="AlgorithmConfig") @json_schema_type class PostTrainingJobLogStream(BaseModel): - """Stream of logs from a finetuning job.""" + """Stream of logs from a finetuning job. + + :param job_uuid: Unique identifier for the training job + :param log_lines: List of log message strings from the training process + """ job_uuid: str log_lines: list[str] @@ -101,6 +174,10 @@ class PostTrainingJobLogStream(BaseModel): @json_schema_type class RLHFAlgorithm(Enum): + """Available reinforcement learning from human feedback algorithms. 
+ :cvar dpo: Direct Preference Optimization algorithm + """ + dpo = "dpo" @@ -114,13 +191,31 @@ class DPOLossType(Enum): @json_schema_type class DPOAlignmentConfig(BaseModel): + """Configuration for Direct Preference Optimization (DPO) alignment. + + :param beta: Temperature parameter for the DPO loss + :param loss_type: The type of loss function to use for DPO + """ + beta: float loss_type: DPOLossType = DPOLossType.sigmoid @json_schema_type class PostTrainingRLHFRequest(BaseModel): - """Request to finetune a model.""" + """Request to finetune a model using reinforcement learning from human feedback. + + :param job_uuid: Unique identifier for the training job + :param finetuned_model: URL or path to the base model to fine-tune + :param dataset_id: Unique identifier for the training dataset + :param validation_dataset_id: Unique identifier for the validation dataset + :param algorithm: RLHF algorithm to use for training + :param algorithm_config: Configuration parameters for the RLHF algorithm + :param optimizer_config: Configuration parameters for the optimization algorithm + :param training_config: Configuration parameters for the training process + :param hyperparam_search_config: Configuration for hyperparameter search + :param logger_config: Configuration for training logging + """ job_uuid: str @@ -146,7 +241,16 @@ class PostTrainingJob(BaseModel): @json_schema_type class PostTrainingJobStatusResponse(BaseModel): - """Status of a finetuning job.""" + """Status of a finetuning job. + + :param job_uuid: Unique identifier for the training job + :param status: Current status of the training job + :param scheduled_at: (Optional) Timestamp when the job was scheduled + :param started_at: (Optional) Timestamp when the job execution began + :param completed_at: (Optional) Timestamp when the job finished, if completed + :param resources_allocated: (Optional) Information about computational resources allocated to the job + :param checkpoints: List of model checkpoints created during training + """ job_uuid: str status: JobStatus @@ -166,7 +270,11 @@ class ListPostTrainingJobsResponse(BaseModel): @json_schema_type class PostTrainingJobArtifactsResponse(BaseModel): - """Artifacts of a finetuning job.""" + """Artifacts of a finetuning job. + + :param job_uuid: Unique identifier for the training job + :param checkpoints: List of model checkpoints created during training + """ job_uuid: str checkpoints: list[Checkpoint] = Field(default_factory=list) diff --git a/llama_stack/apis/providers/providers.py b/llama_stack/apis/providers/providers.py index 4bc977bf1..8a1e93d8f 100644 --- a/llama_stack/apis/providers/providers.py +++ b/llama_stack/apis/providers/providers.py @@ -14,6 +14,15 @@ from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type class ProviderInfo(BaseModel): + """Information about a registered provider including its configuration and health status. + + :param api: The API name this provider implements + :param provider_id: Unique identifier for the provider + :param provider_type: The type of provider implementation + :param config: Configuration parameters for the provider + :param health: Current health status of the provider + """ + api: str provider_id: str provider_type: str @@ -22,6 +31,11 @@ class ProviderInfo(BaseModel): class ListProvidersResponse(BaseModel): + """Response containing a list of all available providers. 
+ + :param data: List of provider information objects + """ + data: list[ProviderInfo] diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index 3aee52b7e..3f374460b 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -15,8 +15,80 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol from llama_stack.schema_utils import json_schema_type, webmethod +# OpenAI Categories to return in the response +class OpenAICategories(StrEnum): + """ + Required set of categories in the moderations API response + """ + + VIOLENCE = "violence" + VIOLENCE_GRAPHIC = "violence/graphic" + HARASSMENT = "harassment" + HARASSMENT_THREATENING = "harassment/threatening" + HATE = "hate" + HATE_THREATENING = "hate/threatening" + ILLICIT = "illicit" + ILLICIT_VIOLENT = "illicit/violent" + SEXUAL = "sexual" + SEXUAL_MINORS = "sexual/minors" + SELF_HARM = "self-harm" + SELF_HARM_INTENT = "self-harm/intent" + SELF_HARM_INSTRUCTIONS = "self-harm/instructions" + + +@json_schema_type +class ModerationObjectResults(BaseModel): + """Moderation results for a single input item. + :param flagged: Whether any of the below categories are flagged. + :param categories: A list of the categories, and whether they are flagged or not. + :param category_applied_input_types: A list of the categories along with the input type(s) that the score applies to. + :param category_scores: A list of the categories along with their scores as predicted by the model. + Required set of categories that need to be in the response + - violence + - violence/graphic + - harassment + - harassment/threatening + - hate + - hate/threatening + - illicit + - illicit/violent + - sexual + - sexual/minors + - self-harm + - self-harm/intent + - self-harm/instructions + """ + + flagged: bool + categories: dict[str, bool] | None = None + category_applied_input_types: dict[str, list[str]] | None = None + category_scores: dict[str, float] | None = None + user_message: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + + +@json_schema_type +class ModerationObject(BaseModel): + """A moderation object. + :param id: The unique identifier for the moderation request. + :param model: The model used to generate the moderation results. + :param results: A list of moderation result objects + """ + + id: str + model: str + results: list[ModerationObjectResults] + + @json_schema_type class ViolationLevel(Enum): + """Severity level of a safety violation. + + :cvar INFO: Informational level violation that does not require action + :cvar WARN: Warning level violation that suggests caution but allows continuation + :cvar ERROR: Error level violation that requires blocking or intervention + """ + INFO = "info" WARN = "warn" ERROR = "error" @@ -24,6 +96,13 @@ class ViolationLevel(Enum): @json_schema_type class SafetyViolation(BaseModel): + """Details of a safety violation detected by content moderation.
+ + :param violation_level: Severity level of the violation + :param user_message: (Optional) Message to convey to the user about the violation + :param metadata: Additional metadata including specific violation codes for debugging and telemetry + """ + violation_level: ViolationLevel # what message should you convey to the user @@ -36,6 +115,11 @@ class SafetyViolation(BaseModel): @json_schema_type class RunShieldResponse(BaseModel): + """Response from running a safety shield. + + :param violation: (Optional) Safety violation detected by the shield, if any + """ + violation: SafetyViolation | None = None @@ -63,3 +147,13 @@ class Safety(Protocol): :returns: A RunShieldResponse. """ ... + + @webmethod(route="/openai/v1/moderations", method="POST") + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + """Classifies if text and/or image inputs are potentially harmful. + :param input: Input (or inputs) to classify. + Can be a single string, an array of strings, or an array of multi-modal input objects similar to other models. + :param model: The content moderation model you would like to use. + :returns: A moderation object. + """ + ... diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index 732e80e79..8ca599b44 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -31,6 +31,12 @@ class ScoringResult(BaseModel): @json_schema_type class ScoreBatchResponse(BaseModel): + """Response from batch scoring operations on datasets. + + :param dataset_id: (Optional) The identifier of the dataset that was scored + :param results: A map of scoring function name to ScoringResult + """ + dataset_id: str | None = None results: dict[str, ScoringResult] diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 684041308..05b6325b7 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -25,6 +25,12 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho # with standard metrics so they can be rolled up? @json_schema_type class ScoringFnParamsType(StrEnum): + """Types of scoring function parameter configurations. + :cvar llm_as_judge: Use an LLM model to evaluate and score responses + :cvar regex_parser: Use regex patterns to extract and score specific parts of responses + :cvar basic: Basic scoring with simple aggregation functions + """ + llm_as_judge = "llm_as_judge" regex_parser = "regex_parser" basic = "basic" @@ -32,6 +38,14 @@ class ScoringFnParamsType(StrEnum): @json_schema_type class AggregationFunctionType(StrEnum): + """Types of aggregation functions for scoring results. + :cvar average: Calculate the arithmetic mean of scores + :cvar weighted_average: Calculate a weighted average of scores + :cvar median: Calculate the median value of scores + :cvar categorical_count: Count occurrences of categorical values + :cvar accuracy: Calculate accuracy as the proportion of correct answers + """ + average = "average" weighted_average = "weighted_average" median = "median" @@ -41,6 +55,14 @@ class AggregationFunctionType(StrEnum): @json_schema_type class LLMAsJudgeScoringFnParams(BaseModel): + """Parameters for LLM-as-judge scoring function configuration. 
+ :param type: The type of scoring function parameters, always llm_as_judge + :param judge_model: Identifier of the LLM model to use as a judge for scoring + :param prompt_template: (Optional) Custom prompt template for the judge model + :param judge_score_regexes: Regexes to extract the answer from generated response + :param aggregation_functions: Aggregation functions to apply to the scores of each row + """ + type: Literal[ScoringFnParamsType.llm_as_judge] = ScoringFnParamsType.llm_as_judge judge_model: str prompt_template: str | None = None @@ -56,6 +78,12 @@ class LLMAsJudgeScoringFnParams(BaseModel): @json_schema_type class RegexParserScoringFnParams(BaseModel): + """Parameters for regex parser scoring function configuration. + :param type: The type of scoring function parameters, always regex_parser + :param parsing_regexes: Regex to extract the answer from generated response + :param aggregation_functions: Aggregation functions to apply to the scores of each row + """ + type: Literal[ScoringFnParamsType.regex_parser] = ScoringFnParamsType.regex_parser parsing_regexes: list[str] = Field( description="Regex to extract the answer from generated response", @@ -69,6 +97,11 @@ class RegexParserScoringFnParams(BaseModel): @json_schema_type class BasicScoringFnParams(BaseModel): + """Parameters for basic scoring function configuration. + :param type: The type of scoring function parameters, always basic + :param aggregation_functions: Aggregation functions to apply to the scores of each row + """ + type: Literal[ScoringFnParamsType.basic] = ScoringFnParamsType.basic aggregation_functions: list[AggregationFunctionType] = Field( description="Aggregation functions to apply to the scores of each row", @@ -100,6 +133,10 @@ class CommonScoringFnFields(BaseModel): @json_schema_type class ScoringFn(CommonScoringFnFields, Resource): + """A scoring function resource for evaluating model outputs. + :param type: The resource type, always scoring_function + """ + type: Literal[ResourceType.scoring_function] = ResourceType.scoring_function @property diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ce1f73d8e..ec1b85349 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -19,7 +19,11 @@ class CommonShieldFields(BaseModel): @json_schema_type class Shield(CommonShieldFields, Resource): - """A safety shield resource that can be used to check content""" + """A safety shield resource that can be used to check content. + + :param params: (Optional) Configuration parameters for the shield + :param type: The resource type, always shield + """ type: Literal[ResourceType.shield] = ResourceType.shield @@ -79,3 +83,11 @@ class Shields(Protocol): :returns: A Shield. """ ... + + @webmethod(route="/shields/{identifier:path}", method="DELETE") + async def unregister_shield(self, identifier: str) -> None: + """Unregister a shield. + + :param identifier: The identifier of the shield to unregister. + """ + ... 
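To make the new DELETE /shields/{identifier:path} route easier to picture, here is a minimal sketch of a registry-backed unregister_shield() implementation. It is illustrative only and not part of this diff: the ShieldNotFoundError class simply follows the exception schema documented in errors.py above, the in-memory registry is hypothetical, and the 'client.shields.list()' hint mirrors the pattern used by the other *NotFoundError classes.

# Illustrative sketch only -- ShieldNotFoundError and InMemoryShieldRegistry are
# hypothetical helpers, not part of this change.
from llama_stack.apis.common.errors import ResourceNotFoundError
from llama_stack.apis.shields.shields import Shield


class ShieldNotFoundError(ResourceNotFoundError):
    """raised when Llama Stack cannot find a referenced shield"""

    def __init__(self, shield_name: str) -> None:
        super().__init__(shield_name, "Shield", "client.shields.list()")


class InMemoryShieldRegistry:
    """Toy registry demonstrating the semantics of the new unregister_shield() method."""

    def __init__(self) -> None:
        self._shields: dict[str, Shield] = {}

    async def register_shield(self, shield: Shield) -> None:
        self._shields[shield.identifier] = shield

    async def unregister_shield(self, identifier: str) -> None:
        # Corresponds to DELETE /shields/{identifier:path} added above.
        if identifier not in self._shields:
            raise ShieldNotFoundError(identifier)
        del self._shields[identifier]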
diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 91e550da9..a7af44b28 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -14,7 +14,15 @@ from llama_stack.schema_utils import json_schema_type, webmethod class FilteringFunction(Enum): - """The type of filtering function.""" + """The type of filtering function. + + :cvar none: No filtering applied, accept all generated synthetic data + :cvar random: Random sampling of generated data points + :cvar top_k: Keep only the top-k highest scoring synthetic data samples + :cvar top_p: Nucleus-style filtering, keep samples exceeding cumulative score threshold + :cvar top_k_top_p: Combined top-k and top-p filtering strategy + :cvar sigmoid: Apply sigmoid function for probability-based filtering + """ none = "none" random = "random" @@ -26,7 +34,12 @@ class FilteringFunction(Enum): @json_schema_type class SyntheticDataGenerationRequest(BaseModel): - """Request to generate synthetic data. A small batch of prompts and a filtering function""" + """Request to generate synthetic data. A small batch of prompts and a filtering function + + :param dialogs: List of conversation messages to use as input for synthetic data generation + :param filtering_function: Type of filtering to apply to generated synthetic data samples + :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint + """ dialogs: list[Message] filtering_function: FilteringFunction = FilteringFunction.none @@ -35,7 +48,11 @@ class SyntheticDataGenerationRequest(BaseModel): @json_schema_type class SyntheticDataGenerationResponse(BaseModel): - """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold.""" + """Response from the synthetic data generation. Batch of (prompt, response, score) tuples that pass the threshold. + + :param synthetic_data: List of generated synthetic data samples that passed the filtering criteria + :param statistics: (Optional) Statistical information about the generation process and filtering results + """ synthetic_data: list[dict[str, Any]] statistics: dict[str, Any] | None = None @@ -48,4 +65,12 @@ class SyntheticDataGeneration(Protocol): dialogs: list[Message], filtering_function: FilteringFunction = FilteringFunction.none, model: str | None = None, - ) -> SyntheticDataGenerationResponse: ... + ) -> SyntheticDataGenerationResponse: + """Generate synthetic data based on input dialogs and apply filtering. + + :param dialogs: List of conversation messages to use as input for synthetic data generation + :param filtering_function: Type of filtering to apply to generated synthetic data samples + :param model: (Optional) The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint + :returns: Response containing filtered synthetic data samples and optional statistics + """ + ... 
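For orientation only, a rough usage sketch of the SyntheticDataGeneration protocol documented above might look as follows. The provider object (impl), the model identifier, and the UserMessage import path are assumptions rather than something introduced by this change.

# Hypothetical usage sketch; `impl` is any provider implementing the protocol.
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.synthetic_data_generation.synthetic_data_generation import (
    FilteringFunction,
    SyntheticDataGeneration,
)


async def generate_samples(impl: SyntheticDataGeneration) -> None:
    # The model must be registered with Llama Stack and visible via /models.
    response = await impl.synthetic_data_generate(
        dialogs=[UserMessage(content="Summarize why score-based filtering matters.")],
        filtering_function=FilteringFunction.top_k,
        model="example-model-id",
    )
    # Only (prompt, response, score) samples that passed the filter are returned.
    for sample in response.synthetic_data:
        print(sample)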
diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 96b317c29..92422ac1b 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -27,12 +27,27 @@ REQUIRED_SCOPE = "telemetry.read" @json_schema_type class SpanStatus(Enum): + """The status of a span indicating whether it completed successfully or with an error. + :cvar OK: Span completed successfully without errors + :cvar ERROR: Span completed with an error or failure + """ + OK = "ok" ERROR = "error" @json_schema_type class Span(BaseModel): + """A span representing a single operation within a trace. + :param span_id: Unique identifier for the span + :param trace_id: Unique identifier for the trace this span belongs to + :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span + :param name: Human-readable name describing the operation this span represents + :param start_time: Timestamp when the operation began + :param end_time: (Optional) Timestamp when the operation finished, if completed + :param attributes: (Optional) Key-value pairs containing additional metadata about the span + """ + span_id: str trace_id: str parent_span_id: str | None = None @@ -49,6 +64,13 @@ class Span(BaseModel): @json_schema_type class Trace(BaseModel): + """A trace representing the complete execution path of a request across multiple operations. + :param trace_id: Unique identifier for the trace + :param root_span_id: Unique identifier for the root span that started this trace + :param start_time: Timestamp when the trace began + :param end_time: (Optional) Timestamp when the trace finished, if completed + """ + trace_id: str root_span_id: str start_time: datetime @@ -57,6 +79,12 @@ class Trace(BaseModel): @json_schema_type class EventType(Enum): + """The type of telemetry event being logged. + :cvar UNSTRUCTURED_LOG: A simple log message with severity level + :cvar STRUCTURED_LOG: A structured log event with typed payload data + :cvar METRIC: A metric measurement with value and unit + """ + UNSTRUCTURED_LOG = "unstructured_log" STRUCTURED_LOG = "structured_log" METRIC = "metric" @@ -64,6 +92,15 @@ class EventType(Enum): @json_schema_type class LogSeverity(Enum): + """The severity level of a log message. + :cvar VERBOSE: Detailed diagnostic information for troubleshooting + :cvar DEBUG: Debug information useful during development + :cvar INFO: General informational messages about normal operation + :cvar WARN: Warning messages about potentially problematic situations + :cvar ERROR: Error messages indicating failures that don't stop execution + :cvar CRITICAL: Critical error messages indicating severe failures + """ + VERBOSE = "verbose" DEBUG = "debug" INFO = "info" @@ -73,6 +110,13 @@ class LogSeverity(Enum): class EventCommon(BaseModel): + """Common fields shared by all telemetry events. + :param trace_id: Unique identifier for the trace this event belongs to + :param span_id: Unique identifier for the span this event belongs to + :param timestamp: Timestamp when the event occurred + :param attributes: (Optional) Key-value pairs containing additional metadata about the event + """ + trace_id: str span_id: str timestamp: datetime @@ -81,6 +125,12 @@ class EventCommon(BaseModel): @json_schema_type class UnstructuredLogEvent(EventCommon): + """An unstructured log event containing a simple text message. 
+ :param type: Event type identifier set to UNSTRUCTURED_LOG + :param message: The log message text + :param severity: The severity level of the log message + """ + type: Literal[EventType.UNSTRUCTURED_LOG] = EventType.UNSTRUCTURED_LOG message: str severity: LogSeverity @@ -88,6 +138,13 @@ class UnstructuredLogEvent(EventCommon): @json_schema_type class MetricEvent(EventCommon): + """A metric event containing a measured value. + :param type: Event type identifier set to METRIC + :param metric: The name of the metric being measured + :param value: The numeric value of the metric measurement + :param unit: The unit of measurement for the metric value + """ + type: Literal[EventType.METRIC] = EventType.METRIC metric: str # this would be an enum value: int | float @@ -96,6 +153,12 @@ class MetricEvent(EventCommon): @json_schema_type class MetricInResponse(BaseModel): + """A metric value included in API responses. + :param metric: The name of the metric + :param value: The numeric value of the metric + :param unit: (Optional) The unit of measurement for the metric value + """ + metric: str value: int | float unit: str | None = None @@ -122,17 +185,32 @@ class MetricInResponse(BaseModel): class MetricResponseMixin(BaseModel): + """Mixin class for API responses that can include metrics. + :param metrics: (Optional) List of metrics associated with the API response + """ + metrics: list[MetricInResponse] | None = None @json_schema_type class StructuredLogType(Enum): + """The type of structured log event payload. + :cvar SPAN_START: Event indicating the start of a new span + :cvar SPAN_END: Event indicating the completion of a span + """ + SPAN_START = "span_start" SPAN_END = "span_end" @json_schema_type class SpanStartPayload(BaseModel): + """Payload for a span start event. + :param type: Payload type identifier set to SPAN_START + :param name: Human-readable name describing the operation this span represents + :param parent_span_id: (Optional) Unique identifier for the parent span, if this is a child span + """ + type: Literal[StructuredLogType.SPAN_START] = StructuredLogType.SPAN_START name: str parent_span_id: str | None = None @@ -140,6 +218,11 @@ class SpanStartPayload(BaseModel): @json_schema_type class SpanEndPayload(BaseModel): + """Payload for a span end event. + :param type: Payload type identifier set to SPAN_END + :param status: The final status of the span indicating success or failure + """ + type: Literal[StructuredLogType.SPAN_END] = StructuredLogType.SPAN_END status: SpanStatus @@ -153,6 +236,11 @@ register_schema(StructuredLogPayload, name="StructuredLogPayload") @json_schema_type class StructuredLogEvent(EventCommon): + """A structured log event containing typed payload data. + :param type: Event type identifier set to STRUCTURED_LOG + :param payload: The structured payload data for the log event + """ + type: Literal[EventType.STRUCTURED_LOG] = EventType.STRUCTURED_LOG payload: StructuredLogPayload @@ -166,6 +254,14 @@ register_schema(Event, name="Event") @json_schema_type class EvalTrace(BaseModel): + """A trace record for evaluation purposes. 
+ :param session_id: Unique identifier for the evaluation session + :param step: The evaluation step or phase identifier + :param input: The input data for the evaluation + :param output: The actual output produced during evaluation + :param expected_output: The expected output for comparison during evaluation + """ + session_id: str step: str input: str @@ -175,11 +271,22 @@ class EvalTrace(BaseModel): @json_schema_type class SpanWithStatus(Span): + """A span that includes status information. + :param status: (Optional) The current status of the span + """ + status: SpanStatus | None = None @json_schema_type class QueryConditionOp(Enum): + """Comparison operators for query conditions. + :cvar EQ: Equal to comparison + :cvar NE: Not equal to comparison + :cvar GT: Greater than comparison + :cvar LT: Less than comparison + """ + EQ = "eq" NE = "ne" GT = "gt" @@ -188,29 +295,59 @@ class QueryConditionOp(Enum): @json_schema_type class QueryCondition(BaseModel): + """A condition for filtering query results. + :param key: The attribute key to filter on + :param op: The comparison operator to apply + :param value: The value to compare against + """ + key: str op: QueryConditionOp value: Any class QueryTracesResponse(BaseModel): + """Response containing a list of traces. + :param data: List of traces matching the query criteria + """ + data: list[Trace] class QuerySpansResponse(BaseModel): + """Response containing a list of spans. + :param data: List of spans matching the query criteria + """ + data: list[Span] class QuerySpanTreeResponse(BaseModel): + """Response containing a tree structure of spans. + :param data: Dictionary mapping span IDs to spans with status information + """ + data: dict[str, SpanWithStatus] class MetricQueryType(Enum): + """The type of metric query to perform. + :cvar RANGE: Query metrics over a time range + :cvar INSTANT: Query metrics at a specific point in time + """ + RANGE = "range" INSTANT = "instant" class MetricLabelOperator(Enum): + """Operators for matching metric labels. + :cvar EQUALS: Label value must equal the specified value + :cvar NOT_EQUALS: Label value must not equal the specified value + :cvar REGEX_MATCH: Label value must match the specified regular expression + :cvar REGEX_NOT_MATCH: Label value must not match the specified regular expression + """ + EQUALS = "=" NOT_EQUALS = "!=" REGEX_MATCH = "=~" @@ -218,6 +355,12 @@ class MetricLabelOperator(Enum): class MetricLabelMatcher(BaseModel): + """A matcher for filtering metrics by label values. + :param name: The name of the label to match + :param value: The value to match against + :param operator: The comparison operator to use for matching + """ + name: str value: str operator: MetricLabelOperator = MetricLabelOperator.EQUALS @@ -225,24 +368,44 @@ class MetricLabelMatcher(BaseModel): @json_schema_type class MetricLabel(BaseModel): + """A label associated with a metric. + :param name: The name of the label + :param value: The value of the label + """ + name: str value: str @json_schema_type class MetricDataPoint(BaseModel): + """A single data point in a metric time series. + :param timestamp: Unix timestamp when the metric value was recorded + :param value: The numeric value of the metric at this timestamp + """ + timestamp: int value: float @json_schema_type class MetricSeries(BaseModel): + """A time series of metric data points. 
+ :param metric: The name of the metric + :param labels: List of labels associated with this metric series + :param values: List of data points in chronological order + """ + metric: str labels: list[MetricLabel] values: list[MetricDataPoint] class QueryMetricsResponse(BaseModel): + """Response containing metric time series data. + :param data: List of metric series matching the query criteria + """ + data: list[MetricSeries] diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 1d5e7b6cb..651016bd1 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -22,7 +22,7 @@ class RRFRanker(BaseModel): :param type: The type of ranker, always "rrf" :param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. - Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009). + Must be greater than 0 """ type: Literal["rrf"] = "rrf" @@ -76,12 +76,25 @@ class RAGDocument(BaseModel): @json_schema_type class RAGQueryResult(BaseModel): + """Result of a RAG query containing retrieved content and metadata. + + :param content: (Optional) The retrieved content from the query + :param metadata: Additional metadata about the query result + """ + content: InterleavedContent | None = None metadata: dict[str, Any] = Field(default_factory=dict) @json_schema_type class RAGQueryGenerator(Enum): + """Types of query generators for RAG systems. + + :cvar default: Default query generator using simple text processing + :cvar llm: LLM-based query generator for enhanced query understanding + :cvar custom: Custom query generator implementation + """ + default = "default" llm = "llm" custom = "custom" @@ -103,12 +116,25 @@ class RAGSearchMode(StrEnum): @json_schema_type class DefaultRAGQueryGeneratorConfig(BaseModel): + """Configuration for the default RAG query generator. + + :param type: Type of query generator, always 'default' + :param separator: String separator used to join query terms + """ + type: Literal["default"] = "default" separator: str = " " @json_schema_type class LLMRAGQueryGeneratorConfig(BaseModel): + """Configuration for the LLM-based RAG query generator. + + :param type: Type of query generator, always 'llm' + :param model: Name of the language model to use for query generation + :param template: Template string for formatting the query generation prompt + """ + type: Literal["llm"] = "llm" model: str template: str @@ -166,7 +192,12 @@ class RAGToolRuntime(Protocol): vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - """Index documents so they can be used by the RAG system""" + """Index documents so they can be used by the RAG system. + + :param documents: List of documents to index in the RAG system + :param vector_db_id: ID of the vector database to store the document embeddings + :param chunk_size_in_tokens: (Optional) Size in tokens for document chunking during indexing + """ ... @webmethod(route="/tool-runtime/rag-tool/query", method="POST") @@ -176,5 +207,11 @@ class RAGToolRuntime(Protocol): vector_db_ids: list[str], query_config: RAGQueryConfig | None = None, ) -> RAGQueryResult: - """Query the RAG system for context; typically invoked by the agent""" + """Query the RAG system for context; typically invoked by the agent. 
+ + :param content: The query content to search for in the indexed documents + :param vector_db_ids: List of vector database IDs to search within + :param query_config: (Optional) Configuration parameters for the query operation + :returns: RAGQueryResult containing the retrieved content and metadata + """ ... diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index 7d1eeeefb..52b86375a 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -20,6 +20,15 @@ from .rag_tool import RAGToolRuntime @json_schema_type class ToolParameter(BaseModel): + """Parameter definition for a tool. + + :param name: Name of the parameter + :param parameter_type: Type of the parameter (e.g., string, integer) + :param description: Human-readable description of what the parameter does + :param required: Whether this parameter is required for tool invocation + :param default: (Optional) Default value for the parameter if not provided + """ + name: str parameter_type: str description: str @@ -29,6 +38,15 @@ class ToolParameter(BaseModel): @json_schema_type class Tool(Resource): + """A tool that can be invoked by agents. + + :param type: Type of resource, always 'tool' + :param toolgroup_id: ID of the tool group this tool belongs to + :param description: Human-readable description of what the tool does + :param parameters: List of parameters this tool accepts + :param metadata: (Optional) Additional metadata about the tool + """ + type: Literal[ResourceType.tool] = ResourceType.tool toolgroup_id: str description: str @@ -38,6 +56,14 @@ class Tool(Resource): @json_schema_type class ToolDef(BaseModel): + """Tool definition used in runtime contexts. + + :param name: Name of the tool + :param description: (Optional) Human-readable description of what the tool does + :param parameters: (Optional) List of parameters this tool accepts + :param metadata: (Optional) Additional metadata about the tool + """ + name: str description: str | None = None parameters: list[ToolParameter] | None = None @@ -46,6 +72,14 @@ class ToolDef(BaseModel): @json_schema_type class ToolGroupInput(BaseModel): + """Input data for registering a tool group. + + :param toolgroup_id: Unique identifier for the tool group + :param provider_id: ID of the provider that will handle this tool group + :param args: (Optional) Additional arguments to pass to the provider + :param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools + """ + toolgroup_id: str provider_id: str args: dict[str, Any] | None = None @@ -54,6 +88,13 @@ class ToolGroupInput(BaseModel): @json_schema_type class ToolGroup(Resource): + """A group of related tools managed together. + + :param type: Type of resource, always 'tool_group' + :param mcp_endpoint: (Optional) Model Context Protocol endpoint for remote tools + :param args: (Optional) Additional arguments for the tool group + """ + type: Literal[ResourceType.tool_group] = ResourceType.tool_group mcp_endpoint: URL | None = None args: dict[str, Any] | None = None @@ -61,6 +102,14 @@ class ToolGroup(Resource): @json_schema_type class ToolInvocationResult(BaseModel): + """Result of a tool invocation. 
+ + :param content: (Optional) The output content from the tool execution + :param error_message: (Optional) Error message if the tool execution failed + :param error_code: (Optional) Numeric error code if the tool execution failed + :param metadata: (Optional) Additional metadata about the tool execution + """ + content: InterleavedContent | None = None error_message: str | None = None error_code: int | None = None @@ -73,14 +122,29 @@ class ToolStore(Protocol): class ListToolGroupsResponse(BaseModel): + """Response containing a list of tool groups. + + :param data: List of tool groups + """ + data: list[ToolGroup] class ListToolsResponse(BaseModel): + """Response containing a list of tools. + + :param data: List of tools + """ + data: list[Tool] class ListToolDefsResponse(BaseModel): + """Response containing a list of tool definitions. + + :param data: List of tool definitions + """ + data: list[ToolDef] @@ -158,6 +222,11 @@ class ToolGroups(Protocol): class SpecialToolGroup(Enum): + """Special tool groups with predefined functionality. + + :cvar rag_tool: Retrieval-Augmented Generation tool group for document search and retrieval + """ + rag_tool = "rag_tool" diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 325e21bab..47820fa0f 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -15,6 +15,13 @@ from llama_stack.schema_utils import json_schema_type, webmethod @json_schema_type class VectorDB(Resource): + """Vector database resource for storing and querying vector embeddings. + + :param type: Type of resource, always 'vector_db' for vector databases + :param embedding_model: Name of the embedding model to use for vector generation + :param embedding_dimension: Dimension of the embedding vectors + """ + type: Literal[ResourceType.vector_db] = ResourceType.vector_db embedding_model: str @@ -31,6 +38,14 @@ class VectorDB(Resource): class VectorDBInput(BaseModel): + """Input parameters for creating or configuring a vector database. + + :param vector_db_id: Unique identifier for the vector database + :param embedding_model: Name of the embedding model to use for vector generation + :param embedding_dimension: Dimension of the embedding vectors + :param provider_vector_db_id: (Optional) Provider-specific identifier for the vector database + """ + vector_db_id: str embedding_model: str embedding_dimension: int @@ -39,6 +54,11 @@ class VectorDBInput(BaseModel): class ListVectorDBsResponse(BaseModel): + """Response from listing vector databases. 
+ + :param data: List of vector databases + """ + data: list[VectorDB] diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 853c4656c..3e8065cfb 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -16,7 +16,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol -from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id +from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.strong_typing.schema import register_schema @@ -94,12 +94,27 @@ class Chunk(BaseModel): @json_schema_type class QueryChunksResponse(BaseModel): + """Response from querying chunks in a vector database. + + :param chunks: List of content chunks returned from the query + :param scores: Relevance scores corresponding to each returned chunk + """ + chunks: list[Chunk] scores: list[float] @json_schema_type class VectorStoreFileCounts(BaseModel): + """File processing status counts for a vector store. + + :param completed: Number of files that have been successfully processed + :param cancelled: Number of files that had their processing cancelled + :param failed: Number of files that failed to process + :param in_progress: Number of files currently being processed + :param total: Total number of files in the vector store + """ + completed: int cancelled: int failed: int @@ -109,7 +124,20 @@ class VectorStoreFileCounts(BaseModel): @json_schema_type class VectorStoreObject(BaseModel): - """OpenAI Vector Store object.""" + """OpenAI Vector Store object. + + :param id: Unique identifier for the vector store + :param object: Object type identifier, always "vector_store" + :param created_at: Timestamp when the vector store was created + :param name: (Optional) Name of the vector store + :param usage_bytes: Storage space used by the vector store in bytes + :param file_counts: File processing status counts for the vector store + :param status: Current status of the vector store + :param expires_after: (Optional) Expiration policy for the vector store + :param expires_at: (Optional) Timestamp when the vector store will expire + :param last_active_at: (Optional) Timestamp of last activity on the vector store + :param metadata: Set of key-value pairs that can be attached to the vector store + """ id: str object: str = "vector_store" @@ -126,7 +154,14 @@ class VectorStoreObject(BaseModel): @json_schema_type class VectorStoreCreateRequest(BaseModel): - """Request to create a vector store.""" + """Request to create a vector store. + + :param name: (Optional) Name for the vector store + :param file_ids: List of file IDs to include in the vector store + :param expires_after: (Optional) Expiration policy for the vector store + :param chunking_strategy: (Optional) Strategy for splitting files into chunks + :param metadata: Set of key-value pairs that can be attached to the vector store + """ name: str | None = None file_ids: list[str] = Field(default_factory=list) @@ -137,7 +172,12 @@ class VectorStoreCreateRequest(BaseModel): @json_schema_type class VectorStoreModifyRequest(BaseModel): - """Request to modify a vector store.""" + """Request to modify a vector store. 
+ + :param name: (Optional) Updated name for the vector store + :param expires_after: (Optional) Updated expiration policy for the vector store + :param metadata: (Optional) Updated set of key-value pairs for the vector store + """ name: str | None = None expires_after: dict[str, Any] | None = None @@ -146,7 +186,14 @@ class VectorStoreModifyRequest(BaseModel): @json_schema_type class VectorStoreListResponse(BaseModel): - """Response from listing vector stores.""" + """Response from listing vector stores. + + :param object: Object type identifier, always "list" + :param data: List of vector store objects + :param first_id: (Optional) ID of the first vector store in the list for pagination + :param last_id: (Optional) ID of the last vector store in the list for pagination + :param has_more: Whether there are more vector stores available beyond this page + """ object: str = "list" data: list[VectorStoreObject] @@ -157,7 +204,14 @@ class VectorStoreListResponse(BaseModel): @json_schema_type class VectorStoreSearchRequest(BaseModel): - """Request to search a vector store.""" + """Request to search a vector store. + + :param query: Search query as a string or list of strings + :param filters: (Optional) Filters based on file attributes to narrow search results + :param max_num_results: Maximum number of results to return, defaults to 10 + :param ranking_options: (Optional) Options for ranking and filtering search results + :param rewrite_query: Whether to rewrite the query for better vector search performance + """ query: str | list[str] filters: dict[str, Any] | None = None @@ -168,13 +222,26 @@ class VectorStoreSearchRequest(BaseModel): @json_schema_type class VectorStoreContent(BaseModel): + """Content item from a vector store file or search result. + + :param type: Content type, currently only "text" is supported + :param text: The actual text content + """ + type: Literal["text"] text: str @json_schema_type class VectorStoreSearchResponse(BaseModel): - """Response from searching a vector store.""" + """Response from searching a vector store. + + :param file_id: Unique identifier of the file containing the result + :param filename: Name of the file containing the result + :param score: Relevance score for this search result + :param attributes: (Optional) Key-value attributes associated with the file + :param content: List of content items matching the search query + """ file_id: str filename: str @@ -185,7 +252,14 @@ class VectorStoreSearchResponse(BaseModel): @json_schema_type class VectorStoreSearchResponsePage(BaseModel): - """Response from searching a vector store.""" + """Paginated response from searching a vector store. + + :param object: Object type identifier for the search results page + :param search_query: The original search query that was executed + :param data: List of search result objects + :param has_more: Whether there are more results available beyond this page + :param next_page: (Optional) Token for retrieving the next page of results + """ object: str = "vector_store.search_results.page" search_query: str @@ -196,7 +270,12 @@ class VectorStoreSearchResponsePage(BaseModel): @json_schema_type class VectorStoreDeleteResponse(BaseModel): - """Response from deleting a vector store.""" + """Response from deleting a vector store. 
+ + :param id: Unique identifier of the deleted vector store + :param object: Object type identifier for the deletion response + :param deleted: Whether the deletion operation was successful + """ id: str object: str = "vector_store.deleted" @@ -205,17 +284,34 @@ class VectorStoreDeleteResponse(BaseModel): @json_schema_type class VectorStoreChunkingStrategyAuto(BaseModel): + """Automatic chunking strategy for vector store files. + + :param type: Strategy type, always "auto" for automatic chunking + """ + type: Literal["auto"] = "auto" @json_schema_type class VectorStoreChunkingStrategyStaticConfig(BaseModel): + """Configuration for static chunking strategy. + + :param chunk_overlap_tokens: Number of tokens to overlap between adjacent chunks + :param max_chunk_size_tokens: Maximum number of tokens per chunk, must be between 100 and 4096 + """ + chunk_overlap_tokens: int = 400 max_chunk_size_tokens: int = Field(800, ge=100, le=4096) @json_schema_type class VectorStoreChunkingStrategyStatic(BaseModel): + """Static chunking strategy with configurable parameters. + + :param type: Strategy type, always "static" for static chunking + :param static: Configuration parameters for the static chunking strategy + """ + type: Literal["static"] = "static" static: VectorStoreChunkingStrategyStaticConfig @@ -227,6 +323,12 @@ register_schema(VectorStoreChunkingStrategy, name="VectorStoreChunkingStrategy") class SearchRankingOptions(BaseModel): + """Options for ranking and filtering search results. + + :param ranker: (Optional) Name of the ranking algorithm to use + :param score_threshold: (Optional) Minimum relevance score threshold for results + """ + ranker: str | None = None # NOTE: OpenAI File Search Tool requires threshold to be between 0 and 1, however # we don't guarantee that the score is between 0 and 1, so will leave this unconstrained @@ -236,6 +338,12 @@ class SearchRankingOptions(BaseModel): @json_schema_type class VectorStoreFileLastError(BaseModel): + """Error information for failed vector store file processing. + + :param code: Error code indicating the type of failure + :param message: Human-readable error message describing the failure + """ + code: Literal["server_error"] | Literal["rate_limit_exceeded"] message: str @@ -246,7 +354,18 @@ register_schema(VectorStoreFileStatus, name="VectorStoreFileStatus") @json_schema_type class VectorStoreFileObject(BaseModel): - """OpenAI Vector Store File object.""" + """OpenAI Vector Store File object. + + :param id: Unique identifier for the file + :param object: Object type identifier, always "vector_store.file" + :param attributes: Key-value attributes associated with the file + :param chunking_strategy: Strategy used for splitting the file into chunks + :param created_at: Timestamp when the file was added to the vector store + :param last_error: (Optional) Error information if file processing failed + :param status: Current processing status of the file + :param usage_bytes: Storage space used by this file in bytes + :param vector_store_id: ID of the vector store containing this file + """ id: str object: str = "vector_store.file" @@ -261,7 +380,14 @@ class VectorStoreFileObject(BaseModel): @json_schema_type class VectorStoreListFilesResponse(BaseModel): - """Response from listing vector stores.""" + """Response from listing files in a vector store. 
+ + :param object: Object type identifier, always "list" + :param data: List of vector store file objects + :param first_id: (Optional) ID of the first file in the list for pagination + :param last_id: (Optional) ID of the last file in the list for pagination + :param has_more: Whether there are more files available beyond this page + """ object: str = "list" data: list[VectorStoreFileObject] @@ -272,7 +398,13 @@ class VectorStoreListFilesResponse(BaseModel): @json_schema_type class VectorStoreFileContentsResponse(BaseModel): - """Response from retrieving the contents of a vector store file.""" + """Response from retrieving the contents of a vector store file. + + :param file_id: Unique identifier for the file + :param filename: Name of the file + :param attributes: Key-value attributes associated with the file + :param content: List of content items from the file + """ file_id: str filename: str @@ -282,7 +414,12 @@ class VectorStoreFileContentsResponse(BaseModel): @json_schema_type class VectorStoreFileDeleteResponse(BaseModel): - """Response from deleting a vector store file.""" + """Response from deleting a vector store file. + + :param id: Unique identifier of the deleted file + :param object: Object type identifier for the deletion response + :param deleted: Whether the deletion operation was successful + """ id: str object: str = "vector_store.file.deleted" @@ -478,6 +615,11 @@ class VectorIO(Protocol): """List files in a vector store. :param vector_store_id: The ID of the vector store to list files from. + :param limit: (Optional) A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20. + :param order: (Optional) Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order. + :param after: (Optional) A cursor for use in pagination. `after` is an object ID that defines your place in the list. + :param before: (Optional) A cursor for use in pagination. `before` is an object ID that defines your place in the list. + :param filter: (Optional) Filter by file status to only return files with the specified status. :returns: A VectorStoreListFilesResponse containing the list of files. """ ... 
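For reference, a minimal usage sketch of the chunking-strategy models documented in the vector_io hunks above. This is illustrative only: it assumes the classes are importable from the module path shown in this diff and that the project is on the pydantic v2 API; the field values are made up for the example.

# Assumption: llama_stack.apis.vector_io.vector_io exposes the models exactly as
# declared in the hunks above (pydantic v2, hence .model_dump()).
from llama_stack.apis.vector_io.vector_io import (
    VectorStoreChunkingStrategyAuto,
    VectorStoreChunkingStrategyStatic,
    VectorStoreChunkingStrategyStaticConfig,
)

# Static strategy: max_chunk_size_tokens is bounded to 100-4096 by the Field constraint.
static_strategy = VectorStoreChunkingStrategyStatic(
    static=VectorStoreChunkingStrategyStaticConfig(
        chunk_overlap_tokens=200,      # illustrative value
        max_chunk_size_tokens=1024,    # illustrative value, within the 100-4096 bounds
    )
)

# Automatic strategy: carries only its "auto" type discriminator.
auto_strategy = VectorStoreChunkingStrategyAuto()

# The Literal type discriminators make the strategies safe to round-trip as JSON.
print(static_strategy.model_dump())  # {'type': 'static', 'static': {...}}
print(auto_strategy.model_dump())    # {'type': 'auto'}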
diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index 30b6e11e9..70cb9f4db 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -323,7 +323,7 @@ def _hf_download( from huggingface_hub import snapshot_download from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError - from llama_stack.distribution.utils.model_utils import model_local_dir + from llama_stack.core.utils.model_utils import model_local_dir repo_id = model.huggingface_repo if repo_id is None: @@ -361,7 +361,7 @@ def _meta_download( info: "LlamaDownloadInfo", max_concurrent_downloads: int, ): - from llama_stack.distribution.utils.model_utils import model_local_dir + from llama_stack.core.utils.model_utils import model_local_dir output_dir = Path(model_local_dir(model.descriptor())) os.makedirs(output_dir, exist_ok=True) @@ -403,7 +403,7 @@ class Manifest(BaseModel): def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int): - from llama_stack.distribution.utils.model_utils import model_local_dir + from llama_stack.core.utils.model_utils import model_local_dir with open(manifest_file) as f: d = json.load(f) diff --git a/llama_stack/cli/model/list.py b/llama_stack/cli/model/list.py index cf84dd526..f46a8c88d 100644 --- a/llama_stack/cli/model/list.py +++ b/llama_stack/cli/model/list.py @@ -11,7 +11,7 @@ from pathlib import Path from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.table import print_table -from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR +from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.models.llama.sku_list import all_registered_models diff --git a/llama_stack/cli/model/remove.py b/llama_stack/cli/model/remove.py index 98710d82b..138e06a2a 100644 --- a/llama_stack/cli/model/remove.py +++ b/llama_stack/cli/model/remove.py @@ -9,7 +9,7 @@ import os import shutil from llama_stack.cli.subcommand import Subcommand -from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR +from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.models.llama.sku_list import resolve_model diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index fbf4871c4..c6e204773 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -23,73 +23,80 @@ from termcolor import colored, cprint from llama_stack.cli.stack.utils import ImageType from llama_stack.cli.table import print_table -from llama_stack.distribution.build import ( +from llama_stack.core.build import ( SERVER_DEPENDENCIES, build_image, get_provider_dependencies, ) -from llama_stack.distribution.configure import parse_and_maybe_upgrade_config -from llama_stack.distribution.datatypes import ( +from llama_stack.core.configure import parse_and_maybe_upgrade_config +from llama_stack.core.datatypes import ( BuildConfig, BuildProvider, DistributionSpec, Provider, StackRunConfig, ) -from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.external import load_external_apis -from llama_stack.distribution.resolver import InvalidProviderError -from llama_stack.distribution.stack import replace_env_vars -from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR -from llama_stack.distribution.utils.dynamic import instantiate_class_type -from llama_stack.distribution.utils.exec import formulate_run_args, run_command -from 
llama_stack.distribution.utils.image_types import LlamaStackImageType +from llama_stack.core.distribution import get_provider_registry +from llama_stack.core.external import load_external_apis +from llama_stack.core.resolver import InvalidProviderError +from llama_stack.core.stack import replace_env_vars +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR, EXTERNAL_PROVIDERS_DIR +from llama_stack.core.utils.dynamic import instantiate_class_type +from llama_stack.core.utils.exec import formulate_run_args, run_command +from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.providers.datatypes import Api -TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" +DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions" @lru_cache -def available_templates_specs() -> dict[str, BuildConfig]: +def available_distros_specs() -> dict[str, BuildConfig]: import yaml - template_specs = {} - for p in TEMPLATES_PATH.rglob("*build.yaml"): - template_name = p.parent.name + distro_specs = {} + for p in DISTRIBS_PATH.rglob("*build.yaml"): + distro_name = p.parent.name with open(p) as f: build_config = BuildConfig(**yaml.safe_load(f)) - template_specs[template_name] = build_config - return template_specs + distro_specs[distro_name] = build_config + return distro_specs def run_stack_build_command(args: argparse.Namespace) -> None: - if args.list_templates: - return _run_template_list_cmd() + if args.list_distros: + return _run_distro_list_cmd() if args.image_type == ImageType.VENV.value: current_venv = os.environ.get("VIRTUAL_ENV") image_name = args.image_name or current_venv - elif args.image_type == ImageType.CONDA.value: - current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") - image_name = args.image_name or current_conda_env else: image_name = args.image_name if args.template: - available_templates = available_templates_specs() - if args.template not in available_templates: + cprint( + "The --template argument is deprecated. Please use --distro instead.", + color="red", + file=sys.stderr, + ) + distro_name = args.template + else: + distro_name = args.distribution + + if distro_name: + available_distros = available_distros_specs() + if distro_name not in available_distros: cprint( - f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates", + f"Could not find distribution {distro_name}. Please run `llama stack build --list-distros` to check out the available distributions", color="red", file=sys.stderr, ) sys.exit(1) - build_config = available_templates[args.template] + build_config = available_distros[distro_name] if args.image_type: build_config.image_type = args.image_type else: cprint( - f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}", + f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {distro_name}", color="red", file=sys.stderr, ) @@ -132,14 +139,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) if not args.image_type: cprint( - f"Please specify a image-type (container | conda | venv) for {args.template}", + f"Please specify a image-type (container | venv) for {args.template}", color="red", file=sys.stderr, ) sys.exit(1) build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec) - elif not args.config and not args.template: + elif not args.config and not distro_name: name = prompt( "> Enter a name for your Llama Stack (e.g. 
my-local-stack): ", validator=Validator.from_callable( @@ -158,22 +165,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ), ) - if image_type == ImageType.CONDA.value: - if not image_name: - cprint( - f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`", - color="yellow", - file=sys.stderr, - ) - image_name = f"llamastack-{name}" - else: - cprint( - f"Using conda environment {image_name}", - color="green", - file=sys.stderr, - ) - else: - image_name = f"llamastack-{name}" + image_name = f"llamastack-{name}" cprint( textwrap.dedent( @@ -236,7 +228,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: sys.exit(1) if args.print_deps_only: - print(f"# Dependencies for {args.template or args.config or image_name}") + print(f"# Dependencies for {distro_name or args.config or image_name}") normal_deps, special_deps, external_provider_dependencies = get_provider_dependencies(build_config) normal_deps += SERVER_DEPENDENCIES print(f"uv pip install {' '.join(normal_deps)}") @@ -251,7 +243,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: build_config, image_name=image_name, config_path=args.config, - template_name=args.template, + distro_name=distro_name, ) except (Exception, RuntimeError) as exc: @@ -279,7 +271,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: config = parse_and_maybe_upgrade_config(config_dict) if config.external_providers_dir and not config.external_providers_dir.exists(): config.external_providers_dir.mkdir(exist_ok=True) - run_args = formulate_run_args(args.image_type, args.image_name) + run_args = formulate_run_args(args.image_type, image_name or config.image_name) run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)]) run_command(run_args) @@ -362,20 +354,17 @@ def _generate_run_config( def _run_stack_build_command_from_build_config( build_config: BuildConfig, image_name: str | None = None, - template_name: str | None = None, + distro_name: str | None = None, config_path: str | None = None, ) -> Path | Traversable: image_name = image_name or build_config.image_name if build_config.image_type == LlamaStackImageType.CONTAINER.value: - if template_name: - image_name = f"distribution-{template_name}" + if distro_name: + image_name = f"distribution-{distro_name}" else: if not image_name: raise ValueError("Please specify an image name when building a container image without a template") - elif build_config.image_type == LlamaStackImageType.CONDA.value: - if not image_name: - raise ValueError("Please specify an image name when building a conda image") - elif build_config.image_type == LlamaStackImageType.VENV.value: + else: if not image_name and os.environ.get("UV_SYSTEM_PYTHON"): image_name = "__system__" if not image_name: @@ -385,9 +374,9 @@ def _run_stack_build_command_from_build_config( if image_name is None: raise ValueError("image_name should not be None after validation") - if template_name: - build_dir = DISTRIBS_BASE_DIR / template_name - build_file_path = build_dir / f"{template_name}-build.yaml" + if distro_name: + build_dir = DISTRIBS_BASE_DIR / distro_name + build_file_path = build_dir / f"{distro_name}-build.yaml" else: if image_name is None: raise ValueError("image_name cannot be None") @@ -398,7 +387,7 @@ def _run_stack_build_command_from_build_config( run_config_file = None # Generate the run.yaml so it can be included in the container image with the proper entrypoint # Only do this if we're 
building a container image and we're not using a template - if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path: + if build_config.image_type == LlamaStackImageType.CONTAINER.value and not distro_name and config_path: cprint("Generating run.yaml file", color="yellow", file=sys.stderr) run_config_file = _generate_run_config(build_config, build_dir, image_name) @@ -431,48 +420,46 @@ def _run_stack_build_command_from_build_config( return_code = build_image( build_config, - build_file_path, image_name, - template_or_config=template_name or config_path or str(build_file_path), + distro_or_config=distro_name or config_path or str(build_file_path), run_config=run_config_file.as_posix() if run_config_file else None, ) if return_code != 0: raise RuntimeError(f"Failed to build image {image_name}") - if template_name: - # copy run.yaml from template to build_dir instead of generating it again - template_path = importlib.resources.files("llama_stack") / f"templates/{template_name}/run.yaml" - run_config_file = build_dir / f"{template_name}-run.yaml" + if distro_name: + # copy run.yaml from distribution to build_dir instead of generating it again + distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro_name}/run.yaml" + run_config_file = build_dir / f"{distro_name}-run.yaml" - with importlib.resources.as_file(template_path) as path: + with importlib.resources.as_file(distro_path) as path: shutil.copy(path, run_config_file) cprint("Build Successful!", color="green", file=sys.stderr) - cprint(f"You can find the newly-built template here: {run_config_file}", color="blue", file=sys.stderr) + cprint(f"You can find the newly-built distribution here: {run_config_file}", color="blue", file=sys.stderr) cprint( "You can run the new Llama Stack distro via: " + colored(f"llama stack run {run_config_file} --image-type {build_config.image_type}", "blue"), color="green", file=sys.stderr, ) - return template_path + return distro_path else: return _generate_run_config(build_config, build_dir, image_name) -def _run_template_list_cmd() -> None: - # eventually, this should query a registry at llama.meta.com/llamastack/distributions +def _run_distro_list_cmd() -> None: headers = [ - "Template Name", + "Distribution Name", # "Providers", "Description", ] rows = [] - for template_name, spec in available_templates_specs().items(): + for distro_name, spec in available_distros_specs().items(): rows.append( [ - template_name, + distro_name, # json.dumps(spec.distribution_spec.providers, indent=2), spec.distribution_spec.description, ] diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 2c402beeb..80cf6fb38 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -27,21 +27,31 @@ class StackBuild(Subcommand): "--config", type=str, default=None, - help="Path to a config file to use for the build. You can find example configs in llama_stack/distributions/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively", + help="Path to a config file to use for the build. You can find example configs in llama_stack.cores/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively", ) self.parser.add_argument( "--template", type=str, default=None, - help="Name of the example template config to use for build. 
You may use `llama stack build --list-templates` to check out the available templates", + help="""(deprecated) Name of the example template config to use for build. You may use `llama stack build --list-distros` to check out the available distributions""", + ) + self.parser.add_argument( + "--distro", + "--distribution", + dest="distribution", + type=str, + default=None, + help="""Name of the distribution to use for build. You may use `llama stack build --list-distros` to check out the available distributions""", ) self.parser.add_argument( - "--list-templates", + "--list-distros", + "--list-distributions", action="store_true", + dest="list_distros", default=False, - help="Show the available templates for building a Llama Stack distribution", + help="Show the available distributions for building a Llama Stack distribution", ) self.parser.add_argument( @@ -56,7 +66,7 @@ class StackBuild(Subcommand): "--image-name", type=str, help=textwrap.dedent( - f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for + f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the virtual environment to use for the build. If not specified, currently active environment will be used if found. """ ), diff --git a/llama_stack/cli/stack/list_apis.py b/llama_stack/cli/stack/list_apis.py index cac803f92..6eed5ca51 100644 --- a/llama_stack/cli/stack/list_apis.py +++ b/llama_stack/cli/stack/list_apis.py @@ -26,7 +26,7 @@ class StackListApis(Subcommand): def _run_apis_list_cmd(self, args: argparse.Namespace) -> None: from llama_stack.cli.table import print_table - from llama_stack.distribution.distribution import stack_apis + from llama_stack.core.distribution import stack_apis # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py index deebd937b..b78b3c31f 100644 --- a/llama_stack/cli/stack/list_providers.py +++ b/llama_stack/cli/stack/list_providers.py @@ -23,7 +23,7 @@ class StackListProviders(Subcommand): @property def providable_apis(self): - from llama_stack.distribution.distribution import providable_apis + from llama_stack.core.distribution import providable_apis return [api.value for api in providable_apis()] @@ -38,7 +38,7 @@ class StackListProviders(Subcommand): def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: from llama_stack.cli.table import print_table - from llama_stack.distribution.distribution import Api, get_provider_registry + from llama_stack.core.distribution import Api, get_provider_registry all_providers = get_provider_registry() if args.api: diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 3cb2e213c..c8ffce034 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -35,8 +35,8 @@ class StackRun(Subcommand): "config", type=str, nargs="?", # Make it optional - metavar="config | template", - help="Path to config file to use for the run or name of known template (`llama stack list` for a list).", + metavar="config | distro", + help="Path to config file to use for the run or name of known distro (`llama stack list` for a list).", ) self.parser.add_argument( "--port", @@ -47,7 +47,8 @@ class StackRun(Subcommand): self.parser.add_argument( "--image-name", type=str, - help="Name of the image to run.", + default=None, + help="Name of the image to run. 
Defaults to the current environment", ) self.parser.add_argument( "--env", @@ -58,7 +59,7 @@ class StackRun(Subcommand): self.parser.add_argument( "--image-type", type=str, - help="Image Type used during the build. This can be either conda or container or venv.", + help="Image Type used during the build. This can be only venv.", choices=[e.value for e in ImageType if e.value != ImageType.CONTAINER.value], ) self.parser.add_argument( @@ -67,44 +68,62 @@ class StackRun(Subcommand): help="Start the UI server", ) - # If neither image type nor image name is provided, but at the same time - # the current environment has conda breadcrumbs, then assume what the user - # wants to use conda mode and not the usual default mode (using - # pre-installed system packages). - # - # Note: yes, this is hacky. It's implemented this way to keep the existing - # conda users unaffected by the switch of the default behavior to using - # system packages. - def _get_image_type_and_name(self, args: argparse.Namespace) -> tuple[str, str]: - conda_env = os.environ.get("CONDA_DEFAULT_ENV") - if conda_env and args.image_name == conda_env: - logger.warning(f"Conda detected. Using conda environment {conda_env} for the run.") - return ImageType.CONDA.value, args.image_name - return args.image_type, args.image_name + def _resolve_config_and_distro(self, args: argparse.Namespace) -> tuple[Path | None, str | None]: + """Resolve config file path and distribution name from args.config""" + from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR + + if not args.config: + return None, None + + config_file = Path(args.config) + has_yaml_suffix = args.config.endswith(".yaml") + distro_name = None + + if not config_file.exists() and not has_yaml_suffix: + # check if this is a distribution + config_file = Path(REPO_ROOT) / "llama_stack" / "distributions" / args.config / "run.yaml" + if config_file.exists(): + distro_name = args.config + + if not config_file.exists() and not has_yaml_suffix: + # check if it's a build config saved to ~/.llama dir + config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml") + + if not config_file.exists(): + self.parser.error( + f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file" + ) + + if not config_file.is_file(): + self.parser.error( + f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}" + ) + + return config_file, distro_name def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml - from llama_stack.distribution.configure import parse_and_maybe_upgrade_config - from llama_stack.distribution.utils.exec import formulate_run_args, run_command + from llama_stack.core.configure import parse_and_maybe_upgrade_config + from llama_stack.core.utils.exec import formulate_run_args, run_command if args.enable_ui: self._start_ui_development_server(args.port) - image_type, image_name = self._get_image_type_and_name(args) + image_type, image_name = args.image_type, args.image_name if args.config: try: - from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template + from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro - config_file = resolve_config_or_template(args.config, Mode.RUN) + config_file = resolve_config_or_distro(args.config, Mode.RUN) except ValueError as e: self.parser.error(str(e)) else: config_file = None # Check if config is required based on image type 
- if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file: - self.parser.error("Config file is required for venv and conda environments") + if image_type == ImageType.VENV.value and not config_file: + self.parser.error("Config file is required for venv environment") if config_file: logger.info(f"Using run configuration: {config_file}") @@ -127,7 +146,7 @@ class StackRun(Subcommand): # using the current environment packages. if not image_type and not image_name: logger.info("No image type or image name provided. Assuming environment packages.") - from llama_stack.distribution.server.server import main as server_main + from llama_stack.core.server.server import main as server_main # Build the server args from the current args passed to the CLI server_args = argparse.Namespace() diff --git a/llama_stack/cli/stack/utils.py b/llama_stack/cli/stack/utils.py index 74a606b2b..fdf9e1761 100644 --- a/llama_stack/cli/stack/utils.py +++ b/llama_stack/cli/stack/utils.py @@ -8,7 +8,6 @@ from enum import Enum class ImageType(Enum): - CONDA = "conda" CONTAINER = "container" VENV = "venv" diff --git a/llama_stack/cli/utils.py b/llama_stack/cli/utils.py index 94cff42e8..c9c51d933 100644 --- a/llama_stack/cli/utils.py +++ b/llama_stack/cli/utils.py @@ -11,38 +11,19 @@ from llama_stack.log import get_logger logger = get_logger(name=__name__, category="cli") -def add_config_template_args(parser: argparse.ArgumentParser): - """Add unified config/template arguments with backward compatibility.""" +# TODO: this can probably just be inlined now? +def add_config_distro_args(parser: argparse.ArgumentParser): + """Add unified config/distro arguments.""" group = parser.add_mutually_exclusive_group(required=True) group.add_argument( "config", nargs="?", - help="Configuration file path or template name", - ) - - # Backward compatibility arguments (deprecated) - group.add_argument( - "--config", - dest="config_deprecated", - help="(DEPRECATED) Use positional argument [config] instead. Configuration file path", - ) - - group.add_argument( - "--template", - dest="template_deprecated", - help="(DEPRECATED) Use positional argument [config] instead. Template name", + help="Configuration file path or distribution name", ) def get_config_from_args(args: argparse.Namespace) -> str | None: - """Extract config value from parsed arguments, handling both new and deprecated forms.""" if args.config is not None: return str(args.config) - elif hasattr(args, "config_deprecated") and args.config_deprecated is not None: - logger.warning("Using deprecated --config argument. Use positional argument [config] instead.") - return str(args.config_deprecated) - elif hasattr(args, "template_deprecated") and args.template_deprecated is not None: - logger.warning("Using deprecated --template argument. 
Use positional argument [config] instead.") - return str(args.template_deprecated) return None diff --git a/llama_stack/cli/verify_download.py b/llama_stack/cli/verify_download.py index 3a1af3cbc..b7f4cfdb5 100644 --- a/llama_stack/cli/verify_download.py +++ b/llama_stack/cli/verify_download.py @@ -107,7 +107,7 @@ def verify_files(model_dir: Path, checksums: dict[str, str], console: Console) - def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser): - from llama_stack.distribution.utils.model_utils import model_local_dir + from llama_stack.core.utils.model_utils import model_local_dir console = Console() model_dir = Path(model_local_dir(args.model_id)) diff --git a/llama_stack/distribution/__init__.py b/llama_stack/core/__init__.py similarity index 100% rename from llama_stack/distribution/__init__.py rename to llama_stack/core/__init__.py diff --git a/llama_stack/distribution/access_control/__init__.py b/llama_stack/core/access_control/__init__.py similarity index 100% rename from llama_stack/distribution/access_control/__init__.py rename to llama_stack/core/access_control/__init__.py diff --git a/llama_stack/distribution/access_control/access_control.py b/llama_stack/core/access_control/access_control.py similarity index 98% rename from llama_stack/distribution/access_control/access_control.py rename to llama_stack/core/access_control/access_control.py index 64c0122c1..bde5cfd76 100644 --- a/llama_stack/distribution/access_control/access_control.py +++ b/llama_stack/core/access_control/access_control.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.distribution.datatypes import User +from llama_stack.core.datatypes import User from .conditions import ( Condition, diff --git a/llama_stack/distribution/access_control/conditions.py b/llama_stack/core/access_control/conditions.py similarity index 100% rename from llama_stack/distribution/access_control/conditions.py rename to llama_stack/core/access_control/conditions.py diff --git a/llama_stack/distribution/access_control/datatypes.py b/llama_stack/core/access_control/datatypes.py similarity index 100% rename from llama_stack/distribution/access_control/datatypes.py rename to llama_stack/core/access_control/datatypes.py diff --git a/llama_stack/distribution/build.py b/llama_stack/core/build.py similarity index 83% rename from llama_stack/distribution/build.py rename to llama_stack/core/build.py index b4eaac1c7..b3e35ecef 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/core/build.py @@ -7,18 +7,17 @@ import importlib.resources import logging import sys -from pathlib import Path from pydantic import BaseModel from termcolor import cprint -from llama_stack.distribution.datatypes import BuildConfig -from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.external import load_external_apis -from llama_stack.distribution.utils.exec import run_command -from llama_stack.distribution.utils.image_types import LlamaStackImageType +from llama_stack.core.datatypes import BuildConfig +from llama_stack.core.distribution import get_provider_registry +from llama_stack.core.external import load_external_apis +from llama_stack.core.utils.exec import run_command +from llama_stack.core.utils.image_types import LlamaStackImageType +from llama_stack.distributions.template import DistributionTemplate from llama_stack.providers.datatypes import Api -from llama_stack.templates.template import DistributionTemplate log = logging.getLogger(__name__) @@ -106,9 +105,8 @@ def 
print_pip_install_help(config: BuildConfig): def build_image( build_config: BuildConfig, - build_file_path: Path, image_name: str, - template_or_config: str, + distro_or_config: str, run_config: str | None = None, ): container_base = build_config.distribution_spec.container_image or "python:3.12-slim" @@ -122,11 +120,11 @@ def build_image( normal_deps.extend(api_spec.pip_packages) if build_config.image_type == LlamaStackImageType.CONTAINER.value: - script = str(importlib.resources.files("llama_stack") / "distribution/build_container.sh") + script = str(importlib.resources.files("llama_stack") / "core/build_container.sh") args = [ script, - "--template-or-config", - template_or_config, + "--distro-or-config", + distro_or_config, "--image-name", image_name, "--container-base", @@ -138,19 +136,8 @@ def build_image( # build arguments if run_config is not None: args.extend(["--run-config", run_config]) - elif build_config.image_type == LlamaStackImageType.CONDA.value: - script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh") - args = [ - script, - "--env-name", - str(image_name), - "--build-file-path", - str(build_file_path), - "--normal-deps", - " ".join(normal_deps), - ] - elif build_config.image_type == LlamaStackImageType.VENV.value: - script = str(importlib.resources.files("llama_stack") / "distribution/build_venv.sh") + else: + script = str(importlib.resources.files("llama_stack") / "core/build_venv.sh") args = [ script, "--env-name", diff --git a/llama_stack/distribution/build_conda_env.sh b/llama_stack/core/build_conda_env.sh similarity index 100% rename from llama_stack/distribution/build_conda_env.sh rename to llama_stack/core/build_conda_env.sh diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/core/build_container.sh similarity index 95% rename from llama_stack/distribution/build_container.sh rename to llama_stack/core/build_container.sh index 50d8e4925..424b40a9d 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/core/build_container.sh @@ -43,7 +43,7 @@ normal_deps="" external_provider_deps="" optional_deps="" run_config="" -template_or_config="" +distro_or_config="" while [[ $# -gt 0 ]]; do key="$1" @@ -96,12 +96,12 @@ while [[ $# -gt 0 ]]; do run_config="$2" shift 2 ;; - --template-or-config) + --distro-or-config) if [[ -z "$2" || "$2" == --* ]]; then - echo "Error: --template-or-config requires a string value" >&2 + echo "Error: --distro-or-config requires a string value" >&2 usage fi - template_or_config="$2" + distro_or_config="$2" shift 2 ;; *) @@ -327,12 +327,11 @@ EOF # If a run config is provided, we use the --config flag if [[ -n "$run_config" ]]; then add_to_container << EOF -ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"] +ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$RUN_CONFIG_PATH"] EOF -# If a template is provided (not a yaml file), we use the --template flag -elif [[ "$template_or_config" != *.yaml ]]; then +elif [[ "$distro_or_config" != *.yaml ]]; then add_to_container << EOF -ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"] +ENTRYPOINT ["python", "-m", "llama_stack.core.server.server", "$distro_or_config"] EOF fi diff --git a/llama_stack/distribution/build_venv.sh b/llama_stack/core/build_venv.sh similarity index 97% rename from llama_stack/distribution/build_venv.sh rename to llama_stack/core/build_venv.sh index 93db9ab28..a2838803f 100755 --- 
a/llama_stack/distribution/build_venv.sh +++ b/llama_stack/core/build_venv.sh @@ -6,9 +6,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# TODO: combine this with build_conda_env.sh since it is almost identical -# the only difference is that we don't do any conda-specific setup - LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-} TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} @@ -95,6 +92,8 @@ if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" fi +ENVNAME="" + # pre-run checks to make sure we can proceed with the installation pre_run_checks() { local env_name="$1" diff --git a/llama_stack/distribution/client.py b/llama_stack/core/client.py similarity index 100% rename from llama_stack/distribution/client.py rename to llama_stack/core/client.py diff --git a/llama_stack/distribution/common.sh b/llama_stack/core/common.sh similarity index 59% rename from llama_stack/distribution/common.sh rename to llama_stack/core/common.sh index 5f764bcca..021baaddc 100755 --- a/llama_stack/distribution/common.sh +++ b/llama_stack/core/common.sh @@ -7,12 +7,10 @@ # the root directory of this source tree. cleanup() { - envname="$1" - - set +x - echo "Cleaning up..." - conda deactivate - conda env remove --name "$envname" -y + # For venv environments, no special cleanup is needed + # This function exists to avoid "function not found" errors + local env_name="$1" + echo "Cleanup called for environment: $env_name" } handle_int() { @@ -31,19 +29,7 @@ handle_exit() { fi } -setup_cleanup_handlers() { - trap handle_int INT - trap handle_exit EXIT - if is_command_available conda; then - __conda_setup="$('conda' 'shell.bash' 'hook' 2>/dev/null)" - eval "$__conda_setup" - conda deactivate - else - echo "conda is not available" - exit 1 - fi -} # check if a command is present is_command_available() { diff --git a/llama_stack/distribution/configure.py b/llama_stack/core/configure.py similarity index 93% rename from llama_stack/distribution/configure.py rename to llama_stack/core/configure.py index 20be040a0..9e18b438c 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/core/configure.py @@ -7,20 +7,20 @@ import logging import textwrap from typing import Any -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( LLAMA_STACK_RUN_CONFIG_VERSION, DistributionSpec, Provider, StackRunConfig, ) -from llama_stack.distribution.distribution import ( +from llama_stack.core.distribution import ( builtin_automatically_routed_apis, get_provider_registry, ) -from llama_stack.distribution.stack import cast_image_name_to_string, replace_env_vars -from llama_stack.distribution.utils.config_dirs import EXTERNAL_PROVIDERS_DIR -from llama_stack.distribution.utils.dynamic import instantiate_class_type -from llama_stack.distribution.utils.prompt_for_config import prompt_for_config +from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars +from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR +from llama_stack.core.utils.dynamic import instantiate_class_type +from llama_stack.core.utils.prompt_for_config import prompt_for_config from llama_stack.providers.datatypes import Api, ProviderSpec logger = logging.getLogger(__name__) diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/core/datatypes.py similarity index 98% rename from llama_stack/distribution/datatypes.py rename to 
llama_stack/core/datatypes.py index 60c317337..a1b6ad32b 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -24,7 +24,7 @@ from llama_stack.apis.shields import Shield, ShieldInput from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput from llama_stack.apis.vector_io import VectorIO -from llama_stack.distribution.access_control.datatypes import AccessRule +from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.providers.datatypes import Api, ProviderSpec from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqlStoreConfig @@ -432,8 +432,8 @@ class BuildConfig(BaseModel): distribution_spec: DistributionSpec = Field(description="The distribution spec to build including API providers. ") image_type: str = Field( - default="conda", - description="Type of package to build (conda | container | venv)", + default="venv", + description="Type of package to build (container | venv)", ) image_name: str | None = Field( default=None, diff --git a/llama_stack/distribution/distribution.py b/llama_stack/core/distribution.py similarity index 98% rename from llama_stack/distribution/distribution.py rename to llama_stack/core/distribution.py index 6e7297e32..977eb5393 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/core/distribution.py @@ -12,8 +12,8 @@ from typing import Any import yaml from pydantic import BaseModel -from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec -from llama_stack.distribution.external import load_external_apis +from llama_stack.core.datatypes import BuildConfig, DistributionSpec +from llama_stack.core.external import load_external_apis from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( AdapterSpec, diff --git a/llama_stack/distribution/external.py b/llama_stack/core/external.py similarity index 96% rename from llama_stack/distribution/external.py rename to llama_stack/core/external.py index 0a7da16b1..12e9824ad 100644 --- a/llama_stack/distribution/external.py +++ b/llama_stack/core/external.py @@ -8,7 +8,7 @@ import yaml from llama_stack.apis.datatypes import Api, ExternalApiSpec -from llama_stack.distribution.datatypes import BuildConfig, StackRunConfig +from llama_stack.core.datatypes import BuildConfig, StackRunConfig from llama_stack.log import get_logger logger = get_logger(name=__name__, category="core") diff --git a/llama_stack/distribution/inspect.py b/llama_stack/core/inspect.py similarity index 93% rename from llama_stack/distribution/inspect.py rename to llama_stack/core/inspect.py index f62de4f6b..37dab4199 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/core/inspect.py @@ -15,9 +15,9 @@ from llama_stack.apis.inspect import ( RouteInfo, VersionInfo, ) -from llama_stack.distribution.datatypes import StackRunConfig -from llama_stack.distribution.external import load_external_apis -from llama_stack.distribution.server.routes import get_all_api_routes +from llama_stack.core.datatypes import StackRunConfig +from llama_stack.core.external import load_external_apis +from llama_stack.core.server.routes import get_all_api_routes from llama_stack.providers.datatypes import HealthStatus diff --git a/llama_stack/distribution/library_client.py b/llama_stack/core/library_client.py similarity index 89% rename from 
llama_stack/distribution/library_client.py rename to llama_stack/core/library_client.py index 1c28983cf..5fbbf1aff 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/core/library_client.py @@ -31,23 +31,23 @@ from pydantic import BaseModel, TypeAdapter from rich.console import Console from termcolor import cprint -from llama_stack.distribution.build import print_pip_install_help -from llama_stack.distribution.configure import parse_and_maybe_upgrade_config -from llama_stack.distribution.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec -from llama_stack.distribution.request_headers import ( +from llama_stack.core.build import print_pip_install_help +from llama_stack.core.configure import parse_and_maybe_upgrade_config +from llama_stack.core.datatypes import Api, BuildConfig, BuildProvider, DistributionSpec +from llama_stack.core.request_headers import ( PROVIDER_DATA_VAR, request_provider_data_context, ) -from llama_stack.distribution.resolver import ProviderRegistry -from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls -from llama_stack.distribution.stack import ( +from llama_stack.core.resolver import ProviderRegistry +from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls +from llama_stack.core.stack import ( construct_stack, - get_stack_run_config_from_template, + get_stack_run_config_from_distro, replace_env_vars, ) -from llama_stack.distribution.utils.config import redact_sensitive_fields -from llama_stack.distribution.utils.context import preserve_contexts_async_generator -from llama_stack.distribution.utils.exec import in_notebook +from llama_stack.core.utils.config import redact_sensitive_fields +from llama_stack.core.utils.context import preserve_contexts_async_generator +from llama_stack.core.utils.exec import in_notebook from llama_stack.providers.utils.telemetry.tracing import ( CURRENT_TRACE_CONTEXT, end_trace, @@ -138,14 +138,14 @@ class LibraryClientHttpxResponse: class LlamaStackAsLibraryClient(LlamaStackClient): def __init__( self, - config_path_or_template_name: str, + config_path_or_distro_name: str, skip_logger_removal: bool = False, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_template_name, custom_provider_registry, provider_data + config_path_or_distro_name, custom_provider_registry, provider_data ) self.pool_executor = ThreadPoolExecutor(max_workers=4) self.skip_logger_removal = skip_logger_removal @@ -212,7 +212,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): def __init__( self, - config_path_or_template_name: str, + config_path_or_distro_name: str, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, ): @@ -222,20 +222,21 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") - if config_path_or_template_name.endswith(".yaml"): - config_path = Path(config_path_or_template_name) + if config_path_or_distro_name.endswith(".yaml"): + config_path = Path(config_path_or_distro_name) if not config_path.exists(): raise ValueError(f"Config file {config_path} does not exist") config_dict = 
replace_env_vars(yaml.safe_load(config_path.read_text())) config = parse_and_maybe_upgrade_config(config_dict) else: - # template - config = get_stack_run_config_from_template(config_path_or_template_name) + # distribution + config = get_stack_run_config_from_distro(config_path_or_distro_name) - self.config_path_or_template_name = config_path_or_template_name + self.config_path_or_distro_name = config_path_or_distro_name self.config = config self.custom_provider_registry = custom_provider_registry self.provider_data = provider_data + self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError async def initialize(self) -> bool: try: @@ -244,11 +245,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): except ModuleNotFoundError as _e: cprint(_e.msg, color="red", file=sys.stderr) cprint( - "Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n", + "Using llama-stack as a library requires installing dependencies depending on the distribution (providers) you choose.\n", color="yellow", file=sys.stderr, ) - if self.config_path_or_template_name.endswith(".yaml"): + if self.config_path_or_distro_name.endswith(".yaml"): providers: dict[str, list[BuildProvider]] = {} for api, run_providers in self.config.providers.items(): for provider in run_providers: @@ -266,7 +267,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): else: prefix = "!" if in_notebook() else "" cprint( - f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n", + f"Please run:\n\n{prefix}llama stack build --distro {self.config_path_or_distro_name} --image-type venv\n\n", "yellow", file=sys.stderr, ) @@ -282,7 +283,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not os.environ.get("PYTEST_CURRENT_TEST"): console = Console() - console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") + console.print(f"Using config [blue]{self.config_path_or_distro_name}[/blue]:") safe_config = redact_sensitive_fields(self.config.model_dump()) console.print(yaml.dump(safe_config, indent=2)) @@ -297,8 +298,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): stream=False, stream_cls=None, ): - if not self.route_impls: - raise ValueError("Client not initialized") + if self.route_impls is None: + raise ValueError("Client not initialized. 
Please call initialize() first.") # Create headers with provider data if available headers = options.headers or {} @@ -353,9 +354,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): cast_to: Any, options: Any, ): - if self.route_impls is None: - raise ValueError("Client not initialized") - + assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy path = options.url body = options.params or {} body |= options.json_data or {} @@ -412,9 +411,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): options: Any, stream_cls: Any, ): - if self.route_impls is None: - raise ValueError("Client not initialized") - + assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy path = options.url body = options.params or {} body |= options.json_data or {} @@ -474,9 +471,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if not body: return {} - if self.route_impls is None: - raise ValueError("Client not initialized") - + assert self.route_impls is not None # Should be guaranteed by request() method, assertion for mypy exclude_params = exclude_params or set() func, _, _, _ = find_matching_route(method, path, self.route_impls) diff --git a/llama_stack/distribution/providers.py b/llama_stack/core/providers.py similarity index 100% rename from llama_stack/distribution/providers.py rename to llama_stack/core/providers.py diff --git a/llama_stack/distribution/request_headers.py b/llama_stack/core/request_headers.py similarity index 98% rename from llama_stack/distribution/request_headers.py rename to llama_stack/core/request_headers.py index 509c2be44..35ac72775 100644 --- a/llama_stack/distribution/request_headers.py +++ b/llama_stack/core/request_headers.py @@ -10,7 +10,7 @@ import logging from contextlib import AbstractContextManager from typing import Any -from llama_stack.distribution.datatypes import User +from llama_stack.core.datatypes import User from .utils.dynamic import instantiate_class_type diff --git a/llama_stack/distribution/resolver.py b/llama_stack/core/resolver.py similarity index 97% rename from llama_stack/distribution/resolver.py rename to llama_stack/core/resolver.py index db6856ed2..70c78fb01 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/core/resolver.py @@ -27,18 +27,18 @@ from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO -from llama_stack.distribution.client import get_client_impl -from llama_stack.distribution.datatypes import ( +from llama_stack.core.client import get_client_impl +from llama_stack.core.datatypes import ( AccessRule, AutoRoutedProviderSpec, Provider, RoutingTableProviderSpec, StackRunConfig, ) -from llama_stack.distribution.distribution import builtin_automatically_routed_apis -from llama_stack.distribution.external import load_external_apis -from llama_stack.distribution.store import DistributionRegistry -from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.core.distribution import builtin_automatically_routed_apis +from llama_stack.core.external import load_external_apis +from llama_stack.core.store import DistributionRegistry +from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( Api, @@ -183,7 +183,7 @@ def 
specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, spec=RoutingTableProviderSpec( api=info.routing_table_api, router_api=info.router_api, - module="llama_stack.distribution.routers", + module="llama_stack.core.routers", api_dependencies=[], deps__=[f"inner-{info.router_api.value}"], ), @@ -197,7 +197,7 @@ def specs_for_autorouted_apis(apis_to_serve: list[str] | set[str]) -> dict[str, config={}, spec=AutoRoutedProviderSpec( api=info.router_api, - module="llama_stack.distribution.routers", + module="llama_stack.core.routers", routing_table_api=info.routing_table_api, api_dependencies=[info.routing_table_api], # Add telemetry as an optional dependency to all auto-routed providers diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/core/routers/__init__.py similarity index 94% rename from llama_stack/distribution/routers/__init__.py rename to llama_stack/core/routers/__init__.py index 8671a62e1..1faace34a 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/core/routers/__init__.py @@ -6,9 +6,9 @@ from typing import Any -from llama_stack.distribution.datatypes import AccessRule, RoutedProtocol -from llama_stack.distribution.stack import StackRunConfig -from llama_stack.distribution.store import DistributionRegistry +from llama_stack.core.datatypes import AccessRule, RoutedProtocol +from llama_stack.core.stack import StackRunConfig +from llama_stack.core.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore diff --git a/llama_stack/distribution/routers/datasets.py b/llama_stack/core/routers/datasets.py similarity index 100% rename from llama_stack/distribution/routers/datasets.py rename to llama_stack/core/routers/datasets.py diff --git a/llama_stack/distribution/routers/eval_scoring.py b/llama_stack/core/routers/eval_scoring.py similarity index 100% rename from llama_stack/distribution/routers/eval_scoring.py rename to llama_stack/core/routers/eval_scoring.py diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/core/routers/inference.py similarity index 56% rename from llama_stack/distribution/routers/inference.py rename to llama_stack/core/routers/inference.py index a5cc8c4b5..79ab7c34f 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -7,6 +7,7 @@ import asyncio import time from collections.abc import AsyncGenerator, AsyncIterator +from datetime import UTC, datetime from typing import Annotated, Any from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam @@ -17,6 +18,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, ) +from llama_stack.apis.common.errors import ModelNotFoundError from llama_stack.apis.inference import ( BatchChatCompletionResponse, BatchCompletionResponse, @@ -24,14 +26,21 @@ from llama_stack.apis.inference import ( ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, + CompletionResponse, + CompletionResponseStreamChunk, EmbeddingsResponse, EmbeddingTaskType, Inference, ListOpenAIChatCompletionResponse, LogProbConfig, Message, + OpenAIAssistantMessageParam, OpenAIChatCompletion, OpenAIChatCompletionChunk, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIChoiceLogprobs, OpenAICompletion, OpenAICompletionWithInputMessages, 
OpenAIEmbeddingsResponse, @@ -54,7 +63,6 @@ from llama_stack.models.llama.llama3.chat_format import ChatFormat from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable from llama_stack.providers.utils.inference.inference_store import InferenceStore -from llama_stack.providers.utils.inference.stream_utils import stream_and_store_openai_completion from llama_stack.providers.utils.telemetry.tracing import get_current_span logger = get_logger(name=__name__, category="core") @@ -79,11 +87,9 @@ class InferenceRouter(Inference): async def initialize(self) -> None: logger.debug("InferenceRouter.initialize") - pass async def shutdown(self) -> None: logger.debug("InferenceRouter.shutdown") - pass async def register_model( self, @@ -120,6 +126,7 @@ class InferenceRouter(Inference): if span is None: logger.warning("No span found for token usage metrics") return [] + metrics = [ ("prompt_tokens", prompt_tokens), ("completion_tokens", completion_tokens), @@ -133,7 +140,7 @@ class InferenceRouter(Inference): span_id=span.span_id, metric=metric_name, value=value, - timestamp=time.time(), + timestamp=datetime.now(UTC), unit="tokens", attributes={ "model_id": model.model_id, @@ -190,7 +197,7 @@ class InferenceRouter(Inference): sampling_params = SamplingParams() model = await self.routing_table.get_model(model_id) if model is None: - raise ValueError(f"Model '{model_id}' not found") + raise ModelNotFoundError(model_id) if model.model_type == ModelType.embedding: raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") if tool_config: @@ -235,49 +242,26 @@ class InferenceRouter(Inference): prompt_tokens = await self._count_tokens(messages, tool_config.tool_prompt_format) if stream: - - async def stream_generator(): - completion_text = "" - async for chunk in await provider.chat_completion(**params): - if chunk.event.event_type == ChatCompletionResponseEventType.progress: - if chunk.event.delta.type == "text": - completion_text += chunk.event.delta.text - if chunk.event.event_type == ChatCompletionResponseEventType.complete: - completion_tokens = await self._count_tokens( - [ - CompletionMessage( - content=completion_text, - stop_reason=StopReason.end_of_turn, - ) - ], - tool_config.tool_prompt_format, - ) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, - ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics - yield chunk - - return stream_generator() - else: - response = await provider.chat_completion(**params) - completion_tokens = await self._count_tokens( - [response.completion_message], - tool_config.tool_prompt_format, + response_stream = await provider.chat_completion(**params) + return self.stream_tokens_and_compute_metrics( + response=response_stream, + prompt_tokens=prompt_tokens, + model=model, + tool_prompt_format=tool_config.tool_prompt_format, ) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, - ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics - return response + + response = await provider.chat_completion(**params) + metrics = await self.count_tokens_and_compute_metrics( + response=response, + 
prompt_tokens=prompt_tokens, + model=model, + tool_prompt_format=tool_config.tool_prompt_format, + ) + # these metrics will show up in the client response. + response.metrics = ( + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + ) + return response async def batch_chat_completion( self, @@ -319,7 +303,7 @@ class InferenceRouter(Inference): ) model = await self.routing_table.get_model(model_id) if model is None: - raise ValueError(f"Model '{model_id}' not found") + raise ModelNotFoundError(model_id) if model.model_type == ModelType.embedding: raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions") provider = await self.routing_table.get_provider_impl(model_id) @@ -333,39 +317,20 @@ class InferenceRouter(Inference): ) prompt_tokens = await self._count_tokens(content) - + response = await provider.completion(**params) if stream: - - async def stream_generator(): - completion_text = "" - async for chunk in await provider.completion(**params): - if hasattr(chunk, "delta"): - completion_text += chunk.delta - if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: - completion_tokens = await self._count_tokens(completion_text) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, - ) - chunk.metrics = metrics if chunk.metrics is None else chunk.metrics + metrics - yield chunk - - return stream_generator() - else: - response = await provider.completion(**params) - completion_tokens = await self._count_tokens(response.content) - total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) - metrics = await self._compute_and_log_token_usage( - prompt_tokens or 0, - completion_tokens or 0, - total_tokens, - model, + return self.stream_tokens_and_compute_metrics( + response=response, + prompt_tokens=prompt_tokens, + model=model, ) - response.metrics = metrics if response.metrics is None else response.metrics + metrics - return response + + metrics = await self.count_tokens_and_compute_metrics( + response=response, prompt_tokens=prompt_tokens, model=model + ) + response.metrics = metrics if response.metrics is None else response.metrics + metrics + + return response async def batch_completion( self, @@ -392,7 +357,7 @@ class InferenceRouter(Inference): logger.debug(f"InferenceRouter.embeddings: {model_id}") model = await self.routing_table.get_model(model_id) if model is None: - raise ValueError(f"Model '{model_id}' not found") + raise ModelNotFoundError(model_id) if model.model_type == ModelType.llm: raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings") provider = await self.routing_table.get_provider_impl(model_id) @@ -432,7 +397,7 @@ class InferenceRouter(Inference): ) model_obj = await self.routing_table.get_model(model) if model_obj is None: - raise ValueError(f"Model '{model}' not found") + raise ModelNotFoundError(model) if model_obj.model_type == ModelType.embedding: raise ValueError(f"Model '{model}' is an embedding model and does not support completions") @@ -458,9 +423,29 @@ class InferenceRouter(Inference): prompt_logprobs=prompt_logprobs, suffix=suffix, ) - provider = await self.routing_table.get_provider_impl(model_obj.identifier) - return await provider.openai_completion(**params) + if stream: + return await provider.openai_completion(**params) + # TODO: Metrics do NOT work with openai_completion 
stream=True due to the fact + # that we do not return an AsyncIterator, our tests expect a stream of chunks we cannot intercept currently. + # response_stream = await provider.openai_completion(**params) + + response = await provider.openai_completion(**params) + if self.telemetry: + metrics = self._construct_metrics( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + model=model_obj, + ) + for metric in metrics: + await self.telemetry.log_event(metric) + + # these metrics will show up in the client response. + response.metrics = ( + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + ) + return response async def openai_chat_completion( self, @@ -493,7 +478,7 @@ class InferenceRouter(Inference): ) model_obj = await self.routing_table.get_model(model) if model_obj is None: - raise ValueError(f"Model '{model}' not found") + raise ModelNotFoundError(model) if model_obj.model_type == ModelType.embedding: raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") @@ -538,18 +523,38 @@ class InferenceRouter(Inference): top_p=top_p, user=user, ) - provider = await self.routing_table.get_provider_impl(model_obj.identifier) if stream: response_stream = await provider.openai_chat_completion(**params) - if self.store: - return stream_and_store_openai_completion(response_stream, model, self.store, messages) - return response_stream - else: - response = await self._nonstream_openai_chat_completion(provider, params) - if self.store: - await self.store.store_chat_completion(response, messages) - return response + + # For streaming, the provider returns AsyncIterator[OpenAIChatCompletionChunk] + # We need to add metrics to each chunk and store the final completion + return self.stream_tokens_and_compute_metrics_openai_chat( + response=response_stream, + model=model_obj, + messages=messages, + ) + + response = await self._nonstream_openai_chat_completion(provider, params) + + # Store the response with the ID that will be returned to the client + if self.store: + await self.store.store_chat_completion(response, messages) + + if self.telemetry: + metrics = self._construct_metrics( + prompt_tokens=response.usage.prompt_tokens, + completion_tokens=response.usage.completion_tokens, + total_tokens=response.usage.total_tokens, + model=model_obj, + ) + for metric in metrics: + await self.telemetry.log_event(metric) + # these metrics will show up in the client response. 
+ response.metrics = ( + metrics if not hasattr(response, "metrics") or response.metrics is None else response.metrics + metrics + ) + return response async def openai_embeddings( self, @@ -564,7 +569,7 @@ class InferenceRouter(Inference): ) model_obj = await self.routing_table.get_model(model) if model_obj is None: - raise ValueError(f"Model '{model}' not found") + raise ModelNotFoundError(model) if model_obj.model_type != ModelType.embedding: raise ValueError(f"Model '{model}' is not an embedding model") @@ -626,3 +631,244 @@ class InferenceRouter(Inference): status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}" ) return health_statuses + + async def stream_tokens_and_compute_metrics( + self, + response, + prompt_tokens, + model, + tool_prompt_format: ToolPromptFormat | None = None, + ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None] | AsyncGenerator[CompletionResponseStreamChunk, None]: + completion_text = "" + async for chunk in response: + complete = False + if hasattr(chunk, "event"): # only ChatCompletions have .event + if chunk.event.event_type == ChatCompletionResponseEventType.progress: + if chunk.event.delta.type == "text": + completion_text += chunk.event.delta.text + if chunk.event.event_type == ChatCompletionResponseEventType.complete: + complete = True + completion_tokens = await self._count_tokens( + [ + CompletionMessage( + content=completion_text, + stop_reason=StopReason.end_of_turn, + ) + ], + tool_prompt_format=tool_prompt_format, + ) + else: + if hasattr(chunk, "delta"): + completion_text += chunk.delta + if hasattr(chunk, "stop_reason") and chunk.stop_reason and self.telemetry: + complete = True + completion_tokens = await self._count_tokens(completion_text) + # if we are done receiving tokens + if complete: + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + + # Create a separate span for streaming completion metrics + if self.telemetry: + # Log metrics in the new span context + completion_metrics = self._construct_metrics( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model=model, + ) + for metric in completion_metrics: + if metric.metric in [ + "completion_tokens", + "total_tokens", + ]: # Only log completion and total tokens + await self.telemetry.log_event(metric) + + # Return metrics in response + async_metrics = [ + MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + ] + chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + else: + # Fallback if no telemetry + completion_metrics = self._construct_metrics( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + async_metrics = [ + MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics + ] + chunk.metrics = async_metrics if chunk.metrics is None else chunk.metrics + async_metrics + yield chunk + + async def count_tokens_and_compute_metrics( + self, + response: ChatCompletionResponse | CompletionResponse, + prompt_tokens, + model, + tool_prompt_format: ToolPromptFormat | None = None, + ): + if isinstance(response, ChatCompletionResponse): + content = [response.completion_message] + else: + content = response.content + completion_tokens = await self._count_tokens(messages=content, tool_prompt_format=tool_prompt_format) + total_tokens = (prompt_tokens or 0) + (completion_tokens or 0) + + # Create a separate span for completion metrics + if self.telemetry: + # Log metrics in the new 
span context + completion_metrics = self._construct_metrics( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model=model, + ) + for metric in completion_metrics: + if metric.metric in ["completion_tokens", "total_tokens"]: # Only log completion and total tokens + await self.telemetry.log_event(metric) + + # Return metrics in response + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in completion_metrics] + + # Fallback if no telemetry + metrics = self._construct_metrics( + prompt_tokens or 0, + completion_tokens or 0, + total_tokens, + model, + ) + return [MetricInResponse(metric=metric.metric, value=metric.value) for metric in metrics] + + async def stream_tokens_and_compute_metrics_openai_chat( + self, + response: AsyncIterator[OpenAIChatCompletionChunk], + model: Model, + messages: list[OpenAIMessageParam] | None = None, + ) -> AsyncIterator[OpenAIChatCompletionChunk]: + """Stream OpenAI chat completion chunks, compute metrics, and store the final completion.""" + id = None + created = None + choices_data: dict[int, dict[str, Any]] = {} + + try: + async for chunk in response: + # Skip None chunks + if chunk is None: + continue + + # Capture ID and created timestamp from first chunk + if id is None and chunk.id: + id = chunk.id + if created is None and chunk.created: + created = chunk.created + + # Accumulate choice data for final assembly + if chunk.choices: + for choice_delta in chunk.choices: + idx = choice_delta.index + if idx not in choices_data: + choices_data[idx] = { + "content_parts": [], + "tool_calls_builder": {}, + "finish_reason": None, + "logprobs_content_parts": [], + } + current_choice_data = choices_data[idx] + + if choice_delta.delta: + delta = choice_delta.delta + if delta.content: + current_choice_data["content_parts"].append(delta.content) + if delta.tool_calls: + for tool_call_delta in delta.tool_calls: + tc_idx = tool_call_delta.index + if tc_idx not in current_choice_data["tool_calls_builder"]: + current_choice_data["tool_calls_builder"][tc_idx] = { + "id": None, + "type": "function", + "function_name_parts": [], + "function_arguments_parts": [], + } + builder = current_choice_data["tool_calls_builder"][tc_idx] + if tool_call_delta.id: + builder["id"] = tool_call_delta.id + if tool_call_delta.type: + builder["type"] = tool_call_delta.type + if tool_call_delta.function: + if tool_call_delta.function.name: + builder["function_name_parts"].append(tool_call_delta.function.name) + if tool_call_delta.function.arguments: + builder["function_arguments_parts"].append( + tool_call_delta.function.arguments + ) + if choice_delta.finish_reason: + current_choice_data["finish_reason"] = choice_delta.finish_reason + if choice_delta.logprobs and choice_delta.logprobs.content: + current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) + + # Compute metrics on final chunk + if chunk.choices and chunk.choices[0].finish_reason: + completion_text = "" + for choice_data in choices_data.values(): + completion_text += "".join(choice_data["content_parts"]) + + # Add metrics to the chunk + if self.telemetry and chunk.usage: + metrics = self._construct_metrics( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + total_tokens=chunk.usage.total_tokens, + model=model, + ) + for metric in metrics: + await self.telemetry.log_event(metric) + + yield chunk + finally: + # Store the final assembled completion + if id and self.store and messages: + 
assembled_choices: list[OpenAIChoice] = [] + for choice_idx, choice_data in choices_data.items(): + content_str = "".join(choice_data["content_parts"]) + assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [] + if choice_data["tool_calls_builder"]: + for tc_build_data in choice_data["tool_calls_builder"].values(): + if tc_build_data["id"]: + func_name = "".join(tc_build_data["function_name_parts"]) + func_args = "".join(tc_build_data["function_arguments_parts"]) + assembled_tool_calls.append( + OpenAIChatCompletionToolCall( + id=tc_build_data["id"], + type=tc_build_data["type"], + function=OpenAIChatCompletionToolCallFunction( + name=func_name, arguments=func_args + ), + ) + ) + message = OpenAIAssistantMessageParam( + role="assistant", + content=content_str if content_str else None, + tool_calls=assembled_tool_calls if assembled_tool_calls else None, + ) + logprobs_content = choice_data["logprobs_content_parts"] + final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None + + assembled_choices.append( + OpenAIChoice( + finish_reason=choice_data["finish_reason"], + index=choice_idx, + message=message, + logprobs=final_logprobs, + ) + ) + + final_response = OpenAIChatCompletion( + id=id, + choices=assembled_choices, + created=created or int(time.time()), + model=model.identifier, + object="chat.completion", + ) + await self.store.store_chat_completion(final_response, messages) diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py new file mode 100644 index 000000000..9bf2b1bac --- /dev/null +++ b/llama_stack/core/routers/safety.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.apis.inference import ( + Message, +) +from llama_stack.apis.safety import RunShieldResponse, Safety +from llama_stack.apis.safety.safety import ModerationObject, OpenAICategories +from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger +from llama_stack.providers.datatypes import RoutingTable + +logger = get_logger(name=__name__, category="core") + + +class SafetyRouter(Safety): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + logger.debug("Initializing SafetyRouter") + self.routing_table = routing_table + + async def initialize(self) -> None: + logger.debug("SafetyRouter.initialize") + pass + + async def shutdown(self) -> None: + logger.debug("SafetyRouter.shutdown") + pass + + async def register_shield( + self, + shield_id: str, + provider_shield_id: str | None = None, + provider_id: str | None = None, + params: dict[str, Any] | None = None, + ) -> Shield: + logger.debug(f"SafetyRouter.register_shield: {shield_id}") + return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) + + async def unregister_shield(self, identifier: str) -> None: + logger.debug(f"SafetyRouter.unregister_shield: {identifier}") + return await self.routing_table.unregister_shield(identifier) + + async def run_shield( + self, + shield_id: str, + messages: list[Message], + params: dict[str, Any] | None = None, + ) -> RunShieldResponse: + logger.debug(f"SafetyRouter.run_shield: {shield_id}") + provider = await self.routing_table.get_provider_impl(shield_id) + return await provider.run_shield( + shield_id=shield_id, + messages=messages, + params=params, + ) + + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + async def get_shield_id(self, model: str) -> str: + """Get the shield id whose provider_resource_id matches the given model.""" + list_shields_response = await self.routing_table.list_shields() + + matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id] + if not matches: + raise ValueError(f"No shield associated with provider_resource_id {model}") + if len(matches) > 1: + raise ValueError(f"Multiple shields associated with provider_resource_id {model}") + return matches[0] + + shield_id = await get_shield_id(self, model) + logger.debug(f"SafetyRouter.run_moderation: {shield_id}") + provider = await self.routing_table.get_provider_impl(shield_id) + + response = await provider.run_moderation( + input=input, + model=model, + ) + self._validate_required_categories_exist(response) + + return response + + def _validate_required_categories_exist(self, response: ModerationObject) -> None: + """Validate that the ProviderImpl response contains the required OpenAI moderation categories.""" + required_categories = list(map(str, OpenAICategories)) + + categories = response.results[0].categories + category_applied_input_types = response.results[0].category_applied_input_types + category_scores = response.results[0].category_scores + + for i in [categories, category_applied_input_types, category_scores]: + if not set(required_categories).issubset(set(i.keys())): + raise ValueError( + f"ProviderImpl response is missing required categories: {set(required_categories) - set(i.keys())}" + ) diff --git a/llama_stack/distribution/routers/tool_runtime.py b/llama_stack/core/routers/tool_runtime.py similarity index 100% rename from llama_stack/distribution/routers/tool_runtime.py rename to llama_stack/core/routers/tool_runtime.py diff --git 
a/llama_stack/distribution/routers/vector_io.py b/llama_stack/core/routers/vector_io.py similarity index 100% rename from llama_stack/distribution/routers/vector_io.py rename to llama_stack/core/routers/vector_io.py diff --git a/llama_stack/distribution/routing_tables/__init__.py b/llama_stack/core/routing_tables/__init__.py similarity index 100% rename from llama_stack/distribution/routing_tables/__init__.py rename to llama_stack/core/routing_tables/__init__.py diff --git a/llama_stack/distribution/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py similarity index 97% rename from llama_stack/distribution/routing_tables/benchmarks.py rename to llama_stack/core/routing_tables/benchmarks.py index 815483494..74bee8040 100644 --- a/llama_stack/distribution/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -7,7 +7,7 @@ from typing import Any from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( BenchmarkWithOwner, ) from llama_stack.log import get_logger diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/core/routing_tables/common.py similarity index 95% rename from llama_stack/distribution/routing_tables/common.py rename to llama_stack/core/routing_tables/common.py index caf0780fd..339ff6da4 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -6,19 +6,20 @@ from typing import Any +from llama_stack.apis.common.errors import ModelNotFoundError from llama_stack.apis.models import Model from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn -from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed -from llama_stack.distribution.access_control.datatypes import Action -from llama_stack.distribution.datatypes import ( +from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.core.access_control.datatypes import Action +from llama_stack.core.datatypes import ( AccessRule, RoutableObject, RoutableObjectWithProvider, RoutedProtocol, ) -from llama_stack.distribution.request_headers import get_authenticated_user -from llama_stack.distribution.store import DistributionRegistry +from llama_stack.core.request_headers import get_authenticated_user +from llama_stack.core.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable @@ -59,6 +60,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None: return await p.unregister_vector_db(obj.identifier) elif api == Api.inference: return await p.unregister_model(obj.identifier) + elif api == Api.safety: + return await p.unregister_shield(obj.identifier) elif api == Api.datasetio: return await p.unregister_dataset(obj.identifier) elif api == Api.tool_runtime: @@ -257,7 +260,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) -> models = await routing_table.get_all_with_type("model") matching_models = [m for m in models if m.provider_resource_id == model_id] if len(matching_models) == 0: - raise ValueError(f"Model '{model_id}' not found") + raise ModelNotFoundError(model_id) if len(matching_models) > 1: raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}") diff 
--git a/llama_stack/distribution/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py similarity index 93% rename from llama_stack/distribution/routing_tables/datasets.py rename to llama_stack/core/routing_tables/datasets.py index 47894313a..fc6a75df4 100644 --- a/llama_stack/distribution/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -7,6 +7,7 @@ import uuid from typing import Any +from llama_stack.apis.common.errors import DatasetNotFoundError from llama_stack.apis.datasets import ( Dataset, DatasetPurpose, @@ -18,7 +19,7 @@ from llama_stack.apis.datasets import ( URIDataSource, ) from llama_stack.apis.resource import ResourceType -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( DatasetWithOwner, ) from llama_stack.log import get_logger @@ -35,7 +36,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def get_dataset(self, dataset_id: str) -> Dataset: dataset = await self.get_object_by_identifier("dataset", dataset_id) if dataset is None: - raise ValueError(f"Dataset '{dataset_id}' not found") + raise DatasetNotFoundError(dataset_id) return dataset async def register_dataset( @@ -87,6 +88,4 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def unregister_dataset(self, dataset_id: str) -> None: dataset = await self.get_dataset(dataset_id) - if dataset is None: - raise ValueError(f"Dataset {dataset_id} not found") await self.unregister_object(dataset) diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/core/routing_tables/models.py similarity index 97% rename from llama_stack/distribution/routing_tables/models.py rename to llama_stack/core/routing_tables/models.py index 3928307c6..c76619271 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -7,8 +7,9 @@ import time from typing import Any +from llama_stack.apis.common.errors import ModelNotFoundError from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( ModelWithOwner, RegistryEntrySource, ) @@ -111,7 +112,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def unregister_model(self, model_id: str) -> None: existing_model = await self.get_model(model_id) if existing_model is None: - raise ValueError(f"Model {model_id} not found") + raise ModelNotFoundError(model_id) await self.unregister_object(existing_model) async def update_registered_models( diff --git a/llama_stack/distribution/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py similarity index 97% rename from llama_stack/distribution/routing_tables/scoring_functions.py rename to llama_stack/core/routing_tables/scoring_functions.py index 742cc3ca6..5874ba941 100644 --- a/llama_stack/distribution/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -12,7 +12,7 @@ from llama_stack.apis.scoring_functions import ( ScoringFnParams, ScoringFunctions, ) -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( ScoringFnWithOwner, ) from llama_stack.log import get_logger diff --git a/llama_stack/distribution/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py similarity index 90% rename from llama_stack/distribution/routing_tables/shields.py rename to 
llama_stack/core/routing_tables/shields.py index 5215981b9..e08f35bfc 100644 --- a/llama_stack/distribution/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -8,7 +8,7 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.shields import ListShieldsResponse, Shield, Shields -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( ShieldWithOwner, ) from llama_stack.log import get_logger @@ -55,3 +55,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): ) await self.register_object(shield) return shield + + async def unregister_shield(self, identifier: str) -> None: + existing_shield = await self.get_shield(identifier) + await self.unregister_object(existing_shield) diff --git a/llama_stack/distribution/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py similarity index 95% rename from llama_stack/distribution/routing_tables/toolgroups.py rename to llama_stack/core/routing_tables/toolgroups.py index 22c4e109a..e172af991 100644 --- a/llama_stack/distribution/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -7,8 +7,9 @@ from typing import Any from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.errors import ToolGroupNotFoundError from llama_stack.apis.tools import ListToolGroupsResponse, ListToolsResponse, Tool, ToolGroup, ToolGroups -from llama_stack.distribution.datatypes import ToolGroupWithOwner +from llama_stack.core.datatypes import ToolGroupWithOwner from llama_stack.log import get_logger from .common import CommonRoutingTableImpl @@ -87,7 +88,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) if tool_group is None: - raise ValueError(f"Tool group '{toolgroup_id}' not found") + raise ToolGroupNotFoundError(toolgroup_id) return tool_group async def get_tool(self, tool_name: str) -> Tool: @@ -125,7 +126,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): async def unregister_toolgroup(self, toolgroup_id: str) -> None: tool_group = await self.get_tool_group(toolgroup_id) if tool_group is None: - raise ValueError(f"Tool group {toolgroup_id} not found") + raise ToolGroupNotFoundError(toolgroup_id) await self.unregister_object(tool_group) async def shutdown(self) -> None: diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py similarity index 96% rename from llama_stack/distribution/routing_tables/vector_dbs.py rename to llama_stack/core/routing_tables/vector_dbs.py index 58ecf24da..c81a27a3b 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -8,6 +8,7 @@ from typing import Any from pydantic import TypeAdapter +from llama_stack.apis.common.errors import ModelNotFoundError, VectorStoreNotFoundError from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs @@ -22,7 +23,7 @@ from llama_stack.apis.vector_io.vector_io import ( VectorStoreObject, VectorStoreSearchResponsePage, ) -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( VectorDBWithOwner, ) from llama_stack.log import get_logger @@ -39,7 +40,7 @@ class 
VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): async def get_vector_db(self, vector_db_id: str) -> VectorDB: vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) if vector_db is None: - raise ValueError(f"Vector DB '{vector_db_id}' not found") + raise VectorStoreNotFoundError(vector_db_id) return vector_db async def register_vector_db( @@ -63,7 +64,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): raise ValueError("No provider available. Please configure a vector_io provider.") model = await lookup_model(self, embedding_model) if model is None: - raise ValueError(f"Model {embedding_model} not found") + raise ModelNotFoundError(embedding_model) if model.model_type != ModelType.embedding: raise ValueError(f"Model {embedding_model} is not an embedding model") if "embedding_dimension" not in model.metadata: @@ -83,8 +84,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): async def unregister_vector_db(self, vector_db_id: str) -> None: existing_vector_db = await self.get_vector_db(vector_db_id) - if existing_vector_db is None: - raise ValueError(f"Vector DB {vector_db_id} not found") await self.unregister_object(existing_vector_db) async def openai_retrieve_vector_store( diff --git a/llama_stack/distribution/server/__init__.py b/llama_stack/core/server/__init__.py similarity index 100% rename from llama_stack/distribution/server/__init__.py rename to llama_stack/core/server/__init__.py diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/core/server/auth.py similarity index 95% rename from llama_stack/distribution/server/auth.py rename to llama_stack/core/server/auth.py index 87c1a2ab6..e4fb4ff2b 100644 --- a/llama_stack/distribution/server/auth.py +++ b/llama_stack/core/server/auth.py @@ -9,10 +9,10 @@ import json import httpx from aiohttp import hdrs -from llama_stack.distribution.datatypes import AuthenticationConfig, User -from llama_stack.distribution.request_headers import user_from_scope -from llama_stack.distribution.server.auth_providers import create_auth_provider -from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls +from llama_stack.core.datatypes import AuthenticationConfig, User +from llama_stack.core.request_headers import user_from_scope +from llama_stack.core.server.auth_providers import create_auth_provider +from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger logger = get_logger(name=__name__, category="auth") diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/core/server/auth_providers.py similarity index 99% rename from llama_stack/distribution/server/auth_providers.py rename to llama_stack/core/server/auth_providers.py index 9b0e182f5..73d5581c2 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -14,7 +14,7 @@ import httpx from jose import jwt from pydantic import BaseModel, Field -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( AuthenticationConfig, CustomAuthConfig, GitHubTokenAuthConfig, diff --git a/llama_stack/distribution/server/quota.py b/llama_stack/core/server/quota.py similarity index 100% rename from llama_stack/distribution/server/quota.py rename to llama_stack/core/server/quota.py diff --git a/llama_stack/distribution/server/routes.py b/llama_stack/core/server/routes.py similarity index 98% rename from 
llama_stack/distribution/server/routes.py rename to llama_stack/core/server/routes.py index ca6f629af..7baf20da5 100644 --- a/llama_stack/distribution/server/routes.py +++ b/llama_stack/core/server/routes.py @@ -15,7 +15,7 @@ from starlette.routing import Route from llama_stack.apis.datatypes import Api, ExternalApiSpec from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup from llama_stack.apis.version import LLAMA_STACK_API_VERSION -from llama_stack.distribution.resolver import api_protocol_map +from llama_stack.core.resolver import api_protocol_map from llama_stack.schema_utils import WebMethod EndpointFunc = Callable[..., Any] diff --git a/llama_stack/distribution/server/server.py b/llama_stack/core/server/server.py similarity index 95% rename from llama_stack/distribution/server/server.py rename to llama_stack/core/server/server.py index 96a0d60e7..fe5cc68d7 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/core/server/server.py @@ -32,36 +32,36 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse -from llama_stack.cli.utils import add_config_template_args, get_config_from_args -from llama_stack.distribution.access_control.access_control import AccessDeniedError -from llama_stack.distribution.datatypes import ( +from llama_stack.cli.utils import add_config_distro_args, get_config_from_args +from llama_stack.core.access_control.access_control import AccessDeniedError +from llama_stack.core.datatypes import ( AuthenticationRequiredError, LoggingConfig, StackRunConfig, ) -from llama_stack.distribution.distribution import builtin_automatically_routed_apis -from llama_stack.distribution.external import ExternalApiSpec, load_external_apis -from llama_stack.distribution.request_headers import ( +from llama_stack.core.distribution import builtin_automatically_routed_apis +from llama_stack.core.external import ExternalApiSpec, load_external_apis +from llama_stack.core.request_headers import ( PROVIDER_DATA_VAR, request_provider_data_context, user_from_scope, ) -from llama_stack.distribution.resolver import InvalidProviderError -from llama_stack.distribution.server.routes import ( +from llama_stack.core.resolver import InvalidProviderError +from llama_stack.core.server.routes import ( find_matching_route, get_all_api_routes, initialize_route_impls, ) -from llama_stack.distribution.stack import ( +from llama_stack.core.stack import ( cast_image_name_to_string, construct_stack, replace_env_vars, shutdown_stack, validate_env_pair, ) -from llama_stack.distribution.utils.config import redact_sensitive_fields -from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template -from llama_stack.distribution.utils.context import preserve_contexts_async_generator +from llama_stack.core.utils.config import redact_sensitive_fields +from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro +from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig @@ -377,7 +377,7 @@ def main(args: argparse.Namespace | None = None): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") - add_config_template_args(parser) + add_config_distro_args(parser) parser.add_argument( "--port", type=int, 
@@ -396,8 +396,8 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - config_or_template = get_config_from_args(args) - config_file = resolve_config_or_template(config_or_template, Mode.RUN) + config_or_distro = get_config_from_args(args) + config_file = resolve_config_or_distro(config_or_distro, Mode.RUN) logger_config = None with open(config_file) as fp: diff --git a/llama_stack/distribution/stack.py b/llama_stack/core/stack.py similarity index 90% rename from llama_stack/distribution/stack.py rename to llama_stack/core/stack.py index 40e0b9b50..87a3978c1 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/core/stack.py @@ -34,14 +34,14 @@ from llama_stack.apis.telemetry import Telemetry from llama_stack.apis.tools import RAGToolRuntime, ToolGroups, ToolRuntime from llama_stack.apis.vector_dbs import VectorDBs from llama_stack.apis.vector_io import VectorIO -from llama_stack.distribution.datatypes import Provider, StackRunConfig -from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl -from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig -from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls -from llama_stack.distribution.routing_tables.common import CommonRoutingTableImpl -from llama_stack.distribution.store.registry import create_dist_registry -from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.core.datatypes import Provider, StackRunConfig +from llama_stack.core.distribution import get_provider_registry +from llama_stack.core.inspect import DistributionInspectConfig, DistributionInspectImpl +from llama_stack.core.providers import ProviderImpl, ProviderImplConfig +from llama_stack.core.resolver import ProviderRegistry, resolve_impls +from llama_stack.core.routing_tables.common import CommonRoutingTableImpl +from llama_stack.core.store.registry import create_dist_registry +from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api @@ -94,6 +94,7 @@ RESOURCES = [ REGISTRY_REFRESH_INTERVAL_SECONDS = 300 REGISTRY_REFRESH_TASK = None +TEST_RECORDING_CONTEXT = None async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]): @@ -307,6 +308,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf async def construct_stack( run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None ) -> dict[Api, Any]: + if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: + from llama_stack.testing.inference_recorder import setup_inference_recording + + global TEST_RECORDING_CONTEXT + TEST_RECORDING_CONTEXT = setup_inference_recording() + if TEST_RECORDING_CONTEXT: + TEST_RECORDING_CONTEXT.__enter__() + logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") + dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) policy = run_config.server.auth.access_policy if run_config.server.auth else [] impls = await resolve_impls( @@ -352,6 +362,13 @@ async def shutdown_stack(impls: dict[Api, Any]): except (Exception, asyncio.CancelledError) as e: logger.exception(f"Failed to shutdown {impl_name}: {e}") + global TEST_RECORDING_CONTEXT + if TEST_RECORDING_CONTEXT: + try: + TEST_RECORDING_CONTEXT.__exit__(None, None, 
None) + except Exception as e: + logger.error(f"Error during inference recording cleanup: {e}") + global REGISTRY_REFRESH_TASK if REGISTRY_REFRESH_TASK: REGISTRY_REFRESH_TASK.cancel() @@ -372,12 +389,12 @@ async def refresh_registry_task(impls: dict[Api, Any]): await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS) -def get_stack_run_config_from_template(template: str) -> StackRunConfig: - template_path = importlib.resources.files("llama_stack") / f"templates/{template}/run.yaml" +def get_stack_run_config_from_distro(distro: str) -> StackRunConfig: + distro_path = importlib.resources.files("llama_stack") / f"distributions/{distro}/run.yaml" - with importlib.resources.as_file(template_path) as path: + with importlib.resources.as_file(distro_path) as path: if not path.exists(): - raise ValueError(f"Template '{template}' not found at {template_path}") + raise ValueError(f"Distribution '{distro}' not found at {distro_path}") run_config = yaml.safe_load(path.open()) return StackRunConfig(**replace_env_vars(run_config)) diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/core/start_stack.sh similarity index 80% rename from llama_stack/distribution/start_stack.sh rename to llama_stack/core/start_stack.sh index 77a7dc92e..a3fc83265 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/core/start_stack.sh @@ -40,7 +40,6 @@ port="$1" shift SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -source "$SCRIPT_DIR/common.sh" # Initialize variables yaml_config="" @@ -75,9 +74,9 @@ while [[ $# -gt 0 ]]; do esac done -# Check if yaml_config is required based on env_type -if [[ "$env_type" == "venv" || "$env_type" == "conda" ]] && [ -z "$yaml_config" ]; then - echo -e "${RED}Error: --config is required for venv and conda environments${NC}" >&2 +# Check if yaml_config is required +if [[ "$env_type" == "venv" ]] && [ -z "$yaml_config" ]; then + echo -e "${RED}Error: --config is required for venv environment${NC}" >&2 exit 1 fi @@ -101,19 +100,14 @@ case "$env_type" in source "$env_path_or_name/bin/activate" fi ;; - "conda") - if ! is_command_available conda; then - echo -e "${RED}Error: conda not found" >&2 - exit 1 - fi - eval "$(conda shell.bash hook)" - conda deactivate && conda activate "$env_path_or_name" - PYTHON_BINARY="$CONDA_PREFIX/bin/python" - ;; *) + # Handle unsupported env_types here + echo -e "${RED}Error: Unsupported environment type '$env_type'. 
Only 'venv' is supported.${NC}" >&2 + exit 1 + ;; esac -if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then +if [[ "$env_type" == "venv" ]]; then set -x if [ -n "$yaml_config" ]; then @@ -122,7 +116,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then yaml_config_arg="" fi - $PYTHON_BINARY -m llama_stack.distribution.server.server \ + $PYTHON_BINARY -m llama_stack.core.server.server \ $yaml_config_arg \ --port "$port" \ $env_vars \ diff --git a/llama_stack/distribution/store/__init__.py b/llama_stack/core/store/__init__.py similarity index 100% rename from llama_stack/distribution/store/__init__.py rename to llama_stack/core/store/__init__.py diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/core/store/registry.py similarity index 98% rename from llama_stack/distribution/store/registry.py rename to llama_stack/core/store/registry.py index cd7cd9f00..4b60e1001 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -10,8 +10,8 @@ from typing import Protocol import pydantic -from llama_stack.distribution.datatypes import RoutableObjectWithProvider -from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.core.datatypes import RoutableObjectWithProvider +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig diff --git a/llama_stack/distribution/ui/Containerfile b/llama_stack/core/ui/Containerfile similarity index 100% rename from llama_stack/distribution/ui/Containerfile rename to llama_stack/core/ui/Containerfile diff --git a/llama_stack/distribution/ui/README.md b/llama_stack/core/ui/README.md similarity index 93% rename from llama_stack/distribution/ui/README.md rename to llama_stack/core/ui/README.md index 51c2d2bc2..05b4adc26 100644 --- a/llama_stack/distribution/ui/README.md +++ b/llama_stack/core/ui/README.md @@ -9,7 +9,7 @@ 1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html). ``` -llama stack build --template together --image-type conda +llama stack build --distro together --image-type venv llama stack run together ``` @@ -36,7 +36,7 @@ llama-stack-client benchmarks register \ 3. 
Start Streamlit UI ```bash -uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py +uv run --with ".[ui]" streamlit run llama_stack/core/ui/app.py ``` ## Environment Variables diff --git a/llama_stack/distribution/ui/__init__.py b/llama_stack/core/ui/__init__.py similarity index 100% rename from llama_stack/distribution/ui/__init__.py rename to llama_stack/core/ui/__init__.py diff --git a/llama_stack/distribution/ui/app.py b/llama_stack/core/ui/app.py similarity index 100% rename from llama_stack/distribution/ui/app.py rename to llama_stack/core/ui/app.py diff --git a/llama_stack/distribution/ui/modules/__init__.py b/llama_stack/core/ui/modules/__init__.py similarity index 100% rename from llama_stack/distribution/ui/modules/__init__.py rename to llama_stack/core/ui/modules/__init__.py diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/core/ui/modules/api.py similarity index 100% rename from llama_stack/distribution/ui/modules/api.py rename to llama_stack/core/ui/modules/api.py diff --git a/llama_stack/distribution/ui/modules/utils.py b/llama_stack/core/ui/modules/utils.py similarity index 100% rename from llama_stack/distribution/ui/modules/utils.py rename to llama_stack/core/ui/modules/utils.py diff --git a/llama_stack/distribution/ui/page/__init__.py b/llama_stack/core/ui/page/__init__.py similarity index 100% rename from llama_stack/distribution/ui/page/__init__.py rename to llama_stack/core/ui/page/__init__.py diff --git a/llama_stack/distribution/ui/page/distribution/__init__.py b/llama_stack/core/ui/page/distribution/__init__.py similarity index 100% rename from llama_stack/distribution/ui/page/distribution/__init__.py rename to llama_stack/core/ui/page/distribution/__init__.py diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/core/ui/page/distribution/datasets.py similarity index 88% rename from llama_stack/distribution/ui/page/distribution/datasets.py rename to llama_stack/core/ui/page/distribution/datasets.py index 6842b29a7..aab0901ac 100644 --- a/llama_stack/distribution/ui/page/distribution/datasets.py +++ b/llama_stack/core/ui/page/distribution/datasets.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api def datasets(): diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/core/ui/page/distribution/eval_tasks.py similarity index 90% rename from llama_stack/distribution/ui/page/distribution/eval_tasks.py rename to llama_stack/core/ui/page/distribution/eval_tasks.py index 492be4700..1a0ce502b 100644 --- a/llama_stack/distribution/ui/page/distribution/eval_tasks.py +++ b/llama_stack/core/ui/page/distribution/eval_tasks.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api def benchmarks(): diff --git a/llama_stack/distribution/ui/page/distribution/models.py b/llama_stack/core/ui/page/distribution/models.py similarity index 87% rename from llama_stack/distribution/ui/page/distribution/models.py rename to llama_stack/core/ui/page/distribution/models.py index f29459098..f84508746 100644 --- a/llama_stack/distribution/ui/page/distribution/models.py +++ b/llama_stack/core/ui/page/distribution/models.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api
import llama_stack_api def models(): diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/core/ui/page/distribution/providers.py similarity index 91% rename from llama_stack/distribution/ui/page/distribution/providers.py rename to llama_stack/core/ui/page/distribution/providers.py index c660cb986..3ec6026d1 100644 --- a/llama_stack/distribution/ui/page/distribution/providers.py +++ b/llama_stack/core/ui/page/distribution/providers.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api def providers(): diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/core/ui/page/distribution/resources.py similarity index 70% rename from llama_stack/distribution/ui/page/distribution/resources.py rename to llama_stack/core/ui/page/distribution/resources.py index 5e10e6e80..c56fcfff3 100644 --- a/llama_stack/distribution/ui/page/distribution/resources.py +++ b/llama_stack/core/ui/page/distribution/resources.py @@ -6,12 +6,12 @@ from streamlit_option_menu import option_menu -from llama_stack.distribution.ui.page.distribution.datasets import datasets -from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks -from llama_stack.distribution.ui.page.distribution.models import models -from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions -from llama_stack.distribution.ui.page.distribution.shields import shields -from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs +from llama_stack.core.ui.page.distribution.datasets import datasets +from llama_stack.core.ui.page.distribution.eval_tasks import benchmarks +from llama_stack.core.ui.page.distribution.models import models +from llama_stack.core.ui.page.distribution.scoring_functions import scoring_functions +from llama_stack.core.ui.page.distribution.shields import shields +from llama_stack.core.ui.page.distribution.vector_dbs import vector_dbs def resources_page(): diff --git a/llama_stack/distribution/ui/page/distribution/scoring_functions.py b/llama_stack/core/ui/page/distribution/scoring_functions.py similarity index 89% rename from llama_stack/distribution/ui/page/distribution/scoring_functions.py rename to llama_stack/core/ui/page/distribution/scoring_functions.py index 193146356..2a5196fa9 100644 --- a/llama_stack/distribution/ui/page/distribution/scoring_functions.py +++ b/llama_stack/core/ui/page/distribution/scoring_functions.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api def scoring_functions(): diff --git a/llama_stack/distribution/ui/page/distribution/shields.py b/llama_stack/core/ui/page/distribution/shields.py similarity index 88% rename from llama_stack/distribution/ui/page/distribution/shields.py rename to llama_stack/core/ui/page/distribution/shields.py index 67d66d64f..ecce2f12b 100644 --- a/llama_stack/distribution/ui/page/distribution/shields.py +++ b/llama_stack/core/ui/page/distribution/shields.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api def shields(): diff --git a/llama_stack/distribution/ui/page/distribution/vector_dbs.py b/llama_stack/core/ui/page/distribution/vector_dbs.py similarity index 90% rename from 
llama_stack/distribution/ui/page/distribution/vector_dbs.py rename to llama_stack/core/ui/page/distribution/vector_dbs.py index 49a4f25bb..e81077d2a 100644 --- a/llama_stack/distribution/ui/page/distribution/vector_dbs.py +++ b/llama_stack/core/ui/page/distribution/vector_dbs.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api def vector_dbs(): diff --git a/llama_stack/distribution/ui/page/evaluations/__init__.py b/llama_stack/core/ui/page/evaluations/__init__.py similarity index 100% rename from llama_stack/distribution/ui/page/evaluations/__init__.py rename to llama_stack/core/ui/page/evaluations/__init__.py diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/core/ui/page/evaluations/app_eval.py similarity index 97% rename from llama_stack/distribution/ui/page/evaluations/app_eval.py rename to llama_stack/core/ui/page/evaluations/app_eval.py index d7bc6388c..07e6349c9 100644 --- a/llama_stack/distribution/ui/page/evaluations/app_eval.py +++ b/llama_stack/core/ui/page/evaluations/app_eval.py @@ -9,8 +9,8 @@ import json import pandas as pd import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api -from llama_stack.distribution.ui.modules.utils import process_dataset +from llama_stack.core.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.utils import process_dataset def application_evaluation_page(): diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/core/ui/page/evaluations/native_eval.py similarity index 99% rename from llama_stack/distribution/ui/page/evaluations/native_eval.py rename to llama_stack/core/ui/page/evaluations/native_eval.py index 97f875e17..2bef63b2f 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/core/ui/page/evaluations/native_eval.py @@ -9,7 +9,7 @@ import json import pandas as pd import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api def select_benchmark_1(): diff --git a/llama_stack/distribution/ui/page/playground/__init__.py b/llama_stack/core/ui/page/playground/__init__.py similarity index 100% rename from llama_stack/distribution/ui/page/playground/__init__.py rename to llama_stack/core/ui/page/playground/__init__.py diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/core/ui/page/playground/chat.py similarity index 98% rename from llama_stack/distribution/ui/page/playground/chat.py rename to llama_stack/core/ui/page/playground/chat.py index fcaf08795..d391d0fb7 100644 --- a/llama_stack/distribution/ui/page/playground/chat.py +++ b/llama_stack/core/ui/page/playground/chat.py @@ -6,7 +6,7 @@ import streamlit as st -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api # Sidebar configurations with st.sidebar: diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/core/ui/page/playground/rag.py similarity index 98% rename from llama_stack/distribution/ui/page/playground/rag.py rename to llama_stack/core/ui/page/playground/rag.py index 696d89bc2..2ffae1c33 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/core/ui/page/playground/rag.py @@ -10,8 +10,8 @@ import streamlit as st from llama_stack_client import Agent, AgentEventLogger, 
RAGDocument from llama_stack.apis.common.content_types import ToolCallDelta -from llama_stack.distribution.ui.modules.api import llama_stack_api -from llama_stack.distribution.ui.modules.utils import data_url_from_file +from llama_stack.core.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.utils import data_url_from_file def rag_chat_page(): diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/core/ui/page/playground/tools.py similarity index 99% rename from llama_stack/distribution/ui/page/playground/tools.py rename to llama_stack/core/ui/page/playground/tools.py index 149d8cce9..602c9eea1 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ b/llama_stack/core/ui/page/playground/tools.py @@ -13,7 +13,7 @@ from llama_stack_client import Agent from llama_stack_client.lib.agents.react.agent import ReActAgent from llama_stack_client.lib.agents.react.tool_parser import ReActOutput -from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.core.ui.modules.api import llama_stack_api class AgentType(enum.Enum): diff --git a/llama_stack/distribution/ui/requirements.txt b/llama_stack/core/ui/requirements.txt similarity index 100% rename from llama_stack/distribution/ui/requirements.txt rename to llama_stack/core/ui/requirements.txt diff --git a/llama_stack/distribution/utils/__init__.py b/llama_stack/core/utils/__init__.py similarity index 100% rename from llama_stack/distribution/utils/__init__.py rename to llama_stack/core/utils/__init__.py diff --git a/llama_stack/distribution/utils/config.py b/llama_stack/core/utils/config.py similarity index 100% rename from llama_stack/distribution/utils/config.py rename to llama_stack/core/utils/config.py diff --git a/llama_stack/distribution/utils/config_dirs.py b/llama_stack/core/utils/config_dirs.py similarity index 100% rename from llama_stack/distribution/utils/config_dirs.py rename to llama_stack/core/utils/config_dirs.py diff --git a/llama_stack/core/utils/config_resolution.py b/llama_stack/core/utils/config_resolution.py new file mode 100644 index 000000000..30cd71e15 --- /dev/null +++ b/llama_stack/core/utils/config_resolution.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import StrEnum +from pathlib import Path + +from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="config_resolution") + + +DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" + + +class Mode(StrEnum): + RUN = "run" + BUILD = "build" + + +def resolve_config_or_distro( + config_or_distro: str, + mode: Mode = Mode.RUN, +) -> Path: + """ + Resolve a config/distro argument to a concrete config file path. 
+ + Args: + config_or_distro: User input (file path, distribution name, or built distribution) + mode: Mode resolving for ("run", "build", "server") + + Returns: + Path to the resolved config file + + Raises: + ValueError: If resolution fails + """ + + # Strategy 1: Try as file path first + config_path = Path(config_or_distro) + if config_path.exists() and config_path.is_file(): + logger.info(f"Using file path: {config_path}") + return config_path.resolve() + + # Strategy 2: Try as distribution name (if no .yaml extension) + if not config_or_distro.endswith(".yaml"): + distro_config = _get_distro_config_path(config_or_distro, mode) + if distro_config.exists(): + logger.info(f"Using distribution: {distro_config}") + return distro_config + + # Strategy 3: Try as built distribution name + distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + distrib_config = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + # Strategy 4: Failed - provide helpful error + raise ValueError(_format_resolution_error(config_or_distro, mode)) + + +def _get_distro_config_path(distro_name: str, mode: Mode) -> Path: + """Get the config file path for a distro.""" + return DISTRO_DIR / distro_name / f"{mode}.yaml" + + +def _format_resolution_error(config_or_distro: str, mode: Mode) -> str: + """Format a helpful error message for resolution failures.""" + from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR + + distro_path = _get_distro_config_path(config_or_distro, mode) + distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" + distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_distro}" / f"{config_or_distro}-{mode}.yaml" + + available_distros = _get_available_distros() + distros_str = ", ".join(available_distros) if available_distros else "none found" + + return f"""Could not resolve config or distribution '{config_or_distro}'. + +Tried the following locations: + 1. As file path: {Path(config_or_distro).resolve()} + 2. As distribution: {distro_path} + 3. As built distribution: ({distrib_path}, {distrib_path2}) + +Available distributions: {distros_str} + +Did you mean one of these distributions? 
+{_format_distro_suggestions(available_distros, config_or_distro)} +""" + + +def _get_available_distros() -> list[str]: + """Get list of available distro names.""" + if not DISTRO_DIR.exists() and not DISTRIBS_BASE_DIR.exists(): + return [] + + return list( + set( + [d.name for d in DISTRO_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + ) + ) + + +def _format_distro_suggestions(distros: list[str], user_input: str) -> str: + """Format distro suggestions for error messages, showing closest matches first.""" + if not distros: + return " (no distros found)" + + import difflib + + # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive) + close_matches = difflib.get_close_matches(user_input, distros, n=3, cutoff=0.3) + display_distros = close_matches if close_matches else distros[:3] + + suggestions = [f" - {d}" for d in display_distros] + return "\n".join(suggestions) diff --git a/llama_stack/distribution/utils/context.py b/llama_stack/core/utils/context.py similarity index 100% rename from llama_stack/distribution/utils/context.py rename to llama_stack/core/utils/context.py diff --git a/llama_stack/distribution/utils/dynamic.py b/llama_stack/core/utils/dynamic.py similarity index 100% rename from llama_stack/distribution/utils/dynamic.py rename to llama_stack/core/utils/dynamic.py diff --git a/llama_stack/core/utils/exec.py b/llama_stack/core/utils/exec.py new file mode 100644 index 000000000..1b2b782fe --- /dev/null +++ b/llama_stack/core/utils/exec.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +import os +import signal +import subprocess +import sys + +from termcolor import cprint + +log = logging.getLogger(__name__) + +import importlib + + +def formulate_run_args(image_type: str, image_name: str) -> list: + # Only venv is supported now + current_venv = os.environ.get("VIRTUAL_ENV") + env_name = image_name or current_venv + if not env_name: + cprint( + "No current virtual environment detected, please specify a virtual environment name with --image-name", + color="red", + file=sys.stderr, + ) + return [] + + cprint(f"Using virtual environment: {env_name}", file=sys.stderr) + + script = importlib.resources.files("llama_stack") / "core/start_stack.sh" + run_args = [ + script, + image_type, + env_name, + ] + + return run_args + + +def in_notebook(): + try: + from IPython import get_ipython + + ipython = get_ipython() + if ipython is None or "IPKernelApp" not in ipython.config: # pragma: no cover + return False + except ImportError: + return False + except AttributeError: + return False + return True + + +def run_command(command: list[str]) -> int: + """ + Run a command with interrupt handling and output capture. + Uses subprocess.run with direct stream piping for better performance. + + Args: + command (list): The command to run. + + Returns: + int: The return code of the command. + """ + original_sigint = signal.getsignal(signal.SIGINT) + ctrl_c_pressed = False + + def sigint_handler(signum, frame): + nonlocal ctrl_c_pressed + ctrl_c_pressed = True + log.info("\nCtrl-C detected. 
Aborting...") + + try: + # Set up the signal handler + signal.signal(signal.SIGINT, sigint_handler) + + # Run the command with stdout/stderr piped directly to system streams + result = subprocess.run( + command, + text=True, + check=False, + ) + return result.returncode + except subprocess.SubprocessError as e: + log.error(f"Subprocess error: {e}") + return 1 + except Exception as e: + log.exception(f"Unexpected error: {e}") + return 1 + finally: + # Restore the original signal handler + signal.signal(signal.SIGINT, original_sigint) diff --git a/llama_stack/distribution/utils/image_types.py b/llama_stack/core/utils/image_types.py similarity index 93% rename from llama_stack/distribution/utils/image_types.py rename to llama_stack/core/utils/image_types.py index 403c91ca6..9e140dc5c 100644 --- a/llama_stack/distribution/utils/image_types.py +++ b/llama_stack/core/utils/image_types.py @@ -9,5 +9,4 @@ import enum class LlamaStackImageType(enum.Enum): CONTAINER = "container" - CONDA = "conda" VENV = "venv" diff --git a/llama_stack/distribution/utils/model_utils.py b/llama_stack/core/utils/model_utils.py similarity index 100% rename from llama_stack/distribution/utils/model_utils.py rename to llama_stack/core/utils/model_utils.py diff --git a/llama_stack/distribution/utils/prompt_for_config.py b/llama_stack/core/utils/prompt_for_config.py similarity index 100% rename from llama_stack/distribution/utils/prompt_for_config.py rename to llama_stack/core/utils/prompt_for_config.py diff --git a/llama_stack/distribution/utils/serialize.py b/llama_stack/core/utils/serialize.py similarity index 100% rename from llama_stack/distribution/utils/serialize.py rename to llama_stack/core/utils/serialize.py diff --git a/llama_stack/distribution/routers/safety.py b/llama_stack/distribution/routers/safety.py deleted file mode 100644 index 26ee8e722..000000000 --- a/llama_stack/distribution/routers/safety.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from llama_stack.apis.inference import ( - Message, -) -from llama_stack.apis.safety import RunShieldResponse, Safety -from llama_stack.apis.shields import Shield -from llama_stack.log import get_logger -from llama_stack.providers.datatypes import RoutingTable - -logger = get_logger(name=__name__, category="core") - - -class SafetyRouter(Safety): - def __init__( - self, - routing_table: RoutingTable, - ) -> None: - logger.debug("Initializing SafetyRouter") - self.routing_table = routing_table - - async def initialize(self) -> None: - logger.debug("SafetyRouter.initialize") - pass - - async def shutdown(self) -> None: - logger.debug("SafetyRouter.shutdown") - pass - - async def register_shield( - self, - shield_id: str, - provider_shield_id: str | None = None, - provider_id: str | None = None, - params: dict[str, Any] | None = None, - ) -> Shield: - logger.debug(f"SafetyRouter.register_shield: {shield_id}") - return await self.routing_table.register_shield(shield_id, provider_shield_id, provider_id, params) - - async def run_shield( - self, - shield_id: str, - messages: list[Message], - params: dict[str, Any] = None, - ) -> RunShieldResponse: - logger.debug(f"SafetyRouter.run_shield: {shield_id}") - provider = await self.routing_table.get_provider_impl(shield_id) - return await provider.run_shield( - shield_id=shield_id, - messages=messages, - params=params, - ) diff --git a/llama_stack/distribution/utils/config_resolution.py b/llama_stack/distribution/utils/config_resolution.py deleted file mode 100644 index 7e8de1242..000000000 --- a/llama_stack/distribution/utils/config_resolution.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from enum import StrEnum -from pathlib import Path - -from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR -from llama_stack.log import get_logger - -logger = get_logger(name=__name__, category="config_resolution") - - -TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates" - - -class Mode(StrEnum): - RUN = "run" - BUILD = "build" - - -def resolve_config_or_template( - config_or_template: str, - mode: Mode = Mode.RUN, -) -> Path: - """ - Resolve a config/template argument to a concrete config file path. 
- - Args: - config_or_template: User input (file path, template name, or built distribution) - mode: Mode resolving for ("run", "build", "server") - - Returns: - Path to the resolved config file - - Raises: - ValueError: If resolution fails - """ - - # Strategy 1: Try as file path first - config_path = Path(config_or_template) - if config_path.exists() and config_path.is_file(): - logger.info(f"Using file path: {config_path}") - return config_path.resolve() - - # Strategy 2: Try as template name (if no .yaml extension) - if not config_or_template.endswith(".yaml"): - template_config = _get_template_config_path(config_or_template, mode) - if template_config.exists(): - logger.info(f"Using template: {template_config}") - return template_config - - # Strategy 3: Try as built distribution name - distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" - if distrib_config.exists(): - logger.info(f"Using built distribution: {distrib_config}") - return distrib_config - - distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" - if distrib_config.exists(): - logger.info(f"Using built distribution: {distrib_config}") - return distrib_config - - # Strategy 4: Failed - provide helpful error - raise ValueError(_format_resolution_error(config_or_template, mode)) - - -def _get_template_config_path(template_name: str, mode: Mode) -> Path: - """Get the config file path for a template.""" - return TEMPLATE_DIR / template_name / f"{mode}.yaml" - - -def _format_resolution_error(config_or_template: str, mode: Mode) -> str: - """Format a helpful error message for resolution failures.""" - from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR - - template_path = _get_template_config_path(config_or_template, mode) - distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" - distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" - - available_templates = _get_available_templates() - templates_str = ", ".join(available_templates) if available_templates else "none found" - - return f"""Could not resolve config or template '{config_or_template}'. - -Tried the following locations: - 1. As file path: {Path(config_or_template).resolve()} - 2. As template: {template_path} - 3. As built distribution: ({distrib_path}, {distrib_path2}) - -Available templates: {templates_str} - -Did you mean one of these templates? 
-{_format_template_suggestions(available_templates, config_or_template)} -""" - - -def _get_available_templates() -> list[str]: - """Get list of available template names.""" - if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists(): - return [] - - return list( - set( - [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] - + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] - ) - ) - - -def _format_template_suggestions(templates: list[str], user_input: str) -> str: - """Format template suggestions for error messages, showing closest matches first.""" - if not templates: - return " (no templates found)" - - import difflib - - # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive) - close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3) - display_templates = close_matches if close_matches else templates[:3] - - suggestions = [f" - {t}" for t in display_templates] - return "\n".join(suggestions) diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py deleted file mode 100644 index c646ae821..000000000 --- a/llama_stack/distribution/utils/exec.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import logging -import os -import signal -import subprocess -import sys - -from termcolor import cprint - -log = logging.getLogger(__name__) - -import importlib -import json -from pathlib import Path - -from llama_stack.distribution.utils.image_types import LlamaStackImageType - - -def formulate_run_args(image_type: str, image_name: str) -> list[str]: - env_name = "" - - if image_type == LlamaStackImageType.CONDA.value: - current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") - env_name = image_name or current_conda_env - if not env_name: - cprint( - "No current conda environment detected, please specify a conda environment name with --image-name", - color="red", - file=sys.stderr, - ) - return - - def get_conda_prefix(env_name): - # Conda "base" environment does not end with "base" in the - # prefix, so should be handled separately. - if env_name == "base": - return os.environ.get("CONDA_PREFIX") - # Get conda environments info - conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode()) - envs = conda_env_info["envs"] - for envpath in envs: - if os.path.basename(envpath) == env_name: - return envpath - return None - - cprint(f"Using conda environment: {env_name}", color="green", file=sys.stderr) - conda_prefix = get_conda_prefix(env_name) - if not conda_prefix: - cprint( - f"Conda environment {env_name} does not exist.", - color="red", - file=sys.stderr, - ) - return - - build_file = Path(conda_prefix) / "llamastack-build.yaml" - if not build_file.exists(): - cprint( - f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name", - color="red", - file=sys.stderr, - ) - return - else: - # else must be venv since that is the only valid option left. 
- current_venv = os.environ.get("VIRTUAL_ENV") - env_name = image_name or current_venv - if not env_name: - cprint( - "No current virtual environment detected, please specify a virtual environment name with --image-name", - color="red", - file=sys.stderr, - ) - return - cprint(f"Using virtual environment: {env_name}", file=sys.stderr) - - script = importlib.resources.files("llama_stack") / "distribution/start_stack.sh" - run_args = [ - script, - image_type, - env_name, - ] - - return run_args - - -def in_notebook(): - try: - from IPython import get_ipython - - if "IPKernelApp" not in get_ipython().config: # pragma: no cover - return False - except ImportError: - return False - except AttributeError: - return False - return True - - -def run_command(command: list[str]) -> int: - """ - Run a command with interrupt handling and output capture. - Uses subprocess.run with direct stream piping for better performance. - - Args: - command (list): The command to run. - - Returns: - int: The return code of the command. - """ - original_sigint = signal.getsignal(signal.SIGINT) - ctrl_c_pressed = False - - def sigint_handler(signum, frame): - nonlocal ctrl_c_pressed - ctrl_c_pressed = True - log.info("\nCtrl-C detected. Aborting...") - - try: - # Set up the signal handler - signal.signal(signal.SIGINT, sigint_handler) - - # Run the command with stdout/stderr piped directly to system streams - result = subprocess.run( - command, - text=True, - check=False, - ) - return result.returncode - except subprocess.SubprocessError as e: - log.error(f"Subprocess error: {e}") - return 1 - except Exception as e: - log.exception(f"Unexpected error: {e}") - return 1 - finally: - # Restore the original signal handler - signal.signal(signal.SIGINT, original_sigint) diff --git a/llama_stack/templates/__init__.py b/llama_stack/distributions/__init__.py similarity index 100% rename from llama_stack/templates/__init__.py rename to llama_stack/distributions/__init__.py diff --git a/llama_stack/templates/ci-tests/__init__.py b/llama_stack/distributions/ci-tests/__init__.py similarity index 100% rename from llama_stack/templates/ci-tests/__init__.py rename to llama_stack/distributions/ci-tests/__init__.py diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml similarity index 97% rename from llama_stack/templates/ci-tests/build.yaml rename to llama_stack/distributions/ci-tests/build.yaml index b5df01923..2f9ae8682 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -47,8 +47,7 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol -image_type: conda -image_name: ci-tests +image_type: venv additional_pip_packages: - aiosqlite - asyncpg diff --git a/llama_stack/templates/ci-tests/ci_tests.py b/llama_stack/distributions/ci-tests/ci_tests.py similarity index 88% rename from llama_stack/templates/ci-tests/ci_tests.py rename to llama_stack/distributions/ci-tests/ci_tests.py index 49cb36e39..8fb61faca 100644 --- a/llama_stack/templates/ci-tests/ci_tests.py +++ b/llama_stack/distributions/ci-tests/ci_tests.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
-from llama_stack.templates.template import DistributionTemplate +from llama_stack.distributions.template import DistributionTemplate from ..starter.starter import get_distribution_template as get_starter_distribution_template diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml similarity index 98% rename from llama_stack/templates/ci-tests/run.yaml rename to llama_stack/distributions/ci-tests/run.yaml index 84eacae1f..188c66275 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -19,7 +19,7 @@ providers: config: base_url: https://api.cerebras.ai api_key: ${env.CEREBRAS_API_KEY:=} - - provider_id: ollama + - provider_id: ${env.OLLAMA_URL:+ollama} provider_type: remote::ollama config: url: ${env.OLLAMA_URL:=http://localhost:11434} @@ -154,6 +154,7 @@ providers: checkpoint_format: huggingface distributed_backend: null device: cpu + dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/dell/__init__.py b/llama_stack/distributions/dell/__init__.py similarity index 100% rename from llama_stack/templates/dell/__init__.py rename to llama_stack/distributions/dell/__init__.py diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/distributions/dell/build.yaml similarity index 96% rename from llama_stack/templates/dell/build.yaml rename to llama_stack/distributions/dell/build.yaml index 4a4ecb37c..acd5d827c 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/distributions/dell/build.yaml @@ -29,8 +29,7 @@ distribution_spec: - provider_type: remote::brave-search - provider_type: remote::tavily-search - provider_type: inline::rag-runtime -image_type: conda -image_name: dell +image_type: venv additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/dell/dell.py b/llama_stack/distributions/dell/dell.py similarity index 97% rename from llama_stack/templates/dell/dell.py rename to llama_stack/distributions/dell/dell.py index 64e01535c..b561ea00e 100644 --- a/llama_stack/templates/dell/dell.py +++ b/llama_stack/distributions/dell/dell.py @@ -5,17 +5,17 @@ # the root directory of this source tree. 
from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput, ) +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings def get_distribution_template() -> DistributionTemplate: diff --git a/llama_stack/templates/dell/doc_template.md b/llama_stack/distributions/dell/doc_template.md similarity index 97% rename from llama_stack/templates/dell/doc_template.md rename to llama_stack/distributions/dell/doc_template.md index 6bdd7f81c..34b87c907 100644 --- a/llama_stack/templates/dell/doc_template.md +++ b/llama_stack/distributions/dell/doc_template.md @@ -141,7 +141,7 @@ docker run \ --pull always \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v $HOME/.llama:/root/.llama \ - -v ./llama_stack/templates/tgi/run-with-safety.yaml:/root/my-run.yaml \ + -v ./llama_stack/distributions/tgi/run-with-safety.yaml:/root/my-run.yaml \ llamastack/distribution-{{ name }} \ --config /root/my-run.yaml \ --port $LLAMA_STACK_PORT \ @@ -157,7 +157,7 @@ docker run \ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available. ```bash -llama stack build --template {{ name }} --image-type conda +llama stack build --distro {{ name }} --image-type conda llama stack run {{ name }} --port $LLAMA_STACK_PORT \ --env INFERENCE_MODEL=$INFERENCE_MODEL \ diff --git a/llama_stack/templates/dell/run-with-safety.yaml b/llama_stack/distributions/dell/run-with-safety.yaml similarity index 100% rename from llama_stack/templates/dell/run-with-safety.yaml rename to llama_stack/distributions/dell/run-with-safety.yaml diff --git a/llama_stack/templates/dell/run.yaml b/llama_stack/distributions/dell/run.yaml similarity index 100% rename from llama_stack/templates/dell/run.yaml rename to llama_stack/distributions/dell/run.yaml diff --git a/llama_stack/templates/meta-reference-gpu/__init__.py b/llama_stack/distributions/meta-reference-gpu/__init__.py similarity index 100% rename from llama_stack/templates/meta-reference-gpu/__init__.py rename to llama_stack/distributions/meta-reference-gpu/__init__.py diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/distributions/meta-reference-gpu/build.yaml similarity index 95% rename from llama_stack/templates/meta-reference-gpu/build.yaml rename to llama_stack/distributions/meta-reference-gpu/build.yaml index e929f1683..47e782c85 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/distributions/meta-reference-gpu/build.yaml @@ -28,8 +28,7 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol -image_type: conda -image_name: meta-reference-gpu +image_type: venv additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/meta-reference-gpu/doc_template.md b/llama_stack/distributions/meta-reference-gpu/doc_template.md similarity index 97% rename from llama_stack/templates/meta-reference-gpu/doc_template.md rename to llama_stack/distributions/meta-reference-gpu/doc_template.md index 2ca6793d7..ff45c3826 100644 --- a/llama_stack/templates/meta-reference-gpu/doc_template.md +++ 
b/llama_stack/distributions/meta-reference-gpu/doc_template.md @@ -58,7 +58,7 @@ $ llama model list --downloaded ## Running the Distribution -You can do this via Conda (build code) or Docker which has a pre-built image. +You can do this via venv or Docker which has a pre-built image. ### Via Docker @@ -92,12 +92,12 @@ docker run \ --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B ``` -### Via Conda +### Via venv Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. ```bash -llama stack build --template {{ name }} --image-type conda +llama stack build --distro {{ name }} --image-type venv llama stack run distributions/{{ name }}/run.yaml \ --port 8321 \ --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/distributions/meta-reference-gpu/meta_reference.py similarity index 97% rename from llama_stack/templates/meta-reference-gpu/meta_reference.py rename to llama_stack/distributions/meta-reference-gpu/meta_reference.py index 981c66bf5..78bebb24c 100644 --- a/llama_stack/templates/meta-reference-gpu/meta_reference.py +++ b/llama_stack/distributions/meta-reference-gpu/meta_reference.py @@ -7,13 +7,14 @@ from pathlib import Path from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput, ) +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) @@ -21,7 +22,6 @@ from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings def get_distribution_template() -> DistributionTemplate: diff --git a/llama_stack/templates/meta-reference-gpu/run-with-safety.yaml b/llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml similarity index 100% rename from llama_stack/templates/meta-reference-gpu/run-with-safety.yaml rename to llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/distributions/meta-reference-gpu/run.yaml similarity index 100% rename from llama_stack/templates/meta-reference-gpu/run.yaml rename to llama_stack/distributions/meta-reference-gpu/run.yaml diff --git a/llama_stack/templates/nvidia/__init__.py b/llama_stack/distributions/nvidia/__init__.py similarity index 100% rename from llama_stack/templates/nvidia/__init__.py rename to llama_stack/distributions/nvidia/__init__.py diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/distributions/nvidia/build.yaml similarity index 95% rename from llama_stack/templates/nvidia/build.yaml rename to llama_stack/distributions/nvidia/build.yaml index 7d9a6140f..f3e73a2c1 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/distributions/nvidia/build.yaml @@ -23,8 +23,7 @@ distribution_spec: - provider_type: inline::basic tool_runtime: - provider_type: inline::rag-runtime -image_type: conda -image_name: nvidia +image_type: venv additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/nvidia/doc_template.md b/llama_stack/distributions/nvidia/doc_template.md similarity index 94% 
rename from llama_stack/templates/nvidia/doc_template.md rename to llama_stack/distributions/nvidia/doc_template.md index 5a180d49f..3884e6b51 100644 --- a/llama_stack/templates/nvidia/doc_template.md +++ b/llama_stack/distributions/nvidia/doc_template.md @@ -105,7 +105,7 @@ curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-inst ## Running Llama Stack with NVIDIA -You can do this via Conda or venv (build code), or Docker which has a pre-built image. +You can do this via venv (build code), or Docker which has a pre-built image. ### Via Docker @@ -124,24 +124,13 @@ docker run \ --env NVIDIA_API_KEY=$NVIDIA_API_KEY ``` -### Via Conda - -```bash -INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct -llama stack build --template nvidia --image-type conda -llama stack run ./run.yaml \ - --port 8321 \ - --env NVIDIA_API_KEY=$NVIDIA_API_KEY \ - --env INFERENCE_MODEL=$INFERENCE_MODEL -``` - ### Via venv If you've set up your local development environment, you can also build the image using your local virtual environment. ```bash INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct -llama stack build --template nvidia --image-type venv +llama stack build --distro nvidia --image-type venv llama stack run ./run.yaml \ --port 8321 \ --env NVIDIA_API_KEY=$NVIDIA_API_KEY \ diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/distributions/nvidia/nvidia.py similarity index 96% rename from llama_stack/templates/nvidia/nvidia.py rename to llama_stack/distributions/nvidia/nvidia.py index df82cf7c0..aedda0ae9 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/distributions/nvidia/nvidia.py @@ -6,13 +6,13 @@ from pathlib import Path -from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput +from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry def get_distribution_template() -> DistributionTemplate: diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/distributions/nvidia/run-with-safety.yaml similarity index 100% rename from llama_stack/templates/nvidia/run-with-safety.yaml rename to llama_stack/distributions/nvidia/run-with-safety.yaml diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/distributions/nvidia/run.yaml similarity index 100% rename from llama_stack/templates/nvidia/run.yaml rename to llama_stack/distributions/nvidia/run.yaml diff --git a/llama_stack/templates/open-benchmark/__init__.py b/llama_stack/distributions/open-benchmark/__init__.py similarity index 100% rename from llama_stack/templates/open-benchmark/__init__.py rename to llama_stack/distributions/open-benchmark/__init__.py diff --git a/llama_stack/templates/open-benchmark/build.yaml b/llama_stack/distributions/open-benchmark/build.yaml similarity index 96% rename from llama_stack/templates/open-benchmark/build.yaml rename to 
llama_stack/distributions/open-benchmark/build.yaml index 1c7966cdd..6ff4155dc 100644 --- a/llama_stack/templates/open-benchmark/build.yaml +++ b/llama_stack/distributions/open-benchmark/build.yaml @@ -32,8 +32,7 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol -image_type: conda -image_name: open-benchmark +image_type: venv additional_pip_packages: - aiosqlite - sqlalchemy[asyncio] diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/distributions/open-benchmark/open_benchmark.py similarity index 99% rename from llama_stack/templates/open-benchmark/open_benchmark.py rename to llama_stack/distributions/open-benchmark/open_benchmark.py index 0a0d9fb14..af08ac7ba 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/distributions/open-benchmark/open_benchmark.py @@ -7,7 +7,7 @@ from llama_stack.apis.datasets import DatasetPurpose, URIDataSource from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( BenchmarkInput, BuildProvider, DatasetInput, @@ -16,6 +16,11 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) +from llama_stack.distributions.template import ( + DistributionTemplate, + RunConfigSettings, + get_model_registry, +) from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( SQLiteVectorIOConfig, ) @@ -29,11 +34,6 @@ from llama_stack.providers.remote.vector_io.pgvector.config import ( PGVectorVectorIOConfig, ) from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry -from llama_stack.templates.template import ( - DistributionTemplate, - RunConfigSettings, - get_model_registry, -) def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]: diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/distributions/open-benchmark/run.yaml similarity index 100% rename from llama_stack/templates/open-benchmark/run.yaml rename to llama_stack/distributions/open-benchmark/run.yaml diff --git a/llama_stack/templates/postgres-demo/__init__.py b/llama_stack/distributions/postgres-demo/__init__.py similarity index 100% rename from llama_stack/templates/postgres-demo/__init__.py rename to llama_stack/distributions/postgres-demo/__init__.py diff --git a/llama_stack/templates/postgres-demo/build.yaml b/llama_stack/distributions/postgres-demo/build.yaml similarity index 94% rename from llama_stack/templates/postgres-demo/build.yaml rename to llama_stack/distributions/postgres-demo/build.yaml index 5a1c9129a..e5a5d3b83 100644 --- a/llama_stack/templates/postgres-demo/build.yaml +++ b/llama_stack/distributions/postgres-demo/build.yaml @@ -18,8 +18,7 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol -image_type: conda -image_name: postgres-demo +image_type: venv additional_pip_packages: - asyncpg - psycopg2-binary diff --git a/llama_stack/templates/postgres-demo/postgres_demo.py b/llama_stack/distributions/postgres-demo/postgres_demo.py similarity index 96% rename from llama_stack/templates/postgres-demo/postgres_demo.py rename to llama_stack/distributions/postgres-demo/postgres_demo.py index d9ded9a86..c04cfedfa 100644 --- a/llama_stack/templates/postgres-demo/postgres_demo.py +++ b/llama_stack/distributions/postgres-demo/postgres_demo.py @@ -6,22 +6,22 @@ 
from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput, ) +from llama_stack.distributions.template import ( + DistributionTemplate, + RunConfigSettings, +) from llama_stack.providers.inline.inference.sentence_transformers import SentenceTransformersInferenceConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig -from llama_stack.templates.template import ( - DistributionTemplate, - RunConfigSettings, -) def get_distribution_template() -> DistributionTemplate: @@ -123,7 +123,7 @@ def get_distribution_template() -> DistributionTemplate: config=dict( service_name="${env.OTEL_SERVICE_NAME:=\u200b}", sinks="${env.TELEMETRY_SINKS:=console,otel_trace}", - otel_trace_endpoint="${env.OTEL_TRACE_ENDPOINT:=http://localhost:4318/v1/traces}", + otel_exporter_otlp_endpoint="${env.OTEL_EXPORTER_OTLP_ENDPOINT:=http://localhost:4318/v1/traces}", ), ) ], diff --git a/llama_stack/templates/postgres-demo/run.yaml b/llama_stack/distributions/postgres-demo/run.yaml similarity index 96% rename from llama_stack/templates/postgres-demo/run.yaml rename to llama_stack/distributions/postgres-demo/run.yaml index 747b7dc53..0cf0e82e6 100644 --- a/llama_stack/templates/postgres-demo/run.yaml +++ b/llama_stack/distributions/postgres-demo/run.yaml @@ -55,7 +55,7 @@ providers: config: service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" sinks: ${env.TELEMETRY_SINKS:=console,otel_trace} - otel_trace_endpoint: ${env.OTEL_TRACE_ENDPOINT:=http://localhost:4318/v1/traces} + otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=http://localhost:4318/v1/traces} tool_runtime: - provider_id: brave-search provider_type: remote::brave-search diff --git a/llama_stack/templates/starter/__init__.py b/llama_stack/distributions/starter/__init__.py similarity index 100% rename from llama_stack/templates/starter/__init__.py rename to llama_stack/distributions/starter/__init__.py diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/distributions/starter/build.yaml similarity index 97% rename from llama_stack/templates/starter/build.yaml rename to llama_stack/distributions/starter/build.yaml index 1c67c433e..f95a03a9e 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -47,8 +47,7 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol -image_type: conda -image_name: starter +image_type: venv additional_pip_packages: - aiosqlite - asyncpg diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/distributions/starter/run.yaml similarity index 98% rename from llama_stack/templates/starter/run.yaml rename to llama_stack/distributions/starter/run.yaml index 0b7e71a75..8bd737686 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -19,7 +19,7 @@ providers: config: base_url: https://api.cerebras.ai api_key: ${env.CEREBRAS_API_KEY:=} - - provider_id: ollama + - provider_id: ${env.OLLAMA_URL:+ollama} provider_type: remote::ollama config: url: ${env.OLLAMA_URL:=http://localhost:11434} @@ -154,6 +154,7 @@ 
providers: checkpoint_format: huggingface distributed_backend: null device: cpu + dpo_output_dir: ~/.llama/distributions/starter/dpo_output eval: - provider_id: meta-reference provider_type: inline::meta-reference diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/distributions/starter/starter.py similarity index 97% rename from llama_stack/templates/starter/starter.py rename to llama_stack/distributions/starter/starter.py index d0782797f..a970f2d1c 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -7,14 +7,18 @@ from typing import Any -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( BuildProvider, Provider, ProviderSpec, ShieldInput, ToolGroupInput, ) -from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.core.utils.dynamic import instantiate_class_type +from llama_stack.distributions.template import ( + DistributionTemplate, + RunConfigSettings, +) from llama_stack.providers.datatypes import RemoteProviderSpec from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.inline.inference.sentence_transformers import ( @@ -33,10 +37,6 @@ from llama_stack.providers.remote.vector_io.pgvector.config import ( PGVectorVectorIOConfig, ) from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig -from llama_stack.templates.template import ( - DistributionTemplate, - RunConfigSettings, -) def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]: @@ -66,6 +66,7 @@ ENABLED_INFERENCE_PROVIDERS = [ ] INFERENCE_PROVIDER_IDS = { + "ollama": "${env.OLLAMA_URL:+ollama}", "vllm": "${env.VLLM_URL:+vllm}", "tgi": "${env.TGI_URL:+tgi}", "cerebras": "${env.CEREBRAS_API_KEY:+cerebras}", diff --git a/llama_stack/templates/template.py b/llama_stack/distributions/template.py similarity index 98% rename from llama_stack/templates/template.py rename to llama_stack/distributions/template.py index 084996cd4..d564312dc 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/distributions/template.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetPurpose from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ( +from llama_stack.core.datatypes import ( LLAMA_STACK_RUN_CONFIG_VERSION, Api, BenchmarkInput, @@ -27,8 +27,9 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) -from llama_stack.distribution.distribution import get_provider_registry -from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.core.distribution import get_provider_registry +from llama_stack.core.utils.dynamic import instantiate_class_type +from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import get_pip_packages as get_kv_pip_packages @@ -314,8 +315,7 @@ class DistributionTemplate(BaseModel): container_image=self.container_image, providers=build_providers, ), - image_type="conda", - image_name=self.name, + image_type=LlamaStackImageType.VENV.value, # default to venv additional_pip_packages=sorted(set(additional_pip_packages)), ) diff --git a/llama_stack/templates/watsonx/__init__.py 
b/llama_stack/distributions/watsonx/__init__.py similarity index 100% rename from llama_stack/templates/watsonx/__init__.py rename to llama_stack/distributions/watsonx/__init__.py diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/distributions/watsonx/build.yaml similarity index 76% rename from llama_stack/templates/watsonx/build.yaml rename to llama_stack/distributions/watsonx/build.yaml index bc992f0c7..bf4be7eaf 100644 --- a/llama_stack/templates/watsonx/build.yaml +++ b/llama_stack/distributions/watsonx/build.yaml @@ -35,16 +35,11 @@ distribution_spec: - provider_id: braintrust provider_type: inline::braintrust tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - - provider_id: tavily-search - provider_type: remote::tavily-search - - provider_id: rag-runtime - provider_type: inline::rag-runtime - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol -image_type: conda -image_name: watsonx + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime + - provider_type: remote::model-context-protocol +image_type: venv additional_pip_packages: - sqlalchemy[asyncio] - aiosqlite diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/distributions/watsonx/run.yaml similarity index 100% rename from llama_stack/templates/watsonx/run.yaml rename to llama_stack/distributions/watsonx/run.yaml diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/distributions/watsonx/watsonx.py similarity index 95% rename from llama_stack/templates/watsonx/watsonx.py rename to llama_stack/distributions/watsonx/watsonx.py index 5d8332c4f..1ef2ef339 100644 --- a/llama_stack/templates/watsonx/watsonx.py +++ b/llama_stack/distributions/watsonx/watsonx.py @@ -7,13 +7,13 @@ from pathlib import Path from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput +from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ToolGroupInput +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.remote.inference.watsonx import WatsonXConfig from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES -from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry def get_distribution_template() -> DistributionTemplate: diff --git a/llama_stack/log.py b/llama_stack/log.py index fb6fa85f9..ab53e08c0 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -15,7 +15,7 @@ from rich.errors import MarkupError from rich.logging import RichHandler from termcolor import cprint -from .distribution.datatypes import LoggingConfig +from llama_stack.core.datatypes import LoggingConfig # Default log level DEFAULT_LOG_LEVEL = logging.INFO diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index f9f463bf9..5e15dd8e1 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -65,6 +65,8 @@ class ModelsProtocolPrivate(Protocol): class ShieldsProtocolPrivate(Protocol): async def register_shield(self, shield: Shield) -> None: ... + async def unregister_shield(self, identifier: str) -> None: ... 
+ class VectorDBsProtocolPrivate(Protocol): async def register_vector_db(self, vector_db: VectorDB) -> None: ... diff --git a/llama_stack/providers/inline/agents/meta_reference/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py index 4a77e65b9..334c32e15 100644 --- a/llama_stack/providers/inline/agents/meta_reference/__init__.py +++ b/llama_stack/providers/inline/agents/meta_reference/__init__.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.distribution.datatypes import AccessRule, Api +from llama_stack.core.datatypes import AccessRule, Api from .config import MetaReferenceAgentsImplConfig diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 3c34c71fb..5f7c90879 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -44,6 +44,7 @@ from llama_stack.apis.common.content_types import ( ToolCallDelta, ToolCallParseStatus, ) +from llama_stack.apis.common.errors import SessionNotFoundError from llama_stack.apis.inference import ( ChatCompletionResponseEventType, CompletionMessage, @@ -61,7 +62,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO -from llama_stack.distribution.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -214,7 +215,7 @@ class ChatAgent(ShieldRunnerMixin): is_resume = isinstance(request, AgentTurnResumeRequest) session_info = await self.storage.get_session_info(request.session_id) if session_info is None: - raise ValueError(f"Session {request.session_id} not found") + raise SessionNotFoundError(request.session_id) turns = await self.storage.get_session_turns(request.session_id) if is_resume and len(turns) == 0: diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 4d0c429bd..15695ec48 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -41,7 +41,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO -from llama_stack.distribution.datatypes import AccessRule +from llama_stack.core.datatypes import AccessRule from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.responses.responses_store import ResponsesStore @@ -230,8 +230,6 @@ class MetaReferenceAgentsImpl(Agents): agent = await self._get_agent_impl(agent_id) session_info = await agent.storage.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") turns = await agent.storage.get_session_turns(session_id) if turn_ids: turns = [turn for turn in turns if turn.turn_id in turn_ids] @@ -244,9 +242,6 @@ class MetaReferenceAgentsImpl(Agents): async def delete_agents_session(self, agent_id: str, session_id: str) -> None: agent = await self._get_agent_impl(agent_id) - session_info = await 
agent.storage.get_session_info(session_id) - if session_info is None: - raise ValueError(f"Session {session_id} not found") # Delete turns first, then the session await agent.storage.delete_session_turns(session_id) diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 437d617ad..7a8d99b78 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -10,10 +10,11 @@ import uuid from datetime import UTC, datetime from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn -from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed -from llama_stack.distribution.access_control.datatypes import AccessRule -from llama_stack.distribution.datatypes import User -from llama_stack.distribution.request_headers import get_authenticated_user +from llama_stack.apis.common.errors import SessionNotFoundError +from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.core.access_control.datatypes import AccessRule +from llama_stack.core.datatypes import User +from llama_stack.core.request_headers import get_authenticated_user from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -61,12 +62,12 @@ class AgentPersistence: ) return session_id - async def get_session_info(self, session_id: str) -> AgentSessionInfo | None: + async def get_session_info(self, session_id: str) -> AgentSessionInfo: value = await self.kvstore.get( key=f"session:{self.agent_id}:{session_id}", ) if not value: - return None + raise SessionNotFoundError(session_id) session_info = AgentSessionInfo(**json.loads(value)) @@ -95,7 +96,7 @@ class AgentPersistence: async def add_vector_db_to_session(self, session_id: str, vector_db_id: str): session_info = await self.get_session_if_accessible(session_id) if session_info is None: - raise ValueError(f"Session {session_id} not found or access denied") + raise SessionNotFoundError(session_id) session_info.vector_db_id = vector_db_id await self.kvstore.set( @@ -105,7 +106,7 @@ class AgentPersistence: async def add_turn_to_session(self, session_id: str, turn: Turn): if not await self.get_session_if_accessible(session_id): - raise ValueError(f"Session {session_id} not found or access denied") + raise SessionNotFoundError(session_id) await self.kvstore.set( key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", @@ -114,7 +115,7 @@ class AgentPersistence: async def get_session_turns(self, session_id: str) -> list[Turn]: if not await self.get_session_if_accessible(session_id): - raise ValueError(f"Session {session_id} not found or access denied") + raise SessionNotFoundError(session_id) values = await self.kvstore.values_in_range( start_key=f"session:{self.agent_id}:{session_id}:", @@ -137,7 +138,7 @@ class AgentPersistence: async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None: if not await self.get_session_if_accessible(session_id): - raise ValueError(f"Session {session_id} not found or access denied") + raise SessionNotFoundError(session_id) value = await self.kvstore.get( key=f"session:{self.agent_id}:{session_id}:{turn_id}", @@ -148,7 +149,7 @@ class AgentPersistence: async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): if not await 
self.get_session_if_accessible(session_id): - raise ValueError(f"Session {session_id} not found or access denied") + raise SessionNotFoundError(session_id) await self.kvstore.set( key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", @@ -166,7 +167,7 @@ class AgentPersistence: async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int): if not await self.get_session_if_accessible(session_id): - raise ValueError(f"Session {session_id} not found or access denied") + raise SessionNotFoundError(session_id) await self.kvstore.set( key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", @@ -218,6 +219,6 @@ class AgentPersistence: """ session_info = await self.get_session_info(session_id) if session_info is None: - raise ValueError(f"Session {session_id} not found") + raise SessionNotFoundError(session_id) await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}") diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index da71ecb17..e8ebeb30d 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -5,8 +5,6 @@ # the root directory of this source tree. from typing import Any -import pandas - from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Dataset @@ -44,6 +42,8 @@ class PandasDataframeDataset: if self.dataset_def.source.type == "uri": self.df = await get_dataframe_from_uri(self.dataset_def.source.uri) elif self.dataset_def.source.type == "rows": + import pandas + self.df = pandas.DataFrame(self.dataset_def.source.rows) else: raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}") @@ -103,6 +103,8 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): return paginate_records(records, start_index, limit) async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: + import pandas + dataset_def = self.dataset_infos[dataset_id] dataset_impl = PandasDataframeDataset(dataset_def) await dataset_impl.load() diff --git a/llama_stack/providers/inline/eval/meta_reference/__init__.py b/llama_stack/providers/inline/eval/meta_reference/__init__.py index 7afe7f33b..cf2578a72 100644 --- a/llama_stack/providers/inline/eval/meta_reference/__init__.py +++ b/llama_stack/providers/inline/eval/meta_reference/__init__.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 

 from typing import Any

-from llama_stack.distribution.datatypes import Api
+from llama_stack.core.datatypes import Api

 from .config import MetaReferenceEvalConfig
diff --git a/llama_stack/providers/inline/files/localfs/__init__.py b/llama_stack/providers/inline/files/localfs/__init__.py
index 71664efad..363b6f04c 100644
--- a/llama_stack/providers/inline/files/localfs/__init__.py
+++ b/llama_stack/providers/inline/files/localfs/__init__.py
@@ -6,7 +6,7 @@

 from typing import Any

-from llama_stack.distribution.datatypes import AccessRule, Api
+from llama_stack.core.datatypes import AccessRule, Api

 from .config import LocalfsFilesImplConfig
 from .files import LocalfsFilesImpl
diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py
index 433762c5a..1e9dca3b5 100644
--- a/llama_stack/providers/inline/files/localfs/files.py
+++ b/llama_stack/providers/inline/files/localfs/files.py
@@ -19,7 +19,7 @@ from llama_stack.apis.files import (
     OpenAIFileObject,
     OpenAIFilePurpose,
 )
-from llama_stack.distribution.datatypes import AccessRule
+from llama_stack.core.datatypes import AccessRule
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
 from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
diff --git a/llama_stack/providers/inline/inference/meta_reference/common.py b/llama_stack/providers/inline/inference/meta_reference/common.py
index beb0d39d4..1e164430d 100644
--- a/llama_stack/providers/inline/inference/meta_reference/common.py
+++ b/llama_stack/providers/inline/inference/meta_reference/common.py
@@ -6,7 +6,7 @@

 from pathlib import Path

-from llama_stack.distribution.utils.model_utils import model_local_dir
+from llama_stack.core.utils.model_utils import model_local_dir


 def model_checkpoint_dir(model_id) -> str:
diff --git a/llama_stack/providers/inline/post_training/huggingface/__init__.py b/llama_stack/providers/inline/post_training/huggingface/__init__.py
index cc1a671c1..96c45cc4f 100644
--- a/llama_stack/providers/inline/post_training/huggingface/__init__.py
+++ b/llama_stack/providers/inline/post_training/huggingface/__init__.py
@@ -6,7 +6,7 @@

 from typing import Any

-from llama_stack.distribution.datatypes import Api
+from llama_stack.core.datatypes import Api

 from .config import HuggingFacePostTrainingConfig

diff --git a/llama_stack/providers/inline/post_training/huggingface/config.py b/llama_stack/providers/inline/post_training/huggingface/config.py
index 06c6d8073..04e286ff0 100644
--- a/llama_stack/providers/inline/post_training/huggingface/config.py
+++ b/llama_stack/providers/inline/post_training/huggingface/config.py
@@ -67,6 +67,17 @@ class HuggingFacePostTrainingConfig(BaseModel):
     # Can improve data transfer speed to GPU but uses more memory
     dataloader_pin_memory: bool = True

+    # DPO-specific parameters
+    dpo_beta: float = 0.1
+    use_reference_model: bool = True
+    dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
+    dpo_output_dir: str
+
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
-        return {"checkpoint_format": "huggingface", "distributed_backend": None, "device": "cpu"}
+        return {
+            "checkpoint_format": "huggingface",
+            "distributed_backend": None,
+            "device": "cpu",
+            "dpo_output_dir": __distro_dir__ + "/dpo_output",
+        }
diff --git a/llama_stack/providers/inline/post_training/huggingface/post_training.py
b/llama_stack/providers/inline/post_training/huggingface/post_training.py index 0b2760792..22ace1ae0 100644 --- a/llama_stack/providers/inline/post_training/huggingface/post_training.py +++ b/llama_stack/providers/inline/post_training/huggingface/post_training.py @@ -22,12 +22,8 @@ from llama_stack.apis.post_training import ( from llama_stack.providers.inline.post_training.huggingface.config import ( HuggingFacePostTrainingConfig, ) -from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import ( - HFFinetuningSingleDevice, -) from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus -from llama_stack.schema_utils import webmethod class TrainingArtifactType(Enum): @@ -36,6 +32,7 @@ class TrainingArtifactType(Enum): _JOB_TYPE_SUPERVISED_FINE_TUNE = "supervised-fine-tune" +_JOB_TYPE_DPO_TRAINING = "dpo-training" class HuggingFacePostTrainingImpl: @@ -81,6 +78,10 @@ class HuggingFacePostTrainingImpl: algorithm_config: AlgorithmConfig | None = None, ) -> PostTrainingJob: async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): + from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device import ( + HFFinetuningSingleDevice, + ) + on_log_message_cb("Starting HF finetuning") recipe = HFFinetuningSingleDevice( @@ -119,12 +120,41 @@ class HuggingFacePostTrainingImpl: hyperparam_search_config: dict[str, Any], logger_config: dict[str, Any], ) -> PostTrainingJob: - raise NotImplementedError("DPO alignment is not implemented yet") + async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): + from llama_stack.providers.inline.post_training.huggingface.recipes.finetune_single_device_dpo import ( + HFDPOAlignmentSingleDevice, + ) - async def get_training_jobs(self) -> ListPostTrainingJobsResponse: - return ListPostTrainingJobsResponse( - data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()] - ) + on_log_message_cb("Starting HF DPO alignment") + + recipe = HFDPOAlignmentSingleDevice( + job_uuid=job_uuid, + datasetio_api=self.datasetio_api, + datasets_api=self.datasets_api, + ) + + resources_allocated, checkpoints = await recipe.train( + model=finetuned_model, + output_dir=f"{self.config.dpo_output_dir}/{job_uuid}", + job_uuid=job_uuid, + dpo_config=algorithm_config, + config=training_config, + provider_config=self.config, + ) + + on_artifact_collected_cb(self._resources_stats_to_artifact(resources_allocated)) + if checkpoints: + for checkpoint in checkpoints: + artifact = self._checkpoint_to_artifact(checkpoint) + on_artifact_collected_cb(artifact) + else: + on_log_message_cb("Warning: No checkpoints were saved during DPO training") + + on_status_change_cb(SchedulerJobStatus.completed) + on_log_message_cb("HF DPO alignment completed") + + job_uuid = self._scheduler.schedule(_JOB_TYPE_DPO_TRAINING, job_uuid, handler) + return PostTrainingJob(job_uuid=job_uuid) @staticmethod def _get_artifacts_metadata_by_type(job, artifact_type): @@ -139,7 +169,6 @@ class HuggingFacePostTrainingImpl: data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value) return data[0] if data else None - @webmethod(route="/post-training/job/status") async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None: job = self._scheduler.get_job(job_uuid) @@ -166,11 +195,14 @@ class HuggingFacePostTrainingImpl: 
resources_allocated=self._get_resources_allocated(job), ) - @webmethod(route="/post-training/job/cancel") async def cancel_training_job(self, job_uuid: str) -> None: self._scheduler.cancel(job_uuid) - @webmethod(route="/post-training/job/artifacts") async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None: job = self._scheduler.get_job(job_uuid) return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) + + async def get_training_jobs(self) -> ListPostTrainingJobsResponse: + return ListPostTrainingJobsResponse( + data=[PostTrainingJob(job_uuid=job.id) for job in self._scheduler.get_jobs()] + ) diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index 2a024eb25..2574b995b 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -8,30 +8,13 @@ import gc import json import logging import multiprocessing -import os -import signal -import sys -from datetime import UTC, datetime from pathlib import Path from typing import Any -import psutil - -from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device - -# Set tokenizer parallelism environment variable -os.environ["TOKENIZERS_PARALLELISM"] = "false" - -# Force PyTorch to use OpenBLAS instead of MKL -os.environ["MKL_THREADING_LAYER"] = "GNU" -os.environ["MKL_SERVICE_FORCE_INTEL"] = "0" -os.environ["MKL_NUM_THREADS"] = "1" - import torch from datasets import Dataset from peft import LoraConfig from transformers import ( - AutoConfig, AutoModelForCausalLM, AutoTokenizer, ) @@ -45,93 +28,25 @@ from llama_stack.apis.post_training import ( LoraFinetuningConfig, TrainingConfig, ) +from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig +from ..utils import ( + calculate_training_steps, + create_checkpoints, + get_memory_stats, + get_save_strategy, + load_model, + load_rows_from_dataset, + setup_environment, + setup_signal_handlers, + setup_torch_device, + split_dataset, +) logger = logging.getLogger(__name__) -def get_gb(to_convert: int) -> str: - """Converts memory stats to GB and formats to 2 decimal places. 
- Args: - to_convert: Memory value in bytes - Returns: - str: Memory value in GB formatted to 2 decimal places - """ - return f"{(to_convert / (1024**3)):.2f}" - - -def get_memory_stats(device: torch.device) -> dict[str, Any]: - """Get memory statistics for the given device.""" - stats = { - "system_memory": { - "total": get_gb(psutil.virtual_memory().total), - "available": get_gb(psutil.virtual_memory().available), - "used": get_gb(psutil.virtual_memory().used), - "percent": psutil.virtual_memory().percent, - } - } - - if device.type == "cuda": - stats["device_memory"] = { - "allocated": get_gb(torch.cuda.memory_allocated(device)), - "reserved": get_gb(torch.cuda.memory_reserved(device)), - "max_allocated": get_gb(torch.cuda.max_memory_allocated(device)), - } - elif device.type == "mps": - # MPS doesn't provide direct memory stats, but we can track system memory - stats["device_memory"] = { - "note": "MPS memory stats not directly available", - "system_memory_used": get_gb(psutil.virtual_memory().used), - } - elif device.type == "cpu": - # For CPU, we track process memory usage - process = psutil.Process() - stats["device_memory"] = { - "process_rss": get_gb(process.memory_info().rss), - "process_vms": get_gb(process.memory_info().vms), - "process_percent": process.memory_percent(), - } - - return stats - - -def setup_torch_device(device_str: str) -> torch.device: - """Initialize and validate a PyTorch device. - This function handles device initialization and validation for different device types: - - CUDA: Validates CUDA availability and handles device selection - - MPS: Validates MPS availability for Apple Silicon - - CPU: Basic validation - - HPU: Raises error as it's not supported - Args: - device_str: String specifying the device ('cuda', 'cpu', 'mps') - Returns: - torch.device: The initialized and validated device - Raises: - RuntimeError: If device initialization fails or device is not supported - """ - try: - device = torch.device(device_str) - except RuntimeError as e: - raise RuntimeError(f"Error getting Torch Device {str(e)}") from e - - # Validate device capabilities - if device.type == "cuda": - if not torch.cuda.is_available(): - raise RuntimeError( - f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device." 
- ) - if device.index is None: - device = torch.device(device.type, torch.cuda.current_device()) - elif device.type == "mps": - if not torch.backends.mps.is_available(): - raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.") - elif device.type == "hpu": - raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.") - - return device - - class HFFinetuningSingleDevice: def __init__( self, @@ -262,19 +177,6 @@ class HFFinetuningSingleDevice: remove_columns=ds.column_names, ) - async def _setup_data(self, dataset_id: str) -> list[dict[str, Any]]: - """Load dataset from llama stack dataset provider""" - try: - all_rows = await self.datasetio_api.iterrows( - dataset_id=dataset_id, - limit=-1, - ) - if not isinstance(all_rows.data, list): - raise RuntimeError("Expected dataset data to be a list") - return all_rows.data - except Exception as e: - raise RuntimeError(f"Failed to load dataset: {str(e)}") from e - def _run_training_sync( self, model: str, @@ -327,7 +229,7 @@ class HFFinetuningSingleDevice: # Load dataset logger.info(f"Loading dataset: {config.data_config.dataset_id}") - rows = await self._setup_data(config.data_config.dataset_id) + rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id) if not self.validate_dataset_format(rows): raise ValueError("Dataset is missing required fields: input_query, expected_answer, chat_completion_input") logger.info(f"Loaded {len(rows)} rows from dataset") @@ -369,47 +271,10 @@ class HFFinetuningSingleDevice: raise ValueError(f"Failed to create dataset: {str(e)}") from e # Split dataset - logger.info("Splitting dataset into train and validation sets") - train_val_split = ds.train_test_split(test_size=0.1, seed=42) - train_dataset = train_val_split["train"] - eval_dataset = train_val_split["test"] - logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + train_dataset, eval_dataset = split_dataset(ds) return train_dataset, eval_dataset, tokenizer - def load_model( - self, - model: str, - device: torch.device, - provider_config: HuggingFacePostTrainingConfig, - ) -> AutoModelForCausalLM: - """Load and initialize the model for training. 
- Args: - model: The model identifier to load - device: The device to load the model onto - provider_config: Provider-specific configuration - Returns: - The loaded and initialized model - Raises: - RuntimeError: If model loading fails - """ - logger.info("Loading the base model") - try: - model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) - model_obj = AutoModelForCausalLM.from_pretrained( - model, - torch_dtype="auto" if device.type != "cpu" else "float32", - quantization_config=None, - config=model_config, - **provider_config.model_specific_config, - ) - # Always move model to specified device - model_obj = model_obj.to(device) - logger.info(f"Model loaded and moved to device: {model_obj.device}") - return model_obj - except Exception as e: - raise RuntimeError(f"Failed to load model: {str(e)}") from e - def setup_training_args( self, config: TrainingConfig, @@ -439,27 +304,12 @@ class HFFinetuningSingleDevice: raise ValueError("DataConfig is required for training") data_config = config.data_config - # Calculate steps - total_steps = steps_per_epoch * config.n_epochs - max_steps = min(config.max_steps_per_epoch, total_steps) - logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch - - logger.info("Training configuration:") - logger.info(f"- Steps per epoch: {steps_per_epoch}") - logger.info(f"- Total steps: {total_steps}") - logger.info(f"- Max steps: {max_steps}") - logger.info(f"- Logging steps: {logging_steps}") - - # Configure save strategy - save_strategy = "no" - eval_strategy = "no" - if output_dir_path: - save_strategy = "epoch" - eval_strategy = "epoch" - logger.info(f"Will save checkpoints to {output_dir_path}") + # Calculate steps and get save strategy + step_info = calculate_training_steps(steps_per_epoch, config) + save_strategy, eval_strategy = get_save_strategy(output_dir_path) return SFTConfig( - max_steps=max_steps, + max_steps=step_info["max_steps"], output_dir=str(output_dir_path) if output_dir_path is not None else None, num_train_epochs=config.n_epochs, per_device_train_batch_size=data_config.batch_size, @@ -483,7 +333,7 @@ class HFFinetuningSingleDevice: load_best_model_at_end=True if output_dir_path else False, metric_for_best_model="eval_loss", greater_is_better=False, - logging_steps=logging_steps, + logging_steps=step_info["logging_steps"], ) def save_model( @@ -523,13 +373,11 @@ class HFFinetuningSingleDevice: ) -> None: """Run the training process with signal handling.""" - def signal_handler(signum, frame): - """Handle termination signals gracefully.""" - logger.info(f"Received signal {signum}, initiating graceful shutdown") - sys.exit(0) + # Setup environment variables + setup_environment() - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) + # Setup signal handlers + setup_signal_handlers() # Convert config dicts back to objects logger.info("Initializing configuration objects") @@ -558,7 +406,7 @@ class HFFinetuningSingleDevice: ) # Load model - model_obj = self.load_model(model, device, provider_config_obj) + model_obj = load_model(model, device, provider_config_obj) # Initialize trainer logger.info("Initializing SFTTrainer") @@ -633,7 +481,7 @@ class HFFinetuningSingleDevice: # Train in a separate process logger.info("Starting training in separate process") try: - # Set multiprocessing start method to 'spawn' for CUDA/MPS compatibility + # Setup multiprocessing for device if device.type in ["cuda", "mps"]: multiprocessing.set_start_method("spawn", 
force=True) @@ -663,37 +511,7 @@ class HFFinetuningSingleDevice: checkpoints = [] if output_dir_path: - # Get all checkpoint directories and sort them numerically - checkpoint_dirs = sorted( - [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()], - key=lambda x: int(x.name.split("-")[1]), - ) - - # Add all checkpoint directories - for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1): - # Get the creation time of the directory - created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC) - - checkpoint = Checkpoint( - identifier=checkpoint_dir.name, - created_at=created_time, - epoch=epoch_number, - post_training_job_id=job_uuid, - path=str(checkpoint_dir), - ) - checkpoints.append(checkpoint) - - # Add the merged model as a checkpoint - merged_model_path = output_dir_path / "merged_model" - if merged_model_path.exists(): - checkpoint = Checkpoint( - identifier=f"{model}-sft-{config.n_epochs}", - created_at=datetime.now(UTC), - epoch=config.n_epochs, - post_training_job_id=job_uuid, - path=str(merged_model_path), - ) - checkpoints.append(checkpoint) + checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "merged_model") return memory_stats, checkpoints if checkpoints else None finally: diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py new file mode 100644 index 000000000..a7c19faac --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py @@ -0,0 +1,485 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import gc +import logging +import multiprocessing +from pathlib import Path +from typing import Any + +import torch +from datasets import Dataset +from transformers import ( + AutoTokenizer, +) +from trl import DPOConfig, DPOTrainer + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.post_training import ( + Checkpoint, + DPOAlignmentConfig, + TrainingConfig, +) +from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device + +from ..config import HuggingFacePostTrainingConfig +from ..utils import ( + calculate_training_steps, + create_checkpoints, + get_memory_stats, + get_save_strategy, + load_model, + load_rows_from_dataset, + setup_environment, + setup_signal_handlers, + setup_torch_device, + split_dataset, +) + +logger = logging.getLogger(__name__) + + +class HFDPOAlignmentSingleDevice: + def __init__( + self, + job_uuid: str, + datasetio_api: DatasetIO, + datasets_api: Datasets, + ): + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + self.job_uuid = job_uuid + + def validate_dataset_format(self, rows: list[dict]) -> None: + """Validate that the dataset has the required fields for DPO training.""" + required_fields = ["prompt", "chosen", "rejected"] + + if not rows: + logger.warning("Dataset is empty") + raise ValueError("Dataset is empty") + + for i, row in enumerate(rows): + if not isinstance(row, dict): + logger.warning(f"Row {i} is not a dictionary") + raise ValueError(f"Row {i} is not a dictionary") + + for field in required_fields: + if field not in row: + logger.warning(f"Row {i} missing required DPO field: {field}") + raise ValueError(f"Row {i} missing required DPO field: {field}") + + # Handle both string and list formats + if field == "prompt": + # Prompt should be a string + if not isinstance(row[field], str): + logger.warning(f"Row {i} field '{field}' is not a string") + raise ValueError(f"Row {i} field '{field}' is not a string") + if not row[field].strip(): + logger.warning(f"Row {i} field '{field}' is empty") + raise ValueError(f"Row {i} field '{field}' is empty") + else: + # chosen/rejected can be either strings or lists of messages + if isinstance(row[field], str): + if not row[field].strip(): + logger.warning(f"Row {i} field '{field}' is empty") + raise ValueError(f"Row {i} field '{field}' is empty") + elif isinstance(row[field], list): + if not row[field]: + logger.warning(f"Row {i} field '{field}' is empty list") + raise ValueError(f"Row {i} field '{field}' is empty list") + else: + logger.warning(f"Row {i} field '{field}' is neither string nor list") + raise ValueError(f"Row {i} field '{field}' is neither string nor list") + + logger.info(f"DPO dataset validation passed: {len(rows)} preference examples") + + def _process_dpo_format(self, row: dict) -> tuple[str | None, str | None, str | None]: + """Process a row in DPO format, handling both string and conversation list formats.""" + if all(field in row for field in ["prompt", "chosen", "rejected"]): + prompt = row["prompt"] + + # Handle chosen field - convert list to string if needed + if isinstance(row["chosen"], list): + # For conversation format, concatenate messages + chosen = "\n".join( + [msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["chosen"]] + ) + else: + chosen = row["chosen"] + + # Handle rejected field - convert list to string if needed + if isinstance(row["rejected"], list): + # For conversation format, concatenate messages + rejected = 
"\n".join( + [msg.get("content", "") if isinstance(msg, dict) else str(msg) for msg in row["rejected"]] + ) + else: + rejected = row["rejected"] + + return prompt, chosen, rejected + return None, None, None + + def _format_text_for_dpo(self, prompt: str, response: str, provider_config: HuggingFacePostTrainingConfig) -> str: + """Format prompt and response text based on model requirements.""" + if hasattr(provider_config, "chat_template") and provider_config.chat_template: + # Use the chat template, supporting both {prompt}/{response} and {input}/{output} + template = provider_config.chat_template + # Try prompt/response first (DPO style) + if "{prompt}" in template and "{response}" in template: + return template.format(prompt=prompt, response=response) + # Fall back to input/output (SFT style) + elif "{input}" in template and "{output}" in template: + return template.format(input=prompt, output=response) + else: + # If template doesn't have expected placeholders, use default + return f"{prompt}\n{response}" + return f"{prompt}\n{response}" + + def _create_dataset( + self, rows: list[dict], config: TrainingConfig, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Create and preprocess the dataset for DPO.""" + dpo_examples = [] + for row in rows: + prompt, chosen, rejected = self._process_dpo_format(row) + + if prompt and chosen and rejected: + # Format the texts + chosen_formatted = self._format_text_for_dpo(prompt, chosen, provider_config) + rejected_formatted = self._format_text_for_dpo(prompt, rejected, provider_config) + + dpo_examples.append( + { + "prompt": prompt, + "chosen": chosen_formatted, + "rejected": rejected_formatted, + } + ) + + if not dpo_examples: + raise ValueError("No valid preference examples found in dataset") + + logger.info(f"Created DPO dataset with {len(dpo_examples)} preference pairs") + return Dataset.from_list(dpo_examples) + + def _preprocess_dataset( + self, ds: Dataset, tokenizer: AutoTokenizer, provider_config: HuggingFacePostTrainingConfig + ) -> Dataset: + """Preprocess the dataset with tokenizer for DPO.""" + # DPOTrainer expects raw text, so we don't tokenize here + # Just return the dataset as is + return ds + + def _run_training_sync( + self, + model: str, + provider_config: dict[str, Any], + dpo_config: dict[str, Any], + config: dict[str, Any], + output_dir_path: Path | None, + ) -> None: + """Synchronous wrapper for running DPO training process.""" + import asyncio + + logger.info("Starting DPO training process with async wrapper") + asyncio.run( + self._run_training( + model=model, + provider_config=provider_config, + dpo_config=dpo_config, + config=config, + output_dir_path=output_dir_path, + ) + ) + + async def load_dataset( + self, + model: str, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[Dataset, Dataset, AutoTokenizer]: + """Load and prepare the dataset for DPO training.""" + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for DPO training") + + # Load dataset + logger.info(f"Loading dataset: {config.data_config.dataset_id}") + rows = await load_rows_from_dataset(self.datasetio_api, config.data_config.dataset_id) + self.validate_dataset_format(rows) + logger.info(f"Loaded {len(rows)} rows from dataset") + + # Initialize tokenizer + logger.info(f"Initializing tokenizer for model: {model}") + try: + tokenizer = AutoTokenizer.from_pretrained(model, **provider_config.model_specific_config) + + # Set pad token to eos token if not present + 
if not tokenizer.pad_token: + tokenizer.pad_token = tokenizer.eos_token + + # Set padding side to left for DPO + tokenizer.padding_side = "left" + + # Set truncation side to right to keep the beginning of the sequence + tokenizer.truncation_side = "right" + + # Set model max length to match provider config + tokenizer.model_max_length = provider_config.max_seq_length + + logger.info("Tokenizer initialized successfully for DPO") + except Exception as e: + raise RuntimeError(f"Failed to initialize tokenizer: {str(e)}") from e + + # Create and preprocess dataset + logger.info("Creating and preprocessing dataset for DPO") + try: + ds = self._create_dataset(rows, config, provider_config) + ds = self._preprocess_dataset(ds, tokenizer, provider_config) + logger.info(f"Dataset created with {len(ds)} examples") + except Exception as e: + raise ValueError(f"Failed to create dataset: {str(e)}") from e + + # Split dataset + train_dataset, eval_dataset = split_dataset(ds) + + return train_dataset, eval_dataset, tokenizer + + def setup_training_args( + self, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + dpo_config: DPOAlignmentConfig, + device: torch.device, + output_dir_path: Path | None, + steps_per_epoch: int, + ) -> DPOConfig: + """Setup DPO training arguments.""" + logger.info("Configuring DPO training arguments") + lr = 5e-7 # Lower learning rate for DPO + if config.optimizer_config: + lr = config.optimizer_config.lr + logger.info(f"Using custom learning rate: {lr}") + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + data_config = config.data_config + + # Calculate steps and get save strategy + step_info = calculate_training_steps(steps_per_epoch, config) + save_strategy, eval_strategy = get_save_strategy(output_dir_path) + + logger.info("DPO training configuration:") + logger.info(f"- DPO beta: {dpo_config.beta}") + logger.info(f"- DPO loss type: {provider_config.dpo_loss_type}") + + # Calculate max prompt length as half of max sequence length + max_prompt_length = provider_config.max_seq_length // 2 + + return DPOConfig( + max_steps=step_info["max_steps"], + output_dir=str(output_dir_path) if output_dir_path is not None else None, + num_train_epochs=config.n_epochs, + per_device_train_batch_size=data_config.batch_size, + fp16=device.type == "cuda", + bf16=False, # Causes CPU issues. 
+ eval_strategy=eval_strategy, + use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False, + save_strategy=save_strategy, + report_to="none", + max_length=provider_config.max_seq_length, + max_prompt_length=max_prompt_length, + gradient_accumulation_steps=config.gradient_accumulation_steps, + gradient_checkpointing=provider_config.gradient_checkpointing, + learning_rate=lr, + warmup_ratio=provider_config.warmup_ratio, + weight_decay=provider_config.weight_decay, + remove_unused_columns=False, + dataloader_pin_memory=provider_config.dataloader_pin_memory, + dataloader_num_workers=provider_config.dataloader_num_workers, + load_best_model_at_end=True if output_dir_path else False, + metric_for_best_model="eval_loss", + greater_is_better=False, + logging_steps=step_info["logging_steps"], + save_total_limit=provider_config.save_total_limit, + # DPO specific parameters + beta=dpo_config.beta, + loss_type=provider_config.dpo_loss_type, + ) + + def save_model( + self, + trainer: DPOTrainer, + output_dir_path: Path, + ) -> None: + """Save the trained DPO model.""" + logger.info("Saving final DPO model") + + save_path = output_dir_path / "dpo_model" + logger.info(f"Saving model to {save_path}") + + # Save model and tokenizer + trainer.save_model(str(save_path)) + + async def _run_training( + self, + model: str, + provider_config: dict[str, Any], + dpo_config: dict[str, Any], + config: dict[str, Any], + output_dir_path: Path | None, + ) -> None: + """Run the DPO training process with signal handling.""" + + # Setup environment variables + setup_environment() + + # Setup signal handlers + setup_signal_handlers() + + # Convert config dicts back to objects + logger.info("Initializing configuration objects") + provider_config_obj = HuggingFacePostTrainingConfig(**provider_config) + config_obj = TrainingConfig(**config) + dpo_config_obj = DPOAlignmentConfig(**dpo_config) + + # Initialize and validate device + device = setup_torch_device(provider_config_obj.device) + logger.info(f"Using device '{device}'") + + # Load dataset and tokenizer + train_dataset, eval_dataset, tokenizer = await self.load_dataset(model, config_obj, provider_config_obj) + + # Calculate steps per epoch + if not config_obj.data_config: + raise ValueError("DataConfig is required for training") + steps_per_epoch = len(train_dataset) // config_obj.data_config.batch_size + + # Setup training arguments + training_args = self.setup_training_args( + config_obj, + provider_config_obj, + dpo_config_obj, + device, + output_dir_path, + steps_per_epoch, + ) + + # Load model and reference model + model_obj = load_model(model, device, provider_config_obj) + ref_model = None + if provider_config_obj.use_reference_model: + logger.info("Loading separate reference model for DPO") + ref_model = load_model(model, device, provider_config_obj) + else: + logger.info("Using shared reference model for DPO") + + # Initialize DPO trainer + logger.info("Initializing DPOTrainer") + trainer = DPOTrainer( + model=model_obj, + ref_model=ref_model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + ) + + try: + # Train + logger.info("Starting DPO training") + trainer.train() + logger.info("DPO training completed successfully") + + # Save final model if output directory is provided + if output_dir_path: + logger.info(f"Saving model to output directory: {output_dir_path}") + self.save_model(trainer, output_dir_path) + logger.info("Model save completed") + + finally: + # 
Clean up resources + logger.info("Cleaning up resources") + if hasattr(trainer, "model"): + evacuate_model_from_device(trainer.model, device.type) + if ref_model: + evacuate_model_from_device(ref_model, device.type) + del trainer + del ref_model + gc.collect() + logger.info("Cleanup completed") + logger.info("DPO training process finishing successfully") + + async def train( + self, + model: str, + output_dir: str | None, + job_uuid: str, + dpo_config: DPOAlignmentConfig, + config: TrainingConfig, + provider_config: HuggingFacePostTrainingConfig, + ) -> tuple[dict[str, Any], list[Checkpoint] | None]: + """Train a model using HuggingFace's DPOTrainer""" + # Initialize and validate device + device = setup_torch_device(provider_config.device) + logger.info(f"Using device '{device}'") + + output_dir_path = None + if output_dir: + output_dir_path = Path(output_dir) + + # Track memory stats + memory_stats = { + "initial": get_memory_stats(device), + "after_training": None, + "final": None, + } + + # Validate data config + if not config.data_config: + raise ValueError("DataConfig is required for training") + + # Train in a separate process + logger.info("Starting DPO training in separate process") + try: + # Setup multiprocessing for device + if device.type in ["cuda", "mps"]: + multiprocessing.set_start_method("spawn", force=True) + + process = multiprocessing.Process( + target=self._run_training_sync, + kwargs={ + "model": model, + "provider_config": provider_config.model_dump(), + "dpo_config": dpo_config.model_dump(), + "config": config.model_dump(), + "output_dir_path": output_dir_path, + }, + ) + process.start() + + # Monitor the process + while process.is_alive(): + process.join(timeout=1) # Check every second + if not process.is_alive(): + break + + # Get the return code + if process.exitcode != 0: + raise RuntimeError(f"DPO training failed with exit code {process.exitcode}") + + memory_stats["after_training"] = get_memory_stats(device) + + checkpoints = [] + if output_dir_path: + checkpoints = create_checkpoints(output_dir_path, job_uuid, model, config, "dpo_model") + + return memory_stats, checkpoints if checkpoints else None + finally: + memory_stats["final"] = get_memory_stats(device) + gc.collect() diff --git a/llama_stack/providers/inline/post_training/huggingface/utils.py b/llama_stack/providers/inline/post_training/huggingface/utils.py new file mode 100644 index 000000000..3147c19ab --- /dev/null +++ b/llama_stack/providers/inline/post_training/huggingface/utils.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import logging +import os +import signal +import sys +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + +import psutil +import torch +from datasets import Dataset +from transformers import AutoConfig, AutoModelForCausalLM + +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.post_training import Checkpoint, TrainingConfig + +from .config import HuggingFacePostTrainingConfig + +logger = logging.getLogger(__name__) + + +def setup_environment(): + """Setup common environment variables for training.""" + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["MKL_THREADING_LAYER"] = "GNU" + os.environ["MKL_SERVICE_FORCE_INTEL"] = "0" + os.environ["MKL_NUM_THREADS"] = "1" + + +def bytes_to_gb(to_convert: int) -> str: + """Converts memory stats to GB and formats to 2 decimal places. + Args: + to_convert: Memory value in bytes + Returns: + str: Memory value in GB formatted to 2 decimal places + """ + return f"{(to_convert / (1024**3)):.2f}" + + +def get_memory_stats(device: torch.device) -> dict[str, Any]: + """Get memory statistics for the given device.""" + stats = { + "system_memory": { + "total": bytes_to_gb(psutil.virtual_memory().total), + "available": bytes_to_gb(psutil.virtual_memory().available), + "used": bytes_to_gb(psutil.virtual_memory().used), + "percent": psutil.virtual_memory().percent, + } + } + + if device.type == "cuda": + stats["device_memory"] = { + "allocated": bytes_to_gb(torch.cuda.memory_allocated(device)), + "reserved": bytes_to_gb(torch.cuda.memory_reserved(device)), + "max_allocated": bytes_to_gb(torch.cuda.max_memory_allocated(device)), + } + elif device.type == "mps": + # MPS doesn't provide direct memory stats, but we can track system memory + stats["device_memory"] = { + "note": "MPS memory stats not directly available", + "system_memory_used": bytes_to_gb(psutil.virtual_memory().used), + } + elif device.type == "cpu": + # For CPU, we track process memory usage + process = psutil.Process() + stats["device_memory"] = { + "process_rss": bytes_to_gb(process.memory_info().rss), + "process_vms": bytes_to_gb(process.memory_info().vms), + "process_percent": process.memory_percent(), + } + + return stats + + +def setup_torch_device(device_str: str) -> torch.device: + """Initialize and validate a PyTorch device. + This function handles device initialization and validation for different device types: + - CUDA: Validates CUDA availability and handles device selection + - MPS: Validates MPS availability for Apple Silicon + - CPU: Basic validation + - HPU: Raises error as it's not supported + Args: + device_str: String specifying the device ('cuda', 'cpu', 'mps') + Returns: + torch.device: The initialized and validated device + Raises: + RuntimeError: If device initialization fails or device is not supported + """ + try: + device = torch.device(device_str) + except RuntimeError as e: + raise RuntimeError(f"Error getting Torch Device {str(e)}") from e + + # Validate device capabilities + if device.type == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError( + f"{device.type}: Torch has no CUDA/ROCm support or could not detect a compatible device." 
+ ) + if device.index is None: + device = torch.device(device.type, torch.cuda.current_device()) + elif device.type == "mps": + if not torch.backends.mps.is_available(): + raise RuntimeError(f"{device.type}: Torch has no MPS support or could not detect a compatible device.") + elif device.type == "hpu": + raise RuntimeError(f"{device.type}: training does not support Intel Gaudi.") + + return device + + +async def load_rows_from_dataset(datasetio_api: DatasetIO, dataset_id: str) -> list[dict[str, Any]]: + """Load dataset from llama stack dataset provider""" + try: + all_rows = await datasetio_api.iterrows( + dataset_id=dataset_id, + limit=-1, + ) + if not isinstance(all_rows.data, list): + raise RuntimeError("Expected dataset data to be a list") + return all_rows.data + except Exception as e: + raise RuntimeError(f"Failed to load dataset: {str(e)}") from e + + +def load_model( + model: str, + device: torch.device, + provider_config: HuggingFacePostTrainingConfig, +) -> AutoModelForCausalLM: + """Load and initialize the model for training. + Args: + model: The model identifier to load + device: The device to load the model onto + provider_config: Provider-specific configuration + Returns: + The loaded and initialized model + Raises: + RuntimeError: If model loading fails + """ + logger.info("Loading the base model") + try: + model_config = AutoConfig.from_pretrained(model, **provider_config.model_specific_config) + model_obj = AutoModelForCausalLM.from_pretrained( + model, + torch_dtype="auto" if device.type != "cpu" else "float32", + quantization_config=None, + config=model_config, + **provider_config.model_specific_config, + ) + # Always move model to specified device + model_obj = model_obj.to(device) + logger.info(f"Model loaded and moved to device: {model_obj.device}") + return model_obj + except Exception as e: + raise RuntimeError(f"Failed to load model: {str(e)}") from e + + +def split_dataset(ds: Dataset) -> tuple[Dataset, Dataset]: + """Split dataset into train and validation sets. + Args: + ds: Dataset to split + Returns: + tuple: (train_dataset, eval_dataset) + """ + logger.info("Splitting dataset into train and validation sets") + train_val_split = ds.train_test_split(test_size=0.1, seed=42) + train_dataset = train_val_split["train"] + eval_dataset = train_val_split["test"] + logger.info(f"Split dataset into {len(train_dataset)} training and {len(eval_dataset)} validation examples") + return train_dataset, eval_dataset + + +def setup_signal_handlers(): + """Setup signal handlers for graceful shutdown.""" + + def signal_handler(signum, frame): + logger.info(f"Received signal {signum}, initiating graceful shutdown") + sys.exit(0) + + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + +def calculate_training_steps(steps_per_epoch: int, config: TrainingConfig) -> dict[str, int]: + """Calculate training steps and logging configuration. 
+ Args: + steps_per_epoch: Number of training steps per epoch + config: Training configuration + Returns: + dict: Dictionary with calculated step values + """ + total_steps = steps_per_epoch * config.n_epochs + max_steps = min(config.max_steps_per_epoch, total_steps) + logging_steps = max(1, steps_per_epoch // 50) # Log 50 times per epoch + + logger.info("Training configuration:") + logger.info(f"- Steps per epoch: {steps_per_epoch}") + logger.info(f"- Total steps: {total_steps}") + logger.info(f"- Max steps: {max_steps}") + logger.info(f"- Logging steps: {logging_steps}") + + return {"total_steps": total_steps, "max_steps": max_steps, "logging_steps": logging_steps} + + +def get_save_strategy(output_dir_path: Path | None) -> tuple[str, str]: + """Get save and evaluation strategy based on output directory. + Args: + output_dir_path: Optional path to save the model + Returns: + tuple: (save_strategy, eval_strategy) + """ + if output_dir_path: + logger.info(f"Will save checkpoints to {output_dir_path}") + return "epoch", "epoch" + return "no", "no" + + +def create_checkpoints( + output_dir_path: Path, job_uuid: str, model: str, config: TrainingConfig, final_model_name: str +) -> list[Checkpoint]: + """Create checkpoint objects from training output. + Args: + output_dir_path: Path to the training output directory + job_uuid: Unique identifier for the training job + model: Model identifier + config: Training configuration + final_model_name: Name of the final model directory ("merged_model" for SFT, "dpo_model" for DPO) + Returns: + List of Checkpoint objects + """ + checkpoints = [] + + # Add checkpoint directories + checkpoint_dirs = sorted( + [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()], + key=lambda x: int(x.name.split("-")[1]), + ) + + for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1): + created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC) + checkpoint = Checkpoint( + identifier=checkpoint_dir.name, + created_at=created_time, + epoch=epoch_number, + post_training_job_id=job_uuid, + path=str(checkpoint_dir), + ) + checkpoints.append(checkpoint) + + # Add final model + final_model_path = output_dir_path / final_model_name + if final_model_path.exists(): + training_type = "sft" if final_model_name == "merged_model" else "dpo" + checkpoint = Checkpoint( + identifier=f"{model}-{training_type}-{config.n_epochs}", + created_at=datetime.now(UTC), + epoch=config.n_epochs, + post_training_job_id=job_uuid, + path=str(final_model_path), + ) + checkpoints.append(checkpoint) + + return checkpoints diff --git a/llama_stack/providers/inline/post_training/torchtune/__init__.py b/llama_stack/providers/inline/post_training/torchtune/__init__.py index 7a2f9eba2..af4ebd92a 100644 --- a/llama_stack/providers/inline/post_training/torchtune/__init__.py +++ b/llama_stack/providers/inline/post_training/torchtune/__init__.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from .config import TorchtunePostTrainingConfig diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index d20e11b11..765f6789d 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -23,12 +23,8 @@ from llama_stack.apis.post_training import ( from 
llama_stack.providers.inline.post_training.torchtune.config import ( TorchtunePostTrainingConfig, ) -from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import ( - LoraFinetuningSingleDevice, -) from llama_stack.providers.utils.scheduler import JobArtifact, Scheduler from llama_stack.providers.utils.scheduler import JobStatus as SchedulerJobStatus -from llama_stack.schema_utils import webmethod class TrainingArtifactType(Enum): @@ -84,6 +80,10 @@ class TorchtunePostTrainingImpl: if isinstance(algorithm_config, LoraFinetuningConfig): async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_cb): + from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import ( + LoraFinetuningSingleDevice, + ) + on_log_message_cb("Starting Lora finetuning") recipe = LoraFinetuningSingleDevice( @@ -144,7 +144,6 @@ class TorchtunePostTrainingImpl: data = cls._get_artifacts_metadata_by_type(job, TrainingArtifactType.RESOURCES_STATS.value) return data[0] if data else None - @webmethod(route="/post-training/job/status") async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse | None: job = self._scheduler.get_job(job_uuid) @@ -171,11 +170,9 @@ class TorchtunePostTrainingImpl: resources_allocated=self._get_resources_allocated(job), ) - @webmethod(route="/post-training/job/cancel") async def cancel_training_job(self, job_uuid: str) -> None: self._scheduler.cancel(job_uuid) - @webmethod(route="/post-training/job/artifacts") async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse | None: job = self._scheduler.get_job(job_uuid) return PostTrainingJobArtifactsResponse(job_uuid=job_uuid, checkpoints=self._get_checkpoints(job)) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index fed19428c..49e1c95b8 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -43,8 +43,8 @@ from llama_stack.apis.post_training import ( QATFinetuningConfig, TrainingConfig, ) -from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR -from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR +from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 0d1c4ffe1..f83c39a6a 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -4,7 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
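For a concrete sense of the step math in `calculate_training_steps` above, the sketch below runs the same three formulas on assumed example values; in the real code the inputs come from the dataset size and `TrainingConfig`.

```python
# Worked example of the step arithmetic in calculate_training_steps, with
# plain ints standing in for the TrainingConfig fields (assumed values).
steps_per_epoch = 200
n_epochs = 3                # config.n_epochs
max_steps_per_epoch = 100   # config.max_steps_per_epoch

total_steps = steps_per_epoch * n_epochs            # 600
max_steps = min(max_steps_per_epoch, total_steps)   # 100: total run is capped at max_steps_per_epoch
logging_steps = max(1, steps_per_epoch // 50)       # 4: roughly 50 log points per epoch

print(total_steps, max_steps, logging_steps)  # 600 100 4
```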
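The torchtune change above moves the recipe import into the job handler; a minimal sketch of that deferred-import pattern, with the handler body reduced to the parts that matter:

```python
# Minimal sketch of the deferred-import pattern used in the torchtune handler:
# the heavy recipe module is imported only when a fine-tuning job starts, so
# importing the provider module itself stays cheap.
async def handler(on_log_message_cb) -> None:
    from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
        LoraFinetuningSingleDevice,
    )

    on_log_message_cb("Starting Lora finetuning")
    # ... LoraFinetuningSingleDevice is then constructed with the job and
    # training config arguments shown in the surrounding diff.
```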
+import logging import re +import uuid from string import Template from typing import Any @@ -20,8 +22,9 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) +from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults, OpenAICategories from llama_stack.apis.shields import Shield -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from llama_stack.models.llama.datatypes import Role from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -67,6 +70,31 @@ SAFETY_CATEGORIES_TO_CODE_MAP = { CAT_ELECTIONS: "S13", CAT_CODE_INTERPRETER_ABUSE: "S14", } +SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()} + +OPENAI_TO_LLAMA_CATEGORIES_MAP = { + OpenAICategories.VIOLENCE: [CAT_VIOLENT_CRIMES], + OpenAICategories.VIOLENCE_GRAPHIC: [CAT_VIOLENT_CRIMES], + OpenAICategories.HARRASMENT: [CAT_CHILD_EXPLOITATION], + OpenAICategories.HARRASMENT_THREATENING: [CAT_VIOLENT_CRIMES, CAT_CHILD_EXPLOITATION], + OpenAICategories.HATE: [CAT_HATE], + OpenAICategories.HATE_THREATENING: [CAT_HATE, CAT_VIOLENT_CRIMES], + OpenAICategories.ILLICIT: [CAT_NON_VIOLENT_CRIMES], + OpenAICategories.ILLICIT_VIOLENT: [CAT_VIOLENT_CRIMES, CAT_INDISCRIMINATE_WEAPONS], + OpenAICategories.SEXUAL: [CAT_SEX_CRIMES, CAT_SEXUAL_CONTENT], + OpenAICategories.SEXUAL_MINORS: [CAT_CHILD_EXPLOITATION], + OpenAICategories.SELF_HARM: [CAT_SELF_HARM], + OpenAICategories.SELF_HARM_INTENT: [CAT_SELF_HARM], + OpenAICategories.SELF_HARM_INSTRUCTIONS: [CAT_SELF_HARM, CAT_SPECIALIZED_ADVICE], + # These are custom categories that are not in the OpenAI moderation categories + "custom/defamation": [CAT_DEFAMATION], + "custom/specialized_advice": [CAT_SPECIALIZED_ADVICE], + "custom/privacy_violation": [CAT_PRIVACY], + "custom/intellectual_property": [CAT_INTELLECTUAL_PROPERTY], + "custom/weapons": [CAT_INDISCRIMINATE_WEAPONS], + "custom/elections": [CAT_ELECTIONS], + "custom/code_interpreter_abuse": [CAT_CODE_INTERPRETER_ABUSE], +} DEFAULT_LG_V3_SAFETY_CATEGORIES = [ @@ -150,6 +178,11 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if not model_id: raise ValueError("Llama Guard shield must have a model id") + async def unregister_shield(self, identifier: str) -> None: + # LlamaGuard doesn't need to do anything special for unregistration + # The routing table handles the removal from the registry + pass + async def run_shield( self, shield_id: str, @@ -189,6 +222,34 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): return await impl.run(messages) + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + if isinstance(input, list): + messages = input.copy() + else: + messages = [input] + + # convert to user messages format with role + messages = [UserMessage(content=m) for m in messages] + + # Determine safety categories based on the model type + # For known Llama Guard models, use specific categories + if model in LLAMA_GUARD_MODEL_IDS: + # Use the mapped model for categories but the original model_id for inference + mapped_model = LLAMA_GUARD_MODEL_IDS[model] + safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES) + else: + # For unknown models, use default Llama Guard 3 8B categories + safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE] + + impl = LlamaGuardShield( + model=model, + inference_api=self.inference_api, + 
excluded_categories=self.config.excluded_categories, + safety_categories=safety_categories, + ) + + return await impl.run_moderation(messages) + class LlamaGuardShield: def __init__( @@ -335,3 +396,117 @@ class LlamaGuardShield: ) raise ValueError(f"Unexpected response: {response}") + + async def run_moderation(self, messages: list[Message]) -> ModerationObject: + if not messages: + return self.create_moderation_object(self.model) + + # TODO: Add Image based support for OpenAI Moderations + shield_input_message = self.build_text_shield_input(messages) + + response = await self.inference_api.openai_chat_completion( + model=self.model, + messages=[shield_input_message], + stream=False, + ) + content = response.choices[0].message.content + content = content.strip() + return self.get_moderation_object(content) + + def create_moderation_object(self, model: str, unsafe_code: str | None = None) -> ModerationObject: + """Create a ModerationObject for either safe or unsafe content. + + Args: + model: The model name + unsafe_code: Optional comma-separated list of safety codes. If None, creates safe object. + + Returns: + ModerationObject with appropriate configuration + """ + # Set default values for safe case + categories = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), False) + category_scores = dict.fromkeys(OPENAI_TO_LLAMA_CATEGORIES_MAP.keys(), 1.0) + category_applied_input_types = {key: [] for key in OPENAI_TO_LLAMA_CATEGORIES_MAP.keys()} + flagged = False + user_message = None + metadata = {} + + # Handle unsafe case + if unsafe_code: + unsafe_code_list = [code.strip() for code in unsafe_code.split(",")] + invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP] + if invalid_codes: + logging.warning(f"Invalid safety codes returned: {invalid_codes}") + # just returning safe object, as we don't know what the invalid codes can map to + return ModerationObject( + id=f"modr-{uuid.uuid4()}", + model=model, + results=[ + ModerationObjectResults( + flagged=flagged, + categories=categories, + category_applied_input_types=category_applied_input_types, + category_scores=category_scores, + user_message=user_message, + metadata=metadata, + ) + ], + ) + + # Get OpenAI categories for the unsafe codes + openai_categories = [] + for code in unsafe_code_list: + llama_guard_category = SAFETY_CODE_TO_CATEGORIES_MAP[code] + openai_categories.extend( + k for k, v_l in OPENAI_TO_LLAMA_CATEGORIES_MAP.items() if llama_guard_category in v_l + ) + + # Update categories for unsafe content + categories = {k: k in openai_categories for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} + category_scores = {k: 1.0 if k in openai_categories else 0.0 for k in OPENAI_TO_LLAMA_CATEGORIES_MAP} + category_applied_input_types = { + k: ["text"] if k in openai_categories else [] for k in OPENAI_TO_LLAMA_CATEGORIES_MAP + } + flagged = True + user_message = CANNED_RESPONSE_TEXT + metadata = {"violation_type": unsafe_code_list} + + return ModerationObject( + id=f"modr-{uuid.uuid4()}", + model=model, + results=[ + ModerationObjectResults( + flagged=flagged, + categories=categories, + category_applied_input_types=category_applied_input_types, + category_scores=category_scores, + user_message=user_message, + metadata=metadata, + ) + ], + ) + + def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool: + """Check if content is safe based on response and unsafe code.""" + if response.strip() == SAFE_RESPONSE: + return True + + if unsafe_code: + unsafe_code_list = unsafe_code.split(",") 
+ if set(unsafe_code_list).issubset(set(self.excluded_categories)): + return True + + return False + + def get_moderation_object(self, response: str) -> ModerationObject: + response = response.strip() + if self.is_content_safe(response): + return self.create_moderation_object(self.model) + unsafe_code = self.check_unsafe_response(response) + if not unsafe_code: + raise ValueError(f"Unexpected response: {response}") + + if self.is_content_safe(response, unsafe_code): + return self.create_moderation_object(self.model) + else: + return self.create_moderation_object(self.model, unsafe_code) diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index ff87889ea..796771ee1 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -18,7 +18,7 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield -from llama_stack.distribution.utils.model_utils import model_local_dir +from llama_stack.core.utils.model_utils import model_local_dir from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -46,6 +46,9 @@ class PromptGuardSafetyImpl(Safety, ShieldsProtocolPrivate): if shield.provider_resource_id != PROMPT_GUARD_MODEL: raise ValueError(f"Only {PROMPT_GUARD_MODEL} is supported for Prompt Guard. ") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, diff --git a/llama_stack/providers/inline/scoring/basic/__init__.py b/llama_stack/providers/inline/scoring/basic/__init__.py index d9d150b1a..c996b9c2d 100644 --- a/llama_stack/providers/inline/scoring/basic/__init__.py +++ b/llama_stack/providers/inline/scoring/basic/__init__.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
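To make the new moderation path concrete, the sketch below walks a Llama Guard style completion through the same safe/unsafe/excluded decision that `is_content_safe` and `get_moderation_object` apply; the literal response format, the value of `SAFE_RESPONSE`, and the excluded category are assumptions for illustration.

```python
# Sketch of the moderation verdict logic added above, using assumed inputs.
# A "safe" completion maps to an all-clear ModerationObject; an "unsafe"
# completion is reduced to its category codes, which are only ignored when
# every code is in the shield's excluded_categories.
SAFE_RESPONSE = "safe"          # assumed value of the module constant
excluded_categories = {"S7"}    # assumed shield configuration

def flagged(response: str) -> bool:
    response = response.strip()
    if response == SAFE_RESPONSE:
        return False
    # assume the completion looks like "unsafe\nS1,S13" (verdict, then codes)
    codes = [c.strip() for c in response.splitlines()[-1].split(",")]
    return not set(codes).issubset(excluded_categories)

print(flagged("safe"))            # False
print(flagged("unsafe\nS7"))      # False: the only violation is excluded
print(flagged("unsafe\nS1,S13"))  # True: mapped onto OpenAI categories via the table above
```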
from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from .config import BasicScoringConfig diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 09f89be5e..91b10daae 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -14,7 +14,7 @@ from llama_stack.apis.scoring import ( ScoringResult, ) from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.utils.common.data_schema_validator import ( get_valid_schemas, diff --git a/llama_stack/providers/inline/scoring/braintrust/__init__.py b/llama_stack/providers/inline/scoring/braintrust/__init__.py index 8ea6e9b96..3b492ae3f 100644 --- a/llama_stack/providers/inline/scoring/braintrust/__init__.py +++ b/llama_stack/providers/inline/scoring/braintrust/__init__.py @@ -7,7 +7,7 @@ from typing import Any from pydantic import BaseModel -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from .config import BraintrustScoringConfig diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index d6655d657..14810f706 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -29,8 +29,8 @@ from llama_stack.apis.scoring import ( ScoringResultRow, ) from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams -from llama_stack.distribution.datatypes import Api -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.datatypes import Api +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.utils.common.data_schema_validator import ( get_valid_schemas, diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py index 88bf10737..76735fcb3 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/__init__.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from .config import LlmAsJudgeScoringConfig diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index 2bd113a94..fd651877c 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -15,7 +15,7 @@ from llama_stack.apis.scoring import ( ScoringResult, ) from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.utils.common.data_schema_validator import ( get_valid_schemas, diff --git a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py index 09e97136a..21743b653 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/__init__.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/__init__.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from .config import TelemetryConfig, TelemetrySink diff --git a/llama_stack/providers/inline/telemetry/meta_reference/config.py b/llama_stack/providers/inline/telemetry/meta_reference/config.py index f2a7c2a6e..31ae80050 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/config.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/config.py @@ -9,7 +9,7 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR class TelemetrySink(StrEnum): diff --git a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py index b4c77437d..78e49af94 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/console_span_processor.py @@ -28,9 +28,6 @@ class ConsoleSpanProcessor(SpanProcessor): logger.info(f"[dim]{timestamp}[/dim] [bold magenta][START][/bold magenta] [dim]{span.name}[/dim]") def on_end(self, span: ReadableSpan) -> None: - if span.attributes and span.attributes.get("__autotraced__"): - return - timestamp = datetime.fromtimestamp(span.end_time / 1e9, tz=UTC).strftime("%H:%M:%S.%f")[:-3] span_context = f"[dim]{timestamp}[/dim] [bold magenta][END][/bold magenta] [dim]{span.name}[/dim]" if span.status.status_code == StatusCode.ERROR: @@ -67,7 +64,7 @@ class ConsoleSpanProcessor(SpanProcessor): for key, value in event.attributes.items(): if key.startswith("__") or key in ["message", "severity"]: continue - logger.info(f"/r[dim]{key}[/dim]: {value}") + logger.info(f"[dim]{key}[/dim]: {value}") def shutdown(self) -> None: """Shutdown the processor.""" diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index c63fc23c2..d99255c79 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -4,10 +4,13 @@ # This source 
code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging import threading from typing import Any from opentelemetry import metrics, trace + +logger = logging.getLogger(__name__) from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.metrics import MeterProvider @@ -36,7 +39,7 @@ from llama_stack.apis.telemetry import ( Trace, UnstructuredLogEvent, ) -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) @@ -110,7 +113,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): if TelemetrySink.SQLITE in self.config.sinks: trace.get_tracer_provider().add_span_processor(SQLiteSpanProcessor(self.config.sqlite_db_path)) if TelemetrySink.CONSOLE in self.config.sinks: - trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor()) + trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor(print_attributes=True)) if TelemetrySink.OTEL_METRIC in self.config.sinks: self.meter = metrics.get_meter(__name__) @@ -126,9 +129,11 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): trace.get_tracer_provider().force_flush() async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None: + logger.debug(f"DEBUG: log_event called with event type: {type(event).__name__}") if isinstance(event, UnstructuredLogEvent): self._log_unstructured(event, ttl_seconds) elif isinstance(event, MetricEvent): + logger.debug("DEBUG: Routing MetricEvent to _log_metric") self._log_metric(event) elif isinstance(event, StructuredLogEvent): self._log_structured(event, ttl_seconds) @@ -188,6 +193,38 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): return _GLOBAL_STORAGE["gauges"][name] def _log_metric(self, event: MetricEvent) -> None: + # Always log to console if console sink is enabled (debug) + if TelemetrySink.CONSOLE in self.config.sinks: + logger.debug(f"METRIC: {event.metric}={event.value} {event.unit} {event.attributes}") + + # Add metric as an event to the current span + try: + with self._lock: + # Only try to add to span if we have a valid span_id + if event.span_id: + try: + span_id = int(event.span_id, 16) + span = _GLOBAL_STORAGE["active_spans"].get(span_id) + + if span: + timestamp_ns = int(event.timestamp.timestamp() * 1e9) + span.add_event( + name=f"metric.{event.metric}", + attributes={ + "value": event.value, + "unit": event.unit, + **(event.attributes or {}), + }, + timestamp=timestamp_ns, + ) + except (ValueError, KeyError): + # Invalid span_id or span not found, but we already logged to console above + pass + except Exception: + # Lock acquisition failed + logger.debug("Failed to acquire lock to add metric to span") + + # Log to OpenTelemetry meter if available if self.meter is None: return if isinstance(event.value, int): diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index edee4649d..7a5373726 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -15,6 +15,7 @@ import faiss import numpy as np from numpy.typing import NDArray +from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files import Files from 
llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.vector_dbs import VectorDB @@ -159,8 +160,11 @@ class FaissIndex(EmbeddingIndex): for d, i in zip(distances[0], indices[0], strict=False): if i < 0: continue + score = 1.0 / float(d) if d != 0 else float("inf") + if score < score_threshold: + continue chunks.append(self.chunk_by_index[int(i)]) - scores.append(1.0 / float(d) if d != 0 else float("inf")) + scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores) @@ -285,7 +289,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr ) -> QueryChunksResponse: index = self.cache.get(vector_db_id) if index is None: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) diff --git a/llama_stack/providers/inline/vector_io/qdrant/__init__.py b/llama_stack/providers/inline/vector_io/qdrant/__init__.py index ee33b3797..bc9014c68 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/__init__.py +++ b/llama_stack/providers/inline/vector_io/qdrant/__init__.py @@ -4,14 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.datatypes import Api, ProviderSpec +from typing import Any + +from llama_stack.providers.datatypes import Api from .config import QdrantVectorIOConfig -async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): +async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]): from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter - impl = QdrantVectorIOAdapter(config, deps[Api.inference]) + assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}" + files_api = deps.get(Api.files) + impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py index 7cc91d918..e15c27ea1 100644 --- a/llama_stack/providers/inline/vector_io/qdrant/config.py +++ b/llama_stack/providers/inline/vector_io/qdrant/config.py @@ -9,15 +9,23 @@ from typing import Any from pydantic import BaseModel +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) from llama_stack.schema_utils import json_schema_type @json_schema_type class QdrantVectorIOConfig(BaseModel): path: str + kvstore: KVStoreConfig @classmethod def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: return { "path": "${env.QDRANT_PATH:=~/.llama/" + __distro_dir__ + "}/" + "qdrant.db", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, db_name="qdrant_registry.db" + ), } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index cfa4e2263..1fff7b484 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -15,6 +15,7 @@ import numpy as np import sqlite_vec from numpy.typing import NDArray +from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files import Files from llama_stack.apis.inference import Inference from llama_stack.apis.vector_dbs import VectorDB @@ -508,11 +509,11 @@ class 
SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc return self.cache[vector_db_id] if self.vector_db_store is None: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) vector_db = self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) index = VectorDBWithIndex( vector_db=vector_db, @@ -537,7 +538,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api # and then call our index's add_chunks. await index.insert_chunks(chunks) @@ -547,14 +548,14 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc ) -> QueryChunksResponse: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: """Delete a chunk from a sqlite_vec index.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: - raise ValueError(f"Vector DB {store_id} not found") + raise VectorStoreNotFoundError(store_id) for chunk_id in chunk_ids: # Use the index's delete_chunk method diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 063b382df..846f7b88e 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -460,6 +460,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more module="llama_stack.providers.inline.vector_io.qdrant", config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig", api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], description=r""" [Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly in memory. @@ -516,6 +517,7 @@ Please refer to the inline provider documentation. 
""", ), api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], ), remote_provider_spec( Api.vector_io, diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index fafd1d8ff..a34e354bf 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -6,8 +6,6 @@ from typing import Any from urllib.parse import parse_qs, urlparse -import datasets as hf_datasets - from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Dataset @@ -73,6 +71,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): start_index: int | None = None, limit: int | None = None, ) -> PaginatedResponse: + import datasets as hf_datasets + dataset_def = self.dataset_infos[dataset_id] path, params = parse_hf_params(dataset_def) loaded_dataset = hf_datasets.load_dataset(path, **params) @@ -81,6 +81,8 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): return paginate_records(records, start_index, limit) async def append_rows(self, dataset_id: str, rows: list[dict[str, Any]]) -> None: + import datasets as hf_datasets + dataset_def = self.dataset_infos[dataset_id] path, params = parse_hf_params(dataset_def) loaded_dataset = hf_datasets.load_dataset(path, **params) diff --git a/llama_stack/providers/remote/datasetio/nvidia/README.md b/llama_stack/providers/remote/datasetio/nvidia/README.md index 8b1e2e6ee..74e0895f4 100644 --- a/llama_stack/providers/remote/datasetio/nvidia/README.md +++ b/llama_stack/providers/remote/datasetio/nvidia/README.md @@ -20,7 +20,7 @@ This provider enables dataset management using NVIDIA's NeMo Customizer service. Build the NVIDIA environment: ```bash -llama stack build --template nvidia --image-type conda +llama stack build --distro nvidia --image-type venv ``` ### Basic Usage using the LlamaStack Python Client @@ -34,7 +34,7 @@ os.environ["NVIDIA_API_KEY"] = "your-api-key" os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" os.environ["NVIDIA_DATASET_NAMESPACE"] = "default" os.environ["NVIDIA_PROJECT_ID"] = "test-project" -from llama_stack.distribution.library_client import LlamaStackAsLibraryClient +from llama_stack.core.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") client.initialize() diff --git a/llama_stack/providers/remote/eval/nvidia/__init__.py b/llama_stack/providers/remote/eval/nvidia/__init__.py index 55e3754f3..1314fdb83 100644 --- a/llama_stack/providers/remote/eval/nvidia/__init__.py +++ b/llama_stack/providers/remote/eval/nvidia/__init__.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api from .config import NVIDIAEvalConfig diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index c76aa39f3..ca4c7b578 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -39,7 +39,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index a353c67f5..2505718e0 100644 --- a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -18,7 +18,7 @@ This provider enables running inference using NVIDIA NIM. Build the NVIDIA environment: ```bash -llama stack build --template nvidia --image-type conda +llama stack build --distro nvidia --image-type venv ``` ### Basic Usage using the LlamaStack Python Client @@ -33,7 +33,7 @@ os.environ["NVIDIA_API_KEY"] = ( ) os.environ["NVIDIA_BASE_URL"] = "http://nim.test" # NIM URL -from llama_stack.distribution.library_client import LlamaStackAsLibraryClient +from llama_stack.core.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") client.initialize() diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 098e4d324..26b4dec76 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -112,7 +112,8 @@ class OllamaInferenceAdapter( @property def openai_client(self) -> AsyncOpenAI: if self._openai_client is None: - self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama") + url = self.config.url.rstrip("/") + self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama") return self._openai_client async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index d5b3a5973..2f1cd40f2 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -34,7 +34,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model -from llama_stack.distribution.library_client import convert_pydantic_to_json_value, convert_to_pydantic +from llama_stack.core.library_client import convert_pydantic_to_json_value, convert_to_pydantic from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 8ba705f59..96469acac 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ 
-4,178 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -from collections.abc import Iterable - -import requests -from openai.types.chat import ( - ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, -) -from openai.types.chat import ( - ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, -) -from openai.types.chat import ( - ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, -) -from openai.types.chat import ( - ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, -) -from openai.types.chat import ( - ChatCompletionMessageParam as OpenAIChatCompletionMessage, -) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) -from openai.types.chat import ( - ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, -) -from openai.types.chat import ( - ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, -) -from openai.types.chat import ( - ChatCompletionUserMessageParam as OpenAIChatCompletionUserMessage, -) -from openai.types.chat.chat_completion_content_part_image_param import ( - ImageURL as OpenAIImageURL, -) -from openai.types.chat.chat_completion_message_tool_call_param import ( - Function as OpenAIFunction, -) - -from llama_stack.apis.common.content_types import ( - ImageContentItem, - InterleavedContent, - TextContentItem, -) -from llama_stack.apis.inference import ( - ChatCompletionRequest, - CompletionMessage, - JsonSchemaResponseFormat, - Message, - SystemMessage, - ToolChoice, - ToolResponseMessage, - UserMessage, -) -from llama_stack.apis.models import Model -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from llama_stack.providers.utils.inference.openai_compat import ( - convert_tooldef_to_openai_tool, - get_sampling_options, -) -from llama_stack.providers.utils.inference.prompt_adapter import convert_image_content_to_url from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") - - -async def convert_message_to_openai_dict_with_b64_images( - message: Message | dict, -) -> OpenAIChatCompletionMessage: - """ - Convert a Message to an OpenAI API-compatible dictionary. - """ - # users can supply a dict instead of a Message object, we'll - # convert it to a Message object and proceed with some type safety. - if isinstance(message, dict): - if "role" not in message: - raise ValueError("role is required in message") - if message["role"] == "user": - message = UserMessage(**message) - elif message["role"] == "assistant": - message = CompletionMessage(**message) - elif message["role"] == "tool": - message = ToolResponseMessage(**message) - elif message["role"] == "system": - message = SystemMessage(**message) - else: - raise ValueError(f"Unsupported message role: {message['role']}") - - # Map Llama Stack spec to OpenAI spec - - # str -> str - # {"type": "text", "text": ...} -> {"type": "text", "text": ...} - # {"type": "image", "image": {"url": {"uri": ...}}} -> {"type": "image_url", "image_url": {"url": ...}} - # {"type": "image", "image": {"data": ...}} -> {"type": "image_url", "image_url": {"url": "data:image/?;base64,..."}} - # List[...] -> List[...] 
- async def _convert_message_content( - content: InterleavedContent, - ) -> str | Iterable[OpenAIChatCompletionContentPartParam]: - async def impl( - content_: InterleavedContent, - ) -> str | OpenAIChatCompletionContentPartParam | list[OpenAIChatCompletionContentPartParam]: - # Llama Stack and OpenAI spec match for str and text input - if isinstance(content_, str): - return content_ - elif isinstance(content_, TextContentItem): - return OpenAIChatCompletionContentPartTextParam( - type="text", - text=content_.text, - ) - elif isinstance(content_, ImageContentItem): - return OpenAIChatCompletionContentPartImageParam( - type="image_url", - image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_, download=True)), - ) - elif isinstance(content_, list): - return [await impl(item) for item in content_] - else: - raise ValueError(f"Unsupported content type: {type(content_)}") - - ret = await impl(content) - - # OpenAI*Message expects a str or list - if isinstance(ret, str) or isinstance(ret, list): - return ret - else: - return [ret] - - out: OpenAIChatCompletionMessage = None - if isinstance(message, UserMessage): - out = OpenAIChatCompletionUserMessage( - role="user", - content=await _convert_message_content(message.content), - ) - elif isinstance(message, CompletionMessage): - out = OpenAIChatCompletionAssistantMessage( - role="assistant", - content=await _convert_message_content(message.content), - tool_calls=[ - OpenAIChatCompletionMessageToolCall( - id=tool.call_id, - function=OpenAIFunction( - name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value, - arguments=json.dumps(tool.arguments), - ), - type="function", - ) - for tool in message.tool_calls - ] - or None, - ) - elif isinstance(message, ToolResponseMessage): - out = OpenAIChatCompletionToolMessage( - role="tool", - tool_call_id=message.call_id, - content=await _convert_message_content(message.content), - ) - elif isinstance(message, SystemMessage): - out = OpenAIChatCompletionSystemMessage( - role="system", - content=await _convert_message_content(message.content), - ) - else: - raise ValueError(f"Unsupported message type: {type(message)}") - - return out - class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): - _config: SambaNovaImplConfig - def __init__(self, config: SambaNovaImplConfig): self.config = config self.environment_available_models = [] @@ -185,89 +20,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): litellm_provider_name="sambanova", api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None, provider_data_api_key_field="sambanova_api_key", + openai_compat_api_base=self.config.url, + download_images=True, # SambaNova requires base64 image encoding + json_schema_strict=False, # SambaNova doesn't support strict=True yet ) - - def _get_api_key(self) -> str: - config_api_key = self.config.api_key if self.config.api_key else None - if config_api_key: - return config_api_key.get_secret_value() - else: - provider_data = self.get_request_provider_data() - if provider_data is None or not provider_data.sambanova_api_key: - raise ValueError( - 'Pass Sambanova API Key in the header X-LlamaStack-Provider-Data as { "sambanova_api_key": }' - ) - return provider_data.sambanova_api_key - - async def _get_params(self, request: ChatCompletionRequest) -> dict: - input_dict = {} - - input_dict["messages"] = [await convert_message_to_openai_dict_with_b64_images(m) for m in request.messages] - if fmt := request.response_format: - if not 
isinstance(fmt, JsonSchemaResponseFormat): - raise ValueError( - f"Unsupported response format: {type(fmt)}. Only JsonSchemaResponseFormat is supported." - ) - - fmt = fmt.json_schema - name = fmt["title"] - del fmt["title"] - fmt["additionalProperties"] = False - - # Apply additionalProperties: False recursively to all objects - fmt = self._add_additional_properties_recursive(fmt) - - input_dict["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": name, - "schema": fmt, - "strict": False, - }, - } - if request.tools: - input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools] - if request.tool_config.tool_choice: - input_dict["tool_choice"] = ( - request.tool_config.tool_choice.value - if isinstance(request.tool_config.tool_choice, ToolChoice) - else request.tool_config.tool_choice - ) - - provider_data = self.get_request_provider_data() - key_field = self.provider_data_api_key_field - if provider_data and getattr(provider_data, key_field, None): - api_key = getattr(provider_data, key_field) - else: - api_key = self._get_api_key() - - return { - "model": request.model, - "api_key": api_key, - "api_base": self.config.url, - **input_dict, - "stream": request.stream, - **get_sampling_options(request.sampling_params), - } - - async def register_model(self, model: Model) -> Model: - model_id = self.get_provider_model_id(model.provider_resource_id) - - list_models_url = self.config.url + "/models" - if len(self.environment_available_models) == 0: - try: - response = requests.get(list_models_url) - response.raise_for_status() - except requests.exceptions.RequestException as e: - raise RuntimeError(f"Request to {list_models_url} failed") from e - self.environment_available_models = [model.get("id") for model in response.json().get("data", {})] - - if model_id.split("sambanova/")[-1] not in self.environment_available_models: - logger.warning(f"Model {model_id} not available in {list_models_url}") - return model - - async def initialize(self): - await super().initialize() - - async def shutdown(self): - await super().shutdown() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 46094c146..a06e4173b 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -38,7 +38,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( diff --git a/llama_stack/providers/remote/post_training/nvidia/README.md b/llama_stack/providers/remote/post_training/nvidia/README.md index 3ef538d29..6647316df 100644 --- a/llama_stack/providers/remote/post_training/nvidia/README.md +++ b/llama_stack/providers/remote/post_training/nvidia/README.md @@ -22,7 +22,7 @@ This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service Build the NVIDIA environment: ```bash -llama stack build --template nvidia --image-type conda +llama stack build --distro nvidia --image-type venv ``` ### Basic Usage using the LlamaStack Python Client @@ -40,7 +40,7 @@ os.environ["NVIDIA_DATASET_NAMESPACE"] = "default" os.environ["NVIDIA_PROJECT_ID"] = 
"test-project" os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1" -from llama_stack.distribution.library_client import LlamaStackAsLibraryClient +from llama_stack.core.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") client.initialize() diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index c43b51073..1895e7507 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -52,6 +52,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock" ) + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/nvidia/README.md b/llama_stack/providers/remote/safety/nvidia/README.md index 434db32fb..784ab464f 100644 --- a/llama_stack/providers/remote/safety/nvidia/README.md +++ b/llama_stack/providers/remote/safety/nvidia/README.md @@ -19,7 +19,7 @@ This provider enables safety checks and guardrails for LLM interactions using NV Build the NVIDIA environment: ```bash -llama stack build --template nvidia --image-type conda +llama stack build --distro nvidia --image-type venv ``` ### Basic Usage using the LlamaStack Python Client @@ -32,7 +32,7 @@ import os os.environ["NVIDIA_API_KEY"] = "your-api-key" os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test" -from llama_stack.distribution.library_client import LlamaStackAsLibraryClient +from llama_stack.core.library_client import LlamaStackAsLibraryClient client = LlamaStackAsLibraryClient("nvidia") client.initialize() diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 411badb1c..7f17b1cb6 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -40,6 +40,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): if not shield.provider_resource_id: raise ValueError("Shield model not provided.") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 1a65f6aa1..6c7190afe 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -19,7 +19,7 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new @@ -68,6 +68,9 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide ): logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}") + async def unregister_shield(self, identifier: str) -> None: + pass + async def run_shield( self, shield_id: 
str, messages: list[Message], params: dict[str, Any] | None = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py index 7e82cb6d4..e40903969 100644 --- a/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py +++ b/llama_stack/providers/remote/tool_runtime/bing_search/bing_search.py @@ -18,7 +18,7 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import BingSearchToolConfig diff --git a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py index b96b9e59c..ba3b910d5 100644 --- a/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py +++ b/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py @@ -17,7 +17,7 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py index a9b252dfe..578bb6d34 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py @@ -15,7 +15,7 @@ from llama_stack.apis.tools import ( ToolInvocationResult, ToolRuntime, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool, list_mcp_tools diff --git a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py index 1fe91fd7f..976ec9c57 100644 --- a/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py +++ b/llama_stack/providers/remote/tool_runtime/tavily_search/tavily_search.py @@ -18,7 +18,7 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import TavilySearchToolConfig diff --git a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py index 6e1d0f61d..f12a44958 100644 --- a/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py +++ b/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py @@ -18,7 +18,7 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.distribution.request_headers import 
NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from .config import WolframAlphaToolConfig diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index f1652a80e..b09edb65c 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -7,12 +7,12 @@ import asyncio import logging import os -import re from typing import Any from numpy.typing import NDArray -from pymilvus import DataType, Function, FunctionType, MilvusClient +from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker +from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files.files import Files from llama_stack.apis.inference import Inference, InterleavedContent from llama_stack.apis.vector_dbs import VectorDB @@ -27,9 +27,11 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + RERANKER_TYPE_WEIGHTED, EmbeddingIndex, VectorDBWithIndex, ) +from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig @@ -43,14 +45,6 @@ OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:milvus:{VERSION}::" -def sanitize_collection_name(name: str) -> str: - """ - Sanitize collection name to ensure it only contains numbers, letters, and underscores. - Any other characters are replaced with underscores. - """ - return re.sub(r"[^a-zA-Z0-9_]", "_", name) - - class MilvusIndex(EmbeddingIndex): def __init__( self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None @@ -245,7 +239,53 @@ class MilvusIndex(EmbeddingIndex): reranker_type: str, reranker_params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - raise NotImplementedError("Hybrid search is not supported in Milvus") + """ + Hybrid search using Milvus's native hybrid search capabilities. + + This implementation uses Milvus's hybrid_search method which combines + vector search and BM25 search with configurable reranking strategies. 
+ """ + search_requests = [] + + # nprobe: Controls search accuracy vs performance trade-off + # 10 balances these trade-offs for RAG applications + search_requests.append( + AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k) + ) + + # drop_ratio_search: Filters low-importance terms to improve search performance + # 0.2 balances noise reduction with recall + search_requests.append( + AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k) + ) + + if reranker_type == RERANKER_TYPE_WEIGHTED: + alpha = (reranker_params or {}).get("alpha", 0.5) + rerank = WeightedRanker(alpha, 1 - alpha) + else: + impact_factor = (reranker_params or {}).get("impact_factor", 60.0) + rerank = RRFRanker(impact_factor) + + search_res = await asyncio.to_thread( + self.client.hybrid_search, + collection_name=self.collection_name, + reqs=search_requests, + ranker=rerank, + limit=k, + output_fields=["chunk_content"], + ) + + chunks = [] + scores = [] + for res in search_res[0]: + chunk = Chunk(**res["entity"]["chunk_content"]) + chunks.append(chunk) + scores.append(res["distance"]) + + filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold] + filtered_scores = [score for score in scores if score >= score_threshold] + + return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) async def delete_chunk(self, chunk_id: str) -> None: """Remove a chunk from the Milvus collection.""" @@ -329,11 +369,11 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP return self.cache[vector_db_id] if self.vector_db_store is None: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) vector_db = await self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) index = VectorDBWithIndex( vector_db=vector_db, @@ -356,7 +396,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) await index.insert_chunks(chunks) @@ -368,7 +408,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> QueryChunksResponse: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) if params and params.get("mode") == "keyword": # Check if this is inline Milvus (Milvus-Lite) @@ -384,7 +424,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP """Delete a chunk from a milvus vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: - raise ValueError(f"Vector DB {store_id} not found") + raise VectorStoreNotFoundError(store_id) for chunk_id in chunk_ids: # Use the index's delete_chunk method diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 643c27328..b1645ac5a 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -13,6 +13,7 @@ from psycopg2 import sql from psycopg2.extras import Json, execute_values 
from pydantic import BaseModel, TypeAdapter +from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files.files import Files from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB @@ -131,8 +132,11 @@ class PGVectorIndex(EmbeddingIndex): chunks = [] scores = [] for doc, dist in results: + score = 1.0 / float(dist) if dist != 0 else float("inf") + if score < score_threshold: + continue chunks.append(Chunk(**doc)) - scores.append(1.0 / float(dist) if dist != 0 else float("inf")) + scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores) @@ -275,7 +279,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco """Delete a chunk from a PostgreSQL vector store.""" index = await self._get_and_cache_vector_db_index(store_id) if not index: - raise ValueError(f"Vector DB {store_id} not found") + raise VectorStoreNotFoundError(store_id) for chunk_id in chunk_ids: # Use the index's delete_chunk method diff --git a/llama_stack/providers/remote/vector_io/qdrant/__init__.py b/llama_stack/providers/remote/vector_io/qdrant/__init__.py index 029de285f..6ce98b17c 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/__init__.py +++ b/llama_stack/providers/remote/vector_io/qdrant/__init__.py @@ -12,6 +12,7 @@ from .config import QdrantVectorIOConfig async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]): from .qdrant import QdrantVectorIOAdapter - impl = QdrantVectorIOAdapter(config, deps[Api.inference]) + files_api = deps.get(Api.files) + impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index 314d3f5f1..ff5506236 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -8,6 +8,10 @@ from typing import Any from pydantic import BaseModel +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) from llama_stack.schema_utils import json_schema_type @@ -23,9 +27,14 @@ class QdrantVectorIOConfig(BaseModel): prefix: str | None = None timeout: int | None = None host: str | None = None + kvstore: KVStoreConfig @classmethod - def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: + def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: return { - "api_key": "${env.QDRANT_API_KEY}", + "api_key": "${env.QDRANT_API_KEY:=}", + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="qdrant_registry.db", + ), } diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 3df3da27f..144da0f4f 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
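A minimal restatement (not part of the patch) of the scoring convention used by the PGVector hunk above and the Weaviate hunk later in this diff: the raw distance is inverted into a similarity-style score, and results below score_threshold are dropped before being returned.

def distance_to_score(dist: float) -> float:
    # Smaller distance => larger score; a zero distance maps to +inf
    return 1.0 / float(dist) if dist != 0 else float("inf")

def apply_score_threshold(
    docs_with_dist: list[tuple[dict, float]], score_threshold: float
) -> list[tuple[dict, float]]:
    scored = [(doc, distance_to_score(dist)) for doc, dist in docs_with_dist]
    return [(doc, score) for doc, score in scored if score >= score_threshold]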
+import asyncio import logging import uuid from typing import Any @@ -12,25 +13,21 @@ from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import PointStruct +from llama_stack.apis.common.errors import VectorStoreNotFoundError +from llama_stack.apis.files import Files from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, - SearchRankingOptions, VectorIO, VectorStoreChunkingStrategy, - VectorStoreDeleteResponse, - VectorStoreFileContentsResponse, VectorStoreFileObject, - VectorStoreFileStatus, - VectorStoreListFilesResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponsePage, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig +from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, @@ -41,6 +38,10 @@ from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" +# KV store prefixes for vector databases +VERSION = "v3" +VECTOR_DBS_PREFIX = f"vector_dbs:qdrant:{VERSION}::" + def convert_id(_id: str) -> str: """ @@ -58,6 +59,11 @@ class QdrantIndex(EmbeddingIndex): self.client = client self.collection_name = collection_name + async def initialize(self) -> None: + # Qdrant collections are created on-demand in add_chunks + # If the collection does not exist, it will be created in add_chunks. 
+ pass + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" @@ -83,7 +89,15 @@ class QdrantIndex(EmbeddingIndex): await self.client.upsert(collection_name=self.collection_name, points=points) async def delete_chunk(self, chunk_id: str) -> None: - raise NotImplementedError("delete_chunk is not supported in qdrant") + """Remove a chunk from the Qdrant collection.""" + try: + await self.client.delete( + collection_name=self.collection_name, + points_selector=models.PointIdsList(points=[convert_id(chunk_id)]), + ) + except Exception as e: + log.error(f"Error deleting chunk {chunk_id} from Qdrant collection {self.collection_name}: {e}") + raise async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: results = ( @@ -135,17 +149,41 @@ class QdrantIndex(EmbeddingIndex): await self.client.delete_collection(collection_name=self.collection_name) -class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): +class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__( - self, config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, inference_api: Api.inference + self, + config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig, + inference_api: Api.inference, + files_api: Files | None = None, ) -> None: self.config = config self.client: AsyncQdrantClient = None self.cache = {} self.inference_api = inference_api + self.files_api = files_api + self.vector_db_store = None + self.kvstore: KVStore | None = None + self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self._qdrant_lock = asyncio.Lock() async def initialize(self) -> None: - self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) + client_config = self.config.model_dump(exclude_none=True, exclude={"kvstore"}) + self.client = AsyncQdrantClient(**client_config) + self.kvstore = await kvstore_impl(self.config.kvstore) + + start_key = VECTOR_DBS_PREFIX + end_key = f"{VECTOR_DBS_PREFIX}\xff" + stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) + + for vector_db_data in stored_vector_dbs: + vector_db = VectorDB.model_validate_json(vector_db_data) + index = VectorDBWithIndex( + vector_db, + QdrantIndex(self.client, vector_db.identifier), + self.inference_api, + ) + self.cache[vector_db.identifier] = index + self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: await self.client.close() @@ -154,6 +192,10 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self, vector_db: VectorDB, ) -> None: + assert self.kvstore is not None + key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}" + await self.kvstore.set(key=key, value=vector_db.model_dump_json()) + index = VectorDBWithIndex( vector_db=vector_db, index=QdrantIndex(self.client, vector_db.identifier), @@ -167,13 +209,19 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): await self.cache[vector_db_id].index.delete() del self.cache[vector_db_id] + assert self.kvstore is not None + await self.kvstore.delete(f"{VECTOR_DBS_PREFIX}{vector_db_id}") + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: if vector_db_id in self.cache: return self.cache[vector_db_id] + if self.vector_db_store is None: + raise ValueError(f"Vector DB not found {vector_db_id}") + vector_db = await 
self.vector_db_store.get_vector_db(vector_db_id) if not vector_db: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) index = VectorDBWithIndex( vector_db=vector_db, @@ -191,7 +239,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): ) -> None: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) await index.insert_chunks(chunks) @@ -203,65 +251,10 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): ) -> QueryChunksResponse: index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) - async def openai_create_vector_store( - self, - name: str, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_list_vector_stores( - self, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int | None = 10, - ranking_options: SearchRankingOptions | None = None, - rewrite_query: bool | None = False, - search_mode: str | None = "vector", - ) -> VectorStoreSearchResponsePage: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - async def openai_attach_file_to_vector_store( self, vector_store_id: str, @@ -269,47 +262,14 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): attributes: dict[str, Any] | None = None, chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_list_files_in_vector_store( - self, - vector_store_id: str, - limit: int | None = 20, - order: str | None = "desc", - after: str | None = None, - before: str | None = None, - filter: VectorStoreFileStatus | None = None, - ) -> VectorStoreListFilesResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store_file( - self, - vector_store_id: str, - file_id: str, - 
) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_retrieve_vector_store_file_contents( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileContentsResponse: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_update_vector_store_file( - self, - vector_store_id: str, - file_id: str, - attributes: dict[str, Any] | None = None, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") - - async def openai_delete_vector_store_file( - self, - vector_store_id: str, - file_id: str, - ) -> VectorStoreFileObject: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") + # Qdrant doesn't allow multiple clients to access the same storage path simultaneously. + async with self._qdrant_lock: + await super().openai_attach_file_to_vector_store(vector_store_id, file_id, attributes, chunking_strategy) async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") + """Delete chunks from a Qdrant vector store.""" + index = await self._get_and_cache_vector_db_index(store_id) + if not index: + raise ValueError(f"Vector DB {store_id} not found") + for chunk_id in chunk_ids: + await index.index.delete_chunk(chunk_id) diff --git a/llama_stack/providers/remote/vector_io/weaviate/__init__.py b/llama_stack/providers/remote/vector_io/weaviate/__init__.py index 22e116c22..9272b21e2 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/__init__.py +++ b/llama_stack/providers/remote/vector_io/weaviate/__init__.py @@ -12,6 +12,6 @@ from .config import WeaviateVectorIOConfig async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]): from .weaviate import WeaviateVectorIOAdapter - impl = WeaviateVectorIOAdapter(config, deps[Api.inference]) + impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None)) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py index 4283b8d3b..b693e294e 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/config.py +++ b/llama_stack/providers/remote/vector_io/weaviate/config.py @@ -12,18 +12,24 @@ from llama_stack.providers.utils.kvstore.config import ( KVStoreConfig, SqliteKVStoreConfig, ) +from llama_stack.schema_utils import json_schema_type -class WeaviateRequestProviderData(BaseModel): - weaviate_api_key: str - weaviate_cluster_url: str +@json_schema_type +class WeaviateVectorIOConfig(BaseModel): + weaviate_api_key: str | None = Field(description="The API key for the Weaviate instance", default=None) + weaviate_cluster_url: str | None = Field(description="The URL of the Weaviate cluster", default="localhost:8080") kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None) - -class WeaviateVectorIOConfig(BaseModel): @classmethod - def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]: + def sample_run_config( + cls, + __distro_dir__: str, + **kwargs: Any, + ) -> dict[str, Any]: return { + "weaviate_api_key": None, + "weaviate_cluster_url": "${env.WEAVIATE_CLUSTER_URL:=localhost:8080}", "kvstore": SqliteKVStoreConfig.sample_run_config( __distro_dir__=__distro_dir__, db_name="weaviate_registry.db", diff 
--git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 543835e20..11da8902c 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -14,19 +14,24 @@ from weaviate.classes.init import Auth from weaviate.classes.query import Filter from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files.files import Files from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import ( + OpenAIVectorStoreMixin, +) from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) +from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name -from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig +from .config import WeaviateVectorIOConfig log = logging.getLogger(__name__) @@ -39,11 +44,19 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class WeaviateIndex(EmbeddingIndex): - def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None): + def __init__( + self, + client: weaviate.Client, + collection_name: str, + kvstore: KVStore | None = None, + ): self.client = client - self.collection_name = collection_name + self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True) self.kvstore = kvstore + async def initialize(self): + pass + async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" @@ -67,10 +80,13 @@ class WeaviateIndex(EmbeddingIndex): collection.data.insert_many(data_objects) async def delete_chunk(self, chunk_id: str) -> None: - raise NotImplementedError("delete_chunk is not supported in Chroma") + sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) + collection = self.client.collections.get(sanitized_collection_name) + collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id])) async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - collection = self.client.collections.get(self.collection_name) + sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) + collection = self.client.collections.get(sanitized_collection_name) results = collection.query.near_vector( near_vector=embedding.tolist(), @@ -89,13 +105,26 @@ class WeaviateIndex(EmbeddingIndex): log.exception(f"Failed to parse document: {chunk_json}") continue + score = 1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf") + if score < score_threshold: + continue + chunks.append(chunk) - scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")) + scores.append(score) return 
QueryChunksResponse(chunks=chunks, scores=scores) - async def delete(self, chunk_ids: list[str]) -> None: - collection = self.client.collections.get(self.collection_name) + async def delete(self, chunk_ids: list[str] | None = None) -> None: + """ + Delete chunks by IDs if provided, otherwise drop the entire collection. + """ + sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True) + if chunk_ids is None: + # Drop entire collection if it exists + if self.client.collections.exists(sanitized_collection_name): + self.client.collections.delete(sanitized_collection_name) + return + collection = self.client.collections.get(sanitized_collection_name) collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids)) async def query_keyword( @@ -119,6 +148,7 @@ class WeaviateIndex(EmbeddingIndex): class WeaviateVectorIOAdapter( + OpenAIVectorStoreMixin, VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate, @@ -140,42 +170,56 @@ class WeaviateVectorIOAdapter( self.metadata_collection_name = "openai_vector_stores_metadata" def _get_client(self) -> weaviate.Client: - provider_data = self.get_request_provider_data() - assert provider_data is not None, "Request provider data must be set" - assert isinstance(provider_data, WeaviateRequestProviderData) - - key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}" - if key in self.client_cache: - return self.client_cache[key] - - client = weaviate.connect_to_weaviate_cloud( - cluster_url=provider_data.weaviate_cluster_url, - auth_credentials=Auth.api_key(provider_data.weaviate_api_key), - ) + if "localhost" in self.config.weaviate_cluster_url: + log.info("using Weaviate locally in container") + host, port = self.config.weaviate_cluster_url.split(":") + key = "local_test" + client = weaviate.connect_to_local( + host=host, + port=port, + ) + else: + log.info("Using Weaviate remote cluster with URL") + key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}" + if key in self.client_cache: + return self.client_cache[key] + client = weaviate.connect_to_weaviate_cloud( + cluster_url=self.config.weaviate_cluster_url, + auth_credentials=Auth.api_key(self.config.weaviate_api_key), + ) self.client_cache[key] = client return client async def initialize(self) -> None: """Set up KV store and load existing vector DBs and OpenAI vector stores.""" - # Initialize KV store for metadata - self.kvstore = await kvstore_impl(self.config.kvstore) + # Initialize KV store for metadata if configured + if self.config.kvstore is not None: + self.kvstore = await kvstore_impl(self.config.kvstore) + else: + self.kvstore = None + log.info("No kvstore configured, registry will not persist across restarts") # Load existing vector DB definitions - start_key = VECTOR_DBS_PREFIX - end_key = f"{VECTOR_DBS_PREFIX}\xff" - stored = await self.kvstore.values_in_range(start_key, end_key) - for raw in stored: - vector_db = VectorDB.model_validate_json(raw) - client = self._get_client() - idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore) - self.cache[vector_db.identifier] = VectorDBWithIndex( - vector_db=vector_db, - index=idx, - inference_api=self.inference_api, - ) + if self.kvstore is not None: + start_key = VECTOR_DBS_PREFIX + end_key = f"{VECTOR_DBS_PREFIX}\xff" + stored = await self.kvstore.values_in_range(start_key, end_key) + for raw in stored: + vector_db = VectorDB.model_validate_json(raw) + client = self._get_client() + idx = 
WeaviateIndex( + client=client, + collection_name=vector_db.identifier, + kvstore=self.kvstore, + ) + self.cache[vector_db.identifier] = VectorDBWithIndex( + vector_db=vector_db, + index=idx, + inference_api=self.inference_api, + ) - # Load OpenAI vector stores metadata into cache - await self.initialize_openai_vector_stores() + # Load OpenAI vector stores metadata into cache + await self.initialize_openai_vector_stores() async def shutdown(self) -> None: for client in self.client_cache.values(): @@ -186,11 +230,11 @@ class WeaviateVectorIOAdapter( vector_db: VectorDB, ) -> None: client = self._get_client() - + sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True) # Create collection if it doesn't exist - if not client.collections.exists(vector_db.identifier): + if not client.collections.exists(sanitized_collection_name): client.collections.create( - name=vector_db.identifier, + name=sanitized_collection_name, vectorizer_config=wvc.config.Configure.Vectorizer.none(), properties=[ wvc.config.Property( @@ -200,30 +244,41 @@ class WeaviateVectorIOAdapter( ], ) - self.cache[vector_db.identifier] = VectorDBWithIndex( + self.cache[sanitized_collection_name] = VectorDBWithIndex( vector_db, - WeaviateIndex(client=client, collection_name=vector_db.identifier), + WeaviateIndex(client=client, collection_name=sanitized_collection_name), self.inference_api, ) - async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: - if vector_db_id in self.cache: - return self.cache[vector_db_id] + async def unregister_vector_db(self, vector_db_id: str) -> None: + client = self._get_client() + sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) + if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False: + log.warning(f"Vector DB {sanitized_collection_name} not found") + return + client.collections.delete(sanitized_collection_name) + await self.cache[sanitized_collection_name].index.delete() + del self.cache[sanitized_collection_name] - vector_db = await self.vector_db_store.get_vector_db(vector_db_id) + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None: + sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) + if sanitized_collection_name in self.cache: + return self.cache[sanitized_collection_name] + + vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name) if not vector_db: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) client = self._get_client() if not client.collections.exists(vector_db.identifier): - raise ValueError(f"Collection with name `{vector_db.identifier}` not found") + raise ValueError(f"Collection with name `{sanitized_collection_name}` not found") index = VectorDBWithIndex( vector_db=vector_db, - index=WeaviateIndex(client=client, collection_name=vector_db.identifier), + index=WeaviateIndex(client=client, collection_name=sanitized_collection_name), inference_api=self.inference_api, ) - self.cache[vector_db_id] = index + self.cache[sanitized_collection_name] = index return index async def insert_chunks( @@ -232,9 +287,10 @@ class WeaviateVectorIOAdapter( chunks: list[Chunk], ttl_seconds: int | None = None, ) -> None: - index = await self._get_and_cache_vector_db_index(vector_db_id) + sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) + 
index = await self._get_and_cache_vector_db_index(sanitized_collection_name) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) await index.insert_chunks(chunks) @@ -244,29 +300,17 @@ class WeaviateVectorIOAdapter( query: InterleavedContent, params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - index = await self._get_and_cache_vector_db_index(vector_db_id) + sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True) + index = await self._get_and_cache_vector_db_index(sanitized_collection_name) if not index: - raise ValueError(f"Vector DB {vector_db_id} not found") + raise VectorStoreNotFoundError(vector_db_id) return await index.query_chunks(query, params) - # OpenAI Vector Stores File operations are not supported in Weaviate - async def _save_openai_vector_store_file( - self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]] - ) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") - - async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") - - async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") - - async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") - - async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") - async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None: - raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate") + sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True) + index = await self._get_and_cache_vector_db_index(sanitized_collection_name) + if not index: + raise ValueError(f"Vector DB {sanitized_collection_name} not found") + + await index.delete(chunk_ids) diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py index 28a243863..b0305104f 100644 --- a/llama_stack/providers/utils/common/data_schema_validator.py +++ b/llama_stack/providers/utils/common/data_schema_validator.py @@ -12,7 +12,7 @@ from llama_stack.apis.common.type_system import ( CompletionInputType, StringType, ) -from llama_stack.distribution.datatypes import Api +from llama_stack.core.datatypes import Api class ColumnName(Enum): diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py index 386ee736d..77b047e2d 100644 --- a/llama_stack/providers/utils/datasetio/url_utils.py +++ b/llama_stack/providers/utils/datasetio/url_utils.py @@ -9,12 +9,12 @@ import base64 import io from urllib.parse import unquote -import pandas - from llama_stack.providers.utils.memory.vector_store import parse_data_url async def get_dataframe_from_uri(uri: str): + import pandas + df = None if uri.endswith(".csv"): # Moving to its own thread to avoid io from blocking the eventloop diff --git a/llama_stack/providers/utils/inference/inference_store.py 
b/llama_stack/providers/utils/inference/inference_store.py index 60a87494e..43006cfd5 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ b/llama_stack/providers/utils/inference/inference_store.py @@ -10,8 +10,8 @@ from llama_stack.apis.inference import ( OpenAIMessageParam, Order, ) -from llama_stack.distribution.datatypes import AccessRule -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.core.datatypes import AccessRule +from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 6ccf2a729..da2e634f6 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -38,7 +38,7 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.distribution.request_headers import NeedsRequestProviderData +from llama_stack.core.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( @@ -72,6 +72,8 @@ class LiteLLMOpenAIMixin( api_key_from_config: str | None, provider_data_api_key_field: str, openai_compat_api_base: str | None = None, + download_images: bool = False, + json_schema_strict: bool = True, ): """ Initialize the LiteLLMOpenAIMixin. @@ -81,6 +83,8 @@ class LiteLLMOpenAIMixin( :param provider_data_api_key_field: The field in the provider data that contains the API key. :param litellm_provider_name: The name of the provider, used for model lookups. :param openai_compat_api_base: The base URL for OpenAI compatibility, or None if not using OpenAI compatibility. + :param download_images: Whether to download images and convert to base64 for message conversion. + :param json_schema_strict: Whether to use strict mode for JSON schema validation. 
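A hypothetical subclass (the adapter name, MODEL_ENTRIES, and config fields are placeholders, not from this patch) showing how a provider could opt into image downloading and out of strict JSON-schema mode via the new keyword arguments:

class ExampleInferenceAdapter(LiteLLMOpenAIMixin):
    def __init__(self, config):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,                  # placeholder model list
            api_key_from_config=config.api_key,
            provider_data_api_key_field="example_api_key",
            download_images=True,       # fetch image URLs and send them as base64
            json_schema_strict=False,   # some backends reject strict json_schema mode
        )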
""" ModelRegistryHelper.__init__(self, model_entries) @@ -88,6 +92,8 @@ class LiteLLMOpenAIMixin( self.api_key_from_config = api_key_from_config self.provider_data_api_key_field = provider_data_api_key_field self.api_base = openai_compat_api_base + self.download_images = download_images + self.json_schema_strict = json_schema_strict if openai_compat_api_base: self.is_openai_compat = True @@ -152,9 +158,8 @@ class LiteLLMOpenAIMixin( params["model"] = self.get_litellm_model_name(params["model"]) logger.debug(f"params to litellm (openai compat): {params}") - # unfortunately, we need to use synchronous litellm.completion here because litellm - # caches various httpx.client objects in a non-eventloop aware manner - response = litellm.completion(**params) + # see https://docs.litellm.ai/docs/completion/stream#async-completion + response = await litellm.acompletion(**params) if stream: return self._stream_chat_completion(response) else: @@ -164,7 +169,7 @@ class LiteLLMOpenAIMixin( self, response: litellm.ModelResponse ) -> AsyncIterator[ChatCompletionResponseStreamChunk]: async def _stream_generator(): - for chunk in response: + async for chunk in response: yield chunk async for chunk in convert_openai_chat_completion_stream( @@ -206,7 +211,9 @@ class LiteLLMOpenAIMixin( async def _get_params(self, request: ChatCompletionRequest) -> dict: input_dict = {} - input_dict["messages"] = [await convert_message_to_openai_dict_new(m) for m in request.messages] + input_dict["messages"] = [ + await convert_message_to_openai_dict_new(m, download_images=self.download_images) for m in request.messages + ] if fmt := request.response_format: if not isinstance(fmt, JsonSchemaResponseFormat): raise ValueError( @@ -226,7 +233,7 @@ class LiteLLMOpenAIMixin( "json_schema": { "name": name, "schema": fmt, - "strict": True, + "strict": self.json_schema_strict, }, } if request.tools: @@ -254,6 +261,12 @@ class LiteLLMOpenAIMixin( api_key = getattr(provider_data, key_field) else: api_key = self.api_key_from_config + if not api_key: + raise ValueError( + "API key is not set. Please provide a valid API key in the " + "provider data header, e.g. x-llamastack-provider-data: " + f'{{"{key_field}": ""}}, or in the provider config.' + ) return api_key async def embeddings( diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 47144ee0e..e6e5ccc8a 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -564,6 +564,7 @@ class UnparseableToolCall(BaseModel): async def convert_message_to_openai_dict_new( message: Message | dict, + download_images: bool = False, ) -> OpenAIChatCompletionMessage: """ Convert a Message to an OpenAI API-compatible dictionary. 
@@ -607,7 +608,9 @@ async def convert_message_to_openai_dict_new( elif isinstance(content_, ImageContentItem): return OpenAIChatCompletionContentPartImageParam( type="image_url", - image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)), + image_url=OpenAIImageURL( + url=await convert_image_content_to_url(content_, download=download_images) + ), ) elif isinstance(content_, list): return [await impl(item) for item in content_] diff --git a/llama_stack/providers/utils/inference/stream_utils.py b/llama_stack/providers/utils/inference/stream_utils.py deleted file mode 100644 index bbfac13a3..000000000 --- a/llama_stack/providers/utils/inference/stream_utils.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from collections.abc import AsyncIterator -from datetime import UTC, datetime -from typing import Any - -from llama_stack.apis.inference import ( - OpenAIAssistantMessageParam, - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAIChatCompletionToolCall, - OpenAIChatCompletionToolCallFunction, - OpenAIChoice, - OpenAIChoiceLogprobs, - OpenAIMessageParam, -) -from llama_stack.providers.utils.inference.inference_store import InferenceStore - - -async def stream_and_store_openai_completion( - provider_stream: AsyncIterator[OpenAIChatCompletionChunk], - model: str, - store: InferenceStore, - input_messages: list[OpenAIMessageParam], -) -> AsyncIterator[OpenAIChatCompletionChunk]: - """ - Wraps a provider's stream, yields chunks, and stores the full completion at the end. - """ - id = None - created = None - choices_data: dict[int, dict[str, Any]] = {} - - try: - async for chunk in provider_stream: - if id is None and chunk.id: - id = chunk.id - if created is None and chunk.created: - created = chunk.created - - if chunk.choices: - for choice_delta in chunk.choices: - idx = choice_delta.index - if idx not in choices_data: - choices_data[idx] = { - "content_parts": [], - "tool_calls_builder": {}, - "finish_reason": None, - "logprobs_content_parts": [], - } - current_choice_data = choices_data[idx] - - if choice_delta.delta: - delta = choice_delta.delta - if delta.content: - current_choice_data["content_parts"].append(delta.content) - if delta.tool_calls: - for tool_call_delta in delta.tool_calls: - tc_idx = tool_call_delta.index - if tc_idx not in current_choice_data["tool_calls_builder"]: - # Initialize with correct structure for _ToolCallBuilderData - current_choice_data["tool_calls_builder"][tc_idx] = { - "id": None, - "type": "function", - "function_name_parts": [], - "function_arguments_parts": [], - } - builder = current_choice_data["tool_calls_builder"][tc_idx] - if tool_call_delta.id: - builder["id"] = tool_call_delta.id - if tool_call_delta.type: - builder["type"] = tool_call_delta.type - if tool_call_delta.function: - if tool_call_delta.function.name: - builder["function_name_parts"].append(tool_call_delta.function.name) - if tool_call_delta.function.arguments: - builder["function_arguments_parts"].append(tool_call_delta.function.arguments) - if choice_delta.finish_reason: - current_choice_data["finish_reason"] = choice_delta.finish_reason - if choice_delta.logprobs and choice_delta.logprobs.content: - # Ensure that we are extending with the correct type - current_choice_data["logprobs_content_parts"].extend(choice_delta.logprobs.content) - yield chunk - finally: - if id: 
- assembled_choices: list[OpenAIChoice] = [] - for choice_idx, choice_data in choices_data.items(): - content_str = "".join(choice_data["content_parts"]) - assembled_tool_calls: list[OpenAIChatCompletionToolCall] = [] - if choice_data["tool_calls_builder"]: - for tc_build_data in choice_data["tool_calls_builder"].values(): - if tc_build_data["id"]: - func_name = "".join(tc_build_data["function_name_parts"]) - func_args = "".join(tc_build_data["function_arguments_parts"]) - assembled_tool_calls.append( - OpenAIChatCompletionToolCall( - id=tc_build_data["id"], - type=tc_build_data["type"], # No or "function" needed, already set - function=OpenAIChatCompletionToolCallFunction(name=func_name, arguments=func_args), - ) - ) - message = OpenAIAssistantMessageParam( - role="assistant", - content=content_str if content_str else None, - tool_calls=assembled_tool_calls if assembled_tool_calls else None, - ) - logprobs_content = choice_data["logprobs_content_parts"] - final_logprobs = OpenAIChoiceLogprobs(content=logprobs_content) if logprobs_content else None - - assembled_choices.append( - OpenAIChoice( - finish_reason=choice_data["finish_reason"], - index=choice_idx, - message=message, - logprobs=final_logprobs, - ) - ) - - final_response = OpenAIChatCompletion( - id=id, - choices=assembled_choices, - created=created or int(datetime.now(UTC).timestamp()), - model=model, - object="chat.completion", - ) - await store.store_chat_completion(final_response, input_messages) diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index 0219bbebe..f00cb1f8b 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -10,7 +10,7 @@ from typing import Annotated, Literal from pydantic import BaseModel, Field, field_validator -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR class KVStoreType(Enum): diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index ee69d7c52..7b6e69df1 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -13,6 +13,7 @@ import uuid from abc import ABC, abstractmethod from typing import Any +from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files import Files, OpenAIFileObject from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( @@ -322,7 +323,7 @@ class OpenAIVectorStoreMixin(ABC): ) -> VectorStoreObject: """Retrieves a vector store.""" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) store_info = self.openai_vector_stores[vector_store_id] return VectorStoreObject(**store_info) @@ -336,7 +337,7 @@ class OpenAIVectorStoreMixin(ABC): ) -> VectorStoreObject: """Modifies a vector store.""" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) store_info = self.openai_vector_stores[vector_store_id].copy() @@ -365,7 +366,7 @@ class OpenAIVectorStoreMixin(ABC): ) -> VectorStoreDeleteResponse: """Delete a vector store.""" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store 
{vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) # Delete from persistent storage (provider-specific) await self._delete_openai_vector_store_from_storage(vector_store_id) @@ -403,7 +404,7 @@ class OpenAIVectorStoreMixin(ABC): raise ValueError(f"search_mode must be one of {valid_modes}, got {search_mode}") if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) if isinstance(query, list): search_query = " ".join(query) @@ -432,10 +433,6 @@ class OpenAIVectorStoreMixin(ABC): # Convert response to OpenAI format data = [] for chunk, score in zip(response.chunks, response.scores, strict=False): - # Apply score based filtering - if score < score_threshold: - continue - # Apply filters if provided if filters: # Simple metadata filtering @@ -556,7 +553,7 @@ class OpenAIVectorStoreMixin(ABC): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) attributes = attributes or {} chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto() @@ -661,7 +658,7 @@ class OpenAIVectorStoreMixin(ABC): order = order or "desc" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) store_info = self.openai_vector_stores[vector_store_id] @@ -709,7 +706,7 @@ class OpenAIVectorStoreMixin(ABC): ) -> VectorStoreFileObject: """Retrieves a vector store file.""" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) store_info = self.openai_vector_stores[vector_store_id] if file_id not in store_info["file_ids"]: @@ -725,7 +722,7 @@ class OpenAIVectorStoreMixin(ABC): ) -> VectorStoreFileContentsResponse: """Retrieves the contents of a vector store file.""" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) @@ -748,7 +745,7 @@ class OpenAIVectorStoreMixin(ABC): ) -> VectorStoreFileObject: """Updates a vector store file.""" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) store_info = self.openai_vector_stores[vector_store_id] if file_id not in store_info["file_ids"]: @@ -766,7 +763,7 @@ class OpenAIVectorStoreMixin(ABC): ) -> VectorStoreFileDeleteResponse: """Deletes a vector store file.""" if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + raise VectorStoreNotFoundError(vector_store_id) dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) chunks = [Chunk.model_validate(c) for c in dict_chunks] diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 4a8749cba..bb9002f30 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ 
b/llama_stack/providers/utils/memory/vector_store.py @@ -30,7 +30,7 @@ from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id +from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id log = logging.getLogger(__name__) @@ -302,23 +302,25 @@ class VectorDBWithIndex: mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) - # Get ranker configuration ranker = params.get("ranker") if ranker is None: - # Default to RRF with impact_factor=60.0 reranker_type = RERANKER_TYPE_RRF reranker_params = {"impact_factor": 60.0} else: - reranker_type = ranker.type - reranker_params = ( - {"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha} - ) + strategy = ranker.get("strategy", "rrf") + if strategy == "weighted": + weights = ranker.get("params", {}).get("weights", [0.5, 0.5]) + reranker_type = RERANKER_TYPE_WEIGHTED + reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5} + else: + reranker_type = RERANKER_TYPE_RRF + k_value = ranker.get("params", {}).get("k", 60.0) + reranker_params = {"impact_factor": k_value} query_string = interleaved_content_as_str(query) if mode == "keyword": return await self.index.query_keyword(query_string, k, score_threshold) - # Calculate embeddings for both vector and hybrid modes embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) if mode == "hybrid": diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index ea6db7991..04778ed1c 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -14,8 +14,8 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectWithInput, ) -from llama_stack.distribution.datatypes import AccessRule -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.core.datatypes import AccessRule +from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR from ..sqlstore.api import ColumnDefinition, ColumnType from ..sqlstore.authorized_sqlstore import AuthorizedSqlStore diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index 864a7dbb6..ccc835768 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -7,11 +7,11 @@ from collections.abc import Mapping from typing import Any, Literal -from llama_stack.distribution.access_control.access_control import default_policy, is_action_allowed -from llama_stack.distribution.access_control.conditions import ProtectedResource -from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope -from llama_stack.distribution.datatypes import User -from llama_stack.distribution.request_headers import get_authenticated_user +from llama_stack.core.access_control.access_control import default_policy, is_action_allowed +from llama_stack.core.access_control.conditions import ProtectedResource +from llama_stack.core.access_control.datatypes import AccessRule, Action, Scope +from 
llama_stack.core.datatypes import User +from llama_stack.core.request_headers import get_authenticated_user from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore diff --git a/llama_stack/providers/utils/sqlstore/sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlstore.py index 9f7eefcf5..fc44402ae 100644 --- a/llama_stack/providers/utils/sqlstore/sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlstore.py @@ -11,7 +11,7 @@ from typing import Annotated, Literal from pydantic import BaseModel, Field -from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR +from llama_stack.core.utils.config_dirs import RUNTIME_BASE_DIR from .api import SqlStore diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index c85722bdc..75b29cdce 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -81,7 +81,7 @@ BACKGROUND_LOGGER = None class BackgroundLogger: - def __init__(self, api: Telemetry, capacity: int = 1000): + def __init__(self, api: Telemetry, capacity: int = 100000): self.api = api self.log_queue = queue.Queue(maxsize=capacity) self.worker_thread = threading.Thread(target=self._process_logs, daemon=True) diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index 76593a4b8..02f7aaf8a 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -22,7 +22,7 @@ from llama_stack.apis.tools import ( ToolInvocationResult, ToolParameter, ) -from llama_stack.distribution.datatypes import AuthenticationRequiredError +from llama_stack.core.datatypes import AuthenticationRequiredError from llama_stack.log import get_logger from llama_stack.providers.utils.tools.ttl_dict import TTLDict diff --git a/llama_stack/providers/utils/vector_io/chunk_utils.py b/llama_stack/providers/utils/vector_io/vector_utils.py similarity index 58% rename from llama_stack/providers/utils/vector_io/chunk_utils.py rename to llama_stack/providers/utils/vector_io/vector_utils.py index 01afa6ec8..f2888043e 100644 --- a/llama_stack/providers/utils/vector_io/chunk_utils.py +++ b/llama_stack/providers/utils/vector_io/vector_utils.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import hashlib +import re import uuid @@ -19,3 +20,20 @@ def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | Non if chunk_window: hash_input += f":{chunk_window}".encode() return str(uuid.UUID(hashlib.md5(hash_input, usedforsecurity=False).hexdigest())) + + +def proper_case(s: str) -> str: + """Convert a string to proper case (first letter uppercase, rest lowercase).""" + return s[0].upper() + s[1:].lower() if s else s + + +def sanitize_collection_name(name: str, weaviate_format=False) -> str: + """ + Sanitize collection name to ensure it only contains numbers, letters, and underscores. + Any other characters are replaced with underscores. + """ + if not weaviate_format: + s = re.sub(r"[^a-zA-Z0-9_]", "_", name) + else: + s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name)) + return s diff --git a/llama_stack/testing/__init__.py b/llama_stack/testing/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/testing/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py new file mode 100644 index 000000000..478f77773 --- /dev/null +++ b/llama_stack/testing/inference_recorder.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from __future__ import annotations # for forward references + +import hashlib +import json +import os +import sqlite3 +from collections.abc import Generator +from contextlib import contextmanager +from enum import StrEnum +from pathlib import Path +from typing import Any, Literal, cast + +from llama_stack.log import get_logger + +logger = get_logger(__name__, category="testing") + +# Global state for the recording system +_current_mode: str | None = None +_current_storage: ResponseStorage | None = None +_original_methods: dict[str, Any] = {} + +from openai.types.completion_choice import CompletionChoice + +# update the "finish_reason" field, since its type definition is wrong (no None is accepted) +CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None +CompletionChoice.model_rebuild() + + +class InferenceMode(StrEnum): + LIVE = "live" + RECORD = "record" + REPLAY = "replay" + + +def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict[str, Any]) -> str: + """Create a normalized hash of the request for consistent matching.""" + # Extract just the endpoint path + from urllib.parse import urlparse + + parsed = urlparse(url) + normalized = {"method": method.upper(), "endpoint": parsed.path, "body": body} + + # Create hash - sort_keys=True ensures deterministic ordering + normalized_json = json.dumps(normalized, sort_keys=True) + return hashlib.sha256(normalized_json.encode()).hexdigest() + + +def get_inference_mode() -> InferenceMode: + return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower()) + + +def setup_inference_recording(): + """ + Returns a context manager that can be used to record or replay inference requests. This is to be used in tests + to increase their reliability and reduce reliance on expensive, external services. + + Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases. + Calls to the /models endpoint are not currently trapped. We probably need to add support for this. + + Two environment variables are required: + - LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. + - LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. + + The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to + quickly find the correct recording for a given request. The JSON files are used to store the request and response + bodies. + """ + mode = get_inference_mode() + + if mode not in InferenceMode: + raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. 
Must be 'live', 'record', or 'replay'") + + if mode == InferenceMode.LIVE: + return None + + if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ: + raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying") + storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] + + return inference_recording(mode=mode, storage_dir=storage_dir) + + +def _serialize_response(response: Any) -> Any: + if hasattr(response, "model_dump"): + data = response.model_dump(mode="json") + return { + "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}", + "__data__": data, + } + elif hasattr(response, "__dict__"): + return dict(response.__dict__) + else: + return response + + +def _deserialize_response(data: dict[str, Any]) -> Any: + # Check if this is a serialized Pydantic model with type information + if isinstance(data, dict) and "__type__" in data and "__data__" in data: + try: + # Import the original class and reconstruct the object + module_path, class_name = data["__type__"].rsplit(".", 1) + module = __import__(module_path, fromlist=[class_name]) + cls = getattr(module, class_name) + + if not hasattr(cls, "model_validate"): + raise ValueError(f"Pydantic class {cls} does not support model_validate?") + + return cls.model_validate(data["__data__"]) + except (ImportError, AttributeError, TypeError, ValueError) as e: + logger.warning(f"Failed to deserialize object of type {data['__type__']}: {e}") + return data["__data__"] + + return data + + +class ResponseStorage: + """Handles SQLite index + JSON file storage/retrieval for inference recordings.""" + + def __init__(self, test_dir: Path): + self.test_dir = test_dir + self.responses_dir = self.test_dir / "responses" + self.db_path = self.test_dir / "index.sqlite" + + self._ensure_directories() + self._init_database() + + def _ensure_directories(self): + self.test_dir.mkdir(parents=True, exist_ok=True) + self.responses_dir.mkdir(exist_ok=True) + + def _init_database(self): + with sqlite3.connect(self.db_path) as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS recordings ( + request_hash TEXT PRIMARY KEY, + response_file TEXT, + endpoint TEXT, + model TEXT, + timestamp TEXT, + is_streaming BOOLEAN + ) + """) + + def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]): + """Store a request/response pair.""" + # Generate unique response filename + response_file = f"{request_hash[:12]}.json" + response_path = self.responses_dir / response_file + + # Serialize response body if needed + serialized_response = dict(response) + if "body" in serialized_response: + if isinstance(serialized_response["body"], list): + # Handle streaming responses (list of chunks) + serialized_response["body"] = [_serialize_response(chunk) for chunk in serialized_response["body"]] + else: + # Handle single response + serialized_response["body"] = _serialize_response(serialized_response["body"]) + + # Save response to JSON file + with open(response_path, "w") as f: + json.dump({"request": request, "response": serialized_response}, f, indent=2) + f.write("\n") + f.flush() + + # Update SQLite index + with sqlite3.connect(self.db_path) as conn: + conn.execute( + """ + INSERT OR REPLACE INTO recordings + (request_hash, response_file, endpoint, model, timestamp, is_streaming) + VALUES (?, ?, ?, ?, datetime('now'), ?) 
+                """,
+                (
+                    request_hash,
+                    response_file,
+                    request.get("endpoint", ""),
+                    request.get("model", ""),
+                    response.get("is_streaming", False),
+                ),
+            )
+
+    def find_recording(self, request_hash: str) -> dict[str, Any] | None:
+        """Find a recorded response by request hash."""
+        with sqlite3.connect(self.db_path) as conn:
+            result = conn.execute(
+                "SELECT response_file FROM recordings WHERE request_hash = ?", (request_hash,)
+            ).fetchone()
+
+        if not result:
+            return None
+
+        response_file = result[0]
+        response_path = self.responses_dir / response_file
+
+        if not response_path.exists():
+            return None
+
+        with open(response_path) as f:
+            data = json.load(f)
+
+        # Deserialize response body if needed
+        if "response" in data and "body" in data["response"]:
+            if isinstance(data["response"]["body"], list):
+                # Handle streaming responses
+                data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
+            else:
+                # Handle single response
+                data["response"]["body"] = _deserialize_response(data["response"]["body"])
+
+        return cast(dict[str, Any], data)
+
+
+async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
+    global _current_mode, _current_storage
+
+    if _current_mode == InferenceMode.LIVE or _current_storage is None:
+        # Normal operation
+        return await original_method(self, *args, **kwargs)
+
+    # Get base URL based on client type
+    if client_type == "openai":
+        base_url = str(self._client.base_url)
+    elif client_type == "ollama":
+        # Get base URL from the client (Ollama client uses host attribute)
+        base_url = getattr(self, "host", "http://localhost:11434")
+        if not base_url.startswith("http"):
+            base_url = f"http://{base_url}"
+    else:
+        raise ValueError(f"Unknown client type: {client_type}")
+
+    url = base_url.rstrip("/") + endpoint
+
+    # Normalize request for matching
+    method = "POST"
+    headers = {}
+    body = kwargs
+
+    request_hash = normalize_request(method, url, headers, body)
+
+    if _current_mode == InferenceMode.REPLAY:
+        recording = _current_storage.find_recording(request_hash)
+        if recording:
+            response_body = recording["response"]["body"]
+
+            if recording["response"].get("is_streaming", False):
+
+                async def replay_stream():
+                    for chunk in response_body:
+                        yield chunk
+
+                return replay_stream()
+            else:
+                return response_body
+        else:
+            raise RuntimeError(
+                f"No recorded response found for request hash: {request_hash}\n"
+                f"Endpoint: {endpoint}\n"
+                f"Model: {body.get('model', 'unknown')}\n"
+                f"To record this response, run with LLAMA_STACK_TEST_INFERENCE_MODE=record"
+            )
+
+    elif _current_mode == InferenceMode.RECORD:
+        response = await original_method(self, *args, **kwargs)
+
+        request_data = {
+            "method": method,
+            "url": url,
+            "headers": headers,
+            "body": body,
+            "endpoint": endpoint,
+            "model": body.get("model", ""),
+        }
+
+        # Determine if this is a streaming request based on request parameters
+        is_streaming = body.get("stream", False)
+
+        if is_streaming:
+            # For streaming responses, we need to collect all chunks immediately before yielding
+            # This ensures the recording is saved even if the generator isn't fully consumed
+            chunks = []
+            async for chunk in response:
+                chunks.append(chunk)
+
+            # Store the recording immediately
+            response_data = {"body": chunks, "is_streaming": True}
+            _current_storage.store_recording(request_hash, request_data, response_data)
+
+            # Return a generator that replays the stored chunks
+            async def replay_recorded_stream():
+                for chunk in chunks:
+                    yield chunk
+
+            return
replay_recorded_stream() + else: + response_data = {"body": response, "is_streaming": False} + _current_storage.store_recording(request_hash, request_data, response_data) + return response + + else: + raise AssertionError(f"Invalid mode: {_current_mode}") + + +def patch_inference_clients(): + """Install monkey patches for OpenAI client methods and Ollama AsyncClient methods.""" + global _original_methods + + from ollama import AsyncClient as OllamaAsyncClient + from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions + from openai.resources.completions import AsyncCompletions + from openai.resources.embeddings import AsyncEmbeddings + + # Store original methods for both OpenAI and Ollama clients + _original_methods = { + "chat_completions_create": AsyncChatCompletions.create, + "completions_create": AsyncCompletions.create, + "embeddings_create": AsyncEmbeddings.create, + "ollama_generate": OllamaAsyncClient.generate, + "ollama_chat": OllamaAsyncClient.chat, + "ollama_embed": OllamaAsyncClient.embed, + "ollama_ps": OllamaAsyncClient.ps, + "ollama_pull": OllamaAsyncClient.pull, + "ollama_list": OllamaAsyncClient.list, + } + + # Create patched methods for OpenAI client + async def patched_chat_completions_create(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["chat_completions_create"], self, "openai", "/v1/chat/completions", *args, **kwargs + ) + + async def patched_completions_create(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["completions_create"], self, "openai", "/v1/completions", *args, **kwargs + ) + + async def patched_embeddings_create(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs + ) + + # Apply OpenAI patches + AsyncChatCompletions.create = patched_chat_completions_create + AsyncCompletions.create = patched_completions_create + AsyncEmbeddings.create = patched_embeddings_create + + # Create patched methods for Ollama client + async def patched_ollama_generate(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["ollama_generate"], self, "ollama", "/api/generate", *args, **kwargs + ) + + async def patched_ollama_chat(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["ollama_chat"], self, "ollama", "/api/chat", *args, **kwargs + ) + + async def patched_ollama_embed(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["ollama_embed"], self, "ollama", "/api/embeddings", *args, **kwargs + ) + + async def patched_ollama_ps(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["ollama_ps"], self, "ollama", "/api/ps", *args, **kwargs + ) + + async def patched_ollama_pull(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["ollama_pull"], self, "ollama", "/api/pull", *args, **kwargs + ) + + async def patched_ollama_list(self, *args, **kwargs): + return await _patched_inference_method( + _original_methods["ollama_list"], self, "ollama", "/api/tags", *args, **kwargs + ) + + # Apply Ollama patches + OllamaAsyncClient.generate = patched_ollama_generate + OllamaAsyncClient.chat = patched_ollama_chat + OllamaAsyncClient.embed = patched_ollama_embed + OllamaAsyncClient.ps = patched_ollama_ps + OllamaAsyncClient.pull = patched_ollama_pull + OllamaAsyncClient.list = patched_ollama_list + + +def 
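A hedged usage sketch for the recording machinery above — not part of the patch; the environment-variable names come from the module docstring, while the recording directory and the client calls are placeholders:

    import os

    os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = "record"  # or "replay"
    os.environ["LLAMA_STACK_TEST_RECORDING_DIR"] = "/tmp/llama-recordings"  # placeholder path

    from llama_stack.testing.inference_recorder import setup_inference_recording

    ctx = setup_inference_recording()
    if ctx is not None:  # None means live mode, so nothing is patched
        with ctx:
            ...  # OpenAI/Ollama async client calls made here are recorded or replayed
    else:
        ...  # live mode: calls go to the real backends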
unpatch_inference_clients(): + """Remove monkey patches and restore original OpenAI and Ollama client methods.""" + global _original_methods + + if not _original_methods: + return + + # Import here to avoid circular imports + from ollama import AsyncClient as OllamaAsyncClient + from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions + from openai.resources.completions import AsyncCompletions + from openai.resources.embeddings import AsyncEmbeddings + + # Restore OpenAI client methods + AsyncChatCompletions.create = _original_methods["chat_completions_create"] + AsyncCompletions.create = _original_methods["completions_create"] + AsyncEmbeddings.create = _original_methods["embeddings_create"] + + # Restore Ollama client methods if they were patched + OllamaAsyncClient.generate = _original_methods["ollama_generate"] + OllamaAsyncClient.chat = _original_methods["ollama_chat"] + OllamaAsyncClient.embed = _original_methods["ollama_embed"] + OllamaAsyncClient.ps = _original_methods["ollama_ps"] + OllamaAsyncClient.pull = _original_methods["ollama_pull"] + OllamaAsyncClient.list = _original_methods["ollama_list"] + + _original_methods.clear() + + +@contextmanager +def inference_recording(mode: str = "live", storage_dir: str | Path | None = None) -> Generator[None, None, None]: + """Context manager for inference recording/replaying.""" + global _current_mode, _current_storage + + # Set defaults + if storage_dir is None: + storage_dir_path = Path.home() / ".llama" / "recordings" + else: + storage_dir_path = Path(storage_dir) + + # Store previous state + prev_mode = _current_mode + prev_storage = _current_storage + + try: + _current_mode = mode + + if mode in ["record", "replay"]: + _current_storage = ResponseStorage(storage_dir_path) + patch_inference_clients() + + yield + + finally: + # Restore previous state + if mode in ["record", "replay"]: + unpatch_inference_clients() + + _current_mode = prev_mode + _current_storage = prev_storage diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx new file mode 100644 index 000000000..c31248b78 --- /dev/null +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -0,0 +1,223 @@ +"use client"; + +import { useState, useEffect } from "react"; +import { flushSync } from "react-dom"; +import { Button } from "@/components/ui/button"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Chat } from "@/components/chat-playground/chat"; +import { type Message } from "@/components/chat-playground/chat-message"; +import { useAuthClient } from "@/hooks/use-auth-client"; +import type { CompletionCreateParams } from "llama-stack-client/resources/chat/completions"; +import type { Model } from "llama-stack-client/resources/models"; + +export default function ChatPlaygroundPage() { + const [messages, setMessages] = useState([]); + const [input, setInput] = useState(""); + const [isGenerating, setIsGenerating] = useState(false); + const [error, setError] = useState(null); + const [models, setModels] = useState([]); + const [selectedModel, setSelectedModel] = useState(""); + const [modelsLoading, setModelsLoading] = useState(true); + const [modelsError, setModelsError] = useState(null); + const client = useAuthClient(); + + const isModelsLoading = modelsLoading ?? 
true; + + + useEffect(() => { + const fetchModels = async () => { + try { + setModelsLoading(true); + setModelsError(null); + const modelList = await client.models.list(); + const llmModels = modelList.filter(model => model.model_type === 'llm'); + setModels(llmModels); + if (llmModels.length > 0) { + setSelectedModel(llmModels[0].identifier); + } + } catch (err) { + console.error("Error fetching models:", err); + setModelsError("Failed to fetch available models"); + } finally { + setModelsLoading(false); + } + }; + + fetchModels(); + }, [client]); + + const extractTextContent = (content: unknown): string => { + if (typeof content === 'string') { + return content; + } + if (Array.isArray(content)) { + return content + .filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text') + .map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '') + .join(''); + } + if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) { + return String(content.text) || ''; + } + return ''; + }; + + const handleInputChange = (e: React.ChangeEvent) => { + setInput(e.target.value); + }; + +const handleSubmit = async (event?: { preventDefault?: () => void }) => { + event?.preventDefault?.(); + if (!input.trim()) return; + + // Add user message to chat + const userMessage: Message = { + id: Date.now().toString(), + role: "user", + content: input.trim(), + createdAt: new Date(), + }; + + setMessages(prev => [...prev, userMessage]); + setInput(""); + + // Use the helper function with the content + await handleSubmitWithContent(userMessage.content); +}; + +const handleSubmitWithContent = async (content: string) => { + setIsGenerating(true); + setError(null); + + try { + const messageParams: CompletionCreateParams["messages"] = [ + ...messages.map(msg => { + const msgContent = typeof msg.content === 'string' ? msg.content : extractTextContent(msg.content); + if (msg.role === "user") { + return { role: "user" as const, content: msgContent }; + } else if (msg.role === "assistant") { + return { role: "assistant" as const, content: msgContent }; + } else { + return { role: "system" as const, content: msgContent }; + } + }), + { role: "user" as const, content } + ]; + + const response = await client.chat.completions.create({ + model: selectedModel, + messages: messageParams, + stream: true, + }); + + const assistantMessage: Message = { + id: (Date.now() + 1).toString(), + role: "assistant", + content: "", + createdAt: new Date(), + }; + + setMessages(prev => [...prev, assistantMessage]); + let fullContent = ""; + for await (const chunk of response) { + if (chunk.choices && chunk.choices[0]?.delta?.content) { + const deltaContent = chunk.choices[0].delta.content; + fullContent += deltaContent; + + flushSync(() => { + setMessages(prev => { + const newMessages = [...prev]; + const lastMessage = newMessages[newMessages.length - 1]; + if (lastMessage.role === "assistant") { + lastMessage.content = fullContent; + } + return newMessages; + }); + }); + } + } + } catch (err) { + console.error("Error sending message:", err); + setError("Failed to send message. 
Please try again."); + setMessages(prev => prev.slice(0, -1)); + } finally { + setIsGenerating(false); + } +}; + const suggestions = [ + "Write a Python function that prints 'Hello, World!'", + "Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?", + "Design a simple algorithm to find the longest palindrome in a string.", + ]; + + const append = (message: { role: "user"; content: string }) => { + const newMessage: Message = { + id: Date.now().toString(), + role: message.role, + content: message.content, + createdAt: new Date(), + }; + setMessages(prev => [...prev, newMessage]) + handleSubmitWithContent(newMessage.content); + }; + + const clearChat = () => { + setMessages([]); + setError(null); + }; + + return ( +
+
+

Chat Playground

+
+ + +
+
+ + {modelsError && ( +
+

{modelsError}

+
+ )} + + {error && ( +
+

{error}

+
+ )} + + +
+ ); +} diff --git a/llama_stack/ui/components.json b/llama_stack/ui/components.json index 4ee62ee10..cef815d9e 100644 --- a/llama_stack/ui/components.json +++ b/llama_stack/ui/components.json @@ -13,7 +13,7 @@ "aliases": { "components": "@/components", "utils": "@/lib/utils", - "ui": "@/components/ui", + "chat": "@/components/chat", "lib": "@/lib", "hooks": "@/hooks" }, diff --git a/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx b/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx index 2e8593bfb..6170e816e 100644 --- a/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx +++ b/llama_stack/ui/components/chat-completions/chat-messasge-item.tsx @@ -7,7 +7,7 @@ import { extractTextFromContentPart } from "@/lib/format-message-content"; import { MessageBlock, ToolCallBlock, -} from "@/components/ui/message-components"; +} from "@/components/chat-playground/message-components"; interface ChatMessageItemProps { message: ChatMessage; diff --git a/llama_stack/ui/components/chat-playground/chat-message.tsx b/llama_stack/ui/components/chat-playground/chat-message.tsx new file mode 100644 index 000000000..e5d621c81 --- /dev/null +++ b/llama_stack/ui/components/chat-playground/chat-message.tsx @@ -0,0 +1,405 @@ +"use client" + +import React, { useMemo, useState } from "react" +import { cva, type VariantProps } from "class-variance-authority" +import { motion } from "framer-motion" +import { Ban, ChevronRight, Code2, Loader2, Terminal } from "lucide-react" + +import { cn } from "@/lib/utils" +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from "@/components/ui/collapsible" +import { FilePreview } from "@/components/ui/file-preview" +import { MarkdownRenderer } from "@/components/chat-playground/markdown-renderer" + +const chatBubbleVariants = cva( + "group/message relative break-words rounded-lg p-3 text-sm sm:max-w-[70%]", + { + variants: { + isUser: { + true: "bg-primary text-primary-foreground", + false: "bg-muted text-foreground", + }, + animation: { + none: "", + slide: "duration-300 animate-in fade-in-0", + scale: "duration-300 animate-in fade-in-0 zoom-in-75", + fade: "duration-500 animate-in fade-in-0", + }, + }, + compoundVariants: [ + { + isUser: true, + animation: "slide", + class: "slide-in-from-right", + }, + { + isUser: false, + animation: "slide", + class: "slide-in-from-left", + }, + { + isUser: true, + animation: "scale", + class: "origin-bottom-right", + }, + { + isUser: false, + animation: "scale", + class: "origin-bottom-left", + }, + ], + } +) + +type Animation = VariantProps["animation"] + +interface Attachment { + name?: string + contentType?: string + url: string +} + +interface PartialToolCall { + state: "partial-call" + toolName: string +} + +interface ToolCall { + state: "call" + toolName: string +} + +interface ToolResult { + state: "result" + toolName: string + result: { + __cancelled?: boolean + [key: string]: any + } +} + +type ToolInvocation = PartialToolCall | ToolCall | ToolResult + +interface ReasoningPart { + type: "reasoning" + reasoning: string +} + +interface ToolInvocationPart { + type: "tool-invocation" + toolInvocation: ToolInvocation +} + +interface TextPart { + type: "text" + text: string +} + +// For compatibility with AI SDK types, not used +interface SourcePart { + type: "source" + source?: any +} + +interface FilePart { + type: "file" + mimeType: string + data: string +} + +interface StepStartPart { + type: "step-start" +} + +type MessagePart = + | TextPart + | ReasoningPart + | 
ToolInvocationPart + | SourcePart + | FilePart + | StepStartPart + +export interface Message { + id: string + role: "user" | "assistant" | (string & {}) + content: string + createdAt?: Date + experimental_attachments?: Attachment[] + toolInvocations?: ToolInvocation[] + parts?: MessagePart[] +} + +export interface ChatMessageProps extends Message { + showTimeStamp?: boolean + animation?: Animation + actions?: React.ReactNode +} + +export const ChatMessage: React.FC = ({ + role, + content, + createdAt, + showTimeStamp = false, + animation = "scale", + actions, + experimental_attachments, + toolInvocations, + parts, +}) => { + const files = useMemo(() => { + return experimental_attachments?.map((attachment) => { + const dataArray = dataUrlToUint8Array(attachment.url) + const file = new File([dataArray], attachment.name ?? "Unknown", { + type: attachment.contentType, + }) + return file + }) + }, [experimental_attachments]) + + const isUser = role === "user" + + const formattedTime = createdAt?.toLocaleTimeString("en-US", { + hour: "2-digit", + minute: "2-digit", + }) + + if (isUser) { + return ( +
+ {files ? ( +
+ {files.map((file, index) => { + return + })} +
+ ) : null} + +
+ {content} +
+ + {showTimeStamp && createdAt ? ( + + ) : null} +
+ ) + } + + if (parts && parts.length > 0) { + return parts.map((part, index) => { + if (part.type === "text") { + return ( +
+
+ {part.text} + {actions ? ( +
+ {actions} +
+ ) : null} +
+ + {showTimeStamp && createdAt ? ( + + ) : null} +
+ ) + } else if (part.type === "reasoning") { + return + } else if (part.type === "tool-invocation") { + return ( + + ) + } + return null + }) + } + + if (toolInvocations && toolInvocations.length > 0) { + return + } + + return ( +
+
+ {content} + {actions ? ( +
+ {actions} +
+ ) : null} +
+ + {showTimeStamp && createdAt ? ( + + ) : null} +
+ ) +} + +function dataUrlToUint8Array(data: string) { + const base64 = data.split(",")[1] + const buf = Buffer.from(base64, "base64") + return new Uint8Array(buf) +} + +const ReasoningBlock = ({ part }: { part: ReasoningPart }) => { + const [isOpen, setIsOpen] = useState(false) + + return ( +
+ +
+ + + +
+ + +
+
+ {part.reasoning} +
+
+
+
+
+
+ ) +} + +function ToolCall({ + toolInvocations, +}: Pick) { + if (!toolInvocations?.length) return null + + return ( +
+ {toolInvocations.map((invocation, index) => { + const isCancelled = + invocation.state === "result" && + invocation.result.__cancelled === true + + if (isCancelled) { + return ( +
+ + + Cancelled{" "} + + {"`"} + {invocation.toolName} + {"`"} + + +
+ ) + } + + switch (invocation.state) { + case "partial-call": + case "call": + return ( +
+ + + Calling{" "} + + {"`"} + {invocation.toolName} + {"`"} + + ... + + +
+ ) + case "result": + return ( +
+
+ + + Result from{" "} + + {"`"} + {invocation.toolName} + {"`"} + + +
+
+                  {JSON.stringify(invocation.result, null, 2)}
+                
+
+ ) + default: + return null + } + })} +
+ ) +} diff --git a/llama_stack/ui/components/chat-playground/chat.tsx b/llama_stack/ui/components/chat-playground/chat.tsx new file mode 100644 index 000000000..ee83fd9bb --- /dev/null +++ b/llama_stack/ui/components/chat-playground/chat.tsx @@ -0,0 +1,349 @@ +"use client" + +import { + forwardRef, + useCallback, + useRef, + useState, + type ReactElement, +} from "react" +import { ArrowDown, ThumbsDown, ThumbsUp } from "lucide-react" + +import { cn } from "@/lib/utils" +import { useAutoScroll } from "@/hooks/use-auto-scroll" +import { Button } from "@/components/ui/button" +import { type Message } from "@/components/chat-playground/chat-message" +import { CopyButton } from "@/components/ui/copy-button" +import { MessageInput } from "@/components/chat-playground/message-input" +import { MessageList } from "@/components/chat-playground/message-list" +import { PromptSuggestions } from "@/components/chat-playground/prompt-suggestions" + +interface ChatPropsBase { + handleSubmit: ( + event?: { preventDefault?: () => void }, + options?: { experimental_attachments?: FileList } + ) => void + messages: Array + input: string + className?: string + handleInputChange: React.ChangeEventHandler + isGenerating: boolean + stop?: () => void + onRateResponse?: ( + messageId: string, + rating: "thumbs-up" | "thumbs-down" + ) => void + setMessages?: (messages: any[]) => void + transcribeAudio?: (blob: Blob) => Promise +} + +interface ChatPropsWithoutSuggestions extends ChatPropsBase { + append?: never + suggestions?: never +} + +interface ChatPropsWithSuggestions extends ChatPropsBase { + append: (message: { role: "user"; content: string }) => void + suggestions: string[] +} + +type ChatProps = ChatPropsWithoutSuggestions | ChatPropsWithSuggestions + +export function Chat({ + messages, + handleSubmit, + input, + handleInputChange, + stop, + isGenerating, + append, + suggestions, + className, + onRateResponse, + setMessages, + transcribeAudio, +}: ChatProps) { + const lastMessage = messages.at(-1) + const isEmpty = messages.length === 0 + const isTyping = lastMessage?.role === "user" + + const messagesRef = useRef(messages) + messagesRef.current = messages + + // Enhanced stop function that marks pending tool calls as cancelled + const handleStop = useCallback(() => { + stop?.() + + if (!setMessages) return + + const latestMessages = [...messagesRef.current] + const lastAssistantMessage = latestMessages.findLast( + (m) => m.role === "assistant" + ) + + if (!lastAssistantMessage) return + + let needsUpdate = false + let updatedMessage = { ...lastAssistantMessage } + + if (lastAssistantMessage.toolInvocations) { + const updatedToolInvocations = lastAssistantMessage.toolInvocations.map( + (toolInvocation) => { + if (toolInvocation.state === "call") { + needsUpdate = true + return { + ...toolInvocation, + state: "result", + result: { + content: "Tool execution was cancelled", + __cancelled: true, // Special marker to indicate cancellation + }, + } as const + } + return toolInvocation + } + ) + + if (needsUpdate) { + updatedMessage = { + ...updatedMessage, + toolInvocations: updatedToolInvocations, + } + } + } + + if (lastAssistantMessage.parts && lastAssistantMessage.parts.length > 0) { + const updatedParts = lastAssistantMessage.parts.map((part: any) => { + if ( + part.type === "tool-invocation" && + part.toolInvocation && + part.toolInvocation.state === "call" + ) { + needsUpdate = true + return { + ...part, + toolInvocation: { + ...part.toolInvocation, + state: "result", + result: { + content: "Tool execution 
was cancelled", + __cancelled: true, + }, + }, + } + } + return part + }) + + if (needsUpdate) { + updatedMessage = { + ...updatedMessage, + parts: updatedParts, + } + } + } + + if (needsUpdate) { + const messageIndex = latestMessages.findIndex( + (m) => m.id === lastAssistantMessage.id + ) + if (messageIndex !== -1) { + latestMessages[messageIndex] = updatedMessage + setMessages(latestMessages) + } + } + }, [stop, setMessages, messagesRef]) + + const messageOptions = useCallback( + (message: Message) => ({ + actions: onRateResponse ? ( + <> +
+ +
+ + + + ) : ( + + ), + }), + [onRateResponse] + ) + + return ( + +
+ {isEmpty && append && suggestions ? ( +
+ +
+ ) : null} + + {messages.length > 0 ? ( + + + + ) : null} +
+ +
+
+ + {({ files, setFiles }) => ( + + )} + +
+
+
+ ) +} +Chat.displayName = "Chat" + +export function ChatMessages({ + messages, + children, +}: React.PropsWithChildren<{ + messages: Message[] +}>) { + const { + containerRef, + scrollToBottom, + handleScroll, + shouldAutoScroll, + handleTouchStart, + } = useAutoScroll([messages]) + + return ( +
+
+ {children} +
+ + {!shouldAutoScroll && ( +
+
+ +
+
+ )} +
+ ) +} + +export const ChatContainer = forwardRef< + HTMLDivElement, + React.HTMLAttributes +>(({ className, ...props }, ref) => { + return ( +
+ ) +}) +ChatContainer.displayName = "ChatContainer" + +interface ChatFormProps { + className?: string + isPending: boolean + handleSubmit: ( + event?: { preventDefault?: () => void }, + options?: { experimental_attachments?: FileList } + ) => void + children: (props: { + files: File[] | null + setFiles: React.Dispatch> + }) => ReactElement +} + +export const ChatForm = forwardRef( + ({ children, handleSubmit, isPending, className }, ref) => { + const [files, setFiles] = useState(null) + + const onSubmit = (event: React.FormEvent) => { + // if (isPending) { + // event.preventDefault() + // return + // } + + if (!files) { + handleSubmit(event) + return + } + + const fileList = createFileList(files) + handleSubmit(event, { experimental_attachments: fileList }) + setFiles(null) + } + + return ( +
+ {children({ files, setFiles })} +
+ ) + } +) +ChatForm.displayName = "ChatForm" + +function createFileList(files: File[] | FileList): FileList { + const dataTransfer = new DataTransfer() + for (const file of Array.from(files)) { + dataTransfer.items.add(file) + } + return dataTransfer.files +} diff --git a/llama_stack/ui/components/chat-playground/interrupt-prompt.tsx b/llama_stack/ui/components/chat-playground/interrupt-prompt.tsx new file mode 100644 index 000000000..757863c62 --- /dev/null +++ b/llama_stack/ui/components/chat-playground/interrupt-prompt.tsx @@ -0,0 +1,41 @@ +"use client" + +import { AnimatePresence, motion } from "framer-motion" +import { X } from "lucide-react" + +interface InterruptPromptProps { + isOpen: boolean + close: () => void +} + +export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) { + return ( + + {isOpen && ( + + Press Enter again to interrupt + + + )} + + ) +} diff --git a/llama_stack/ui/components/chat-playground/markdown-renderer.tsx b/llama_stack/ui/components/chat-playground/markdown-renderer.tsx new file mode 100644 index 000000000..1c2781eaf --- /dev/null +++ b/llama_stack/ui/components/chat-playground/markdown-renderer.tsx @@ -0,0 +1,233 @@ +import React, { Suspense, useEffect, useState } from "react" +import Markdown from "react-markdown" +import remarkGfm from "remark-gfm" + +import { cn } from "@/lib/utils" +import { CopyButton } from "@/components/ui/copy-button" + +interface MarkdownRendererProps { + children: string +} + +export function MarkdownRenderer({ children }: MarkdownRendererProps) { + return ( +
+ + {children} + +
+ ) +} + +interface HighlightedPre extends React.HTMLAttributes { + children: string + language: string +} + +const HighlightedPre = React.memo( + ({ children, language, ...props }: HighlightedPre) => { + const [tokens, setTokens] = useState(null) + const [isSupported, setIsSupported] = useState(false) + + useEffect(() => { + let mounted = true + + const loadAndHighlight = async () => { + try { + const { codeToTokens, bundledLanguages } = await import("shiki") + + if (!mounted) return + + if (!(language in bundledLanguages)) { + setIsSupported(false) + return + } + + setIsSupported(true) + + const { tokens: highlightedTokens } = await codeToTokens(children, { + lang: language as keyof typeof bundledLanguages, + defaultColor: false, + themes: { + light: "github-light", + dark: "github-dark", + }, + }) + + if (mounted) { + setTokens(highlightedTokens) + } + } catch (error) { + if (mounted) { + setIsSupported(false) + } + } + } + + loadAndHighlight() + + return () => { + mounted = false + } + }, [children, language]) + + if (!isSupported) { + return
{children}
+ } + + if (!tokens) { + return
{children}
+ } + + return ( +
+        
+          {tokens.map((line, lineIndex) => (
+            
+              
+                {line.map((token, tokenIndex) => {
+                  const style =
+                    typeof token.htmlStyle === "string"
+                      ? undefined
+                      : token.htmlStyle
+
+                  return (
+                    
+                      {token.content}
+                    
+                  )
+                })}
+              
+              {lineIndex !== tokens.length - 1 && "\n"}
+            
+          ))}
+        
+      
+ ) + } +) +HighlightedPre.displayName = "HighlightedCode" + +interface CodeBlockProps extends React.HTMLAttributes { + children: React.ReactNode + className?: string + language: string +} + +const CodeBlock = ({ + children, + className, + language, + ...restProps +}: CodeBlockProps) => { + const code = + typeof children === "string" + ? children + : childrenTakeAllStringContents(children) + + const preClass = cn( + "overflow-x-scroll rounded-md border bg-background/50 p-4 font-mono text-sm [scrollbar-width:none]", + className + ) + + return ( +
+ + {children} + + } + > + + {code} + + + +
+ +
+
+ ) +} + +function childrenTakeAllStringContents(element: any): string { + if (typeof element === "string") { + return element + } + + if (element?.props?.children) { + let children = element.props.children + + if (Array.isArray(children)) { + return children + .map((child) => childrenTakeAllStringContents(child)) + .join("") + } else { + return childrenTakeAllStringContents(children) + } + } + + return "" +} + +const COMPONENTS = { + h1: withClass("h1", "text-2xl font-semibold"), + h2: withClass("h2", "font-semibold text-xl"), + h3: withClass("h3", "font-semibold text-lg"), + h4: withClass("h4", "font-semibold text-base"), + h5: withClass("h5", "font-medium"), + strong: withClass("strong", "font-semibold"), + a: withClass("a", "text-primary underline underline-offset-2"), + blockquote: withClass("blockquote", "border-l-2 border-primary pl-4"), + code: ({ children, className, node, ...rest }: any) => { + const match = /language-(\w+)/.exec(className || "") + return match ? ( + + {children} + + ) : ( + &]:rounded-md [:not(pre)>&]:bg-background/50 [:not(pre)>&]:px-1 [:not(pre)>&]:py-0.5" + )} + {...rest} + > + {children} + + ) + }, + pre: ({ children }: any) => children, + ol: withClass("ol", "list-decimal space-y-2 pl-6"), + ul: withClass("ul", "list-disc space-y-2 pl-6"), + li: withClass("li", "my-1.5"), + table: withClass( + "table", + "w-full border-collapse overflow-y-auto rounded-md border border-foreground/20" + ), + th: withClass( + "th", + "border border-foreground/20 px-4 py-2 text-left font-bold [&[align=center]]:text-center [&[align=right]]:text-right" + ), + td: withClass( + "td", + "border border-foreground/20 px-4 py-2 text-left [&[align=center]]:text-center [&[align=right]]:text-right" + ), + tr: withClass("tr", "m-0 border-t p-0 even:bg-muted"), + p: withClass("p", "whitespace-pre-wrap"), + hr: withClass("hr", "border-foreground/20"), +} + +function withClass(Tag: keyof JSX.IntrinsicElements, classes: string) { + const Component = ({ node, ...props }: any) => ( + + ) + Component.displayName = Tag + return Component +} + +export default MarkdownRenderer diff --git a/llama_stack/ui/components/ui/message-components.tsx b/llama_stack/ui/components/chat-playground/message-components.tsx similarity index 100% rename from llama_stack/ui/components/ui/message-components.tsx rename to llama_stack/ui/components/chat-playground/message-components.tsx diff --git a/llama_stack/ui/components/chat-playground/message-input.tsx b/llama_stack/ui/components/chat-playground/message-input.tsx new file mode 100644 index 000000000..4a29386d9 --- /dev/null +++ b/llama_stack/ui/components/chat-playground/message-input.tsx @@ -0,0 +1,466 @@ +"use client" + +import React, { useEffect, useRef, useState } from "react" +import { AnimatePresence, motion } from "framer-motion" +import { ArrowUp, Info, Loader2, Mic, Paperclip, Square } from "lucide-react" +import { omit } from "remeda" + +import { cn } from "@/lib/utils" +import { useAudioRecording } from "@/hooks/use-audio-recording" +import { useAutosizeTextArea } from "@/hooks/use-autosize-textarea" +import { AudioVisualizer } from "@/components/ui/audio-visualizer" +import { Button } from "@/components/ui/button" +import { FilePreview } from "@/components/ui/file-preview" +import { InterruptPrompt } from "@/components/chat-playground/interrupt-prompt" + +interface MessageInputBaseProps + extends React.TextareaHTMLAttributes { + value: string + submitOnEnter?: boolean + stop?: () => void + isGenerating: boolean + enableInterrupt?: boolean + 
transcribeAudio?: (blob: Blob) => Promise +} + +interface MessageInputWithoutAttachmentProps extends MessageInputBaseProps { + allowAttachments?: false +} + +interface MessageInputWithAttachmentsProps extends MessageInputBaseProps { + allowAttachments: true + files: File[] | null + setFiles: React.Dispatch> +} + +type MessageInputProps = + | MessageInputWithoutAttachmentProps + | MessageInputWithAttachmentsProps + +export function MessageInput({ + placeholder = "Ask AI...", + className, + onKeyDown: onKeyDownProp, + submitOnEnter = true, + stop, + isGenerating, + enableInterrupt = true, + transcribeAudio, + ...props +}: MessageInputProps) { + const [isDragging, setIsDragging] = useState(false) + const [showInterruptPrompt, setShowInterruptPrompt] = useState(false) + + const { + isListening, + isSpeechSupported, + isRecording, + isTranscribing, + audioStream, + toggleListening, + stopRecording, + } = useAudioRecording({ + transcribeAudio, + onTranscriptionComplete: (text) => { + props.onChange?.({ target: { value: text } } as any) + }, + }) + + useEffect(() => { + if (!isGenerating) { + setShowInterruptPrompt(false) + } + }, [isGenerating]) + + const addFiles = (files: File[] | null) => { + if (props.allowAttachments) { + props.setFiles((currentFiles) => { + if (currentFiles === null) { + return files + } + + if (files === null) { + return currentFiles + } + + return [...currentFiles, ...files] + }) + } + } + + const onDragOver = (event: React.DragEvent) => { + if (props.allowAttachments !== true) return + event.preventDefault() + setIsDragging(true) + } + + const onDragLeave = (event: React.DragEvent) => { + if (props.allowAttachments !== true) return + event.preventDefault() + setIsDragging(false) + } + + const onDrop = (event: React.DragEvent) => { + setIsDragging(false) + if (props.allowAttachments !== true) return + event.preventDefault() + const dataTransfer = event.dataTransfer + if (dataTransfer.files.length) { + addFiles(Array.from(dataTransfer.files)) + } + } + + const onPaste = (event: React.ClipboardEvent) => { + const items = event.clipboardData?.items + if (!items) return + + const text = event.clipboardData.getData("text") + if (text && text.length > 500 && props.allowAttachments) { + event.preventDefault() + const blob = new Blob([text], { type: "text/plain" }) + const file = new File([blob], "Pasted text", { + type: "text/plain", + lastModified: Date.now(), + }) + addFiles([file]) + return + } + + const files = Array.from(items) + .map((item) => item.getAsFile()) + .filter((file) => file !== null) + + if (props.allowAttachments && files.length > 0) { + addFiles(files) + } + } + + const onKeyDown = (event: React.KeyboardEvent) => { + if (submitOnEnter && event.key === "Enter" && !event.shiftKey) { + event.preventDefault() + + if (isGenerating && stop && enableInterrupt) { + if (showInterruptPrompt) { + stop() + setShowInterruptPrompt(false) + event.currentTarget.form?.requestSubmit() + } else if ( + props.value || + (props.allowAttachments && props.files?.length) + ) { + setShowInterruptPrompt(true) + return + } + } + + event.currentTarget.form?.requestSubmit() + } + + onKeyDownProp?.(event) + } + + const textAreaRef = useRef(null) + const [textAreaHeight, setTextAreaHeight] = useState(0) + + useEffect(() => { + if (textAreaRef.current) { + setTextAreaHeight(textAreaRef.current.offsetHeight) + } + }, [props.value]) + + const showFileList = + props.allowAttachments && props.files && props.files.length > 0 + + + useAutosizeTextArea({ + ref: textAreaRef, + maxHeight: 240, + 
borderWidth: 1, + dependencies: [props.value, showFileList], + }) + + return ( +
+ {enableInterrupt && ( + setShowInterruptPrompt(false)} + /> + )} + + + +
+
+