diff --git a/.github/TRIAGERS.md b/.github/TRIAGERS.md index ed4f4a6c6..f5bd11531 100644 --- a/.github/TRIAGERS.md +++ b/.github/TRIAGERS.md @@ -1,2 +1,2 @@ # This file documents Triage members in the Llama Stack community - @bbrowning @franciscojavierarceo @leseb + @franciscojavierarceo diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index 573148e46..60550cfdc 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -2,9 +2,13 @@ name: 'Run and Record Tests' description: 'Run integration tests and handle recording/artifact upload' inputs: - test-types: - description: 'JSON array of test types to run' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run' required: true + test-pattern: + description: 'Regex pattern to pass to pytest -k' + required: false + default: '' stack-config: description: 'Stack configuration to use' required: true @@ -32,12 +36,14 @@ runs: - name: Run Integration Tests shell: bash run: | - ./scripts/integration-tests.sh \ + uv run --no-sync ./scripts/integration-tests.sh \ --stack-config '${{ inputs.stack-config }}' \ --provider '${{ inputs.provider }}' \ - --test-types '${{ inputs.test-types }}' \ + --test-subdirs '${{ inputs.test-subdirs }}' \ + --test-pattern '${{ inputs.test-pattern }}' \ --inference-mode '${{ inputs.inference-mode }}' \ - ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} + ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ + | tee pytest-${{ inputs.inference-mode }}.log - name: Commit and push recordings @@ -57,10 +63,10 @@ runs: git commit -m "Recordings update from CI" fi - git fetch origin ${{ github.event.pull_request.head.ref }} - git rebase origin/${{ github.event.pull_request.head.ref }} + git fetch origin ${{ github.ref_name }} + git rebase origin/${{ github.ref_name }} echo "Rebased successfully" - git push origin HEAD:${{ github.event.pull_request.head.ref }} + git push origin HEAD:${{ github.ref_name }} echo "Pushed successfully" else echo "No recording changes" diff --git a/.github/actions/setup-runner/action.yml b/.github/actions/setup-runner/action.yml index 1ca02bbff..905d6b73a 100644 --- a/.github/actions/setup-runner/action.yml +++ b/.github/actions/setup-runner/action.yml @@ -16,14 +16,16 @@ runs: uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 with: python-version: ${{ inputs.python-version }} - activate-environment: true version: 0.7.6 - name: Install dependencies shell: bash run: | + echo "Updating project dependencies via uv sync" uv sync --all-groups - uv pip install ollama faiss-cpu + + echo "Installing ad-hoc dependencies" + uv pip install faiss-cpu # Install llama-stack-client-python based on the client-version input if [ "${{ inputs.client-version }}" = "latest" ]; then @@ -37,4 +39,5 @@ runs: exit 1 fi - uv pip install -e . 
+ echo "Installed llama packages" + uv pip list | grep llama diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index 4465fe159..d8005866c 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -53,7 +53,22 @@ runs: - name: Build Llama Stack shell: bash run: | - uv run llama stack build --template ci-tests --image-type venv + # Install llama-stack-client-python based on the client-version input + if [ "${{ inputs.client-version }}" = "latest" ]; then + echo "Installing latest llama-stack-client-python from main branch" + export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main + elif [ "${{ inputs.client-version }}" = "published" ]; then + echo "Installing published llama-stack-client-python from PyPI" + unset LLAMA_STACK_CLIENT_DIR + else + echo "Invalid client-version: ${{ inputs.client-version }}" + exit 1 + fi + + echo "Building Llama Stack" + + LLAMA_STACK_DIR=. \ + uv run --no-sync llama stack build --template ci-tests --image-type venv - name: Configure git for commits shell: bash diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 134efd93b..f88402a7a 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,6 +9,7 @@ updates: day: "saturday" commit-message: prefix: chore(github-deps) + - package-ecosystem: "uv" directory: "/" schedule: @@ -19,3 +20,14 @@ updates: - python commit-message: prefix: chore(python-deps) + + - package-ecosystem: npm + directory: "/llama_stack/ui" + schedule: + interval: "weekly" + day: "saturday" + labels: + - type/dependencies + - javascript + commit-message: + prefix: chore(ui-deps) diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 3c3d93dc2..8344d12a4 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -18,5 +18,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). 
Below is a tabl | Close stale issues and PRs | [stale_bot.yml](stale_bot.yml) | Run the Stale Bot action | | Test External Providers Installed via Module | [test-external-provider-module.yml](test-external-provider-module.yml) | Test External Provider installation via Python module | | Test External API and Providers | [test-external.yml](test-external.yml) | Test the External API and Provider mechanisms | +| UI Tests | [ui-unit-tests.yml](ui-unit-tests.yml) | Run the UI test suite | | Unit Tests | [unit-tests.yml](unit-tests.yml) | Run the unit test suite | | Update ReadTheDocs | [update-readthedocs.yml](update-readthedocs.yml) | Update the Llama Stack ReadTheDocs site | diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index e406d99ee..7a75d85f6 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -17,7 +17,7 @@ jobs: pull-requests: write # for peter-evans/create-pull-request to create a PR runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: main fetch-depth: 0 diff --git a/.github/workflows/install-script-ci.yml b/.github/workflows/install-script-ci.yml index 5dc2b4412..a37919f56 100644 --- a/.github/workflows/install-script-ci.yml +++ b/.github/workflows/install-script-ci.yml @@ -16,21 +16,22 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0 - name: Run ShellCheck on install.sh run: shellcheck scripts/install.sh smoke-test-on-dev: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner - name: Build a single provider run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template starter --image-type container --image-name test + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. 
uv run --no-sync \ + llama stack build --template starter --image-type container --image-name test - name: Run installer end-to-end run: | diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index ef2066497..6787806e9 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -10,6 +10,7 @@ on: paths: - 'distributions/**' - 'llama_stack/**' + - '!llama_stack/ui/**' - 'tests/integration/**' - 'uv.lock' - 'pyproject.toml' @@ -17,7 +18,7 @@ on: - '.github/workflows/integration-auth-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -30,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index 4e5b64963..3efd970e1 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -16,7 +16,7 @@ on: - '.github/workflows/integration-sql-store-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -44,7 +44,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f330d2c45..57e582b20 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -5,11 +5,12 @@ run-name: Run the integration test suite from tests/integration in replay mode on: push: branches: [ main ] - pull_request_target: + pull_request: branches: [ main ] types: [opened, synchronize, reopened] paths: - 'llama_stack/**' + - '!llama_stack/ui/**' - 'tests/**' - 'uv.lock' - 'pyproject.toml' @@ -31,35 +32,23 @@ on: description: 'Test against a specific provider' type: string default: 'ollama' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run' + type: string + default: '' + test-pattern: + description: 'Regex pattern to pass to pytest -k' + type: string + default: '' concurrency: # Skip concurrency for pushes to main - each commit should be tested independently - group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: - discover-tests: - runs-on: ubuntu-latest - outputs: - test-types: ${{ steps.generate-test-types.outputs.test-types }} - - steps: - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Generate test types - id: generate-test-types - run: | - # Get test directories dynamically, excluding non-test directories - # NOTE: we are excluding post_training since the tests take too 
long - TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d | - sed 's|tests/integration/||' | - grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" | - sort | jq -R -s -c 'split("\n")[:-1]') - echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT run-replay-mode-tests: - needs: discover-tests runs-on: ubuntu-latest name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} @@ -76,7 +65,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup test environment uses: ./.github/actions/setup-test-environment @@ -90,7 +79,8 @@ jobs: - name: Run tests uses: ./.github/actions/run-and-record-tests with: - test-types: ${{ needs.discover-tests.outputs.test-types }} + test-subdirs: ${{ inputs.test-subdirs }} + test-pattern: ${{ inputs.test-pattern }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} provider: ${{ matrix.provider }} inference-mode: 'replay' diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index f4d28e407..de5701073 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -9,14 +9,17 @@ on: branches: [ main ] paths: - 'llama_stack/**' + - '!llama_stack/ui/**' - 'tests/integration/vector_io/**' - 'uv.lock' - 'pyproject.toml' - 'requirements.txt' - '.github/workflows/integration-vector-io-tests.yml' # This workflow + schedule: + - cron: '0 0 * * *' # (test on python 3.13) Daily at 12 AM UTC concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -25,12 +28,12 @@ jobs: strategy: matrix: vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] - python-version: ["3.12", "3.13"] + python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} fail-fast: false # we want to run all tests regardless of failure steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -141,7 +144,7 @@ jobs: - name: Build Llama Stack run: | - uv run llama stack build --template ci-tests --image-type venv + uv run --no-sync llama stack build --template ci-tests --image-type venv - name: Check Storage and Memory Available Before Tests if: ${{ always() }} @@ -164,7 +167,8 @@ jobs: ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} run: | - uv run pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ + uv run --no-sync \ + pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ tests/integration/vector_io \ --embedding-model 
inline::sentence-transformers/all-MiniLM-L6-v2 diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 4f1c143d2..5f13620f7 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -8,7 +8,7 @@ on: branches: [main] concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -20,7 +20,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: # For dependabot PRs, we need to checkout with a token that can push changes token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }} @@ -36,6 +36,17 @@ jobs: **/requirements*.txt .pre-commit-config.yaml + - name: Set up Node.js + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 + with: + node-version: '20' + cache: 'npm' + cache-dependency-path: 'llama_stack/ui/' + + - name: Install npm dependencies + run: npm ci + working-directory: llama_stack/ui + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 continue-on-error: true env: diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 929d76760..391acbcf8 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -26,7 +26,7 @@ on: - 'pyproject.toml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -36,7 +36,7 @@ jobs: distros: ${{ steps.set-matrix.outputs.distros }} steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Generate Distribution List id: set-matrix @@ -55,7 +55,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -79,7 +79,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -92,7 +92,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -106,6 +106,10 @@ jobs: - name: Inspect the container image entrypoint run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then @@ -117,7 +121,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: 
actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -140,6 +144,10 @@ jobs: - name: Inspect UBI9 image run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index 67dc49cce..bf9a3e057 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -9,6 +9,8 @@ on: pull_request: branches: - main + paths-ignore: + - 'llama_stack/ui/**' jobs: build: @@ -19,10 +21,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install uv - uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3 + uses: astral-sh/setup-uv@4959332f0f014c5280e7eac8b70c90cb574c9f9b # v6.6.0 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 12957db27..d4f5586e2 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -1,93 +1,53 @@ +# This workflow should be run manually when needing to re-record tests. This happens when you have +# - added a new test +# - or changed an existing test such that a new inference call is made +# You should make a PR and then run this workflow on that PR branch. The workflow will re-record the +# tests and commit the recordings to the PR branch. 
name: Integration Tests (Record) run-name: Run the integration test suite from tests/integration on: - pull_request: - branches: [ main ] - types: [opened, synchronize, labeled] - paths: - - 'llama_stack/**' - - 'tests/**' - - 'uv.lock' - - 'pyproject.toml' - - '.github/workflows/record-integration-tests.yml' # This workflow - - '.github/actions/setup-ollama/action.yml' - - '.github/actions/setup-test-environment/action.yml' - - '.github/actions/run-and-record-tests/action.yml' workflow_dispatch: inputs: + test-subdirs: + description: 'Comma-separated list of test subdirectories to run' + type: string + default: '' test-provider: description: 'Test against a specific provider' type: string default: 'ollama' - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + run-vision-tests: + description: 'Whether to run vision tests' + type: boolean + default: false + test-pattern: + description: 'Regex pattern to pass to pytest -k' + type: string + default: '' jobs: - discover-tests: - if: contains(github.event.pull_request.labels.*.name, 're-record-tests') || - contains(github.event.pull_request.labels.*.name, 're-record-vision-tests') - runs-on: ubuntu-latest - outputs: - test-types: ${{ steps.generate-test-types.outputs.test-types }} - matrix-modes: ${{ steps.generate-test-types.outputs.matrix-modes }} - - steps: - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Generate test types - id: generate-test-types - run: | - # Get test directories dynamically, excluding non-test directories - TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" | - grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" | - sort | jq -R -s -c 'split("\n")[:-1]') - echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT - - labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name') - echo "labels=$labels" - - modes_array=() - if [[ $labels == *"re-record-vision-tests"* ]]; then - modes_array+=("vision") - fi - if [[ $labels == *"re-record-tests"* ]]; then - modes_array+=("non-vision") - fi - - # Convert to JSON array - if [ ${#modes_array[@]} -eq 0 ]; then - matrix_modes="[]" - else - matrix_modes=$(printf '%s\n' "${modes_array[@]}" | jq -R -s -c 'split("\n")[:-1]') - fi - echo "matrix_modes=$matrix_modes" - echo "matrix-modes=$matrix_modes" >> $GITHUB_OUTPUT - - env: - GH_TOKEN: ${{ github.token }} - record-tests: - needs: discover-tests runs-on: ubuntu-latest permissions: contents: write - strategy: - fail-fast: false - matrix: - mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }} - steps: + - name: Echo workflow inputs + run: | + echo "::group::Workflow Inputs" + echo "test-subdirs: ${{ inputs.test-subdirs }}" + echo "test-provider: ${{ inputs.test-provider }}" + echo "run-vision-tests: ${{ inputs.run-vision-tests }}" + echo "test-pattern: ${{ inputs.test-pattern }}" + echo "branch: ${{ github.ref_name }}" + echo "::endgroup::" + - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: - ref: ${{ github.event.pull_request.head.ref }} fetch-depth: 0 - name: Setup test environment @@ -96,14 +56,15 @@ jobs: python-version: "3.12" # Use single Python version for recording client-version: "latest" provider: ${{ inputs.test-provider || 'ollama' }} - run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 
'false' }} + run-vision-tests: ${{ inputs.run-vision-tests }} inference-mode: 'record' - name: Run and record tests uses: ./.github/actions/run-and-record-tests with: - test-types: ${{ needs.discover-tests.outputs.test-types }} + test-pattern: ${{ inputs.test-pattern }} + test-subdirs: ${{ inputs.test-subdirs }} stack-config: 'server:ci-tests' # recording must be done with server since more tests are run provider: ${{ inputs.test-provider || 'ollama' }} inference-mode: 'record' - run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }} + run-vision-tests: ${{ inputs.run-vision-tests }} diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 57a4df646..4a078fa00 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -22,6 +22,6 @@ jobs: runs-on: ubuntu-latest steps: - name: Check PR Title's semantic conformance - uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3 + uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test-external-provider-module.yml b/.github/workflows/test-external-provider-module.yml index d61b0dfe9..8a757b068 100644 --- a/.github/workflows/test-external-provider-module.yml +++ b/.github/workflows/test-external-provider-module.yml @@ -27,7 +27,7 @@ jobs: # container and point 'uv pip install' to the correct path... steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml index 27181a236..7ee467451 100644 --- a/.github/workflows/test-external.yml +++ b/.github/workflows/test-external.yml @@ -9,6 +9,7 @@ on: branches: [ main ] paths: - 'llama_stack/**' + - '!llama_stack/ui/**' - 'tests/integration/**' - 'uv.lock' - 'pyproject.toml' @@ -26,7 +27,7 @@ jobs: # container and point 'uv pip install' to the correct path... steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -43,11 +44,11 @@ jobs: - name: Print distro dependencies run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external/build.yaml --print-deps-only + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run --no-sync llama stack build --config tests/external/build.yaml --print-deps-only - name: Build distro from config file run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external/build.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. 
uv run --no-sync llama stack build --config tests/external/build.yaml - name: Start Llama Stack server in background if: ${{ matrix.image-type }} == 'venv' diff --git a/.github/workflows/ui-unit-tests.yml b/.github/workflows/ui-unit-tests.yml new file mode 100644 index 000000000..2afb92bee --- /dev/null +++ b/.github/workflows/ui-unit-tests.yml @@ -0,0 +1,55 @@ +name: UI Tests + +run-name: Run the UI test suite + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + paths: + - 'llama_stack/ui/**' + - '.github/workflows/ui-unit-tests.yml' # This workflow + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} + cancel-in-progress: true + +jobs: + ui-tests: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + node-version: [22] + + steps: + - name: Checkout repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + + - name: Setup Node.js + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 + with: + node-version: ${{ matrix.node-version }} + cache: 'npm' + cache-dependency-path: 'llama_stack/ui/package-lock.json' + + - name: Install dependencies + working-directory: llama_stack/ui + run: npm ci + + - name: Run linting + working-directory: llama_stack/ui + run: npm run lint + + - name: Run format check + working-directory: llama_stack/ui + run: npm run format:check + + - name: Run unit tests + working-directory: llama_stack/ui + env: + CI: true + + run: npm test -- --coverage --watchAll=false --passWithNoTests diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index b133511d1..dd2097a45 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -9,6 +9,7 @@ on: branches: [ main ] paths: - 'llama_stack/**' + - '!llama_stack/ui/**' - 'tests/unit/**' - 'uv.lock' - 'pyproject.toml' @@ -17,7 +18,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -31,7 +32,7 @@ jobs: - "3.13" steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 1dcfdeca5..e12f0adf8 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -27,7 +27,7 @@ on: - '.github/workflows/update-readthedocs.yml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -37,7 +37,7 @@ jobs: TOKEN: ${{ secrets.READTHEDOCS_TOKEN }} steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 30843173c..514fe6d2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,7 @@ exclude: 'build/' default_language_version: python: python3.12 + node: "22" repos: - repo: 
https://github.com/pre-commit/pre-commit-hooks @@ -145,6 +146,32 @@ repos: pass_filenames: false require_serial: true files: ^.github/workflows/.*$ + - id: ui-linter + name: Format & Lint UI + entry: bash ./scripts/run-ui-linter.sh + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true + + - id: check-log-usage + name: Ensure 'llama_stack.log' usage for logging + entry: bash + language: system + types: [python] + pass_filenames: true + args: + - -c + - | + matches=$(grep -EnH '^[^#]*\b(import\s+logging|from\s+logging\b)' "$@" | grep -v -e '#\s*allow-direct-logging' || true) + if [ -n "$matches" ]; then + # GitHub Actions annotation format + while IFS=: read -r file line_num rest; do + echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging" + done <<< "$matches" + exit 1 + fi + exit 0 ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks diff --git a/README.md b/README.md index 8db4580a2..4df4a5372 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,5 @@ # Llama Stack -meta-llama%2Fllama-stack | Trendshift - ------ [![PyPI version](https://img.shields.io/pypi/v/llama_stack.svg)](https://pypi.org/project/llama_stack/) [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 28c02829b..9578703c0 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4605,6 +4605,49 @@ } } }, + "/v1/inference/rerank": { + "post": { + "responses": { + "200": { + "description": "RerankResponse with indices sorted by relevance score (descending).", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Rerank a list of documents based on their relevance to a query.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankRequest" + } + } + }, + "required": true + } + } + }, "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { "post": { "responses": { @@ -8821,6 +8864,61 @@ "title": "OpenAIResponseOutputMessageMCPListTools", "description": "MCP list tools output message containing available tools from an MCP server." 
}, + "OpenAIResponseContentPart": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIResponseContentPartOutputText" + }, + { + "$ref": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "output_text": "#/components/schemas/OpenAIResponseContentPartOutputText", + "refusal": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + } + }, + "OpenAIResponseContentPartOutputText": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "output_text", + "default": "output_text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ], + "title": "OpenAIResponseContentPartOutputText" + }, + "OpenAIResponseContentPartRefusal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "refusal", + "default": "refusal" + }, + "refusal": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "refusal" + ], + "title": "OpenAIResponseContentPartRefusal" + }, "OpenAIResponseObjectStream": { "oneOf": [ { @@ -8877,6 +8975,12 @@ { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted" }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded" + }, + { + "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone" + }, { "$ref": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } @@ -8902,6 +9006,8 @@ "response.mcp_call.in_progress": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress", "response.mcp_call.failed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed", "response.mcp_call.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted", + "response.content_part.added": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded", + "response.content_part.done": "#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone", "response.completed": "#/components/schemas/OpenAIResponseObjectStreamResponseCompleted" } } @@ -8928,6 +9034,80 @@ "title": "OpenAIResponseObjectStreamResponseCompleted", "description": "Streaming event indicating a response has been completed." }, + "OpenAIResponseObjectStreamResponseContentPartAdded": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The content part that was added" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.added", + "default": "response.content_part.added", + "description": "Event type identifier, always \"response.content_part.added\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartAdded", + "description": "Streaming event for when a new content part is added to a response item." 
+ }, + "OpenAIResponseObjectStreamResponseContentPartDone": { + "type": "object", + "properties": { + "response_id": { + "type": "string", + "description": "Unique identifier of the response containing this content" + }, + "item_id": { + "type": "string", + "description": "Unique identifier of the output item containing this content part" + }, + "part": { + "$ref": "#/components/schemas/OpenAIResponseContentPart", + "description": "The completed content part" + }, + "sequence_number": { + "type": "integer", + "description": "Sequential number for ordering streaming events" + }, + "type": { + "type": "string", + "const": "response.content_part.done", + "default": "response.content_part.done", + "description": "Event type identifier, always \"response.content_part.done\"" + } + }, + "additionalProperties": false, + "required": [ + "response_id", + "item_id", + "part", + "sequence_number", + "type" + ], + "title": "OpenAIResponseObjectStreamResponseContentPartDone", + "description": "Streaming event for when a content part is completed." + }, "OpenAIResponseObjectStreamResponseCreated": { "type": "object", "properties": { @@ -14630,7 +14810,8 @@ "OpenAIFilePurpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "title": "OpenAIFilePurpose", "description": "Valid purpose values for OpenAI Files API." @@ -14707,7 +14888,8 @@ "purpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "description": "The intended purpose of the file" } @@ -15926,12 +16108,16 @@ "value": { "type": "number", "description": "The numeric value of the metric at this timestamp" + }, + "unit": { + "type": "string" } }, "additionalProperties": false, "required": [ "timestamp", - "value" + "value", + "unit" ], "title": "MetricDataPoint", "description": "A single data point in a metric time series." @@ -16489,6 +16675,95 @@ ], "title": "RegisterVectorDbRequest" }, + "RerankRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the reranking model to use." + }, + "query": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ], + "description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length." + }, + "items": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ] + }, + "description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length." + }, + "max_num_results": { + "type": "integer", + "description": "(Optional) Maximum number of results to return. Default: returns all." + } + }, + "additionalProperties": false, + "required": [ + "model", + "query", + "items" + ], + "title": "RerankRequest" + }, + "RerankData": { + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "The original index of the document in the input list" + }, + "relevance_score": { + "type": "number", + "description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance." 
+ } + }, + "additionalProperties": false, + "required": [ + "index", + "relevance_score" + ], + "title": "RerankData", + "description": "A single rerank result from a reranking response." + }, + "RerankResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RerankData" + }, + "description": "List of rerank result objects, sorted by relevance score (descending)" + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "RerankResponse", + "description": "Response from a reranking request." + }, "ResumeAgentTurnRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index e79513652..fc590da5b 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3264,6 +3264,37 @@ paths: schema: $ref: '#/components/schemas/QueryTracesRequest' required: true + /v1/inference/rerank: + post: + responses: + '200': + description: >- + RerankResponse with indices sorted by relevance score (descending). + content: + application/json: + schema: + $ref: '#/components/schemas/RerankResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Rerank a list of documents based on their relevance to a query. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RerankRequest' + required: true /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume: post: responses: @@ -6441,6 +6472,43 @@ components: title: OpenAIResponseOutputMessageMCPListTools description: >- MCP list tools output message containing available tools from an MCP server. 
+ OpenAIResponseContentPart: + oneOf: + - $ref: '#/components/schemas/OpenAIResponseContentPartOutputText' + - $ref: '#/components/schemas/OpenAIResponseContentPartRefusal' + discriminator: + propertyName: type + mapping: + output_text: '#/components/schemas/OpenAIResponseContentPartOutputText' + refusal: '#/components/schemas/OpenAIResponseContentPartRefusal' + OpenAIResponseContentPartOutputText: + type: object + properties: + type: + type: string + const: output_text + default: output_text + text: + type: string + additionalProperties: false + required: + - type + - text + title: OpenAIResponseContentPartOutputText + OpenAIResponseContentPartRefusal: + type: object + properties: + type: + type: string + const: refusal + default: refusal + refusal: + type: string + additionalProperties: false + required: + - type + - refusal + title: OpenAIResponseContentPartRefusal OpenAIResponseObjectStream: oneOf: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCreated' @@ -6461,6 +6529,8 @@ components: - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' - $ref: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' discriminator: propertyName: type @@ -6483,6 +6553,8 @@ components: response.mcp_call.in_progress: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallInProgress' response.mcp_call.failed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallFailed' response.mcp_call.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseMcpCallCompleted' + response.content_part.added: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartAdded' + response.content_part.done: '#/components/schemas/OpenAIResponseObjectStreamResponseContentPartDone' response.completed: '#/components/schemas/OpenAIResponseObjectStreamResponseCompleted' "OpenAIResponseObjectStreamResponseCompleted": type: object @@ -6504,6 +6576,76 @@ components: OpenAIResponseObjectStreamResponseCompleted description: >- Streaming event indicating a response has been completed. + "OpenAIResponseObjectStreamResponseContentPartAdded": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The content part that was added + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.added + default: response.content_part.added + description: >- + Event type identifier, always "response.content_part.added" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartAdded + description: >- + Streaming event for when a new content part is added to a response item. 
+ "OpenAIResponseObjectStreamResponseContentPartDone": + type: object + properties: + response_id: + type: string + description: >- + Unique identifier of the response containing this content + item_id: + type: string + description: >- + Unique identifier of the output item containing this content part + part: + $ref: '#/components/schemas/OpenAIResponseContentPart' + description: The completed content part + sequence_number: + type: integer + description: >- + Sequential number for ordering streaming events + type: + type: string + const: response.content_part.done + default: response.content_part.done + description: >- + Event type identifier, always "response.content_part.done" + additionalProperties: false + required: + - response_id + - item_id + - part + - sequence_number + - type + title: >- + OpenAIResponseObjectStreamResponseContentPartDone + description: >- + Streaming event for when a content part is completed. "OpenAIResponseObjectStreamResponseCreated": type: object properties: @@ -10840,6 +10982,7 @@ components: type: string enum: - assistants + - batch title: OpenAIFilePurpose description: >- Valid purpose values for OpenAI Files API. @@ -10908,6 +11051,7 @@ components: type: string enum: - assistants + - batch description: The intended purpose of the file additionalProperties: false required: @@ -11838,10 +11982,13 @@ components: type: number description: >- The numeric value of the metric at this timestamp + unit: + type: string additionalProperties: false required: - timestamp - value + - unit title: MetricDataPoint description: >- A single data point in a metric time series. @@ -12252,6 +12399,76 @@ components: - vector_db_id - embedding_model title: RegisterVectorDbRequest + RerankRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the reranking model to use. + query: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + The search query to rank items against. Can be a string, text content + part, or image content part. The input must not exceed the model's max + input token length. + items: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + List of items to rerank. Each item can be a string, text content part, + or image content part. Each input must not exceed the model's max input + token length. + max_num_results: + type: integer + description: >- + (Optional) Maximum number of results to return. Default: returns all. + additionalProperties: false + required: + - model + - query + - items + title: RerankRequest + RerankData: + type: object + properties: + index: + type: integer + description: >- + The original index of the document in the input list + relevance_score: + type: number + description: >- + The relevance score from the model output. Values are inverted when applicable + so that higher scores indicate greater relevance. + additionalProperties: false + required: + - index + - relevance_score + title: RerankData + description: >- + A single rerank result from a reranking response. 
+ RerankResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/RerankData' + description: >- + List of rerank result objects, sorted by relevance score (descending) + additionalProperties: false + required: + - data + title: RerankResponse + description: Response from a reranking request. ResumeAgentTurnRequest: type: object properties: diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md index 5a10d6498..f8f73a928 100644 --- a/docs/source/concepts/apis.md +++ b/docs/source/concepts/apis.md @@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle. - **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs - **Synthetic Data Generation**: generate synthetic data for model development +- **Batches**: OpenAI-compatible batch management for inference diff --git a/docs/source/contributing/index.md b/docs/source/contributing/index.md index 7a3a1c2e2..1846f4d97 100644 --- a/docs/source/contributing/index.md +++ b/docs/source/contributing/index.md @@ -4,11 +4,11 @@ ## Adding a New Provider -See the [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack. +See: +- [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack. +- [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack. +- [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack. -See the [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack. - -See the [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack. ```{toctree} :maxdepth: 1 :hidden: @@ -19,11 +19,21 @@ new_vector_database ## Testing -See the [Test Page](testing.md) which describes how to test your changes. + +```{include} ../../../tests/README.md +``` + +## Advanced Topics + +For developers who need deeper understanding of the testing system internals: + ```{toctree} :maxdepth: 1 -:hidden: -:caption: Testing -testing -``` \ No newline at end of file +testing/record-replay +``` + +### Benchmarking + +```{include} ../../../docs/source/distributions/k8s-benchmark/README.md +``` diff --git a/docs/source/contributing/testing.md b/docs/source/contributing/testing.md deleted file mode 100644 index 454ded266..000000000 --- a/docs/source/contributing/testing.md +++ /dev/null @@ -1,8 +0,0 @@ -```{include} ../../../tests/README.md -``` - -```{include} ../../../tests/unit/README.md -``` - -```{include} ../../../tests/integration/README.md -``` diff --git a/docs/source/contributing/testing/record-replay.md b/docs/source/contributing/testing/record-replay.md new file mode 100644 index 000000000..3049d333c --- /dev/null +++ b/docs/source/contributing/testing/record-replay.md @@ -0,0 +1,234 @@ +# Record-Replay System + +Understanding how Llama Stack captures and replays API interactions for testing. + +## Overview + +The record-replay system solves a fundamental challenge in AI testing: how do you test against expensive, non-deterministic APIs without breaking the bank or dealing with flaky tests? + +The solution: intercept API calls, store real responses, and replay them later. This gives you real API behavior without the cost or variability. 
+ +## How It Works + +### Request Hashing + +Every API request gets converted to a deterministic hash for lookup: + +```python +def normalize_request(method: str, url: str, headers: dict, body: dict) -> str: + normalized = { + "method": method.upper(), + "endpoint": urlparse(url).path, # Just the path, not full URL + "body": body, # Request parameters + } + return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest() +``` + +**Key insight:** The hashing is intentionally precise. Different whitespace, float precision, or parameter order produces different hashes. This prevents subtle bugs from false cache hits. + +```python +# These produce DIFFERENT hashes: +{"content": "Hello world"} +{"content": "Hello world\n"} +{"temperature": 0.7} +{"temperature": 0.7000001} +``` + +### Client Interception + +The system patches OpenAI and Ollama client methods to intercept calls before they leave your application. This happens transparently - your test code doesn't change. + +### Storage Architecture + +Recordings use a two-tier storage system optimized for both speed and debuggability: + +``` +recordings/ +├── index.sqlite # Fast lookup by request hash +└── responses/ + ├── abc123def456.json # Individual response files + └── def789ghi012.json +``` + +**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies. + +**JSON files** store complete request/response pairs in human-readable format for debugging. + +## Recording Modes + +### LIVE Mode + +Direct API calls with no recording or replay: + +```python +with inference_recording(mode=InferenceMode.LIVE): + response = await client.chat.completions.create(...) +``` + +Use for initial development and debugging against real APIs. + +### RECORD Mode + +Captures API interactions while passing through real responses: + +```python +with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"): + response = await client.chat.completions.create(...) + # Real API call made, response captured AND returned +``` + +The recording process: +1. Request intercepted and hashed +2. Real API call executed +3. Response captured and serialized +4. Recording stored to disk +5. Original response returned to caller + +### REPLAY Mode + +Returns stored responses instead of making API calls: + +```python +with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"): + response = await client.chat.completions.create(...) + # No API call made, cached response returned instantly +``` + +The replay process: +1. Request intercepted and hashed +2. Hash looked up in SQLite index +3. Response loaded from JSON file +4. Response deserialized and returned +5. Error if no recording found + +## Streaming Support + +Streaming APIs present a unique challenge: how do you capture an async generator? + +### The Problem + +```python +# How do you record this? 
+async for chunk in client.chat.completions.create(stream=True): + process(chunk) +``` + +### The Solution + +The system captures all chunks immediately before yielding any: + +```python +async def handle_streaming_record(response): + # Capture complete stream first + chunks = [] + async for chunk in response: + chunks.append(chunk) + + # Store complete recording + storage.store_recording( + request_hash, request_data, {"body": chunks, "is_streaming": True} + ) + + # Return generator that replays captured chunks + async def replay_stream(): + for chunk in chunks: + yield chunk + + return replay_stream() +``` + +This ensures: +- **Complete capture** - The entire stream is saved atomically +- **Interface preservation** - The returned object behaves like the original API +- **Deterministic replay** - Same chunks in the same order every time + +## Serialization + +API responses contain complex Pydantic objects that need careful serialization: + +```python +def _serialize_response(response): + if hasattr(response, "model_dump"): + # Preserve type information for proper deserialization + return { + "__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}", + "__data__": response.model_dump(mode="json"), + } + return response +``` + +This preserves type safety - when replayed, you get the same Pydantic objects with all their validation and methods. + +## Environment Integration + +### Environment Variables + +Control recording behavior globally: + +```bash +export LLAMA_STACK_TEST_INFERENCE_MODE=replay +export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings +pytest tests/integration/ +``` + +### Pytest Integration + +The system integrates automatically based on environment variables, requiring no changes to test code. + +## Debugging Recordings + +### Inspecting Storage + +```bash +# See what's recorded +sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings LIMIT 10;" + +# View specific response +cat recordings/responses/abc123def456.json | jq '.response.body' + +# Find recordings by endpoint +sqlite3 recordings/index.sqlite "SELECT * FROM recordings WHERE endpoint='/v1/chat/completions';" +``` + +### Common Issues + +**Hash mismatches:** Request parameters changed slightly between record and replay +```bash +# Compare request details +cat recordings/responses/abc123.json | jq '.request' +``` + +**Serialization errors:** Response types changed between versions +```bash +# Re-record with updated types +rm recordings/responses/failing_hash.json +LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_failing.py +``` + +**Missing recordings:** New test or changed parameters +```bash +# Record the missing interaction +LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_new.py +``` + +## Design Decisions + +### Why Not Mocks? + +Traditional mocking breaks down with AI APIs because: +- Response structures are complex and evolve frequently +- Streaming behavior is hard to mock correctly +- Edge cases in real APIs get missed +- Mocks become brittle maintenance burdens + +### Why Precise Hashing? + +Loose hashing (normalizing whitespace, rounding floats) seems convenient but hides bugs. If a test changes slightly, you want to know about it rather than accidentally getting the wrong cached response. + +### Why JSON + SQLite? 
+ +- **JSON** - Human readable, diff-friendly, easy to inspect and modify +- **SQLite** - Fast indexed lookups without loading response bodies +- **Hybrid** - Best of both worlds for different use cases + +This system provides reliable, fast testing against real AI APIs while maintaining the ability to debug issues when they arise. \ No newline at end of file diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 335fa3a68..c9677b3b6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -225,8 +225,32 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS + cors: true # Optional: Enable CORS (dev mode) or full config object ``` +### CORS Configuration + +CORS (Cross-Origin Resource Sharing) can be configured in two ways: + +**Local development** (allows localhost origins only): +```yaml +server: + cors: true +``` + +**Explicit configuration** (custom origins and settings): +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://app.example.com"] + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_credentials: true + max_age: 3600 +``` + +When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security. + ### Authentication Configuration > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. @@ -618,6 +642,54 @@ Content-Type: application/json } ``` +### CORS Configuration + +Configure CORS to allow web browsers to make requests from different domains. Disabled by default. + +#### Quick Setup + +For development, use the simple boolean flag: + +```yaml +server: + cors: true # Auto-enables localhost with any port +``` + +This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults. + +#### Custom Configuration + +For specific origins and full control: + +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://staging.myapp.com"] + allow_credentials: true + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern + expose_headers: ["X-Total-Count"] + max_age: 86400 +``` + +#### Configuration Options + +| Field | Description | Default | +| -------------------- | ---------------------------------------------- | ------- | +| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` | +| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` | +| `allow_methods` | Allowed HTTP methods. | `["*"]` | +| `allow_headers` | Allowed headers. | `["*"]` | +| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` | +| `expose_headers` | Headers exposed to browser. | `[]` | +| `max_age` | Preflight cache time (seconds). 
| `600` | + +**Security Notes**: +- `allow_credentials: true` requires explicit origins (no wildcards) +- `cors: true` enables localhost access only (secure for development) +- For public APIs, always specify exact allowed origins + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index fbc48dd95..b9b4b065a 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient( # provider_data is optional, but if you need to pass in any provider specific data, you can do so here. provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]}, ) -client.initialize() ``` This will parse your config and set up any inline implementations and remote clients needed for your implementation. @@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/ ```python client = LlamaStackAsLibraryClient(config_path) -client.initialize() ``` diff --git a/docs/source/distributions/k8s-benchmark/README.md b/docs/source/distributions/k8s-benchmark/README.md new file mode 100644 index 000000000..42da4d466 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/README.md @@ -0,0 +1,156 @@ +# Llama Stack Benchmark Suite on Kubernetes + +## Motivation + +Performance benchmarking is critical for understanding the overhead and characteristics of the Llama Stack abstraction layer compared to direct inference engines like vLLM. + +### Why This Benchmark Suite Exists + +**Performance Validation**: The Llama Stack provides a unified API layer across multiple inference providers, but this abstraction introduces potential overhead. This benchmark suite quantifies the performance impact by comparing: +- Llama Stack inference (with vLLM backend) +- Direct vLLM inference calls +- Both under identical Kubernetes deployment conditions + +**Production Readiness Assessment**: Real-world deployments require understanding performance characteristics under load. This suite simulates concurrent user scenarios with configurable parameters (duration, concurrency, request patterns) to validate production readiness. + +**Regression Detection (TODO)**: As the Llama Stack evolves, this benchmark provides automated regression detection for performance changes. CI/CD pipelines can leverage these benchmarks to catch performance degradations before production deployments. + +**Resource Planning**: By measuring throughput, latency percentiles, and resource utilization patterns, teams can make informed decisions about: +- Kubernetes resource allocation (CPU, memory, GPU) +- Auto-scaling configurations +- Cost optimization strategies + +### Key Metrics Captured + +The benchmark suite measures critical performance indicators: +- **Throughput**: Requests per second under sustained load +- **Latency Distribution**: P50, P95, P99 response times +- **Time to First Token (TTFT)**: Critical for streaming applications +- **Error Rates**: Request failures and timeout analysis + +This data enables data-driven architectural decisions and performance optimization efforts. + +## Setup + +**1. Deploy base k8s infrastructure:** +```bash +cd ../k8s +./apply.sh +``` + +**2. Deploy benchmark components:** +```bash +cd ../k8s-benchmark +./apply.sh +``` + +**3. 
Verify deployment:** +```bash +kubectl get pods +# Should see: llama-stack-benchmark-server, vllm-server, etc. +``` + +## Quick Start + +### Basic Benchmarks + +**Benchmark Llama Stack (default):** +```bash +cd docs/source/distributions/k8s-benchmark/ +./run-benchmark.sh +``` + +**Benchmark vLLM direct:** +```bash +./run-benchmark.sh --target vllm +``` + +### Custom Configuration + +**Extended benchmark with high concurrency:** +```bash +./run-benchmark.sh --target vllm --duration 120 --concurrent 20 +``` + +**Short test run:** +```bash +./run-benchmark.sh --target stack --duration 30 --concurrent 5 +``` + +## Command Reference + +### run-benchmark.sh Options + +```bash +./run-benchmark.sh [options] + +Options: + -t, --target Target to benchmark (default: stack) + -d, --duration Duration in seconds (default: 60) + -c, --concurrent Number of concurrent users (default: 10) + -h, --help Show help message + +Examples: + ./run-benchmark.sh --target vllm # Benchmark vLLM direct + ./run-benchmark.sh --target stack # Benchmark Llama Stack + ./run-benchmark.sh -t vllm -d 120 -c 20 # vLLM with 120s, 20 users +``` + +## Local Testing + +### Running Benchmark Locally + +For local development without Kubernetes: + +**1. Start OpenAI mock server:** +```bash +uv run python openai-mock-server.py --port 8080 +``` + +**2. Run benchmark against mock server:** +```bash +uv run python benchmark.py \ + --base-url http://localhost:8080/v1 \ + --model mock-inference \ + --duration 30 \ + --concurrent 5 +``` + +**3. Test against local vLLM server:** +```bash +# If you have vLLM running locally on port 8000 +uv run python benchmark.py \ + --base-url http://localhost:8000/v1 \ + --model meta-llama/Llama-3.2-3B-Instruct \ + --duration 30 \ + --concurrent 5 +``` + +**4. Profile the running server:** +```bash +./profile_running_server.sh +``` + + + +### OpenAI Mock Server + +The `openai-mock-server.py` provides: +- **OpenAI-compatible API** for testing without real models +- **Configurable streaming delay** via `STREAM_DELAY_SECONDS` env var +- **Consistent responses** for reproducible benchmarks +- **Lightweight testing** without GPU requirements + +**Mock server usage:** +```bash +uv run python openai-mock-server.py --port 8080 +``` + +The mock server is also deployed in k8s as `openai-mock-service:8080` and can be used by changing the Llama Stack configuration to use the `mock-vllm-inference` provider. + +## Files in this Directory + +- `benchmark.py` - Core benchmark script with async streaming support +- `run-benchmark.sh` - Main script with target selection and configuration +- `openai-mock-server.py` - Mock OpenAI API server for local testing +- `README.md` - This documentation file diff --git a/docs/source/distributions/k8s-benchmark/apply.sh b/docs/source/distributions/k8s-benchmark/apply.sh index 119a1c849..4f2270da8 100755 --- a/docs/source/distributions/k8s-benchmark/apply.sh +++ b/docs/source/distributions/k8s-benchmark/apply.sh @@ -8,7 +8,6 @@ # Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh). 
-export MOCK_INFERENCE_PORT=8080 export STREAM_DELAY_SECONDS=0.005 export POSTGRES_USER=llamastack @@ -20,14 +19,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B export MOCK_INFERENCE_MODEL=mock-inference -# Use llama-stack-benchmark-service as the benchmark server -export LOCUST_HOST=http://llama-stack-benchmark-service:8323 -export LOCUST_BASE_PATH=/v1/openai/v1 - -# Use vllm-service as the benchmark server -# export LOCUST_HOST=http://vllm-server:8000 -# export LOCUST_BASE_PATH=/v1 - +export MOCK_INFERENCE_URL=openai-mock-service:8080 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL @@ -35,13 +27,6 @@ set -euo pipefail set -x # Deploy benchmark-specific components -# Deploy OpenAI mock server -kubectl create configmap openai-mock --from-file=openai-mock-server.py \ - --dry-run=client -o yaml | kubectl apply --validate=false -f - - -envsubst < openai-mock-deployment.yaml | kubectl apply --validate=false -f - - -# Create configmap with our custom stack config kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \ --dry-run=client -o yaml > stack-configmap.yaml @@ -49,9 +34,3 @@ kubectl apply --validate=false -f stack-configmap.yaml # Deploy our custom llama stack server (overriding the base one) envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f - - -# Deploy Locust load testing -kubectl create configmap locust-script --from-file=locustfile.py \ - --dry-run=client -o yaml | kubectl apply --validate=false -f - - -envsubst < locust-k8s.yaml | kubectl apply --validate=false -f - diff --git a/docs/source/distributions/k8s-benchmark/benchmark.py b/docs/source/distributions/k8s-benchmark/benchmark.py new file mode 100644 index 000000000..3d0d18150 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/benchmark.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Simple benchmark script for Llama Stack with OpenAI API compatibility. 
+""" + +import argparse +import asyncio +import os +import random +import statistics +import time +from typing import Tuple +import aiohttp + + +class BenchmarkStats: + def __init__(self): + self.response_times = [] + self.ttft_times = [] + self.chunks_received = [] + self.errors = [] + self.success_count = 0 + self.total_requests = 0 + self.concurrent_users = 0 + self.start_time = None + self.end_time = None + self._lock = asyncio.Lock() + + async def add_result(self, response_time: float, chunks: int, ttft: float = None, error: str = None): + async with self._lock: + self.total_requests += 1 + if error: + self.errors.append(error) + else: + self.success_count += 1 + self.response_times.append(response_time) + self.chunks_received.append(chunks) + if ttft is not None: + self.ttft_times.append(ttft) + + def print_summary(self): + if not self.response_times: + print("No successful requests to report") + if self.errors: + print(f"Total errors: {len(self.errors)}") + print("First 5 errors:") + for error in self.errors[:5]: + print(f" {error}") + return + + total_time = self.end_time - self.start_time + success_rate = (self.success_count / self.total_requests) * 100 + + print(f"\n{'='*60}") + print(f"BENCHMARK RESULTS") + print(f"{'='*60}") + print(f"Total time: {total_time:.2f}s") + print(f"Concurrent users: {self.concurrent_users}") + print(f"Total requests: {self.total_requests}") + print(f"Successful requests: {self.success_count}") + print(f"Failed requests: {len(self.errors)}") + print(f"Success rate: {success_rate:.1f}%") + print(f"Requests per second: {self.success_count / total_time:.2f}") + + print(f"\nResponse Time Statistics:") + print(f" Mean: {statistics.mean(self.response_times):.3f}s") + print(f" Median: {statistics.median(self.response_times):.3f}s") + print(f" Min: {min(self.response_times):.3f}s") + print(f" Max: {max(self.response_times):.3f}s") + + if len(self.response_times) > 1: + print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s") + + percentiles = [50, 90, 95, 99] + sorted_times = sorted(self.response_times) + print(f"\nPercentiles:") + for p in percentiles: + idx = int(len(sorted_times) * p / 100) - 1 + idx = max(0, min(idx, len(sorted_times) - 1)) + print(f" P{p}: {sorted_times[idx]:.3f}s") + + if self.ttft_times: + print(f"\nTime to First Token (TTFT) Statistics:") + print(f" Mean: {statistics.mean(self.ttft_times):.3f}s") + print(f" Median: {statistics.median(self.ttft_times):.3f}s") + print(f" Min: {min(self.ttft_times):.3f}s") + print(f" Max: {max(self.ttft_times):.3f}s") + + if len(self.ttft_times) > 1: + print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s") + + sorted_ttft = sorted(self.ttft_times) + print(f"\nTTFT Percentiles:") + for p in percentiles: + idx = int(len(sorted_ttft) * p / 100) - 1 + idx = max(0, min(idx, len(sorted_ttft) - 1)) + print(f" P{p}: {sorted_ttft[idx]:.3f}s") + + if self.chunks_received: + print(f"\nStreaming Statistics:") + print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}") + print(f" Total chunks received: {sum(self.chunks_received)}") + + if self.errors: + print(f"\nErrors (showing first 5):") + for error in self.errors[:5]: + print(f" {error}") + + +class LlamaStackBenchmark: + def __init__(self, base_url: str, model_id: str): + self.base_url = base_url.rstrip('/') + self.model_id = model_id + self.headers = {"Content-Type": "application/json"} + self.test_messages = [ + [{"role": "user", "content": "Hi"}], + [{"role": "user", "content": "What is the capital of France?"}], + 
[{"role": "user", "content": "Explain quantum physics in simple terms."}], + [{"role": "user", "content": "Write a short story about a robot learning to paint."}], + [ + {"role": "user", "content": "What is machine learning?"}, + {"role": "assistant", "content": "Machine learning is a subset of AI..."}, + {"role": "user", "content": "Can you give me a practical example?"} + ] + ] + + + async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]: + """Make a single async streaming chat completion request.""" + messages = random.choice(self.test_messages) + payload = { + "model": self.model_id, + "messages": messages, + "stream": True, + "max_tokens": 100 + } + + start_time = time.time() + chunks_received = 0 + ttft = None + error = None + + session = aiohttp.ClientSession() + + try: + async with session.post( + f"{self.base_url}/chat/completions", + headers=self.headers, + json=payload, + timeout=aiohttp.ClientTimeout(total=30) + ) as response: + if response.status == 200: + async for line in response.content: + if line: + line_str = line.decode('utf-8').strip() + if line_str.startswith('data: '): + chunks_received += 1 + if ttft is None: + ttft = time.time() - start_time + if line_str == 'data: [DONE]': + break + + if chunks_received == 0: + error = "No streaming chunks received" + else: + text = await response.text() + error = f"HTTP {response.status}: {text[:100]}" + + except Exception as e: + error = f"Request error: {str(e)}" + finally: + await session.close() + + response_time = time.time() - start_time + return response_time, chunks_received, ttft, error + + + async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats: + """Run benchmark using async requests for specified duration.""" + stats = BenchmarkStats() + stats.concurrent_users = concurrent_users + stats.start_time = time.time() + + print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users") + print(f"Target URL: {self.base_url}/chat/completions") + print(f"Model: {self.model_id}") + + connector = aiohttp.TCPConnector(limit=concurrent_users) + async with aiohttp.ClientSession(connector=connector) as session: + + async def worker(worker_id: int): + """Worker that sends requests sequentially until canceled.""" + request_count = 0 + while True: + try: + response_time, chunks, ttft, error = await self.make_async_streaming_request() + await stats.add_result(response_time, chunks, ttft, error) + request_count += 1 + + except asyncio.CancelledError: + break + except Exception as e: + await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}") + + # Progress reporting task + async def progress_reporter(): + last_report_time = time.time() + while True: + try: + await asyncio.sleep(1) # Report every second + if time.time() >= last_report_time + 10: # Report every 10 seconds + elapsed = time.time() - stats.start_time + print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s") + last_report_time = time.time() + except asyncio.CancelledError: + break + + # Spawn concurrent workers + tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)] + progress_task = asyncio.create_task(progress_reporter()) + tasks.append(progress_task) + + # Wait for duration then cancel all tasks + await asyncio.sleep(duration) + + for task in tasks: + task.cancel() + + # Wait for all tasks to complete + await asyncio.gather(*tasks, return_exceptions=True) + + stats.end_time = time.time() + return stats + + +def main(): + parser = 
argparse.ArgumentParser(description="Llama Stack Benchmark Tool") + parser.add_argument("--base-url", default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"), + help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)") + parser.add_argument("--model", default=os.getenv("INFERENCE_MODEL", "test-model"), + help="Model ID to use for requests") + parser.add_argument("--duration", type=int, default=60, + help="Duration in seconds to run benchmark (default: 60)") + parser.add_argument("--concurrent", type=int, default=10, + help="Number of concurrent users (default: 10)") + + args = parser.parse_args() + + benchmark = LlamaStackBenchmark(args.base_url, args.model) + + try: + stats = asyncio.run(benchmark.run_benchmark(args.duration, args.concurrent)) + stats.print_summary() + + except KeyboardInterrupt: + print("\nBenchmark interrupted by user") + except Exception as e: + print(f"Benchmark failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/docs/source/distributions/k8s-benchmark/locust-k8s.yaml b/docs/source/distributions/k8s-benchmark/locust-k8s.yaml deleted file mode 100644 index f20a01b2d..000000000 --- a/docs/source/distributions/k8s-benchmark/locust-k8s.yaml +++ /dev/null @@ -1,131 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: locust-master - labels: - app: locust - role: master -spec: - replicas: 1 - selector: - matchLabels: - app: locust - role: master - template: - metadata: - labels: - app: locust - role: master - spec: - containers: - - name: locust-master - image: locustio/locust:2.31.8 - ports: - - containerPort: 8089 # Web UI - - containerPort: 5557 # Master communication - env: - - name: LOCUST_HOST - value: "${LOCUST_HOST}" - - name: LOCUST_LOCUSTFILE - value: "/locust/locustfile.py" - - name: LOCUST_WEB_HOST - value: "0.0.0.0" - - name: LOCUST_MASTER - value: "true" - - name: LOCUST_BASE_PATH - value: "${LOCUST_BASE_PATH}" - - name: INFERENCE_MODEL - value: "${BENCHMARK_INFERENCE_MODEL}" - volumeMounts: - - name: locust-script - mountPath: /locust - command: ["locust"] - args: - - "--master" - - "--web-host=0.0.0.0" - - "--web-port=8089" - - "--host=${LOCUST_HOST}" - - "--locustfile=/locust/locustfile.py" - volumes: - - name: locust-script - configMap: - name: locust-script ---- -apiVersion: apps/v1 -kind: Deployment -metadata: - name: locust-worker - labels: - app: locust - role: worker -spec: - replicas: 2 # Start with 2 workers, can be scaled up - selector: - matchLabels: - app: locust - role: worker - template: - metadata: - labels: - app: locust - role: worker - spec: - containers: - - name: locust-worker - image: locustio/locust:2.31.8 - env: - - name: LOCUST_HOST - value: "${LOCUST_HOST}" - - name: LOCUST_LOCUSTFILE - value: "/locust/locustfile.py" - - name: LOCUST_MASTER_HOST - value: "locust-master-service" - - name: LOCUST_MASTER_PORT - value: "5557" - - name: INFERENCE_MODEL - value: "${BENCHMARK_INFERENCE_MODEL}" - - name: LOCUST_BASE_PATH - value: "${LOCUST_BASE_PATH}" - volumeMounts: - - name: locust-script - mountPath: /locust - command: ["locust"] - args: - - "--worker" - - "--master-host=locust-master-service" - - "--master-port=5557" - - "--locustfile=/locust/locustfile.py" - volumes: - - name: locust-script - configMap: - name: locust-script ---- -apiVersion: v1 -kind: Service -metadata: - name: locust-master-service -spec: - selector: - app: locust - role: master - ports: - - name: web-ui - port: 8089 - targetPort: 8089 - - name: master-comm - port: 5557 - targetPort: 5557 - type: 
ClusterIP ---- -apiVersion: v1 -kind: Service -metadata: - name: locust-web-ui -spec: - selector: - app: locust - role: master - ports: - - port: 8089 - targetPort: 8089 - type: ClusterIP # Keep internal, use port-forward to access diff --git a/docs/source/distributions/k8s-benchmark/locustfile.py b/docs/source/distributions/k8s-benchmark/locustfile.py deleted file mode 100644 index 8e511fa95..000000000 --- a/docs/source/distributions/k8s-benchmark/locustfile.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Locust load testing script for Llama Stack with Prism mock OpenAI provider. -""" - -import random -from locust import HttpUser, task, between -import os - -base_path = os.getenv("LOCUST_BASE_PATH", "/v1/openai/v1") - -MODEL_ID = os.getenv("INFERENCE_MODEL") - -class LlamaStackUser(HttpUser): - wait_time = between(0.0, 0.0001) - - def on_start(self): - """Setup authentication and test data.""" - # No auth required for benchmark server - self.headers = { - "Content-Type": "application/json" - } - - # Test messages of varying lengths - self.test_messages = [ - [{"role": "user", "content": "Hi"}], - [{"role": "user", "content": "What is the capital of France?"}], - [{"role": "user", "content": "Explain quantum physics in simple terms."}], - [{"role": "user", "content": "Write a short story about a robot learning to paint."}], - [ - {"role": "user", "content": "What is machine learning?"}, - {"role": "assistant", "content": "Machine learning is a subset of AI..."}, - {"role": "user", "content": "Can you give me a practical example?"} - ] - ] - - @task(weight=100) - def chat_completion_streaming(self): - """Test streaming chat completion (20% of requests).""" - messages = random.choice(self.test_messages) - payload = { - "model": MODEL_ID, - "messages": messages, - "stream": True, - "max_tokens": 100 - } - - with self.client.post( - f"{base_path}/chat/completions", - headers=self.headers, - json=payload, - stream=True, - catch_response=True - ) as response: - if response.status_code == 200: - chunks_received = 0 - try: - for line in response.iter_lines(): - if line: - line_str = line.decode('utf-8') - if line_str.startswith('data: '): - chunks_received += 1 - if line_str.strip() == 'data: [DONE]': - break - - if chunks_received > 0: - response.success() - else: - response.failure("No streaming chunks received") - except Exception as e: - response.failure(f"Streaming error: {e}") - else: - response.failure(f"HTTP {response.status_code}: {response.text}") diff --git a/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml b/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml deleted file mode 100644 index c72921281..000000000 --- a/docs/source/distributions/k8s-benchmark/openai-mock-deployment.yaml +++ /dev/null @@ -1,52 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: openai-mock - labels: - app: openai-mock -spec: - replicas: 1 - selector: - matchLabels: - app: openai-mock - template: - metadata: - labels: - app: openai-mock - spec: - containers: - - name: openai-mock - image: python:3.12-slim - ports: - - containerPort: ${MOCK_INFERENCE_PORT} - env: - - name: PORT - value: "${MOCK_INFERENCE_PORT}" - - name: MOCK_MODELS - value: "${MOCK_INFERENCE_MODEL}" - - name: STREAM_DELAY_SECONDS - value: "${STREAM_DELAY_SECONDS}" - command: ["sh", "-c"] - args: - 
- | - pip install flask && - python /app/openai-mock-server.py --port ${MOCK_INFERENCE_PORT} - volumeMounts: - - name: openai-mock-script - mountPath: /app - volumes: - - name: openai-mock-script - configMap: - name: openai-mock ---- -apiVersion: v1 -kind: Service -metadata: - name: openai-mock-service -spec: - selector: - app: openai-mock - ports: - - port: 8080 - targetPort: 8080 - type: ClusterIP diff --git a/docs/source/distributions/k8s-benchmark/openai-mock-server.py b/docs/source/distributions/k8s-benchmark/openai-mock-server.py old mode 100644 new mode 100755 index 46c923b60..de0680842 --- a/docs/source/distributions/k8s-benchmark/openai-mock-server.py +++ b/docs/source/distributions/k8s-benchmark/openai-mock-server.py @@ -23,7 +23,7 @@ app = Flask(__name__) # Models from environment variables def get_models(): - models_str = os.getenv("MOCK_MODELS", "mock-inference") + models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct") model_ids = [m.strip() for m in models_str.split(",") if m.strip()] return { @@ -49,13 +49,13 @@ def generate_random_text(length=50): ] return " ".join(random.choices(words, k=length)) -@app.route('/models', methods=['GET']) +@app.route('/v1/models', methods=['GET']) def list_models(): models = get_models() print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}") return jsonify(models) -@app.route('/chat/completions', methods=['POST']) +@app.route('/v1/chat/completions', methods=['POST']) def chat_completions(): """Return OpenAI-formatted chat completion responses.""" data = request.get_json() diff --git a/docs/source/distributions/k8s-benchmark/profile_running_server.sh b/docs/source/distributions/k8s-benchmark/profile_running_server.sh new file mode 100755 index 000000000..65d620583 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/profile_running_server.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Script to profile an already running Llama Stack server +# Usage: ./profile_running_server.sh [duration_seconds] [output_file] + +DURATION=${1:-60} # Default 60 seconds +OUTPUT_FILE=${2:-"llama_stack_profile"} # Default output file + +echo "Looking for running Llama Stack server..." + +# Find the server PID +SERVER_PID=$(ps aux | grep "llama_stack.core.server.server" | grep -v grep | awk '{print $2}' | head -1) + + +if [ -z "$SERVER_PID" ]; then + echo "Error: No running Llama Stack server found" + echo "Please start your server first with:" + echo "LLAMA_STACK_LOGGING=\"all=ERROR\" MOCK_INFERENCE_URL=http://localhost:8080 SAFETY_MODEL=llama-guard3:1b uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml" + exit 1 +fi + +echo "Found Llama Stack server with PID: $SERVER_PID" + +# Start py-spy profiling +echo "Starting py-spy profiling for ${DURATION} seconds..." +echo "Output will be saved to: ${OUTPUT_FILE}.svg" +echo "" +echo "You can now run your load test..." +echo "" + +# Get the full path to py-spy +PYSPY_PATH=$(which py-spy) + +# Check if running as root, if not, use sudo +if [ "$EUID" -ne 0 ]; then + echo "py-spy requires root permissions on macOS. Running with sudo..." 
+ sudo "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID +else + "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID +fi + +echo "" +echo "Profiling completed! Results saved to: ${OUTPUT_FILE}.svg" +echo "" +echo "To view the flame graph:" +echo "open ${OUTPUT_FILE}.svg" diff --git a/docs/source/distributions/k8s-benchmark/run-benchmark.sh b/docs/source/distributions/k8s-benchmark/run-benchmark.sh new file mode 100755 index 000000000..e1c826143 --- /dev/null +++ b/docs/source/distributions/k8s-benchmark/run-benchmark.sh @@ -0,0 +1,148 @@ +#!/usr/bin/env bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +set -euo pipefail + +# Default values +TARGET="stack" +DURATION=60 +CONCURRENT=10 + +# Parse command line arguments +usage() { + echo "Usage: $0 [options]" + echo "Options:" + echo " -t, --target Target to benchmark (default: stack)" + echo " -d, --duration Duration in seconds (default: 60)" + echo " -c, --concurrent Number of concurrent users (default: 10)" + echo " -h, --help Show this help message" + echo "" + echo "Examples:" + echo " $0 --target vllm # Benchmark vLLM direct" + echo " $0 --target stack # Benchmark Llama Stack (default)" + echo " $0 -t vllm -d 120 -c 20 # vLLM with 120s duration, 20 users" +} + +while [[ $# -gt 0 ]]; do + case $1 in + -t|--target) + TARGET="$2" + shift 2 + ;; + -d|--duration) + DURATION="$2" + shift 2 + ;; + -c|--concurrent) + CONCURRENT="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" + usage + exit 1 + ;; + esac +done + +# Validate target +if [[ "$TARGET" != "stack" && "$TARGET" != "vllm" ]]; then + echo "Error: Target must be 'stack' or 'vllm'" + usage + exit 1 +fi + +# Set configuration based on target +if [[ "$TARGET" == "vllm" ]]; then + BASE_URL="http://vllm-server:8000/v1" + JOB_NAME="vllm-benchmark-job" + echo "Benchmarking vLLM direct..." +else + BASE_URL="http://llama-stack-benchmark-service:8323/v1/openai/v1" + JOB_NAME="stack-benchmark-job" + echo "Benchmarking Llama Stack..." +fi + +echo "Configuration:" +echo " Target: $TARGET" +echo " Base URL: $BASE_URL" +echo " Duration: ${DURATION}s" +echo " Concurrent users: $CONCURRENT" +echo "" + +# Create temporary job yaml +TEMP_YAML="/tmp/benchmark-job-temp-$(date +%s).yaml" +cat > "$TEMP_YAML" << EOF +apiVersion: batch/v1 +kind: Job +metadata: + name: $JOB_NAME + namespace: default +spec: + template: + spec: + containers: + - name: benchmark + image: python:3.11-slim + command: ["/bin/bash"] + args: + - "-c" + - | + pip install aiohttp && + python3 /benchmark/benchmark.py \\ + --base-url $BASE_URL \\ + --model \${INFERENCE_MODEL} \\ + --duration $DURATION \\ + --concurrent $CONCURRENT + env: + - name: INFERENCE_MODEL + value: "meta-llama/Llama-3.2-3B-Instruct" + volumeMounts: + - name: benchmark-script + mountPath: /benchmark + resources: + requests: + memory: "256Mi" + cpu: "250m" + limits: + memory: "512Mi" + cpu: "500m" + volumes: + - name: benchmark-script + configMap: + name: benchmark-script + restartPolicy: Never + backoffLimit: 3 +EOF + +echo "Creating benchmark ConfigMap..." +kubectl create configmap benchmark-script \ + --from-file=benchmark.py=benchmark.py \ + --dry-run=client -o yaml | kubectl apply -f - + +echo "Cleaning up any existing benchmark job..." 
+kubectl delete job $JOB_NAME 2>/dev/null || true + +echo "Deploying benchmark Job..." +kubectl apply -f "$TEMP_YAML" + +echo "Waiting for job to start..." +kubectl wait --for=condition=Ready pod -l job-name=$JOB_NAME --timeout=60s + +echo "Following benchmark logs..." +kubectl logs -f job/$JOB_NAME + +echo "Job completed. Checking final status..." +kubectl get job $JOB_NAME + +# Clean up temporary file +rm -f "$TEMP_YAML" diff --git a/docs/source/distributions/k8s-benchmark/stack-configmap.yaml b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml index 653e66756..edf4ebd75 100644 --- a/docs/source/distributions/k8s-benchmark/stack-configmap.yaml +++ b/docs/source/distributions/k8s-benchmark/stack-configmap.yaml @@ -26,13 +26,6 @@ data: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} - - provider_id: mock-vllm-inference - provider_type: remote::vllm - config: - url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT} - max_tokens: 4096 - api_token: fake - tls_verify: false - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} @@ -121,9 +114,6 @@ data: - model_id: ${env.SAFETY_MODEL} provider_id: vllm-safety model_type: llm - - model_id: ${env.MOCK_INFERENCE_MODEL} - provider_id: mock-vllm-inference - model_type: llm shields: - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} vector_dbs: [] diff --git a/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template index bc14d5124..9cb1e5be3 100644 --- a/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template +++ b/docs/source/distributions/k8s-benchmark/stack-k8s.yaml.template @@ -44,8 +44,6 @@ spec: value: "${SAFETY_MODEL}" - name: TAVILY_SEARCH_API_KEY value: "${TAVILY_SEARCH_API_KEY}" - - name: MOCK_INFERENCE_PORT - value: "${MOCK_INFERENCE_PORT}" - name: VLLM_URL value: http://vllm-server.default.svc.cluster.local:8000/v1 - name: VLLM_MAX_TOKENS @@ -54,8 +52,6 @@ spec: value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 - name: VLLM_TLS_VERIFY value: "false" - - name: MOCK_INFERENCE_MODEL - value: "${MOCK_INFERENCE_MODEL}" command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] ports: - containerPort: 8323 diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml index ad56be047..ceb1ba2d9 100644 --- a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml +++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml @@ -3,7 +3,6 @@ image_name: kubernetes-benchmark-demo apis: - agents - inference -- safety - telemetry - tool_runtime - vector_io @@ -16,20 +15,6 @@ providers: max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} - - provider_id: vllm-safety - provider_type: remote::vllm - config: - url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} - max_tokens: ${env.VLLM_MAX_TOKENS:=4096} - api_token: ${env.VLLM_API_TOKEN:=fake} - tls_verify: ${env.VLLM_TLS_VERIFY:=true} - - provider_id: mock-vllm-inference - provider_type: remote::vllm - config: - url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT} - max_tokens: 4096 - api_token: fake - tls_verify: false - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} @@ -45,11 +30,6 @@ providers: db: 
${env.POSTGRES_DB:=llamastack} user: ${env.POSTGRES_USER:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack} - safety: - - provider_id: llama-guard - provider_type: inline::llama-guard - config: - excluded_categories: [] agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -115,14 +95,6 @@ models: - model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference model_type: llm -- model_id: ${env.SAFETY_MODEL} - provider_id: vllm-safety - model_type: llm -- model_id: ${env.MOCK_INFERENCE_MODEL} - provider_id: mock-vllm-inference - model_type: llm -shields: -- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index 92bf9edc0..a2c48d4b9 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -2,6 +2,15 @@ ## Overview +Agents API for creating and interacting with agentic systems. + + Main functionalities provided by this API: + - Create agents with specific instructions and ability to use tools. + - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". + - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). + - Agents can be provided with various shields (see the Safety API for more details). + - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. + This section contains documentation for all available providers for the **agents** API. ## Providers diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md new file mode 100644 index 000000000..d6d2fa9a3 --- /dev/null +++ b/docs/source/providers/batches/index.md @@ -0,0 +1,24 @@ +# Batches + +## Overview + +The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + The API is designed to allow use of openai client libraries for seamless integration. + + This API provides the following extensions: + - idempotent batch creation + + Note: This API is currently under active development and may undergo changes. + +This section contains documentation for all available providers for the **batches** API. + +## Providers + +```{toctree} +:maxdepth: 1 + +inline_reference +``` diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md new file mode 100644 index 000000000..a58e5124d --- /dev/null +++ b/docs/source/providers/batches/inline_reference.md @@ -0,0 +1,23 @@ +# inline::reference + +## Description + +Reference implementation of batches API with KVStore persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. | +| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. | +| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. 
| + +## Sample Configuration + +```yaml +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db + +``` + diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md index d180d256c..a14fada1d 100644 --- a/docs/source/providers/eval/index.md +++ b/docs/source/providers/eval/index.md @@ -2,6 +2,8 @@ ## Overview +Llama Stack Evaluation API for running evaluations on model and agent candidates. + This section contains documentation for all available providers for the **eval** API. ## Providers diff --git a/docs/source/providers/files/index.md b/docs/source/providers/files/index.md index 692aad3ca..128953223 100644 --- a/docs/source/providers/files/index.md +++ b/docs/source/providers/files/index.md @@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files* :maxdepth: 1 inline_localfs +remote_s3 ``` diff --git a/docs/source/providers/files/remote_s3.md b/docs/source/providers/files/remote_s3.md new file mode 100644 index 000000000..2e3cebabd --- /dev/null +++ b/docs/source/providers/files/remote_s3.md @@ -0,0 +1,33 @@ +# remote::s3 + +## Description + +AWS S3-based file storage provider for scalable cloud file management with metadata persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `bucket_name` | `` | No | | S3 bucket name to store files | +| `region` | `` | No | us-east-1 | AWS region where the bucket is located | +| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) | +| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) | +| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) | +| `auto_create_bucket` | `` | No | False | Automatically create the S3 bucket if it doesn't exist | +| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata | + +## Sample Configuration + +```yaml +bucket_name: ${env.S3_BUCKET_NAME} +region: ${env.AWS_REGION:=us-east-1} +aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=} +aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=} +endpoint_url: ${env.S3_ENDPOINT_URL:=} +auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db + +``` + diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 38781e5eb..b6d215474 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -2,6 +2,12 @@ ## Overview +Llama Stack Inference API for generating completions, chat completions, and embeddings. + + This API provides the raw interface to the underlying models. Two kinds of models are supported: + - LLM models: these models generate "raw" and "chat" (conversational) completions. + - Embedding models: these models generate embeddings to be used for semantic search. + This section contains documentation for all available providers for the **inference** API. 
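+
+A minimal sketch of both model kinds in use, assuming a locally running Llama Stack server (default port 8321) reached through its OpenAI-compatible routes; the model identifiers are illustrative placeholders:
+
+```python
+from openai import OpenAI
+
+# Point an OpenAI client at the stack's OpenAI-compatible endpoint
+client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
+
+# LLM model: chat (conversational) completion
+chat = client.chat.completions.create(
+    model="meta-llama/Llama-3.2-3B-Instruct",
+    messages=[{"role": "user", "content": "Hi"}],
+)
+
+# Embedding model: vectors for semantic search
+vectors = client.embeddings.create(
+    model="all-MiniLM-L6-v2",
+    input=["Llama Stack provides a unified inference API."],
+)
+```
+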
## Providers diff --git a/docs/source/providers/post_training/index.md b/docs/source/providers/post_training/index.md index c6c92c40e..5ada6f9aa 100644 --- a/docs/source/providers/post_training/index.md +++ b/docs/source/providers/post_training/index.md @@ -9,7 +9,9 @@ This section contains documentation for all available providers for the **post_t ```{toctree} :maxdepth: 1 -inline_huggingface -inline_torchtune +inline_huggingface-cpu +inline_huggingface-gpu +inline_torchtune-cpu +inline_torchtune-gpu remote_nvidia ``` diff --git a/docs/source/providers/post_training/inline_huggingface-cpu.md b/docs/source/providers/post_training/inline_huggingface-cpu.md new file mode 100644 index 000000000..e663fe8f8 --- /dev/null +++ b/docs/source/providers/post_training/inline_huggingface-cpu.md @@ -0,0 +1,41 @@ +# inline::huggingface-cpu + +## Description + +HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `device` | `` | No | cuda | | +| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | | +| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | | +| `chat_template` | `` | No | <|user|> +{input} +<|assistant|> +{output} | | +| `model_specific_config` | `` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | | +| `max_seq_length` | `` | No | 2048 | | +| `gradient_checkpointing` | `` | No | False | | +| `save_total_limit` | `` | No | 3 | | +| `logging_steps` | `` | No | 10 | | +| `warmup_ratio` | `` | No | 0.1 | | +| `weight_decay` | `` | No | 0.01 | | +| `dataloader_num_workers` | `` | No | 4 | | +| `dataloader_pin_memory` | `` | No | True | | +| `dpo_beta` | `` | No | 0.1 | | +| `use_reference_model` | `` | No | True | | +| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | | +| `dpo_output_dir` | `` | No | | | + +## Sample Configuration + +```yaml +checkpoint_format: huggingface +distributed_backend: null +device: cpu +dpo_output_dir: ~/.llama/dummy/dpo_output + +``` + diff --git a/docs/source/providers/post_training/inline_huggingface-gpu.md b/docs/source/providers/post_training/inline_huggingface-gpu.md new file mode 100644 index 000000000..21bf965fe --- /dev/null +++ b/docs/source/providers/post_training/inline_huggingface-gpu.md @@ -0,0 +1,41 @@ +# inline::huggingface-gpu + +## Description + +HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem. 
+ +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `device` | `` | No | cuda | | +| `distributed_backend` | `Literal['fsdp', 'deepspeed'` | No | | | +| `checkpoint_format` | `Literal['full_state', 'huggingface'` | No | huggingface | | +| `chat_template` | `` | No | <|user|> +{input} +<|assistant|> +{output} | | +| `model_specific_config` | `` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | | +| `max_seq_length` | `` | No | 2048 | | +| `gradient_checkpointing` | `` | No | False | | +| `save_total_limit` | `` | No | 3 | | +| `logging_steps` | `` | No | 10 | | +| `warmup_ratio` | `` | No | 0.1 | | +| `weight_decay` | `` | No | 0.01 | | +| `dataloader_num_workers` | `` | No | 4 | | +| `dataloader_pin_memory` | `` | No | True | | +| `dpo_beta` | `` | No | 0.1 | | +| `use_reference_model` | `` | No | True | | +| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair'` | No | sigmoid | | +| `dpo_output_dir` | `` | No | | | + +## Sample Configuration + +```yaml +checkpoint_format: huggingface +distributed_backend: null +device: cpu +dpo_output_dir: ~/.llama/dummy/dpo_output + +``` + diff --git a/docs/source/providers/post_training/inline_torchtune-cpu.md b/docs/source/providers/post_training/inline_torchtune-cpu.md new file mode 100644 index 000000000..7204e56e8 --- /dev/null +++ b/docs/source/providers/post_training/inline_torchtune-cpu.md @@ -0,0 +1,20 @@ +# inline::torchtune-cpu + +## Description + +TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `torch_seed` | `int \| None` | No | | | +| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | | + +## Sample Configuration + +```yaml +checkpoint_format: meta + +``` + diff --git a/docs/source/providers/post_training/inline_torchtune-gpu.md b/docs/source/providers/post_training/inline_torchtune-gpu.md new file mode 100644 index 000000000..98b94f6f6 --- /dev/null +++ b/docs/source/providers/post_training/inline_torchtune-gpu.md @@ -0,0 +1,20 @@ +# inline::torchtune-gpu + +## Description + +TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `torch_seed` | `int \| None` | No | | | +| `checkpoint_format` | `Literal['meta', 'huggingface'` | No | meta | | + +## Sample Configuration + +```yaml +checkpoint_format: meta + +``` + diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index 8574104dc..591992479 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -623,6 +623,62 @@ class OpenAIResponseObjectStreamResponseMcpCallCompleted(BaseModel): type: Literal["response.mcp_call.completed"] = "response.mcp_call.completed" +@json_schema_type +class OpenAIResponseContentPartOutputText(BaseModel): + type: Literal["output_text"] = "output_text" + text: str + # TODO: add annotations, logprobs, etc. 
+ + +@json_schema_type +class OpenAIResponseContentPartRefusal(BaseModel): + type: Literal["refusal"] = "refusal" + refusal: str + + +OpenAIResponseContentPart = Annotated[ + OpenAIResponseContentPartOutputText | OpenAIResponseContentPartRefusal, + Field(discriminator="type"), +] +register_schema(OpenAIResponseContentPart, name="OpenAIResponseContentPart") + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartAdded(BaseModel): + """Streaming event for when a new content part is added to a response item. + + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The content part that was added + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.added" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.added"] = "response.content_part.added" + + +@json_schema_type +class OpenAIResponseObjectStreamResponseContentPartDone(BaseModel): + """Streaming event for when a content part is completed. + + :param response_id: Unique identifier of the response containing this content + :param item_id: Unique identifier of the output item containing this content part + :param part: The completed content part + :param sequence_number: Sequential number for ordering streaming events + :param type: Event type identifier, always "response.content_part.done" + """ + + response_id: str + item_id: str + part: OpenAIResponseContentPart + sequence_number: int + type: Literal["response.content_part.done"] = "response.content_part.done" + + OpenAIResponseObjectStream = Annotated[ OpenAIResponseObjectStreamResponseCreated | OpenAIResponseObjectStreamResponseOutputItemAdded @@ -642,6 +698,8 @@ OpenAIResponseObjectStream = Annotated[ | OpenAIResponseObjectStreamResponseMcpCallInProgress | OpenAIResponseObjectStreamResponseMcpCallFailed | OpenAIResponseObjectStreamResponseMcpCallCompleted + | OpenAIResponseObjectStreamResponseContentPartAdded + | OpenAIResponseObjectStreamResponseContentPartDone | OpenAIResponseObjectStreamResponseCompleted, Field(discriminator="type"), ] diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py new file mode 100644 index 000000000..9ce7d3d75 --- /dev/null +++ b/llama_stack/apis/batches/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .batches import Batches, BatchObject, ListBatchesResponse + +__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py new file mode 100644 index 000000000..c6bbd92eb --- /dev/null +++ b/llama_stack/apis/batches/batches.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Literal, Protocol, runtime_checkable + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type, webmethod + +try: + from openai.types import Batch as BatchObject +except ImportError as e: + raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e + + +@json_schema_type +class ListBatchesResponse(BaseModel): + """Response containing a list of batch objects.""" + + object: Literal["list"] = "list" + data: list[BatchObject] = Field(..., description="List of batch objects") + first_id: str | None = Field(default=None, description="ID of the first batch in the list") + last_id: str | None = Field(default=None, description="ID of the last batch in the list") + has_more: bool = Field(default=False, description="Whether there are more batches available") + + +@runtime_checkable +class Batches(Protocol): + """ + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + The API is designed to allow use of openai client libraries for seamless integration. + + This API provides the following extensions: + - idempotent batch creation + + Note: This API is currently under active development and may undergo changes. + """ + + @webmethod(route="/openai/v1/batches", method="POST") + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: Literal["24h"], + metadata: dict[str, str] | None = None, + idempotency_key: str | None = None, + ) -> BatchObject: + """Create a new batch for processing multiple API requests. + + :param input_file_id: The ID of an uploaded file containing requests for the batch. + :param endpoint: The endpoint to be used for all requests in the batch. + :param completion_window: The time window within which the batch should be processed. + :param metadata: Optional metadata for the batch. + :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior. + :returns: The created batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}", method="GET") + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch. + + :param batch_id: The ID of the batch to retrieve. + :returns: The batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST") + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress. + + :param batch_id: The ID of the batch to cancel. + :returns: The updated batch object. + """ + ... + + @webmethod(route="/openai/v1/batches", method="GET") + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """List all batches for the current user. + + :param after: A cursor for pagination; returns batches after this batch ID. + :param limit: Number of batches to return (default 20, max 100). + :returns: A list of batch objects. + """ + ... 
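+
+# Usage sketch (illustrative only, not part of the protocol): because the routes
+# above mirror the OpenAI Batches API, an OpenAI client pointed at a Llama Stack
+# server's OpenAI-compatible base URL can drive them directly, e.g.:
+#
+#   from openai import OpenAI
+#
+#   client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
+#   batch = client.batches.create(
+#       input_file_id="file-abc123",   # placeholder ID of a previously uploaded file
+#       endpoint="/v1/chat/completions",
+#       completion_window="24h",
+#   )
+#   print(client.batches.retrieve(batch.id).status)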
diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 6e0fa0b3c..ec3d2b1ce 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -72,3 +72,10 @@ class ModelTypeError(TypeError): f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'" ) super().__init__(message) + + +class ConflictError(ValueError): + """raised when an operation cannot be performed due to a conflict with the current state""" + + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index cabe46a2f..87fc95917 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta): :cvar inference: Text generation, chat completions, and embeddings :cvar safety: Content moderation and safety shields :cvar agents: Agent orchestration and execution + :cvar batches: Batch processing for asynchronous API requests :cvar vector_io: Vector database operations and queries :cvar datasetio: Dataset input/output operations :cvar scoring: Model output evaluation and scoring @@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta): inference = "inference" safety = "safety" agents = "agents" + batches = "batches" vector_io = "vector_io" datasetio = "datasetio" scoring = "scoring" diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index ba8701e23..a1b9dd4dc 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum): """ ASSISTANTS = "assistants" + BATCH = "batch" # TODO: Add other purposes as needed diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7e7bd0a3d..bd4737ca7 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel): embeddings: list[list[float]] +@json_schema_type +class RerankData(BaseModel): + """A single rerank result from a reranking response. + + :param index: The original index of the document in the input list + :param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance. + """ + + index: int + relevance_score: float + + +@json_schema_type +class RerankResponse(BaseModel): + """Response from a reranking request. + + :param data: List of rerank result objects, sorted by relevance score (descending) + """ + + data: list[RerankData] + + @json_schema_type class OpenAIChatCompletionContentPartTextParam(BaseModel): """Text content part for OpenAI-compatible chat completion messages. @@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol): :returns: A BatchCompletionResponse with the full completions. """ raise NotImplementedError("Batch completion is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete @webmethod(route="/inference/chat-completion", method="POST") async def chat_completion( @@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol): :returns: A BatchChatCompletionResponse with the full completions. 
""" raise NotImplementedError("Batch chat completion is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete @webmethod(route="/inference/embeddings", method="POST") async def embeddings( @@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol): """ ... + @webmethod(route="/inference/rerank", method="POST", experimental=True) + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + """Rerank a list of documents based on their relevance to a query. + + :param model: The identifier of the reranking model to use. + :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length. + :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length. + :param max_num_results: (Optional) Maximum number of results to return. Default: returns all. + :returns: RerankResponse with indices sorted by relevance score (descending). + """ + raise NotImplementedError("Reranking is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete + @webmethod(route="/openai/v1/completions", method="POST") async def openai_completion( self, diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 92422ac1b..8d1b5d697 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel): timestamp: int value: float + unit: str @json_schema_type @@ -518,7 +519,7 @@ class Telemetry(Protocol): metric_name: str, start_time: int, end_time: int | None = None, - granularity: str | None = "1d", + granularity: str | None = None, query_type: MetricQueryType = MetricQueryType.RANGE, label_matchers: list[MetricLabelMatcher] | None = None, ) -> QueryMetricsResponse: diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index c8ffce034..b32b8b3ae 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="cli") class StackRun(Subcommand): diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index 4b20588fd..fa1fe632b 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import importlib.resources -import logging import sys from pydantic import BaseModel @@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis from llama_stack.core.utils.exec import run_command from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.distributions.template import DistributionTemplate +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # These are the dependencies needed by the distribution server. # `llama-stack` is automatically installed by the installation script. 
diff --git a/llama_stack/core/build_conda_env.sh b/llama_stack/core/build_conda_env.sh deleted file mode 100755 index 48ac3a1ab..000000000 --- a/llama_stack/core/build_conda_env.sh +++ /dev/null @@ -1,207 +0,0 @@ -#!/bin/bash - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} -LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-} -TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} -PYPI_VERSION=${PYPI_VERSION:-} -# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out -# Reference: https://github.com/astral-sh/uv/pull/1694 -UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} - -set -euo pipefail - -# Define color codes -RED='\033[0;31m' -GREEN='\033[0;32m' -NC='\033[0m' # No Color - -SCRIPT_DIR=$(dirname "$(readlink -f "$0")") -source "$SCRIPT_DIR/common.sh" - -# Usage function -usage() { - echo "Usage: $0 --env-name --build-file-path --normal-deps [--external-provider-deps ] [--optional-deps ]" - echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'" - exit 1 -} - -# Parse arguments -env_name="" -build_file_path="" -normal_deps="" -external_provider_deps="" -optional_deps="" - -while [[ $# -gt 0 ]]; do - key="$1" - case "$key" in - --env-name) - if [[ -z "$2" || "$2" == --* ]]; then - echo "Error: --env-name requires a string value" >&2 - usage - fi - env_name="$2" - shift 2 - ;; - --build-file-path) - if [[ -z "$2" || "$2" == --* ]]; then - echo "Error: --build-file-path requires a string value" >&2 - usage - fi - build_file_path="$2" - shift 2 - ;; - --normal-deps) - if [[ -z "$2" || "$2" == --* ]]; then - echo "Error: --normal-deps requires a string value" >&2 - usage - fi - normal_deps="$2" - shift 2 - ;; - --external-provider-deps) - if [[ -z "$2" || "$2" == --* ]]; then - echo "Error: --external-provider-deps requires a string value" >&2 - usage - fi - external_provider_deps="$2" - shift 2 - ;; - --optional-deps) - if [[ -z "$2" || "$2" == --* ]]; then - echo "Error: --optional-deps requires a string value" >&2 - usage - fi - optional_deps="$2" - shift 2 - ;; - *) - echo "Unknown option: $1" >&2 - usage - ;; - esac -done - -# Check required arguments -if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then - echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2 - usage -fi - -if [ -n "$LLAMA_STACK_DIR" ]; then - echo "Using llama-stack-dir=$LLAMA_STACK_DIR" -fi -if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then - echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR" -fi - -ensure_conda_env_python310() { - # Use only global variables set by flag parser - local python_version="3.12" - - if ! is_command_available conda; then - printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2 - exit 1 - fi - - if conda env list | grep -q "^${env_name} "; then - printf "Conda environment '${env_name}' exists. Checking Python version...\n" - current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2) - if [ "$current_version" = "$python_version" ]; then - printf "Environment '${env_name}' already has Python ${python_version}. 
No action needed.\n" - else - printf "Updating environment '${env_name}' to Python ${python_version}...\n" - conda install -n "${env_name}" python="${python_version}" -y - fi - else - printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n" - conda create -n "${env_name}" python="${python_version}" -y - fi - - eval "$(conda shell.bash hook)" - conda deactivate && conda activate "${env_name}" - "$CONDA_PREFIX"/bin/pip install uv - - if [ -n "$TEST_PYPI_VERSION" ]; then - uv pip install fastapi libcst - uv pip install --extra-index-url https://test.pypi.org/simple/ \ - llama-stack=="$TEST_PYPI_VERSION" \ - "$normal_deps" - if [ -n "$optional_deps" ]; then - IFS='#' read -ra parts <<<"$optional_deps" - for part in "${parts[@]}"; do - echo "$part" - uv pip install $part - done - fi - if [ -n "$external_provider_deps" ]; then - IFS='#' read -ra parts <<<"$external_provider_deps" - for part in "${parts[@]}"; do - echo "$part" - uv pip install "$part" - done - fi - else - if [ -n "$LLAMA_STACK_DIR" ]; then - if [ ! -d "$LLAMA_STACK_DIR" ]; then - printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2 - exit 1 - fi - printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n" - uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" - else - PYPI_VERSION="${PYPI_VERSION:-}" - if [ -n "$PYPI_VERSION" ]; then - SPEC_VERSION="llama-stack==${PYPI_VERSION}" - else - SPEC_VERSION="llama-stack" - fi - uv pip install --no-cache-dir "$SPEC_VERSION" - fi - if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then - if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then - printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2 - exit 1 - fi - printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n" - uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" - fi - printf "Installing pip dependencies\n" - uv pip install $normal_deps - if [ -n "$optional_deps" ]; then - IFS='#' read -ra parts <<<"$optional_deps" - for part in "${parts[@]}"; do - echo "$part" - uv pip install $part - done - fi - if [ -n "$external_provider_deps" ]; then - IFS='#' read -ra parts <<<"$external_provider_deps" - for part in "${parts[@]}"; do - echo "Getting provider spec for module: $part and installing dependencies" - package_name=$(echo "$part" | sed 's/[<>=!].*//') - python3 -c " -import importlib -import sys -try: - module = importlib.import_module(f'$package_name.provider') - spec = module.get_provider_spec() - if hasattr(spec, 'pip_packages') and spec.pip_packages: - print('\\n'.join(spec.pip_packages)) -except Exception as e: - print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr) -" | uv pip install -r - - done - fi - fi - mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml - echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml" -} - -ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps" diff --git a/llama_stack/core/build_venv.sh b/llama_stack/core/build_venv.sh index a2838803f..04927d71e 100755 --- a/llama_stack/core/build_venv.sh +++ b/llama_stack/core/build_venv.sh @@ -151,23 +151,37 @@ run() { fi else if [ -n "$LLAMA_STACK_DIR" ]; then - if [ ! -d "$LLAMA_STACK_DIR" ]; then + # only warn if DIR does not start with "git+" + if [ ! 
-d "$LLAMA_STACK_DIR" ] && [[ "$LLAMA_STACK_DIR" != git+* ]]; then printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2 exit 1 fi printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR" - uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" + # editable only if LLAMA_STACK_DIR does not start with "git+" + if [[ "$LLAMA_STACK_DIR" != git+* ]]; then + EDITABLE="-e" + else + EDITABLE="" + fi + uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_DIR" else uv pip install --no-cache-dir llama-stack fi if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then - if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then + # only warn if DIR does not start with "git+" + if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ] && [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2 exit 1 fi printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR" - uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" + # editable only if LLAMA_STACK_CLIENT_DIR does not start with "git+" + if [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then + EDITABLE="-e" + else + EDITABLE="" + fi + uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_CLIENT_DIR" fi printf "Installing pip dependencies\n" diff --git a/llama_stack/core/configure.py b/llama_stack/core/configure.py index 9e18b438c..64473c053 100644 --- a/llama_stack/core/configure.py +++ b/llama_stack/core/configure.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import textwrap from typing import Any @@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.prompt_for_config import prompt_for_config +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, ProviderSpec -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider: diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index a1b6ad32b..c3940fcbd 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,6 +318,41 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") +class CORSConfig(BaseModel): + allow_origins: list[str] = Field(default_factory=list) + allow_origin_regex: str | None = Field(default=None) + allow_methods: list[str] = Field(default=["OPTIONS"]) + allow_headers: list[str] = Field(default_factory=list) + allow_credentials: bool = Field(default=False) + expose_headers: list[str] = Field(default_factory=list) + max_age: int = Field(default=600, ge=0) + + @model_validator(mode="after") + def validate_credentials_config(self) -> Self: + if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): + raise ValueError("Cannot use wildcard origins with credentials enabled") + return self + + +def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None: + if cors_config is False or cors_config is None: + return None + + if cors_config is True: + # dev mode: allow localhost on any port + return CORSConfig( + 
allow_origins=[], + allow_origin_regex=r"https?://localhost:\d+", + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + ) + + if isinstance(cors_config, CORSConfig): + return cors_config + + raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -349,6 +384,12 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request configuration", ) + cors: bool | CORSConfig | None = Field( + default=None, + description="CORS configuration for cross-origin requests. Can be:\n" + "- true: Enable localhost CORS for development\n" + "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index a93fe509e..9e7a8006c 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -7,7 +7,7 @@ import asyncio import inspect import json -import logging +import logging # allow-direct-logging import os import sys from concurrent.futures import ThreadPoolExecutor @@ -48,6 +48,7 @@ from llama_stack.core.stack import ( from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.core.utils.exec import in_notebook +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry.tracing import ( CURRENT_TRACE_CONTEXT, end_trace, @@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") T = TypeVar("T") @@ -145,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_distro_name, custom_provider_registry, provider_data + config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal ) self.pool_executor = ThreadPoolExecutor(max_workers=4) - self.skip_logger_removal = skip_logger_removal self.provider_data = provider_data self.loop = asyncio.new_event_loop() - def initialize(self): - if in_notebook(): - import nest_asyncio - - nest_asyncio.apply() - if not self.skip_logger_removal: - self._remove_root_logger_handlers() - # use a new event loop to avoid interfering with the main event loop loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - return loop.run_until_complete(self.async_client.initialize()) + loop.run_until_complete(self.async_client.initialize()) finally: asyncio.set_event_loop(None) - def _remove_root_logger_handlers(self): + def initialize(self): """ - Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + Deprecated method for backward compatibility. 
""" - root_logger = logging.getLogger() - - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + pass def request(self, *args, **kwargs): loop = self.loop @@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): config_path_or_distro_name: str, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, + skip_logger_removal: bool = False, ): super().__init__() # when using the library client, we should not log to console since many @@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") + if in_notebook(): + import nest_asyncio + + nest_asyncio.apply() + if not skip_logger_removal: + self._remove_root_logger_handlers() + if config_path_or_distro_name.endswith(".yaml"): config_path = Path(config_path_or_distro_name) if not config_path.exists(): @@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.provider_data = provider_data self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError + def _remove_root_logger_handlers(self): + """ + Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + """ + root_logger = logging.getLogger() + + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + async def initialize(self) -> bool: + """ + Initialize the async client. + + Returns: + bool: True if initialization was successful + """ + try: self.route_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) diff --git a/llama_stack/core/request_headers.py b/llama_stack/core/request_headers.py index 35ac72775..f1ce8281f 100644 --- a/llama_stack/core/request_headers.py +++ b/llama_stack/core/request_headers.py @@ -6,15 +6,15 @@ import contextvars import json -import logging from contextlib import AbstractContextManager from typing import Any from llama_stack.core.datatypes import User +from llama_stack.log import get_logger from .utils.dynamic import instantiate_class_type -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # Context variable for request provider data and auth attributes PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 70c78fb01..7ac98dac8 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -8,6 +8,7 @@ import inspect from typing import Any from llama_stack.apis.agents import Agents +from llama_stack.apis.batches import Batches from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, + Api.batches: Batches, Api.vector_io: VectorIO, Api.vector_dbs: VectorDBs, Api.models: Models, diff --git a/llama_stack/core/routers/datasets.py b/llama_stack/core/routers/datasets.py index d7984f729..2f1d5f78e 100644 --- a/llama_stack/core/routers/datasets.py +++ 
b/llama_stack/core/routers/datasets.py @@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class DatasetIORouter(DatasetIO): diff --git a/llama_stack/core/routers/eval_scoring.py b/llama_stack/core/routers/eval_scoring.py index f7a17eecf..ffca81bf0 100644 --- a/llama_stack/core/routers/eval_scoring.py +++ b/llama_stack/core/routers/eval_scoring.py @@ -16,7 +16,7 @@ from llama_stack.apis.scoring import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ScoringRouter(Scoring): diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 6a3f07247..4b66601bb 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="core::routers") class InferenceRouter(Inference): diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index c76673d2a..9ba3327f1 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -6,16 +6,14 @@ from typing import Any -from llama_stack.apis.inference import ( - Message, -) +from llama_stack.apis.inference import Message from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class SafetyRouter(Safety): @@ -68,6 +66,7 @@ class SafetyRouter(Safety): list_shields_response = await self.routing_table.list_shields() matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id] + if not matches: raise ValueError(f"No shield associated with provider_resource id {model}") if len(matches) > 1: diff --git a/llama_stack/core/routers/tool_runtime.py b/llama_stack/core/routers/tool_runtime.py index 5a40bc0c5..fd606f33b 100644 --- a/llama_stack/core/routers/tool_runtime.py +++ b/llama_stack/core/routers/tool_runtime.py @@ -22,7 +22,7 @@ from llama_stack.log import get_logger from ..routing_tables.toolgroups import ToolGroupsRoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ToolRuntimeRouter(ToolRuntime): diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3d0996c49..786b0e391 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = 
get_logger(name=__name__, category="core::routers") class VectorIORouter(VectorIO): diff --git a/llama_stack/core/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py index 74bee8040..c875dee5b 100644 --- a/llama_stack/core/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 339ff6da4..e523746d8 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def get_impl_api(p: Any) -> Api: diff --git a/llama_stack/core/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py index fc6a75df4..b129c9ec5 100644 --- a/llama_stack/core/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -26,7 +26,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 34c431e00..b6141efa9 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ModelsRoutingTable(CommonRoutingTableImpl, Models): diff --git a/llama_stack/core/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py index 5874ba941..71e5bed63 100644 --- a/llama_stack/core/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -19,7 +19,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index e08f35bfc..b1918d20a 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index 6910b3906..eeea406c1 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py 
@@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index e8dc46997..00f71b4fe 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -30,7 +30,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): diff --git a/llama_stack/core/server/auth.py b/llama_stack/core/server/auth.py index e4fb4ff2b..c98d3bec0 100644 --- a/llama_stack/core/server/auth.py +++ b/llama_stack/core/server/auth.py @@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthenticationMiddleware: diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 73d5581c2..a8af6f75a 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -23,7 +23,7 @@ from llama_stack.core.datatypes import ( ) from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthResponse(BaseModel): diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 1cb850cde..693f224c3 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl -logger = get_logger(name=__name__, category="quota") +logger = get_logger(name=__name__, category="core::server") class QuotaMiddleware: diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index e9d70fc8d..d6dfc3435 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -9,7 +9,7 @@ import asyncio import functools import inspect import json -import logging +import logging # allow-direct-logging import os import ssl import sys @@ -28,10 +28,12 @@ from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.cli.utils import add_config_distro_args, get_config_from_args from llama_stack.core.access_control.access_control import AccessDeniedError @@ -39,6 +41,7 @@ from 
llama_stack.core.datatypes import ( AuthenticationRequiredError, LoggingConfig, StackRunConfig, + process_cors_config, ) from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis @@ -81,7 +84,7 @@ from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="core::server") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -128,6 +131,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro ] }, ) + elif isinstance(exc, ConflictError): + return HTTPException(status_code=409, detail=str(exc)) + elif isinstance(exc, ResourceNotFoundError): + return HTTPException(status_code=404, detail=str(exc)) elif isinstance(exc, ValueError): return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): @@ -408,7 +415,7 @@ def main(args: argparse.Namespace | None = None): config_contents = yaml.safe_load(fp) if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) - logger = get_logger(name=__name__, category="server", config=logger_config) + logger = get_logger(name=__name__, category="core::server", config=logger_config) if args.env: for env_pair in args.env: try: @@ -478,6 +485,12 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) + if config.server.cors: + logger.info("Enabling CORS") + cors_config = process_cors_config(config.server.cors) + if cors_config: + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 4b60e1001..5f4abe9aa 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -16,7 +16,7 @@ from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig -logger = get_logger(__name__, category="core") +logger = get_logger(__name__, category="core::registry") class DistributionRegistry(Protocol): diff --git a/llama_stack/core/utils/config_resolution.py b/llama_stack/core/utils/config_resolution.py index 30cd71e15..182a571ee 100644 --- a/llama_stack/core/utils/config_resolution.py +++ b/llama_stack/core/utils/config_resolution.py @@ -10,7 +10,7 @@ from pathlib import Path from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="config_resolution") +logger = get_logger(name=__name__, category="core") DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" diff --git a/llama_stack/core/utils/exec.py b/llama_stack/core/utils/exec.py index 1b2b782fe..12fb82d01 100644 --- a/llama_stack/core/utils/exec.py +++ b/llama_stack/core/utils/exec.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
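To make the new server.cors plumbing concrete, a small sketch of how process_cors_config behaves for the two supported shapes, written against the datatypes added earlier in this patch rather than taken from it; the origin below is hypothetical.

from llama_stack.core.datatypes import CORSConfig, process_cors_config

# server.cors: true  ->  dev-mode defaults that only match localhost origins.
dev = process_cors_config(True)
assert dev is not None and dev.allow_origin_regex == r"https?://localhost:\d+"

# An explicit configuration, as it would be parsed from a run.yaml server.cors block.
explicit = process_cors_config(
    CORSConfig(
        allow_origins=["https://ui.example.com"],  # hypothetical origin
        allow_methods=["GET", "POST", "OPTIONS"],
        allow_headers=["Content-Type", "Authorization"],
        allow_credentials=True,
    )
)

# The model validator rejects wildcard origins when credentials are enabled,
# e.g. CORSConfig(allow_origins=["*"], allow_credentials=True) fails validation.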
-import logging +import importlib import os import signal import subprocess @@ -12,9 +12,9 @@ import sys from termcolor import cprint -log = logging.getLogger(__name__) +from llama_stack.log import get_logger -import importlib +log = get_logger(name=__name__, category="core") def formulate_run_args(image_type: str, image_name: str) -> list: diff --git a/llama_stack/core/utils/prompt_for_config.py b/llama_stack/core/utils/prompt_for_config.py index 26f6920e0..bac0531ed 100644 --- a/llama_stack/core/utils/prompt_for_config.py +++ b/llama_stack/core/utils/prompt_for_config.py @@ -6,7 +6,6 @@ import inspect import json -import logging from enum import Enum from typing import Annotated, Any, Literal, Union, get_args, get_origin @@ -14,7 +13,9 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefinedType -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="core") def is_list_of_primitives(field_type): diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index e6e699b62..b4701cb81 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -28,12 +28,13 @@ distribution_spec: - provider_type: inline::localfs safety: - provider_type: inline::llama-guard + - provider_type: inline::code-scanner agents: - provider_type: inline::meta-reference telemetry: - provider_type: inline::meta-reference post_training: - - provider_type: inline::huggingface + - provider_type: inline::huggingface-cpu eval: - provider_type: inline::meta-reference datasetio: @@ -48,6 +49,8 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol + batches: + - provider_type: inline::reference image_type: venv additional_pip_packages: - aiosqlite diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 05e1b4576..3acdd20f9 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -2,6 +2,7 @@ version: 2 image_name: ci-tests apis: - agents +- batches - datasetio - eval - files @@ -134,6 +135,8 @@ providers: provider_type: inline::llama-guard config: excluded_categories: [] + - provider_id: code-scanner + provider_type: inline::code-scanner agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -153,8 +156,8 @@ providers: sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} post_training: - - provider_id: huggingface - provider_type: inline::huggingface + - provider_id: huggingface-cpu + provider_type: inline::huggingface-cpu config: checkpoint_format: huggingface distributed_backend: null @@ -204,6 +207,13 @@ providers: provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/batches.db metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db @@ -215,6 +225,9 @@ shields: - shield_id: llama-guard provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_shield_id: ${env.SAFETY_MODEL:=} +- shield_id: code-scanner + 
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} + provider_shield_id: ${env.CODE_SCANNER_MODEL:=} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/distributions/starter-gpu/__init__.py b/llama_stack/distributions/starter-gpu/__init__.py new file mode 100644 index 000000000..e762f9b6e --- /dev/null +++ b/llama_stack/distributions/starter-gpu/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .starter_gpu import get_distribution_template # noqa: F401 diff --git a/llama_stack/distributions/starter-gpu/build.yaml b/llama_stack/distributions/starter-gpu/build.yaml new file mode 100644 index 000000000..ae0680cdc --- /dev/null +++ b/llama_stack/distributions/starter-gpu/build.yaml @@ -0,0 +1,59 @@ +version: 2 +distribution_spec: + description: Quick start template for running Llama Stack with several popular providers. + This distribution is intended for GPU-enabled environments. + providers: + inference: + - provider_type: remote::cerebras + - provider_type: remote::ollama + - provider_type: remote::vllm + - provider_type: remote::tgi + - provider_type: remote::fireworks + - provider_type: remote::together + - provider_type: remote::bedrock + - provider_type: remote::nvidia + - provider_type: remote::openai + - provider_type: remote::anthropic + - provider_type: remote::gemini + - provider_type: remote::vertexai + - provider_type: remote::groq + - provider_type: remote::sambanova + - provider_type: inline::sentence-transformers + vector_io: + - provider_type: inline::faiss + - provider_type: inline::sqlite-vec + - provider_type: inline::milvus + - provider_type: remote::chromadb + - provider_type: remote::pgvector + files: + - provider_type: inline::localfs + safety: + - provider_type: inline::llama-guard + - provider_type: inline::code-scanner + agents: + - provider_type: inline::meta-reference + telemetry: + - provider_type: inline::meta-reference + post_training: + - provider_type: inline::torchtune-gpu + eval: + - provider_type: inline::meta-reference + datasetio: + - provider_type: remote::huggingface + - provider_type: inline::localfs + scoring: + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust + tool_runtime: + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime + - provider_type: remote::model-context-protocol + batches: + - provider_type: inline::reference +image_type: venv +additional_pip_packages: +- aiosqlite +- asyncpg +- sqlalchemy[asyncio] diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml new file mode 100644 index 000000000..81c802317 --- /dev/null +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -0,0 +1,238 @@ +version: 2 +image_name: starter-gpu +apis: +- agents +- batches +- datasetio +- eval +- files +- inference +- post_training +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: ${env.CEREBRAS_API_KEY:+cerebras} + provider_type: remote::cerebras + config: + base_url: https://api.cerebras.ai + api_key: ${env.CEREBRAS_API_KEY:=} + - provider_id: ${env.OLLAMA_URL:+ollama} + provider_type: remote::ollama + config: + url: ${env.OLLAMA_URL:=http://localhost:11434} + - provider_id: ${env.VLLM_URL:+vllm} + provider_type: 
remote::vllm + config: + url: ${env.VLLM_URL:=} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: ${env.TGI_URL:+tgi} + provider_type: remote::tgi + config: + url: ${env.TGI_URL:=} + - provider_id: fireworks + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY:=} + - provider_id: together + provider_type: remote::together + config: + url: https://api.together.xyz/v1 + api_key: ${env.TOGETHER_API_KEY:=} + - provider_id: bedrock + provider_type: remote::bedrock + - provider_id: ${env.NVIDIA_API_KEY:+nvidia} + provider_type: remote::nvidia + config: + url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com} + api_key: ${env.NVIDIA_API_KEY:=} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True} + - provider_id: openai + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY:=} + base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1} + - provider_id: anthropic + provider_type: remote::anthropic + config: + api_key: ${env.ANTHROPIC_API_KEY:=} + - provider_id: gemini + provider_type: remote::gemini + config: + api_key: ${env.GEMINI_API_KEY:=} + - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai} + provider_type: remote::vertexai + config: + project: ${env.VERTEX_AI_PROJECT:=} + location: ${env.VERTEX_AI_LOCATION:=us-central1} + - provider_id: groq + provider_type: remote::groq + config: + url: https://api.groq.com + api_key: ${env.GROQ_API_KEY:=} + - provider_id: sambanova + provider_type: remote::sambanova + config: + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY:=} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec + config: + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + - provider_id: ${env.MILVUS_URL:+milvus} + provider_type: inline::milvus + config: + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + - provider_id: ${env.CHROMADB_URL:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db + - provider_id: ${env.PGVECTOR_DB:+pgvector} + provider_type: remote::pgvector + config: + host: ${env.PGVECTOR_HOST:=localhost} + port: ${env.PGVECTOR_PORT:=5432} + db: ${env.PGVECTOR_DB:=} + user: ${env.PGVECTOR_USER:=} + password: ${env.PGVECTOR_PASSWORD:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + - provider_id: 
code-scanner + provider_type: inline::code-scanner + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db + otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} + post_training: + - provider_id: torchtune-gpu + provider_type: inline::torchtune-gpu + config: + checkpoint_format: meta + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db +models: [] +shields: +- shield_id: llama-guard + provider_id: ${env.SAFETY_MODEL:+llama-guard} + provider_shield_id: ${env.SAFETY_MODEL:=} +- shield_id: code-scanner + provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} + provider_shield_id: ${env.CODE_SCANNER_MODEL:=} +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8321 diff --git a/llama_stack/distributions/starter-gpu/starter_gpu.py b/llama_stack/distributions/starter-gpu/starter_gpu.py new file mode 100644 index 000000000..893df6c17 --- /dev/null +++ b/llama_stack/distributions/starter-gpu/starter_gpu.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ + +from llama_stack.distributions.template import BuildProvider, DistributionTemplate + +from ..starter.starter import get_distribution_template as get_starter_distribution_template + + +def get_distribution_template() -> DistributionTemplate: + template = get_starter_distribution_template() + name = "starter-gpu" + template.name = name + template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments." + + template.providers["post_training"] = [ + BuildProvider(provider_type="inline::torchtune-gpu"), + ] + return template diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index 1a4f81d49..3df0eb129 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -1,6 +1,7 @@ version: 2 distribution_spec: - description: Quick start template for running Llama Stack with several popular providers + description: Quick start template for running Llama Stack with several popular providers. + This distribution is intended for CPU-only environments. providers: inference: - provider_type: remote::cerebras @@ -28,12 +29,13 @@ distribution_spec: - provider_type: inline::localfs safety: - provider_type: inline::llama-guard + - provider_type: inline::code-scanner agents: - provider_type: inline::meta-reference telemetry: - provider_type: inline::meta-reference post_training: - - provider_type: inline::huggingface + - provider_type: inline::huggingface-cpu eval: - provider_type: inline::meta-reference datasetio: @@ -48,6 +50,8 @@ distribution_spec: - provider_type: remote::tavily-search - provider_type: inline::rag-runtime - provider_type: remote::model-context-protocol + batches: + - provider_type: inline::reference image_type: venv additional_pip_packages: - aiosqlite diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 46bd12956..7e1d46a61 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -2,6 +2,7 @@ version: 2 image_name: starter apis: - agents +- batches - datasetio - eval - files @@ -134,6 +135,8 @@ providers: provider_type: inline::llama-guard config: excluded_categories: [] + - provider_id: code-scanner + provider_type: inline::code-scanner agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -153,8 +156,8 @@ providers: sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} post_training: - - provider_id: huggingface - provider_type: inline::huggingface + - provider_id: huggingface-cpu + provider_type: inline::huggingface-cpu config: checkpoint_format: huggingface distributed_backend: null @@ -204,6 +207,13 @@ providers: provider_type: inline::rag-runtime - provider_id: model-context-protocol provider_type: remote::model-context-protocol + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/batches.db metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db @@ -215,6 +225,9 @@ shields: - shield_id: llama-guard provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_shield_id: ${env.SAFETY_MODEL:=} +- shield_id: code-scanner + provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} + provider_shield_id: 
${env.CODE_SCANNER_MODEL:=} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index 0270b68ad..f49da0bb7 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -15,19 +15,14 @@ from llama_stack.core.datatypes import ( ToolGroupInput, ) from llama_stack.core.utils.dynamic import instantiate_class_type -from llama_stack.distributions.template import ( - DistributionTemplate, - RunConfigSettings, -) +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings from llama_stack.providers.datatypes import RemoteProviderSpec from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig -from llama_stack.providers.inline.vector_io.milvus.config import ( - MilvusVectorIOConfig, -) +from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( SQLiteVectorIOConfig, ) @@ -119,10 +114,13 @@ def get_distribution_template() -> DistributionTemplate: BuildProvider(provider_type="remote::pgvector"), ], "files": [BuildProvider(provider_type="inline::localfs")], - "safety": [BuildProvider(provider_type="inline::llama-guard")], + "safety": [ + BuildProvider(provider_type="inline::llama-guard"), + BuildProvider(provider_type="inline::code-scanner"), + ], "agents": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")], - "post_training": [BuildProvider(provider_type="inline::huggingface")], + "post_training": [BuildProvider(provider_type="inline::huggingface-cpu")], "eval": [BuildProvider(provider_type="inline::meta-reference")], "datasetio": [ BuildProvider(provider_type="remote::huggingface"), @@ -139,6 +137,9 @@ def get_distribution_template() -> DistributionTemplate: BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="remote::model-context-protocol"), ], + "batches": [ + BuildProvider(provider_type="inline::reference"), + ], } files_provider = Provider( provider_id="meta-reference-files", @@ -167,12 +168,17 @@ def get_distribution_template() -> DistributionTemplate: provider_id="${env.SAFETY_MODEL:+llama-guard}", provider_shield_id="${env.SAFETY_MODEL:=}", ), + ShieldInput( + shield_id="code-scanner", + provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}", + provider_shield_id="${env.CODE_SCANNER_MODEL:=}", + ), ] return DistributionTemplate( name=name, distro_type="self_hosted", - description="Quick start template for running Llama Stack with several popular providers", + description="Quick start template for running Llama Stack with several popular providers. This distribution is intended for CPU-only environments.", container_image=None, template_path=None, providers=providers, diff --git a/llama_stack/log.py b/llama_stack/log.py index 7507aface..cc4c9d4cf 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -4,16 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging +import logging # allow-direct-logging import os import re -import sys -from logging.config import dictConfig +from logging.config import dictConfig # allow-direct-logging from rich.console import Console from rich.errors import MarkupError from rich.logging import RichHandler -from termcolor import cprint from llama_stack.core.datatypes import LoggingConfig @@ -66,7 +64,6 @@ def config_to_category_levels(category: str, level: str): category_levels["root"] = level_value elif category in CATEGORIES: category_levels[category] = level_value - logging.info(f"Setting '{category}' category to level '{level}'.") else: logging.warning(f"Unknown logging category: {category}. No changes made.") return category_levels @@ -256,7 +253,6 @@ def get_logger( env_config = os.environ.get("LLAMA_STACK_LOGGING", "") if env_config: - cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr) _category_levels.update(parse_environment_config(env_config)) log_file = os.environ.get("LLAMA_STACK_LOG_FILE") diff --git a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py index 5b5969d89..90ced13b2 100644 --- a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py +++ b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py @@ -13,14 +13,15 @@ # Copyright (c) Meta Platforms, Inc. and its affiliates. import math -from logging import getLogger import torch import torch.nn.functional as F +from llama_stack.log import get_logger + from .utils import get_negative_inf_value, to_2tuple -logger = getLogger() +logger = get_logger(name=__name__, category="models::llama") def resize_local_position_embedding(orig_pos_embed, grid_size): diff --git a/llama_stack/models/llama/llama3/multimodal/image_transform.py b/llama_stack/models/llama/llama3/multimodal/image_transform.py index f2761ee47..7b20a31fa 100644 --- a/llama_stack/models/llama/llama3/multimodal/image_transform.py +++ b/llama_stack/models/llama/llama3/multimodal/image_transform.py @@ -13,7 +13,6 @@ import math from collections import defaultdict -from logging import getLogger from typing import Any import torch @@ -21,9 +20,11 @@ import torchvision.transforms as tv from PIL import Image from torchvision.transforms import functional as F +from llama_stack.log import get_logger + IMAGE_RES = 224 -logger = getLogger() +logger = get_logger(name=__name__, category="models::llama") class VariableSizeImageTransform: diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 5f1c3605c..7b501eb0e 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -3,8 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- -import logging import math from collections.abc import Callable from functools import partial @@ -22,6 +20,8 @@ from PIL import Image as PIL_Image from torch import Tensor, nn from torch.distributed import _functional_collectives as funcol +from llama_stack.log import get_logger + from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis from .encoder_utils import ( build_encoder_attention_mask, @@ -34,9 +34,10 @@ from .encoder_utils import ( from .image_transform import VariableSizeImageTransform from .utils import get_negative_inf_value, to_2tuple -logger = logging.getLogger(__name__) MP_SCALE = 8 +logger = get_logger(name=__name__, category="models::llama") + def reduce_from_tensor_model_parallel_region(input_): """All-reduce the input tensor across model parallel group.""" @@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module): if embed is not None: # reshape the weights to the correct shape nt_old, nt_old, _, w = embed.shape - logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") + logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) # assign the weights to the module state_dict[prefix + "embedding"] = embed_new diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py index e47b579e3..ad7ced1c5 100644 --- a/llama_stack/models/llama/llama3/tokenizer.py +++ b/llama_stack/models/llama/llama3/tokenizer.py @@ -4,8 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger from pathlib import Path from typing import ( Literal, @@ -14,11 +14,9 @@ from typing import ( import tiktoken +from llama_stack.log import get_logger from llama_stack.models.llama.tokenizer_utils import load_bpe_file -logger = getLogger(__name__) - - # The tiktoken tokenizer can handle <=400k chars without # pyo3_runtime.PanicException. TIKTOKEN_MAX_ENCODE_CHARS = 400_000 @@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000 _INSTANCE = None +logger = get_logger(name=__name__, category="models::llama") + class Tokenizer: """ diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 574080184..d0e3e7671 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -11,7 +11,7 @@ from llama_stack.log import get_logger from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="models::llama") BUILTIN_TOOL_PATTERN = r'\b(?P\w+)\.call\(query="(?P[^"]*)"\)' CUSTOM_TOOL_CALL_PATTERN = re.compile(r"[^}]+)>(?P{.*?})") diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index 223744a5f..7557a8a64 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
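Aside: in `tool_utils.py` above only the logger category changes, but the surrounding context lines show `BUILTIN_TOOL_PATTERN` and `CUSTOM_TOOL_CALL_PATTERN` with their named groups stripped, most likely HTML-escaping damage in this rendering. Below is a hedged reconstruction with assumed group names (`tool_name`, `query`, `function_name`, `args`) and a tiny self-check; treat the exact patterns as an approximation of the real source, not a quote of it.

```python
import json
import re

# Reconstructed patterns; the named groups are assumptions, since the diff context
# above renders them stripped (e.g. "(?P\w+)" rather than "(?P<tool_name>\w+)").
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

builtin = re.search(BUILTIN_TOOL_PATTERN, 'brave_search.call(query="latest llama release")')
assert builtin is not None and builtin.group("tool_name") == "brave_search"

custom = CUSTOM_TOOL_CALL_PATTERN.search('<function=get_weather>{"city": "Tokyo"}</function>')
assert custom is not None
print(custom.group("function_name"), json.loads(custom.group("args")))
# -> get_weather {'city': 'Tokyo'}
```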
-import logging import os from collections.abc import Callable @@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank from torch import Tensor, nn from torch.nn import functional as F +from llama_stack.log import get_logger + from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="models::llama") def swiglu_wrapper_no_reduce( diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py index e12b2cae0..bfbace8f9 100644 --- a/llama_stack/models/llama/llama4/tokenizer.py +++ b/llama_stack/models/llama/llama4/tokenizer.py @@ -5,7 +5,6 @@ # the root directory of this source tree. from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger from pathlib import Path from typing import ( Literal, @@ -14,11 +13,9 @@ from typing import ( import tiktoken +from llama_stack.log import get_logger from llama_stack.models.llama.tokenizer_utils import load_bpe_file -logger = getLogger(__name__) - - # The tiktoken tokenizer can handle <=400k chars without # pyo3_runtime.PanicException. TIKTOKEN_MAX_ENCODE_CHARS = 400_000 @@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [ "<|fim_suffix|>", ] +logger = get_logger(name=__name__, category="models::llama") + class Tokenizer: """ diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py index a6400c5c9..0a205601f 100644 --- a/llama_stack/models/llama/quantize_impls.py +++ b/llama_stack/models/llama/quantize_impls.py @@ -6,9 +6,10 @@ # type: ignore import collections -import logging -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="models::llama") try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 5f7c90879..fde38515b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" -logger = get_logger(name=__name__, category="agents") +logger = get_logger(name=__name__, category="agents::meta_reference") class ChatAgent(ShieldRunnerMixin): diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 0f12a0865..8bdde86b0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
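Aside: `quantize_impls.py` above keeps its optional `fbgemm_gpu` import inside a `try` block; the rest of the guard falls outside the diff context. The general shape of such an optional-dependency guard is sketched below. The `_HAS_FBGEMM` flag and the log messages are illustrative assumptions, not lines from the module.

```python
# Generic optional-dependency guard in the style used above; the flag name and
# messages are illustrative, not taken from quantize_impls.py.
from llama_stack.log import get_logger

log = get_logger(name=__name__, category="models::llama")

try:
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

    _HAS_FBGEMM = True
    log.info("fbgemm_gpu kernels available; using fused quantized operators")
except ImportError:
    _HAS_FBGEMM = False
    log.warning("fbgemm_gpu not installed; falling back to unquantized paths")
```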
-import logging import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime @@ -42,16 +41,17 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.core.datatypes import AccessRule +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.responses.responses_store import ResponsesStore from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig -from .openai_responses import OpenAIResponsesImpl from .persistence import AgentInfo +from .responses.openai_responses import OpenAIResponsesImpl -logger = logging.getLogger() +logger = get_logger(name=__name__, category="agents::meta_reference") class MetaReferenceAgentsImpl(Agents): diff --git a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/openai_responses.py deleted file mode 100644 index 104f15010..000000000 --- a/llama_stack/providers/inline/agents/meta_reference/openai_responses.py +++ /dev/null @@ -1,989 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import asyncio -import json -import time -import uuid -from collections.abc import AsyncIterator -from typing import Any - -from openai.types.chat import ChatCompletionToolParam -from pydantic import BaseModel - -from llama_stack.apis.agents import Order -from llama_stack.apis.agents.openai_responses import ( - AllowedToolsFilter, - ListOpenAIResponseInputItem, - ListOpenAIResponseObject, - OpenAIDeleteResponseObject, - OpenAIResponseInput, - OpenAIResponseInputFunctionToolCallOutput, - OpenAIResponseInputMessageContent, - OpenAIResponseInputMessageContentImage, - OpenAIResponseInputMessageContentText, - OpenAIResponseInputTool, - OpenAIResponseInputToolFileSearch, - OpenAIResponseInputToolMCP, - OpenAIResponseMessage, - OpenAIResponseObject, - OpenAIResponseObjectStream, - OpenAIResponseObjectStreamResponseCompleted, - OpenAIResponseObjectStreamResponseCreated, - OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, - OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, - OpenAIResponseObjectStreamResponseOutputItemAdded, - OpenAIResponseObjectStreamResponseOutputItemDone, - OpenAIResponseObjectStreamResponseOutputTextDelta, - OpenAIResponseOutput, - OpenAIResponseOutputMessageContent, - OpenAIResponseOutputMessageContentOutputText, - OpenAIResponseOutputMessageFileSearchToolCall, - OpenAIResponseOutputMessageFileSearchToolCallResults, - OpenAIResponseOutputMessageFunctionToolCall, - OpenAIResponseOutputMessageMCPListTools, - OpenAIResponseOutputMessageWebSearchToolCall, - OpenAIResponseText, - OpenAIResponseTextFormat, - WebSearchToolTypes, -) -from llama_stack.apis.common.content_types import TextContentItem -from llama_stack.apis.inference import ( - Inference, - OpenAIAssistantMessageParam, - OpenAIChatCompletion, - OpenAIChatCompletionContentPartImageParam, - OpenAIChatCompletionContentPartParam, - OpenAIChatCompletionContentPartTextParam, - OpenAIChatCompletionToolCall, - OpenAIChatCompletionToolCallFunction, - OpenAIChoice, - OpenAIDeveloperMessageParam, - OpenAIImageURL, - OpenAIJSONSchema, - 
OpenAIMessageParam, - OpenAIResponseFormatJSONObject, - OpenAIResponseFormatJSONSchema, - OpenAIResponseFormatParam, - OpenAIResponseFormatText, - OpenAISystemMessageParam, - OpenAIToolMessageParam, - OpenAIUserMessageParam, -) -from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime -from llama_stack.apis.vector_io import VectorIO -from llama_stack.log import get_logger -from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition -from llama_stack.providers.utils.inference.openai_compat import ( - convert_tooldef_to_openai_tool, -) -from llama_stack.providers.utils.responses.responses_store import ResponsesStore - -logger = get_logger(name=__name__, category="openai_responses") - -OPENAI_RESPONSES_PREFIX = "openai_responses:" - - -async def _convert_response_content_to_chat_content( - content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), -) -> str | list[OpenAIChatCompletionContentPartParam]: - """ - Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. - - The content schemas of each API look similar, but are not exactly the same. - """ - if isinstance(content, str): - return content - - converted_parts = [] - for content_part in content: - if isinstance(content_part, OpenAIResponseInputMessageContentText): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) - elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) - elif isinstance(content_part, OpenAIResponseInputMessageContentImage): - if content_part.image_url: - image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail) - converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url)) - elif isinstance(content_part, str): - converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part)) - else: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context" - ) - return converted_parts - - -async def _convert_response_input_to_chat_messages( - input: str | list[OpenAIResponseInput], -) -> list[OpenAIMessageParam]: - """ - Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. 
- """ - messages: list[OpenAIMessageParam] = [] - if isinstance(input, list): - for input_item in input: - if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): - messages.append( - OpenAIToolMessageParam( - content=input_item.output, - tool_call_id=input_item.call_id, - ) - ) - elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): - tool_call = OpenAIChatCompletionToolCall( - index=0, - id=input_item.call_id, - function=OpenAIChatCompletionToolCallFunction( - name=input_item.name, - arguments=input_item.arguments, - ), - ) - messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) - else: - content = await _convert_response_content_to_chat_content(input_item.content) - message_type = await _get_message_type_by_role(input_item.role) - if message_type is None: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" - ) - messages.append(message_type(content=content)) - else: - messages.append(OpenAIUserMessageParam(content=input)) - return messages - - -async def _convert_chat_choice_to_response_message( - choice: OpenAIChoice, -) -> OpenAIResponseMessage: - """ - Convert an OpenAI Chat Completion choice into an OpenAI Response output message. - """ - output_content = "" - if isinstance(choice.message.content, str): - output_content = choice.message.content - elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): - output_content = choice.message.content.text - else: - raise ValueError( - f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" - ) - - return OpenAIResponseMessage( - id=f"msg_{uuid.uuid4()}", - content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], - status="completed", - role="assistant", - ) - - -async def _convert_response_text_to_chat_response_format( - text: OpenAIResponseText, -) -> OpenAIResponseFormatParam: - """ - Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. 
- """ - if not text.format or text.format["type"] == "text": - return OpenAIResponseFormatText(type="text") - if text.format["type"] == "json_object": - return OpenAIResponseFormatJSONObject() - if text.format["type"] == "json_schema": - return OpenAIResponseFormatJSONSchema( - json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) - ) - raise ValueError(f"Unsupported text format: {text.format}") - - -async def _get_message_type_by_role(role: str): - role_to_type = { - "user": OpenAIUserMessageParam, - "system": OpenAISystemMessageParam, - "assistant": OpenAIAssistantMessageParam, - "developer": OpenAIDeveloperMessageParam, - } - return role_to_type.get(role) - - -class OpenAIResponsePreviousResponseWithInputItems(BaseModel): - input_items: ListOpenAIResponseInputItem - response: OpenAIResponseObject - - -class ChatCompletionContext(BaseModel): - model: str - messages: list[OpenAIMessageParam] - response_tools: list[OpenAIResponseInputTool] | None = None - chat_tools: list[ChatCompletionToolParam] | None = None - mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] - temperature: float | None - response_format: OpenAIResponseFormatParam - - -class OpenAIResponsesImpl: - def __init__( - self, - inference_api: Inference, - tool_groups_api: ToolGroups, - tool_runtime_api: ToolRuntime, - responses_store: ResponsesStore, - vector_io_api: VectorIO, # VectorIO - ): - self.inference_api = inference_api - self.tool_groups_api = tool_groups_api - self.tool_runtime_api = tool_runtime_api - self.responses_store = responses_store - self.vector_io_api = vector_io_api - - async def _prepend_previous_response( - self, - input: str | list[OpenAIResponseInput], - previous_response_id: str | None = None, - ): - if previous_response_id: - previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) - - # previous response input items - new_input_items = previous_response_with_input.input - - # previous response output items - new_input_items.extend(previous_response_with_input.output) - - # new input items from the current request - if isinstance(input, str): - new_input_items.append(OpenAIResponseMessage(content=input, role="user")) - else: - new_input_items.extend(input) - - input = new_input_items - - return input - - async def _prepend_instructions(self, messages, instructions): - if instructions: - messages.insert(0, OpenAISystemMessageParam(content=instructions)) - - async def get_openai_response( - self, - response_id: str, - ) -> OpenAIResponseObject: - response_with_input = await self.responses_store.get_response_object(response_id) - return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) - - async def list_openai_responses( - self, - after: str | None = None, - limit: int | None = 50, - model: str | None = None, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseObject: - return await self.responses_store.list_responses(after, limit, model, order) - - async def list_openai_response_input_items( - self, - response_id: str, - after: str | None = None, - before: str | None = None, - include: list[str] | None = None, - limit: int | None = 20, - order: Order | None = Order.desc, - ) -> ListOpenAIResponseInputItem: - """List input items for a given OpenAI response. - - :param response_id: The ID of the response to retrieve input items for. - :param after: An item ID to list items after, used for pagination. 
- :param before: An item ID to list items before, used for pagination. - :param include: Additional fields to include in the response. - :param limit: A limit on the number of objects to be returned. - :param order: The order to return the input items in. - :returns: An ListOpenAIResponseInputItem. - """ - return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) - - async def _store_response( - self, - response: OpenAIResponseObject, - input: str | list[OpenAIResponseInput], - ) -> None: - new_input_id = f"msg_{uuid.uuid4()}" - if isinstance(input, str): - # synthesize a message from the input string - input_content = OpenAIResponseInputMessageContentText(text=input) - input_content_item = OpenAIResponseMessage( - role="user", - content=[input_content], - id=new_input_id, - ) - input_items_data = [input_content_item] - else: - # we already have a list of messages - input_items_data = [] - for input_item in input: - if isinstance(input_item, OpenAIResponseMessage): - # These may or may not already have an id, so dump to dict, check for id, and add if missing - input_item_dict = input_item.model_dump() - if "id" not in input_item_dict: - input_item_dict["id"] = new_input_id - input_items_data.append(OpenAIResponseMessage(**input_item_dict)) - else: - input_items_data.append(input_item) - - await self.responses_store.store_response_object( - response_object=response, - input=input_items_data, - ) - - async def create_openai_response( - self, - input: str | list[OpenAIResponseInput], - model: str, - instructions: str | None = None, - previous_response_id: str | None = None, - store: bool | None = True, - stream: bool | None = False, - temperature: float | None = None, - text: OpenAIResponseText | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - include: list[str] | None = None, - max_infer_iters: int | None = 10, - ): - stream = bool(stream) - text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text - - stream_gen = self._create_streaming_response( - input=input, - model=model, - instructions=instructions, - previous_response_id=previous_response_id, - store=store, - temperature=temperature, - text=text, - tools=tools, - max_infer_iters=max_infer_iters, - ) - - if stream: - return stream_gen - else: - response = None - async for stream_chunk in stream_gen: - if stream_chunk.type == "response.completed": - if response is not None: - raise ValueError("The response stream completed multiple times! Earlier response: {response}") - response = stream_chunk.response - # don't leave the generator half complete! 
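Aside: the non-streaming path above is implemented by draining the streaming generator and keeping the `response.completed` payload, while still iterating to the end so the generator is not left half complete. A standalone sketch of that pattern follows; the event and field names are illustrative stand-ins, not the Llama Stack types.

```python
# Standalone sketch of the "drain the stream, keep the completed payload" pattern.
# Event and field names here are illustrative, not the Llama Stack types.
import asyncio
from dataclasses import dataclass


@dataclass
class Event:
    type: str
    response: str | None = None


async def stream_events():
    yield Event(type="response.created")
    yield Event(type="response.output_text.delta")
    yield Event(type="response.completed", response="final answer")
    # any cleanup here only runs if the consumer iterates to the end


async def create_response(stream: bool):
    gen = stream_events()
    if stream:
        return gen  # caller consumes the events itself
    final = None
    async for event in gen:  # keep iterating; don't leave the generator half complete
        if event.type == "response.completed":
            if final is not None:
                raise ValueError("stream completed multiple times")
            final = event.response
    if final is None:
        raise ValueError("stream never completed")
    return final


print(asyncio.run(create_response(stream=False)))  # -> final answer
```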
- - if response is None: - raise ValueError("The response stream never completed") - return response - - async def _create_streaming_response( - self, - input: str | list[OpenAIResponseInput], - model: str, - instructions: str | None = None, - previous_response_id: str | None = None, - store: bool | None = True, - temperature: float | None = None, - text: OpenAIResponseText | None = None, - tools: list[OpenAIResponseInputTool] | None = None, - max_infer_iters: int | None = 10, - ) -> AsyncIterator[OpenAIResponseObjectStream]: - output_messages: list[OpenAIResponseOutput] = [] - - # Input preprocessing - input = await self._prepend_previous_response(input, previous_response_id) - messages = await _convert_response_input_to_chat_messages(input) - await self._prepend_instructions(messages, instructions) - - # Structured outputs - response_format = await _convert_response_text_to_chat_response_format(text) - - # Tool setup, TODO: refactor this slightly since this can also yield events - chat_tools, mcp_tool_to_server, mcp_list_message = ( - await self._convert_response_tools_to_chat_tools(tools) if tools else (None, {}, None) - ) - if mcp_list_message: - output_messages.append(mcp_list_message) - - ctx = ChatCompletionContext( - model=model, - messages=messages, - response_tools=tools, - chat_tools=chat_tools, - mcp_tool_to_server=mcp_tool_to_server, - temperature=temperature, - response_format=response_format, - ) - - # Create initial response and emit response.created immediately - response_id = f"resp-{uuid.uuid4()}" - created_at = int(time.time()) - - initial_response = OpenAIResponseObject( - created_at=created_at, - id=response_id, - model=model, - object="response", - status="in_progress", - output=output_messages.copy(), - text=text, - ) - - yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) - - n_iter = 0 - messages = ctx.messages.copy() - - while True: - completion_result = await self.inference_api.openai_chat_completion( - model=ctx.model, - messages=messages, - tools=ctx.chat_tools, - stream=True, - temperature=ctx.temperature, - response_format=ctx.response_format, - ) - - # Process streaming chunks and build complete response - chat_response_id = "" - chat_response_content = [] - chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} - chunk_created = 0 - chunk_model = "" - chunk_finish_reason = "" - sequence_number = 0 - - # Create a placeholder message item for delta events - message_item_id = f"msg_{uuid.uuid4()}" - # Track tool call items for streaming events - tool_call_item_ids: dict[int, str] = {} - - async for chunk in completion_result: - chat_response_id = chunk.id - chunk_created = chunk.created - chunk_model = chunk.model - for chunk_choice in chunk.choices: - # Emit incremental text content as delta events - if chunk_choice.delta.content: - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseOutputTextDelta( - content_index=0, - delta=chunk_choice.delta.content, - item_id=message_item_id, - output_index=0, - sequence_number=sequence_number, - ) - - # Collect content for final response - chat_response_content.append(chunk_choice.delta.content or "") - if chunk_choice.finish_reason: - chunk_finish_reason = chunk_choice.finish_reason - - # Aggregate tool call arguments across chunks - if chunk_choice.delta.tool_calls: - for tool_call in chunk_choice.delta.tool_calls: - response_tool_call = chat_response_tool_calls.get(tool_call.index, None) - # Create new tool call entry if this is the first chunk for this index - 
is_new_tool_call = response_tool_call is None - if is_new_tool_call: - tool_call_dict: dict[str, Any] = tool_call.model_dump() - tool_call_dict.pop("type", None) - response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) - chat_response_tool_calls[tool_call.index] = response_tool_call - - # Create item ID for this tool call for streaming events - tool_call_item_id = f"fc_{uuid.uuid4()}" - tool_call_item_ids[tool_call.index] = tool_call_item_id - - # Emit output_item.added event for the new function call - sequence_number += 1 - function_call_item = OpenAIResponseOutputMessageFunctionToolCall( - arguments="", # Will be filled incrementally via delta events - call_id=tool_call.id or "", - name=tool_call.function.name if tool_call.function else "", - id=tool_call_item_id, - status="in_progress", - ) - yield OpenAIResponseObjectStreamResponseOutputItemAdded( - response_id=response_id, - item=function_call_item, - output_index=len(output_messages), - sequence_number=sequence_number, - ) - - # Stream function call arguments as they arrive - if tool_call.function and tool_call.function.arguments: - tool_call_item_id = tool_call_item_ids[tool_call.index] - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( - delta=tool_call.function.arguments, - item_id=tool_call_item_id, - output_index=len(output_messages), - sequence_number=sequence_number, - ) - - # Accumulate arguments for final response (only for subsequent chunks) - if not is_new_tool_call: - response_tool_call.function.arguments = ( - response_tool_call.function.arguments or "" - ) + tool_call.function.arguments - - # Emit function_call_arguments.done events for completed tool calls - for tool_call_index in sorted(chat_response_tool_calls.keys()): - tool_call_item_id = tool_call_item_ids[tool_call_index] - final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or "" - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone( - arguments=final_arguments, - item_id=tool_call_item_id, - output_index=len(output_messages), - sequence_number=sequence_number, - ) - - # Convert collected chunks to complete response - if chat_response_tool_calls: - tool_calls = [chat_response_tool_calls[i] for i in sorted(chat_response_tool_calls.keys())] - - # when there are tool calls, we need to clear the content - chat_response_content = [] - else: - tool_calls = None - - assistant_message = OpenAIAssistantMessageParam( - content="".join(chat_response_content), - tool_calls=tool_calls, - ) - current_response = OpenAIChatCompletion( - id=chat_response_id, - choices=[ - OpenAIChoice( - message=assistant_message, - finish_reason=chunk_finish_reason, - index=0, - ) - ], - created=chunk_created, - model=chunk_model, - ) - - function_tool_calls = [] - non_function_tool_calls = [] - - next_turn_messages = messages.copy() - for choice in current_response.choices: - next_turn_messages.append(choice.message) - - if choice.message.tool_calls and tools: - for tool_call in choice.message.tool_calls: - if _is_function_tool_call(tool_call, tools): - function_tool_calls.append(tool_call) - else: - non_function_tool_calls.append(tool_call) - else: - output_messages.append(await _convert_chat_choice_to_response_message(choice)) - - # execute non-function tool calls - for tool_call in non_function_tool_calls: - tool_call_log, tool_response_message = await self._execute_tool_call(tool_call, ctx) - if tool_call_log: - output_messages.append(tool_call_log) - - # Emit 
output_item.done event for completed non-function tool call - # Find the item_id for this tool call - matching_item_id = None - for index, item_id in tool_call_item_ids.items(): - response_tool_call = chat_response_tool_calls.get(index) - if response_tool_call and response_tool_call.id == tool_call.id: - matching_item_id = item_id - break - - if matching_item_id: - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseOutputItemDone( - response_id=response_id, - item=tool_call_log, - output_index=len(output_messages) - 1, - sequence_number=sequence_number, - ) - - if tool_response_message: - next_turn_messages.append(tool_response_message) - - for tool_call in function_tool_calls: - # Find the item_id for this tool call from our tracking dictionary - matching_item_id = None - for index, item_id in tool_call_item_ids.items(): - response_tool_call = chat_response_tool_calls.get(index) - if response_tool_call and response_tool_call.id == tool_call.id: - matching_item_id = item_id - break - - # Use existing item_id or create new one if not found - final_item_id = matching_item_id or f"fc_{uuid.uuid4()}" - - function_call_item = OpenAIResponseOutputMessageFunctionToolCall( - arguments=tool_call.function.arguments or "", - call_id=tool_call.id, - name=tool_call.function.name or "", - id=final_item_id, - status="completed", - ) - output_messages.append(function_call_item) - - # Emit output_item.done event for completed function call - sequence_number += 1 - yield OpenAIResponseObjectStreamResponseOutputItemDone( - response_id=response_id, - item=function_call_item, - output_index=len(output_messages) - 1, - sequence_number=sequence_number, - ) - - if not function_tool_calls and not non_function_tool_calls: - break - - if function_tool_calls: - logger.info("Exiting inference loop since there is a function (client-side) tool call") - break - - n_iter += 1 - if n_iter >= max_infer_iters: - logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {max_infer_iters=}") - break - - messages = next_turn_messages - - # Create final response - final_response = OpenAIResponseObject( - created_at=created_at, - id=response_id, - model=model, - object="response", - status="completed", - text=text, - output=output_messages, - ) - - # Emit response.completed - yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) - - if store: - await self._store_response( - response=final_response, - input=input, - ) - - async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: - return await self.responses_store.delete_response_object(response_id) - - async def _convert_response_tools_to_chat_tools( - self, tools: list[OpenAIResponseInputTool] - ) -> tuple[ - list[ChatCompletionToolParam], - dict[str, OpenAIResponseInputToolMCP], - OpenAIResponseOutput | None, - ]: - from llama_stack.apis.agents.openai_responses import ( - MCPListToolsTool, - ) - from llama_stack.apis.tools import Tool - - mcp_tool_to_server = {} - - def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: - tool_def = ToolDefinition( - tool_name=tool_name, - description=tool.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool.parameters - }, - ) - return convert_tooldef_to_openai_tool(tool_def) - - mcp_list_message = None - chat_tools: list[ChatCompletionToolParam] = [] - for input_tool in tools: - # TODO: 
Handle other tool types - if input_tool.type == "function": - chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) - elif input_tool.type in WebSearchToolTypes: - tool_name = "web_search" - tool = await self.tool_groups_api.get_tool(tool_name) - if not tool: - raise ValueError(f"Tool {tool_name} not found") - chat_tools.append(make_openai_tool(tool_name, tool)) - elif input_tool.type == "file_search": - tool_name = "knowledge_search" - tool = await self.tool_groups_api.get_tool(tool_name) - if not tool: - raise ValueError(f"Tool {tool_name} not found") - chat_tools.append(make_openai_tool(tool_name, tool)) - elif input_tool.type == "mcp": - from llama_stack.providers.utils.tools.mcp import list_mcp_tools - - always_allowed = None - never_allowed = None - if input_tool.allowed_tools: - if isinstance(input_tool.allowed_tools, list): - always_allowed = input_tool.allowed_tools - elif isinstance(input_tool.allowed_tools, AllowedToolsFilter): - always_allowed = input_tool.allowed_tools.always - never_allowed = input_tool.allowed_tools.never - - tool_defs = await list_mcp_tools( - endpoint=input_tool.server_url, - headers=input_tool.headers or {}, - ) - - mcp_list_message = OpenAIResponseOutputMessageMCPListTools( - id=f"mcp_list_{uuid.uuid4()}", - status="completed", - server_label=input_tool.server_label, - tools=[], - ) - for t in tool_defs.data: - if never_allowed and t.name in never_allowed: - continue - if not always_allowed or t.name in always_allowed: - chat_tools.append(make_openai_tool(t.name, t)) - if t.name in mcp_tool_to_server: - raise ValueError(f"Duplicate tool name {t.name} found for server {input_tool.server_label}") - mcp_tool_to_server[t.name] = input_tool - mcp_list_message.tools.append( - MCPListToolsTool( - name=t.name, - description=t.description, - input_schema={ - "type": "object", - "properties": { - p.name: { - "type": p.parameter_type, - "description": p.description, - } - for p in t.parameters - }, - "required": [p.name for p in t.parameters if p.required], - }, - ) - ) - else: - raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}") - return chat_tools, mcp_tool_to_server, mcp_list_message - - async def _execute_knowledge_search_via_vector_store( - self, - query: str, - response_file_search_tool: OpenAIResponseInputToolFileSearch, - ) -> ToolInvocationResult: - """Execute knowledge search using vector_stores.search API with filters support.""" - search_results = [] - - # Create search tasks for all vector stores - async def search_single_store(vector_store_id): - try: - search_response = await self.vector_io_api.openai_search_vector_store( - vector_store_id=vector_store_id, - query=query, - filters=response_file_search_tool.filters, - max_num_results=response_file_search_tool.max_num_results, - ranking_options=response_file_search_tool.ranking_options, - rewrite_query=False, - ) - return search_response.data - except Exception as e: - logger.warning(f"Failed to search vector store {vector_store_id}: {e}") - return [] - - # Run all searches in parallel using gather - search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] - all_results = await asyncio.gather(*search_tasks) - - # Flatten results - for results in all_results: - search_results.extend(results) - - # Convert search results to tool result format matching memory.py - # Format the results as interleaved content similar to memory.py - content_items = [] - content_items.append( - 
TextContentItem( - text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" - ) - ) - - for i, result_item in enumerate(search_results): - chunk_text = result_item.content[0].text if result_item.content else "" - metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" - if result_item.attributes: - metadata_text += f", attributes: {result_item.attributes}" - text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" - content_items.append(TextContentItem(text=text_content)) - - content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) - content_items.append( - TextContentItem( - text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n', - ) - ) - - return ToolInvocationResult( - content=content_items, - metadata={ - "document_ids": [r.file_id for r in search_results], - "chunks": [r.content[0].text if r.content else "" for r in search_results], - "scores": [r.score for r in search_results], - }, - ) - - async def _execute_tool_call( - self, - tool_call: OpenAIChatCompletionToolCall, - ctx: ChatCompletionContext, - ) -> tuple[OpenAIResponseOutput | None, OpenAIMessageParam | None]: - from llama_stack.providers.utils.inference.prompt_adapter import ( - interleaved_content_as_str, - ) - - tool_call_id = tool_call.id - function = tool_call.function - tool_kwargs = json.loads(function.arguments) if function.arguments else {} - - if not function or not tool_call_id or not function.name: - return None, None - - error_exc = None - result = None - try: - if ctx.mcp_tool_to_server and function.name in ctx.mcp_tool_to_server: - from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool - - mcp_tool = ctx.mcp_tool_to_server[function.name] - result = await invoke_mcp_tool( - endpoint=mcp_tool.server_url, - headers=mcp_tool.headers or {}, - tool_name=function.name, - kwargs=tool_kwargs, - ) - elif function.name == "knowledge_search": - response_file_search_tool = next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), - None, - ) - if response_file_search_tool: - # Use vector_stores.search API instead of knowledge_search tool - # to support filters and ranking_options - query = tool_kwargs.get("query", "") - result = await self._execute_knowledge_search_via_vector_store( - query=query, - response_file_search_tool=response_file_search_tool, - ) - else: - result = await self.tool_runtime_api.invoke_tool( - tool_name=function.name, - kwargs=tool_kwargs, - ) - except Exception as e: - error_exc = e - - if function.name in ctx.mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import ( - OpenAIResponseOutputMessageMCPCall, - ) - - message = OpenAIResponseOutputMessageMCPCall( - id=tool_call_id, - arguments=function.arguments, - name=function.name, - server_label=ctx.mcp_tool_to_server[function.name].server_label, - ) - if error_exc: - message.error = str(error_exc) - elif (result.error_code and result.error_code > 0) or result.error_message: - message.error = f"Error (code {result.error_code}): {result.error_message}" - elif result.content: - message.output = interleaved_content_as_str(result.content) - else: - if function.name == "web_search": - message = OpenAIResponseOutputMessageWebSearchToolCall( - id=tool_call_id, - status="completed", - ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: - message.status 
= "failed" - elif function.name == "knowledge_search": - message = OpenAIResponseOutputMessageFileSearchToolCall( - id=tool_call_id, - queries=[tool_kwargs.get("query", "")], - status="completed", - ) - if "document_ids" in result.metadata: - message.results = [] - for i, doc_id in enumerate(result.metadata["document_ids"]): - text = result.metadata["chunks"][i] if "chunks" in result.metadata else None - score = result.metadata["scores"][i] if "scores" in result.metadata else None - message.results.append( - OpenAIResponseOutputMessageFileSearchToolCallResults( - file_id=doc_id, - filename=doc_id, - text=text, - score=score, - attributes={}, - ) - ) - if error_exc or (result.error_code and result.error_code > 0) or result.error_message: - message.status = "failed" - else: - raise ValueError(f"Unknown tool {function.name} called") - - input_message = None - if result and result.content: - if isinstance(result.content, str): - content = result.content - elif isinstance(result.content, list): - from llama_stack.apis.common.content_types import ( - ImageContentItem, - TextContentItem, - ) - - content = [] - for item in result.content: - if isinstance(item, TextContentItem): - part = OpenAIChatCompletionContentPartTextParam(text=item.text) - elif isinstance(item, ImageContentItem): - if item.image.data: - url = f"data:image;base64,{item.image.data}" - else: - url = item.image.url - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) - else: - raise ValueError(f"Unknown result content type: {type(item)}") - content.append(part) - else: - raise ValueError(f"Unknown result content type: {type(result.content)}") - input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) - else: - text = str(error_exc) - input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) - - return message, input_message - - -def _is_function_tool_call( - tool_call: OpenAIChatCompletionToolCall, - tools: list[OpenAIResponseInputTool], -) -> bool: - if not tool_call.function: - return False - for t in tools: - if t.type == "function" and t.name == tool_call.function.name: - return True - return False diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 0b234d96c..3b7b4729c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging import uuid from datetime import UTC, datetime @@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.datatypes import User from llama_stack.core.request_headers import get_authenticated_user +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents::meta_reference") class AgentSessionInfo(Session): diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py b/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py new file mode 100644 index 000000000..c632e61aa --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +import uuid +from collections.abc import AsyncIterator + +from pydantic import BaseModel + +from llama_stack.apis.agents import Order +from llama_stack.apis.agents.openai_responses import ( + ListOpenAIResponseInputItem, + ListOpenAIResponseObject, + OpenAIDeleteResponseObject, + OpenAIResponseInput, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputTool, + OpenAIResponseMessage, + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseText, + OpenAIResponseTextFormat, +) +from llama_stack.apis.inference import ( + Inference, + OpenAISystemMessageParam, +) +from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger +from llama_stack.providers.utils.responses.responses_store import ResponsesStore + +from .streaming import StreamingResponseOrchestrator +from .tool_executor import ToolExecutor +from .types import ChatCompletionContext +from .utils import ( + convert_response_input_to_chat_messages, + convert_response_text_to_chat_response_format, +) + +logger = get_logger(name=__name__, category="openai::responses") + + +class OpenAIResponsePreviousResponseWithInputItems(BaseModel): + input_items: ListOpenAIResponseInputItem + response: OpenAIResponseObject + + +class OpenAIResponsesImpl: + def __init__( + self, + inference_api: Inference, + tool_groups_api: ToolGroups, + tool_runtime_api: ToolRuntime, + responses_store: ResponsesStore, + vector_io_api: VectorIO, # VectorIO + ): + self.inference_api = inference_api + self.tool_groups_api = tool_groups_api + self.tool_runtime_api = tool_runtime_api + self.responses_store = responses_store + self.vector_io_api = vector_io_api + self.tool_executor = ToolExecutor( + tool_groups_api=tool_groups_api, + tool_runtime_api=tool_runtime_api, + vector_io_api=vector_io_api, + ) + + async def _prepend_previous_response( + self, + input: str | list[OpenAIResponseInput], + previous_response_id: str | None = None, + ): + if previous_response_id: + previous_response_with_input = await self.responses_store.get_response_object(previous_response_id) + + # previous response input items + new_input_items = previous_response_with_input.input + + # previous response output items + new_input_items.extend(previous_response_with_input.output) + + # new input items from the current request + if isinstance(input, str): + new_input_items.append(OpenAIResponseMessage(content=input, role="user")) + else: + new_input_items.extend(input) + + input = new_input_items + + return input + + async def _prepend_instructions(self, messages, instructions): + if instructions: + messages.insert(0, OpenAISystemMessageParam(content=instructions)) + + async def get_openai_response( + self, + response_id: str, + ) -> OpenAIResponseObject: + response_with_input = await 
self.responses_store.get_response_object(response_id) + return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"}) + + async def list_openai_responses( + self, + after: str | None = None, + limit: int | None = 50, + model: str | None = None, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseObject: + return await self.responses_store.list_responses(after, limit, model, order) + + async def list_openai_response_input_items( + self, + response_id: str, + after: str | None = None, + before: str | None = None, + include: list[str] | None = None, + limit: int | None = 20, + order: Order | None = Order.desc, + ) -> ListOpenAIResponseInputItem: + """List input items for a given OpenAI response. + + :param response_id: The ID of the response to retrieve input items for. + :param after: An item ID to list items after, used for pagination. + :param before: An item ID to list items before, used for pagination. + :param include: Additional fields to include in the response. + :param limit: A limit on the number of objects to be returned. + :param order: The order to return the input items in. + :returns: An ListOpenAIResponseInputItem. + """ + return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order) + + async def _store_response( + self, + response: OpenAIResponseObject, + input: str | list[OpenAIResponseInput], + ) -> None: + new_input_id = f"msg_{uuid.uuid4()}" + if isinstance(input, str): + # synthesize a message from the input string + input_content = OpenAIResponseInputMessageContentText(text=input) + input_content_item = OpenAIResponseMessage( + role="user", + content=[input_content], + id=new_input_id, + ) + input_items_data = [input_content_item] + else: + # we already have a list of messages + input_items_data = [] + for input_item in input: + if isinstance(input_item, OpenAIResponseMessage): + # These may or may not already have an id, so dump to dict, check for id, and add if missing + input_item_dict = input_item.model_dump() + if "id" not in input_item_dict: + input_item_dict["id"] = new_input_id + input_items_data.append(OpenAIResponseMessage(**input_item_dict)) + else: + input_items_data.append(input_item) + + await self.responses_store.store_response_object( + response_object=response, + input=input_items_data, + ) + + async def create_openai_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + stream: bool | None = False, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + include: list[str] | None = None, + max_infer_iters: int | None = 10, + ): + stream = bool(stream) + text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text + + stream_gen = self._create_streaming_response( + input=input, + model=model, + instructions=instructions, + previous_response_id=previous_response_id, + store=store, + temperature=temperature, + text=text, + tools=tools, + max_infer_iters=max_infer_iters, + ) + + if stream: + return stream_gen + else: + response = None + async for stream_chunk in stream_gen: + if stream_chunk.type == "response.completed": + if response is not None: + raise ValueError("The response stream completed multiple times! 
Earlier response: {response}") + response = stream_chunk.response + # don't leave the generator half complete! + + if response is None: + raise ValueError("The response stream never completed") + return response + + async def _create_streaming_response( + self, + input: str | list[OpenAIResponseInput], + model: str, + instructions: str | None = None, + previous_response_id: str | None = None, + store: bool | None = True, + temperature: float | None = None, + text: OpenAIResponseText | None = None, + tools: list[OpenAIResponseInputTool] | None = None, + max_infer_iters: int | None = 10, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + # Input preprocessing + input = await self._prepend_previous_response(input, previous_response_id) + messages = await convert_response_input_to_chat_messages(input) + await self._prepend_instructions(messages, instructions) + + # Structured outputs + response_format = await convert_response_text_to_chat_response_format(text) + + ctx = ChatCompletionContext( + model=model, + messages=messages, + response_tools=tools, + temperature=temperature, + response_format=response_format, + ) + + # Create orchestrator and delegate streaming logic + response_id = f"resp-{uuid.uuid4()}" + created_at = int(time.time()) + + orchestrator = StreamingResponseOrchestrator( + inference_api=self.inference_api, + ctx=ctx, + response_id=response_id, + created_at=created_at, + text=text, + max_infer_iters=max_infer_iters, + tool_executor=self.tool_executor, + ) + + # Stream the response + final_response = None + async for stream_chunk in orchestrator.create_response(): + if stream_chunk.type == "response.completed": + final_response = stream_chunk.response + yield stream_chunk + + # Store the response if requested + if store and final_response: + await self._store_response( + response=final_response, + input=input, + ) + + async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject: + return await self.responses_store.delete_response_object(response_id) diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py new file mode 100644 index 000000000..3e69fa5cd --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -0,0 +1,634 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
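Aside: the refactored `_create_streaming_response` above forwards every chunk from the `StreamingResponseOrchestrator` to the caller while remembering the `response.completed` payload so it can be persisted once streaming finishes. A minimal standalone sketch of that wrap-and-capture shape follows; the chunk type, the in-memory "store", and all names are illustrative, not the Llama Stack API.

```python
# Standalone sketch of wrapping an inner async generator: forward every chunk,
# remember the final one, persist it after the stream ends. Names are illustrative.
import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass


@dataclass
class Chunk:
    type: str
    response: str | None = None


async def orchestrate() -> AsyncIterator[Chunk]:
    yield Chunk(type="response.created")
    yield Chunk(type="response.completed", response="done")


stored: list[str] = []


async def stream_and_store(store: bool = True) -> AsyncIterator[Chunk]:
    final = None
    async for chunk in orchestrate():
        if chunk.type == "response.completed":
            final = chunk.response
        yield chunk               # pass the chunk through unchanged
    if store and final is not None:
        stored.append(final)      # persistence happens only after streaming finishes


async def main() -> None:
    async for chunk in stream_and_store():
        print(chunk.type)
    print("stored:", stored)      # -> stored: ['done']


asyncio.run(main())
```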
+ +import uuid +from collections.abc import AsyncIterator +from typing import Any + +from llama_stack.apis.agents.openai_responses import ( + AllowedToolsFilter, + MCPListToolsTool, + OpenAIResponseContentPartOutputText, + OpenAIResponseInputTool, + OpenAIResponseInputToolMCP, + OpenAIResponseObject, + OpenAIResponseObjectStream, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseContentPartAdded, + OpenAIResponseObjectStreamResponseContentPartDone, + OpenAIResponseObjectStreamResponseCreated, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta, + OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta, + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone, + OpenAIResponseObjectStreamResponseMcpListToolsCompleted, + OpenAIResponseObjectStreamResponseMcpListToolsInProgress, + OpenAIResponseObjectStreamResponseOutputItemAdded, + OpenAIResponseObjectStreamResponseOutputItemDone, + OpenAIResponseObjectStreamResponseOutputTextDelta, + OpenAIResponseOutput, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseText, + WebSearchToolTypes, +) +from llama_stack.apis.inference import ( + Inference, + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChatCompletionToolCall, + OpenAIChoice, +) +from llama_stack.log import get_logger + +from .types import ChatCompletionContext, ChatCompletionResult +from .utils import convert_chat_choice_to_response_message, is_function_tool_call + +logger = get_logger(name=__name__, category="agents::meta_reference") + + +class StreamingResponseOrchestrator: + def __init__( + self, + inference_api: Inference, + ctx: ChatCompletionContext, + response_id: str, + created_at: int, + text: OpenAIResponseText, + max_infer_iters: int, + tool_executor, # Will be the tool execution logic from the main class + ): + self.inference_api = inference_api + self.ctx = ctx + self.response_id = response_id + self.created_at = created_at + self.text = text + self.max_infer_iters = max_infer_iters + self.tool_executor = tool_executor + self.sequence_number = 0 + # Store MCP tool mapping that gets built during tool processing + self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {} + + async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]: + # Initialize output messages + output_messages: list[OpenAIResponseOutput] = [] + # Create initial response and emit response.created immediately + initial_response = OpenAIResponseObject( + created_at=self.created_at, + id=self.response_id, + model=self.ctx.model, + object="response", + status="in_progress", + output=output_messages.copy(), + text=self.text, + ) + + yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) + + # Process all tools (including MCP tools) and emit streaming events + if self.ctx.response_tools: + async for stream_event in self._process_tools(self.ctx.response_tools, output_messages): + yield stream_event + + n_iter = 0 + messages = self.ctx.messages.copy() + + while True: + completion_result = await self.inference_api.openai_chat_completion( + model=self.ctx.model, + messages=messages, + tools=self.ctx.chat_tools, + stream=True, + temperature=self.ctx.temperature, + response_format=self.ctx.response_format, + ) + + # Process streaming chunks and build complete response + completion_result_data = None + async for stream_event_or_result in self._process_streaming_chunks(completion_result, 
output_messages): + if isinstance(stream_event_or_result, ChatCompletionResult): + completion_result_data = stream_event_or_result + else: + yield stream_event_or_result + if not completion_result_data: + raise ValueError("Streaming chunk processor failed to return completion data") + current_response = self._build_chat_completion(completion_result_data) + + function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls( + current_response, messages + ) + + # Handle choices with no tool calls + for choice in current_response.choices: + if not (choice.message.tool_calls and self.ctx.response_tools): + output_messages.append(await convert_chat_choice_to_response_message(choice)) + + # Execute tool calls and coordinate results + async for stream_event in self._coordinate_tool_execution( + function_tool_calls, + non_function_tool_calls, + completion_result_data, + output_messages, + next_turn_messages, + ): + yield stream_event + + if not function_tool_calls and not non_function_tool_calls: + break + + if function_tool_calls: + logger.info("Exiting inference loop since there is a function (client-side) tool call") + break + + n_iter += 1 + if n_iter >= self.max_infer_iters: + logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}") + break + + messages = next_turn_messages + + # Create final response + final_response = OpenAIResponseObject( + created_at=self.created_at, + id=self.response_id, + model=self.ctx.model, + object="response", + status="completed", + text=self.text, + output=output_messages, + ) + + # Emit response.completed + yield OpenAIResponseObjectStreamResponseCompleted(response=final_response) + + def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]: + """Separate tool calls into function and non-function categories.""" + function_tool_calls = [] + non_function_tool_calls = [] + next_turn_messages = messages.copy() + + for choice in current_response.choices: + next_turn_messages.append(choice.message) + + if choice.message.tool_calls and self.ctx.response_tools: + for tool_call in choice.message.tool_calls: + if is_function_tool_call(tool_call, self.ctx.response_tools): + function_tool_calls.append(tool_call) + else: + non_function_tool_calls.append(tool_call) + + return function_tool_calls, non_function_tool_calls, next_turn_messages + + async def _process_streaming_chunks( + self, completion_result, output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]: + """Process streaming chunks and emit events, returning completion data.""" + # Initialize result tracking + chat_response_id = "" + chat_response_content = [] + chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {} + chunk_created = 0 + chunk_model = "" + chunk_finish_reason = "" + + # Create a placeholder message item for delta events + message_item_id = f"msg_{uuid.uuid4()}" + # Track tool call items for streaming events + tool_call_item_ids: dict[int, str] = {} + # Track content parts for streaming events + content_part_emitted = False + + async for chunk in completion_result: + chat_response_id = chunk.id + chunk_created = chunk.created + chunk_model = chunk.model + for chunk_choice in chunk.choices: + # Emit incremental text content as delta events + if chunk_choice.delta.content: + # Emit content_part.added event for first text chunk + if not content_part_emitted: + content_part_emitted = True + self.sequence_number += 1 + yield 
OpenAIResponseObjectStreamResponseContentPartAdded( + response_id=self.response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text="", # Will be filled incrementally via text deltas + ), + sequence_number=self.sequence_number, + ) + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputTextDelta( + content_index=0, + delta=chunk_choice.delta.content, + item_id=message_item_id, + output_index=0, + sequence_number=self.sequence_number, + ) + + # Collect content for final response + chat_response_content.append(chunk_choice.delta.content or "") + if chunk_choice.finish_reason: + chunk_finish_reason = chunk_choice.finish_reason + + # Aggregate tool call arguments across chunks + if chunk_choice.delta.tool_calls: + for tool_call in chunk_choice.delta.tool_calls: + response_tool_call = chat_response_tool_calls.get(tool_call.index, None) + # Create new tool call entry if this is the first chunk for this index + is_new_tool_call = response_tool_call is None + if is_new_tool_call: + tool_call_dict: dict[str, Any] = tool_call.model_dump() + tool_call_dict.pop("type", None) + response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict) + chat_response_tool_calls[tool_call.index] = response_tool_call + + # Create item ID for this tool call for streaming events + tool_call_item_id = f"fc_{uuid.uuid4()}" + tool_call_item_ids[tool_call.index] = tool_call_item_id + + # Emit output_item.added event for the new function call + self.sequence_number += 1 + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments="", # Will be filled incrementally via delta events + call_id=tool_call.id or "", + name=tool_call.function.name if tool_call.function else "", + id=tool_call_item_id, + status="in_progress", + ) + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=self.response_id, + item=function_call_item, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Stream tool call arguments as they arrive (differentiate between MCP and function calls) + if tool_call.function and tool_call.function.arguments: + tool_call_item_id = tool_call_item_ids[tool_call.index] + self.sequence_number += 1 + + # Check if this is an MCP tool call + is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server + if is_mcp_tool: + # Emit MCP-specific argument delta event + yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + else: + # Emit function call argument delta event + yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta( + delta=tool_call.function.arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Accumulate arguments for final response (only for subsequent chunks) + if not is_new_tool_call: + response_tool_call.function.arguments = ( + response_tool_call.function.arguments or "" + ) + tool_call.function.arguments + + # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls) + for tool_call_index in sorted(chat_response_tool_calls.keys()): + tool_call_item_id = tool_call_item_ids[tool_call_index] + final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or "" + tool_call_name = chat_response_tool_calls[tool_call_index].function.name + + # Check if this is an 
MCP tool call + is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server + self.sequence_number += 1 + done_event_cls = ( + OpenAIResponseObjectStreamResponseMcpCallArgumentsDone + if is_mcp_tool + else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone + ) + yield done_event_cls( + arguments=final_arguments, + item_id=tool_call_item_id, + output_index=len(output_messages), + sequence_number=self.sequence_number, + ) + + # Emit content_part.done event if text content was streamed (before content gets cleared) + if content_part_emitted: + final_text = "".join(chat_response_content) + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseContentPartDone( + response_id=self.response_id, + item_id=message_item_id, + part=OpenAIResponseContentPartOutputText( + text=final_text, + ), + sequence_number=self.sequence_number, + ) + + # Clear content when there are tool calls (OpenAI spec behavior) + if chat_response_tool_calls: + chat_response_content = [] + + yield ChatCompletionResult( + response_id=chat_response_id, + content=chat_response_content, + tool_calls=chat_response_tool_calls, + created=chunk_created, + model=chunk_model, + finish_reason=chunk_finish_reason, + message_item_id=message_item_id, + tool_call_item_ids=tool_call_item_ids, + content_part_emitted=content_part_emitted, + ) + + def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion: + """Build OpenAIChatCompletion from ChatCompletionResult.""" + # Convert collected chunks to complete response + if result.tool_calls: + tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())] + else: + tool_calls = None + + assistant_message = OpenAIAssistantMessageParam( + content=result.content_text, + tool_calls=tool_calls, + ) + return OpenAIChatCompletion( + id=result.response_id, + choices=[ + OpenAIChoice( + message=assistant_message, + finish_reason=result.finish_reason, + index=0, + ) + ], + created=result.created, + model=result.model, + ) + + async def _coordinate_tool_execution( + self, + function_tool_calls: list, + non_function_tool_calls: list, + completion_result_data: ChatCompletionResult, + output_messages: list[OpenAIResponseOutput], + next_turn_messages: list, + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Coordinate execution of both function and non-function tool calls.""" + # Execute non-function tool calls + for tool_call in non_function_tool_calls: + # Find the item_id for this tool call + matching_item_id = None + for index, item_id in completion_result_data.tool_call_item_ids.items(): + response_tool_call = completion_result_data.tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use a fallback item_id if not found + if not matching_item_id: + matching_item_id = f"tc_{uuid.uuid4()}" + + # Execute tool call with streaming + tool_call_log = None + tool_response_message = None + async for result in self.tool_executor.execute_tool_call( + tool_call, + self.ctx, + self.sequence_number, + len(output_messages), + matching_item_id, + self.mcp_tool_to_server, + ): + if result.stream_event: + # Forward streaming events + self.sequence_number = result.sequence_number + yield result.stream_event + + if result.final_output_message is not None: + tool_call_log = result.final_output_message + tool_response_message = result.final_input_message + self.sequence_number = result.sequence_number + + if tool_call_log: + output_messages.append(tool_call_log) + + # 
Emit output_item.done event for completed non-function tool call + if matching_item_id: + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=tool_call_log, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + if tool_response_message: + next_turn_messages.append(tool_response_message) + + # Execute function tool calls (client-side) + for tool_call in function_tool_calls: + # Find the item_id for this tool call from our tracking dictionary + matching_item_id = None + for index, item_id in completion_result_data.tool_call_item_ids.items(): + response_tool_call = completion_result_data.tool_calls.get(index) + if response_tool_call and response_tool_call.id == tool_call.id: + matching_item_id = item_id + break + + # Use existing item_id or create new one if not found + final_item_id = matching_item_id or f"fc_{uuid.uuid4()}" + + function_call_item = OpenAIResponseOutputMessageFunctionToolCall( + arguments=tool_call.function.arguments or "", + call_id=tool_call.id, + name=tool_call.function.name or "", + id=final_item_id, + status="completed", + ) + output_messages.append(function_call_item) + + # Emit output_item.done event for completed function call + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=function_call_item, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + async def _process_tools( + self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Process all tools and emit appropriate streaming events.""" + from openai.types.chat import ChatCompletionToolParam + + from llama_stack.apis.tools import Tool + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam: + tool_def = ToolDefinition( + tool_name=tool_name, + description=tool.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool.parameters + }, + ) + return convert_tooldef_to_openai_tool(tool_def) + + # Initialize chat_tools if not already set + if self.ctx.chat_tools is None: + self.ctx.chat_tools = [] + + for input_tool in tools: + if input_tool.type == "function": + self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump())) + elif input_tool.type in WebSearchToolTypes: + tool_name = "web_search" + # Need to access tool_groups_api from tool_executor + tool = await self.tool_executor.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + self.ctx.chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "file_search": + tool_name = "knowledge_search" + tool = await self.tool_executor.tool_groups_api.get_tool(tool_name) + if not tool: + raise ValueError(f"Tool {tool_name} not found") + self.ctx.chat_tools.append(make_openai_tool(tool_name, tool)) + elif input_tool.type == "mcp": + async for stream_event in self._process_mcp_tool(input_tool, output_messages): + yield stream_event + else: + raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: 
{input_tool.type}") + + async def _process_mcp_tool( + self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput] + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Process an MCP tool configuration and emit appropriate streaming events.""" + from llama_stack.providers.utils.tools.mcp import list_mcp_tools + + # Emit mcp_list_tools.in_progress + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress( + sequence_number=self.sequence_number, + ) + + try: + # Parse allowed/never allowed tools + always_allowed = None + never_allowed = None + if mcp_tool.allowed_tools: + if isinstance(mcp_tool.allowed_tools, list): + always_allowed = mcp_tool.allowed_tools + elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter): + always_allowed = mcp_tool.allowed_tools.always + never_allowed = mcp_tool.allowed_tools.never + + # Call list_mcp_tools + tool_defs = await list_mcp_tools( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + ) + + # Create the MCP list tools message + mcp_list_message = OpenAIResponseOutputMessageMCPListTools( + id=f"mcp_list_{uuid.uuid4()}", + server_label=mcp_tool.server_label, + tools=[], + ) + + # Process tools and update context + for t in tool_defs.data: + if never_allowed and t.name in never_allowed: + continue + if not always_allowed or t.name in always_allowed: + # Add to chat tools for inference + from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool + + tool_def = ToolDefinition( + tool_name=t.name, + description=t.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in t.parameters + }, + ) + openai_tool = convert_tooldef_to_openai_tool(tool_def) + if self.ctx.chat_tools is None: + self.ctx.chat_tools = [] + self.ctx.chat_tools.append(openai_tool) + + # Add to MCP tool mapping + if t.name in self.mcp_tool_to_server: + raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}") + self.mcp_tool_to_server[t.name] = mcp_tool + + # Add to MCP list message + mcp_list_message.tools.append( + MCPListToolsTool( + name=t.name, + description=t.description, + input_schema={ + "type": "object", + "properties": { + p.name: { + "type": p.parameter_type, + "description": p.description, + } + for p in t.parameters + }, + "required": [p.name for p in t.parameters if p.required], + }, + ) + ) + + # Add the MCP list message to output + output_messages.append(mcp_list_message) + + # Emit output_item.added for the MCP list tools message + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemAdded( + response_id=self.response_id, + item=mcp_list_message, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + # Emit mcp_list_tools.completed + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted( + sequence_number=self.sequence_number, + ) + + # Emit output_item.done for the MCP list tools message + self.sequence_number += 1 + yield OpenAIResponseObjectStreamResponseOutputItemDone( + response_id=self.response_id, + item=mcp_list_message, + output_index=len(output_messages) - 1, + sequence_number=self.sequence_number, + ) + + except Exception as e: + # TODO: Emit mcp_list_tools.failed event if needed + 
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}") + raise diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py new file mode 100644 index 000000000..b028c018b --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -0,0 +1,379 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import json +from collections.abc import AsyncIterator + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInputToolFileSearch, + OpenAIResponseInputToolMCP, + OpenAIResponseObjectStreamResponseMcpCallCompleted, + OpenAIResponseObjectStreamResponseMcpCallFailed, + OpenAIResponseObjectStreamResponseMcpCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallCompleted, + OpenAIResponseObjectStreamResponseWebSearchCallInProgress, + OpenAIResponseObjectStreamResponseWebSearchCallSearching, + OpenAIResponseOutputMessageFileSearchToolCall, + OpenAIResponseOutputMessageFileSearchToolCallResults, + OpenAIResponseOutputMessageWebSearchToolCall, +) +from llama_stack.apis.common.content_types import ( + ImageContentItem, + TextContentItem, +) +from llama_stack.apis.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIImageURL, + OpenAIToolMessageParam, +) +from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.log import get_logger + +from .types import ChatCompletionContext, ToolExecutionResult + +logger = get_logger(name=__name__, category="agents::meta_reference") + + +class ToolExecutor: + def __init__( + self, + tool_groups_api: ToolGroups, + tool_runtime_api: ToolRuntime, + vector_io_api: VectorIO, + ): + self.tool_groups_api = tool_groups_api + self.tool_runtime_api = tool_runtime_api + self.vector_io_api = vector_io_api + + async def execute_tool_call( + self, + tool_call: OpenAIChatCompletionToolCall, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + tool_call_id = tool_call.id + function = tool_call.function + tool_kwargs = json.loads(function.arguments) if function.arguments else {} + + if not function or not tool_call_id or not function.name: + yield ToolExecutionResult(sequence_number=sequence_number) + return + + # Emit progress events for tool execution start + async for event_result in self._emit_progress_events( + function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server + ): + sequence_number = event_result.sequence_number + yield event_result + + # Execute the actual tool call + error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) + + # Emit completion events for tool execution + has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + async for event_result in self._emit_completion_events( + function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server + ): + sequence_number = event_result.sequence_number + yield event_result + + 
# Build result messages from tool execution + output_message, input_message = await self._build_result_messages( + function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server + ) + + # Yield the final result + yield ToolExecutionResult( + sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message + ) + + async def _execute_knowledge_search_via_vector_store( + self, + query: str, + response_file_search_tool: OpenAIResponseInputToolFileSearch, + ) -> ToolInvocationResult: + """Execute knowledge search using vector_stores.search API with filters support.""" + search_results = [] + + # Create search tasks for all vector stores + async def search_single_store(vector_store_id): + try: + search_response = await self.vector_io_api.openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=response_file_search_tool.filters, + max_num_results=response_file_search_tool.max_num_results, + ranking_options=response_file_search_tool.ranking_options, + rewrite_query=False, + ) + return search_response.data + except Exception as e: + logger.warning(f"Failed to search vector store {vector_store_id}: {e}") + return [] + + # Run all searches in parallel using gather + search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids] + all_results = await asyncio.gather(*search_tasks) + + # Flatten results + for results in all_results: + search_results.extend(results) + + # Convert search results to tool result format matching memory.py + # Format the results as interleaved content similar to memory.py + content_items = [] + content_items.append( + TextContentItem( + text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ) + + for i, result_item in enumerate(search_results): + chunk_text = result_item.content[0].text if result_item.content else "" + metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}" + if result_item.attributes: + metadata_text += f", attributes: {result_item.attributes}" + text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n" + content_items.append(TextContentItem(text=text_content)) + + content_items.append(TextContentItem(text="END of knowledge_search tool results.\n")) + content_items.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{query}". 
Use them as supporting information only in answering this query.\n', + ) + ) + + return ToolInvocationResult( + content=content_items, + metadata={ + "document_ids": [r.file_id for r in search_results], + "chunks": [r.content[0].text if r.content else "" for r in search_results], + "scores": [r.score for r in search_results], + }, + ) + + async def _emit_progress_events( + self, + function_name: str, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + """Emit progress events for tool execution start.""" + # Emit in_progress event based on tool type (only for tools with specific streaming events) + progress_event = None + if mcp_tool_to_server and function_name in mcp_tool_to_server: + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + elif function_name == "web_search": + sequence_number += 1 + progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec + + if progress_event: + yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) + + # For web search, emit searching event + if function_name == "web_search": + sequence_number += 1 + searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) + + async def _execute_tool( + self, + function_name: str, + tool_kwargs: dict, + ctx: ChatCompletionContext, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> tuple[Exception | None, any]: + """Execute the tool and return error exception and result.""" + error_exc = None + result = None + + try: + if mcp_tool_to_server and function_name in mcp_tool_to_server: + from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool + + mcp_tool = mcp_tool_to_server[function_name] + result = await invoke_mcp_tool( + endpoint=mcp_tool.server_url, + headers=mcp_tool.headers or {}, + tool_name=function_name, + kwargs=tool_kwargs, + ) + elif function_name == "knowledge_search": + response_file_search_tool = next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, + ) + if response_file_search_tool: + # Use vector_stores.search API instead of knowledge_search tool + # to support filters and ranking_options + query = tool_kwargs.get("query", "") + result = await self._execute_knowledge_search_via_vector_store( + query=query, + response_file_search_tool=response_file_search_tool, + ) + else: + result = await self.tool_runtime_api.invoke_tool( + tool_name=function_name, + kwargs=tool_kwargs, + ) + except Exception as e: + error_exc = e + + return error_exc, result + + async def _emit_completion_events( + self, + function_name: str, + ctx: ChatCompletionContext, + sequence_number: int, + output_index: int, + item_id: str, + has_error: bool, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> AsyncIterator[ToolExecutionResult]: + """Emit completion or failure events for tool execution.""" + 
completion_event = None + + if mcp_tool_to_server and function_name in mcp_tool_to_server: + sequence_number += 1 + if has_error: + completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + sequence_number=sequence_number, + ) + else: + completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + sequence_number=sequence_number, + ) + elif function_name == "web_search": + sequence_number += 1 + completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ) + # Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec + + if completion_event: + yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + + async def _build_result_messages( + self, + function, + tool_call_id: str, + tool_kwargs: dict, + ctx: ChatCompletionContext, + error_exc: Exception | None, + result: any, + has_error: bool, + mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, + ) -> tuple[any, any]: + """Build output and input messages from tool execution results.""" + from llama_stack.providers.utils.inference.prompt_adapter import ( + interleaved_content_as_str, + ) + + # Build output message + if mcp_tool_to_server and function.name in mcp_tool_to_server: + from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseOutputMessageMCPCall, + ) + + message = OpenAIResponseOutputMessageMCPCall( + id=tool_call_id, + arguments=function.arguments, + name=function.name, + server_label=mcp_tool_to_server[function.name].server_label, + ) + if error_exc: + message.error = str(error_exc) + elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): + message.error = f"Error (code {result.error_code}): {result.error_message}" + elif result and result.content: + message.output = interleaved_content_as_str(result.content) + else: + if function.name == "web_search": + message = OpenAIResponseOutputMessageWebSearchToolCall( + id=tool_call_id, + status="completed", + ) + if has_error: + message.status = "failed" + elif function.name == "knowledge_search": + message = OpenAIResponseOutputMessageFileSearchToolCall( + id=tool_call_id, + queries=[tool_kwargs.get("query", "")], + status="completed", + ) + if result and "document_ids" in result.metadata: + message.results = [] + for i, doc_id in enumerate(result.metadata["document_ids"]): + text = result.metadata["chunks"][i] if "chunks" in result.metadata else None + score = result.metadata["scores"][i] if "scores" in result.metadata else None + message.results.append( + OpenAIResponseOutputMessageFileSearchToolCallResults( + file_id=doc_id, + filename=doc_id, + text=text, + score=score, + attributes={}, + ) + ) + if has_error: + message.status = "failed" + else: + raise ValueError(f"Unknown tool {function.name} called") + + # Build input message + input_message = None + if result and result.content: + if isinstance(result.content, str): + content = result.content + elif isinstance(result.content, list): + content = [] + for item in result.content: + if isinstance(item, TextContentItem): + part = OpenAIChatCompletionContentPartTextParam(text=item.text) + elif isinstance(item, ImageContentItem): + if item.image.data: + url = f"data:image;base64,{item.image.data}" + else: + url = item.image.url + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + else: + raise ValueError(f"Unknown result 
content type: {type(item)}") + content.append(part) + else: + raise ValueError(f"Unknown result content type: {type(result.content)}") + input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + else: + text = str(error_exc) if error_exc else "Tool execution failed" + input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id) + + return message, input_message diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/types.py b/llama_stack/providers/inline/agents/meta_reference/responses/types.py new file mode 100644 index 000000000..89086c262 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/types.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from dataclasses import dataclass + +from openai.types.chat import ChatCompletionToolParam +from pydantic import BaseModel + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInputTool, + OpenAIResponseObjectStream, + OpenAIResponseOutput, +) +from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam + + +class ToolExecutionResult(BaseModel): + """Result of streaming tool execution.""" + + stream_event: OpenAIResponseObjectStream | None = None + sequence_number: int + final_output_message: OpenAIResponseOutput | None = None + final_input_message: OpenAIMessageParam | None = None + + +@dataclass +class ChatCompletionResult: + """Result of processing streaming chat completion chunks.""" + + response_id: str + content: list[str] + tool_calls: dict[int, OpenAIChatCompletionToolCall] + created: int + model: str + finish_reason: str + message_item_id: str # For streaming events + tool_call_item_ids: dict[int, str] # For streaming events + content_part_emitted: bool # Tracking state + + @property + def content_text(self) -> str: + """Get joined content as string.""" + return "".join(self.content) + + @property + def has_tool_calls(self) -> bool: + """Check if there are any tool calls.""" + return bool(self.tool_calls) + + +class ChatCompletionContext(BaseModel): + model: str + messages: list[OpenAIMessageParam] + response_tools: list[OpenAIResponseInputTool] | None = None + chat_tools: list[ChatCompletionToolParam] | None = None + temperature: float | None + response_format: OpenAIResponseFormatParam diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py new file mode 100644 index 000000000..7aaeb4cd5 --- /dev/null +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -0,0 +1,205 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
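For orientation, a minimal sketch (not part of the patch) of how the shared types defined in responses/types.py above fit together: a ChatCompletionContext carries the per-request inference inputs, and a ChatCompletionResult aggregates what the streaming chunks produced. The model id, message text, and ids below are placeholder values.

    from llama_stack.apis.inference import OpenAIResponseFormatText, OpenAIUserMessageParam
    from llama_stack.providers.inline.agents.meta_reference.responses.types import (
        ChatCompletionContext,
        ChatCompletionResult,
    )

    # Inputs for one turn of inference (model id and message are placeholders).
    ctx = ChatCompletionContext(
        model="example-model",
        messages=[OpenAIUserMessageParam(content="Hello")],
        response_tools=None,
        chat_tools=None,
        temperature=None,
        response_format=OpenAIResponseFormatText(type="text"),
    )

    # Aggregate view of a finished streaming turn (values are illustrative).
    result = ChatCompletionResult(
        response_id="chatcmpl-123",
        content=["Hel", "lo!"],
        tool_calls={},
        created=0,
        model=ctx.model,
        finish_reason="stop",
        message_item_id="msg_abc",
        tool_call_item_ids={},
        content_part_emitted=True,
    )
    assert result.content_text == "Hello!" and not result.has_tool_calls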
+ +import uuid + +from llama_stack.apis.agents.openai_responses import ( + OpenAIResponseInput, + OpenAIResponseInputFunctionToolCallOutput, + OpenAIResponseInputMessageContent, + OpenAIResponseInputMessageContentImage, + OpenAIResponseInputMessageContentText, + OpenAIResponseInputTool, + OpenAIResponseMessage, + OpenAIResponseOutputMessageContent, + OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPCall, + OpenAIResponseOutputMessageMCPListTools, + OpenAIResponseText, +) +from llama_stack.apis.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartParam, + OpenAIChatCompletionContentPartTextParam, + OpenAIChatCompletionToolCall, + OpenAIChatCompletionToolCallFunction, + OpenAIChoice, + OpenAIDeveloperMessageParam, + OpenAIImageURL, + OpenAIJSONSchema, + OpenAIMessageParam, + OpenAIResponseFormatJSONObject, + OpenAIResponseFormatJSONSchema, + OpenAIResponseFormatParam, + OpenAIResponseFormatText, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) + + +async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage: + """Convert an OpenAI Chat Completion choice into an OpenAI Response output message.""" + output_content = "" + if isinstance(choice.message.content, str): + output_content = choice.message.content + elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam): + output_content = choice.message.content.text + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}" + ) + + return OpenAIResponseMessage( + id=f"msg_{uuid.uuid4()}", + content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)], + status="completed", + role="assistant", + ) + + +async def convert_response_content_to_chat_content( + content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]), +) -> str | list[OpenAIChatCompletionContentPartParam]: + """ + Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts. + + The content schemas of each API look similar, but are not exactly the same. + """ + if isinstance(content, str): + return content + + converted_parts = [] + for content_part in content: + if isinstance(content_part, OpenAIResponseInputMessageContentText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text)) + elif isinstance(content_part, OpenAIResponseInputMessageContentImage): + if content_part.image_url: + image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail) + converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url)) + elif isinstance(content_part, str): + converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part)) + else: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context" + ) + return converted_parts + + +async def convert_response_input_to_chat_messages( + input: str | list[OpenAIResponseInput], +) -> list[OpenAIMessageParam]: + """ + Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages. 
+ """ + messages: list[OpenAIMessageParam] = [] + if isinstance(input, list): + # extract all OpenAIResponseInputFunctionToolCallOutput items + # so their corresponding OpenAIToolMessageParam instances can + # be added immediately following the corresponding + # OpenAIAssistantMessageParam + tool_call_results = {} + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + tool_call_results[input_item.call_id] = OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.call_id, + ) + + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + # skip as these have been extracted and inserted in order + pass + elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): + tool_call = OpenAIChatCompletionToolCall( + index=0, + id=input_item.call_id, + function=OpenAIChatCompletionToolCallFunction( + name=input_item.name, + arguments=input_item.arguments, + ), + ) + messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + if input_item.call_id in tool_call_results: + messages.append(tool_call_results[input_item.call_id]) + del tool_call_results[input_item.call_id] + elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall): + tool_call = OpenAIChatCompletionToolCall( + index=0, + id=input_item.id, + function=OpenAIChatCompletionToolCallFunction( + name=input_item.name, + arguments=input_item.arguments, + ), + ) + messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + messages.append( + OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.id, + ) + ) + elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools): + # the tool list will be handled separately + pass + else: + content = await convert_response_content_to_chat_content(input_item.content) + message_type = await get_message_type_by_role(input_item.role) + if message_type is None: + raise ValueError( + f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" + ) + messages.append(message_type(content=content)) + if len(tool_call_results): + raise ValueError( + f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call" + ) + else: + messages.append(OpenAIUserMessageParam(content=input)) + return messages + + +async def convert_response_text_to_chat_response_format( + text: OpenAIResponseText, +) -> OpenAIResponseFormatParam: + """ + Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format. 
+ """ + if not text.format or text.format["type"] == "text": + return OpenAIResponseFormatText(type="text") + if text.format["type"] == "json_object": + return OpenAIResponseFormatJSONObject() + if text.format["type"] == "json_schema": + return OpenAIResponseFormatJSONSchema( + json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"]) + ) + raise ValueError(f"Unsupported text format: {text.format}") + + +async def get_message_type_by_role(role: str): + role_to_type = { + "user": OpenAIUserMessageParam, + "system": OpenAISystemMessageParam, + "assistant": OpenAIAssistantMessageParam, + "developer": OpenAIDeveloperMessageParam, + } + return role_to_type.get(role) + + +def is_function_tool_call( + tool_call: OpenAIChatCompletionToolCall, + tools: list[OpenAIResponseInputTool], +) -> bool: + if not tool_call.function: + return False + for t in tools: + if t.type == "function" and t.name == tool_call.function.name: + return True + return False diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 605f387b7..8f3ecf5c9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -5,13 +5,13 @@ # the root directory of this source tree. import asyncio -import logging from llama_stack.apis.inference import Message from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry import tracing -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents::meta_reference") class SafetyException(Exception): # noqa: N818 diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/batches/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py new file mode 100644 index 000000000..a8ae92eb2 --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
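As a quick illustration of convert_response_text_to_chat_response_format defined just above in responses/utils.py, the sketch below shows the plain-text fallback and the json_schema mapping. It assumes OpenAIResponseText accepts a plain dict for its format field; the schema name and body are made-up values.

    import asyncio

    from llama_stack.apis.agents.openai_responses import OpenAIResponseText
    from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
        convert_response_text_to_chat_response_format,
    )


    async def demo() -> None:
        # No format (or {"type": "text"}) falls back to plain text output.
        text_fmt = await convert_response_text_to_chat_response_format(OpenAIResponseText(format=None))

        # A json_schema format is wrapped into an OpenAIResponseFormatJSONSchema.
        schema_fmt = await convert_response_text_to_chat_response_format(
            OpenAIResponseText(format={"type": "json_schema", "name": "weather", "schema": {"type": "object"}})
        )
        print(type(text_fmt).__name__, type(schema_fmt).__name__)


    asyncio.run(demo())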
+ +from typing import Any + +from llama_stack.apis.files import Files +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Models +from llama_stack.core.datatypes import AccessRule, Api +from llama_stack.providers.utils.kvstore import kvstore_impl + +from .batches import ReferenceBatchesImpl +from .config import ReferenceBatchesImplConfig + +__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"] + + +async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): + kvstore = await kvstore_impl(config.kvstore) + inference_api: Inference | None = deps.get(Api.inference) + files_api: Files | None = deps.get(Api.files) + models_api: Models | None = deps.get(Api.models) + + if inference_api is None: + raise ValueError("Inference API is required but not provided in dependencies") + if files_api is None: + raise ValueError("Files API is required but not provided in dependencies") + if models_api is None: + raise ValueError("Models API is required but not provided in dependencies") + + impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py new file mode 100644 index 000000000..26f0ad15a --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -0,0 +1,628 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import hashlib +import itertools +import json +import time +import uuid +from io import BytesIO +from typing import Any, Literal + +from openai.types.batch import BatchError, Errors +from pydantic import BaseModel + +from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError +from llama_stack.apis.files import Files, OpenAIFilePurpose +from llama_stack.apis.inference import ( + Inference, + OpenAIAssistantMessageParam, + OpenAIDeveloperMessageParam, + OpenAIMessageParam, + OpenAISystemMessageParam, + OpenAIToolMessageParam, + OpenAIUserMessageParam, +) +from llama_stack.apis.models import Models +from llama_stack.log import get_logger +from llama_stack.providers.utils.kvstore import KVStore + +from .config import ReferenceBatchesImplConfig + +BATCH_PREFIX = "batch:" + +logger = get_logger(__name__) + + +class AsyncBytesIO: + """ + Async-compatible BytesIO wrapper to allow async file-like operations. + + We use this when uploading files to the Files API, as it expects an + async file-like object. 
+ """ + + def __init__(self, data: bytes): + self._buffer = BytesIO(data) + + async def read(self, n=-1): + return self._buffer.read(n) + + async def seek(self, pos, whence=0): + return self._buffer.seek(pos, whence) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._buffer.close() + + def __getattr__(self, name): + return getattr(self._buffer, name) + + +class BatchRequest(BaseModel): + line_num: int + custom_id: str + method: str + url: str + body: dict[str, Any] + + +def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam: + """Convert a message dictionary to OpenAIMessageParam based on role.""" + role = msg.get("role") + + if role == "user": + return OpenAIUserMessageParam(**msg) + elif role == "system": + return OpenAISystemMessageParam(**msg) + elif role == "assistant": + return OpenAIAssistantMessageParam(**msg) + elif role == "tool": + return OpenAIToolMessageParam(**msg) + elif role == "developer": + return OpenAIDeveloperMessageParam(**msg) + else: + raise ValueError(f"Unknown message role: {role}") + + +class ReferenceBatchesImpl(Batches): + """Reference implementation of the Batches API. + + This implementation processes batch files by making individual requests + to the inference API and generates output files with results. + """ + + def __init__( + self, + config: ReferenceBatchesImplConfig, + inference_api: Inference, + files_api: Files, + models_api: Models, + kvstore: KVStore, + ) -> None: + self.config = config + self.kvstore = kvstore + self.inference_api = inference_api + self.files_api = files_api + self.models_api = models_api + self._processing_tasks: dict[str, asyncio.Task] = {} + self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches) + self._update_batch_lock = asyncio.Lock() + + # this is to allow tests to disable background processing + self.process_batches = True + + async def initialize(self) -> None: + # TODO: start background processing of existing tasks + pass + + async def shutdown(self) -> None: + """Shutdown the batches provider.""" + if self._processing_tasks: + # don't cancel tasks - just let them stop naturally on shutdown + # cancelling would mark batches as "cancelled" in the database + logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") + + # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: Literal["24h"], + metadata: dict[str, str] | None = None, + idempotency_key: str | None = None, + ) -> BatchObject: + """ + Create a new batch for processing multiple API requests. + + This implementation provides optional idempotency: when an idempotency key + (idempotency_key) is provided, a deterministic ID is generated based on the input + parameters. If a batch with the same parameters already exists, it will be + returned instead of creating a duplicate. Without an idempotency key, + each request creates a new batch with a unique ID. + + Args: + input_file_id: The ID of an uploaded file containing requests for the batch. + endpoint: The endpoint to be used for all requests in the batch. + completion_window: The time window within which the batch should be processed. + metadata: Optional metadata for the batch. + idempotency_key: Optional idempotency key for enabling idempotent behavior. + + Returns: + The created or existing batch object. 
+ """ + + # Error handling by levels - + # 0. Input param handling, results in 40x errors before processing, e.g. + # - Wrong completion_window + # - Invalid metadata types + # - Unknown endpoint + # -> no batch created + # 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. + # - input_file_id missing + # - invalid json in file + # - missing custom_id, method, url, body + # - invalid model + # - streaming + # -> batch created, validation sends to failed status + # 2. Processing errors, result in error_file_id entries, e.g. + # - Any error returned from inference endpoint + # -> batch created, goes to completed status + + # TODO: set expiration time for garbage collection + + if endpoint not in ["/v1/chat/completions"]: + raise ValueError( + f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", + ) + + if completion_window != "24h": + raise ValueError( + f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", + ) + + batch_id = f"batch_{uuid.uuid4().hex[:16]}" + + # For idempotent requests, use the idempotency key for the batch ID + # This ensures the same key always maps to the same batch ID, + # allowing us to detect parameter conflicts + if idempotency_key is not None: + hash_input = idempotency_key.encode("utf-8") + hash_digest = hashlib.sha256(hash_input).hexdigest()[:24] + batch_id = f"batch_{hash_digest}" + + try: + existing_batch = await self.retrieve_batch(batch_id) + + if ( + existing_batch.input_file_id != input_file_id + or existing_batch.endpoint != endpoint + or existing_batch.completion_window != completion_window + or existing_batch.metadata != metadata + ): + raise ConflictError( + f"Idempotency key '{idempotency_key}' was previously used with different parameters. " + "Either use a new idempotency key or ensure all parameters match the original request." + ) + + logger.info(f"Returning existing batch with ID: {batch_id}") + return existing_batch + except ResourceNotFoundError: + # Batch doesn't exist, continue with creation + pass + + current_time = int(time.time()) + + batch = BatchObject( + id=batch_id, + object="batch", + endpoint=endpoint, + input_file_id=input_file_id, + completion_window=completion_window, + status="validating", + created_at=current_time, + metadata=metadata, + ) + + await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) + logger.info(f"Created new batch with ID: {batch_id}") + + if self.process_batches: + task = asyncio.create_task(self._process_batch(batch_id)) + self._processing_tasks[batch_id] = task + + return batch + + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress.""" + batch = await self.retrieve_batch(batch_id) + + if batch.status in ["cancelled", "cancelling"]: + return batch + + if batch.status in ["completed", "failed", "expired"]: + raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") + + await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) + + if batch_id in self._processing_tasks: + self._processing_tasks[batch_id].cancel() + # note: task removal and status="cancelled" handled in finally block of _process_batch + + return await self.retrieve_batch(batch_id) + + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """ + List all batches, eventually only for the current user. 
+ + With no notion of user, we return all batches. + """ + batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff") + + batches = [] + for batch_data in batch_values: + if batch_data: + batches.append(BatchObject.model_validate_json(batch_data)) + + batches.sort(key=lambda b: b.created_at, reverse=True) + + start_idx = 0 + if after: + for i, batch in enumerate(batches): + if batch.id == after: + start_idx = i + 1 + break + + page_batches = batches[start_idx : start_idx + limit] + has_more = (start_idx + limit) < len(batches) + + first_id = page_batches[0].id if page_batches else None + last_id = page_batches[-1].id if page_batches else None + + return ListBatchesResponse( + data=page_batches, + first_id=first_id, + last_id=last_id, + has_more=has_more, + ) + + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch.""" + batch_data = await self.kvstore.get(f"batch:{batch_id}") + if not batch_data: + raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") + + return BatchObject.model_validate_json(batch_data) + + async def _update_batch(self, batch_id: str, **updates) -> None: + """Update batch fields in kvstore.""" + async with self._update_batch_lock: + try: + batch = await self.retrieve_batch(batch_id) + + # batch processing is async. once cancelling, only allow "cancelled" status updates + if batch.status == "cancelling" and updates.get("status") != "cancelled": + logger.info( + f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}" + ) + return + + if "errors" in updates: + updates["errors"] = updates["errors"].model_dump() + + batch_dict = batch.model_dump() + batch_dict.update(updates) + + await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict)) + except Exception as e: + logger.error(f"Failed to update batch {batch_id}: {e}") + + async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]: + """ + Read & validate input, return errors and valid input. 
+ + Validation of + - input_file_id existance + - valid json + - custom_id, method, url, body presence and valid + - no streaming + """ + requests: list[BatchRequest] = [] + errors: list[BatchError] = [] + try: + await self.files_api.openai_retrieve_file(batch.input_file_id) + except Exception: + errors.append( + BatchError( + code="invalid_request", + line=None, + message=f"Cannot find file {batch.input_file_id}.", + param="input_file_id", + ) + ) + return errors, requests + + # TODO(SECURITY): do something about large files + file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) + file_content = file_content_response.body.decode("utf-8") + for line_num, line in enumerate(file_content.strip().split("\n"), 1): + if line.strip(): # skip empty lines + try: + request = json.loads(line) + + if not isinstance(request, dict): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message="Each line must be a JSON dictionary object", + ) + ) + continue + + valid = True + + for param, expected_type, type_string in [ + ("custom_id", str, "string"), + ("method", str, "string"), + ("url", str, "string"), + ("body", dict, "JSON dictionary object"), + ]: + if param not in request: + errors.append( + BatchError( + code="missing_required_parameter", + line=line_num, + message=f"Missing required parameter: {param}", + param=param, + ) + ) + valid = False + elif not isinstance(request[param], expected_type): + param_name = "URL" if param == "url" else param.capitalize() + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param_name} must be a {type_string}", + param=param, + ) + ) + valid = False + + if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint: + errors.append( + BatchError( + code="invalid_url", + line=line_num, + message="URL provided for this request does not match the batch endpoint", + param="url", + ) + ) + valid = False + + if (body := request.get("body")) and isinstance(body, dict): + if body.get("stream", False): + errors.append( + BatchError( + code="streaming_unsupported", + line=line_num, + message="Streaming is not supported in batch processing", + param="body.stream", + ) + ) + valid = False + + for param, expected_type, type_string in [ + ("model", str, "a string"), + # messages is specific to /v1/chat/completions + # we could skip validating messages here and let inference fail. however, + # that would be a very expensive way to find out messages is wrong. + ("messages", list, "an array"), # TODO: allow messages to be a string? 
+ ]: + if param not in body: + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} parameter is required", + param=f"body.{param}", + ) + ) + valid = False + elif not isinstance(body[param], expected_type): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} must be {type_string}", + param=f"body.{param}", + ) + ) + valid = False + + if "model" in body and isinstance(body["model"], str): + try: + await self.models_api.get_model(body["model"]) + except Exception: + errors.append( + BatchError( + code="model_not_found", + line=line_num, + message=f"Model '{body['model']}' does not exist or is not supported", + param="body.model", + ) + ) + valid = False + + if valid: + assert isinstance(url, str), "URL must be a string" # for mypy + assert isinstance(body, dict), "Body must be a dictionary" # for mypy + requests.append( + BatchRequest( + line_num=line_num, + url=url, + method=request["method"], + custom_id=request["custom_id"], + body=body, + ), + ) + except json.JSONDecodeError: + errors.append( + BatchError( + code="invalid_json_line", + line=line_num, + message="This line is not parseable as valid JSON.", + ) + ) + + return errors, requests + + async def _process_batch(self, batch_id: str) -> None: + """Background task to process a batch of requests.""" + try: + logger.info(f"Starting batch processing for {batch_id}") + async with self._batch_semaphore: # semaphore to limit concurrency + logger.info(f"Acquired semaphore for batch {batch_id}") + await self._process_batch_impl(batch_id) + except asyncio.CancelledError: + logger.info(f"Batch processing cancelled for {batch_id}") + await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time())) + except Exception as e: + logger.error(f"Batch processing failed for {batch_id}: {e}") + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="internal_error", message=str(e))]), + ) + finally: + self._processing_tasks.pop(batch_id, None) + + async def _process_batch_impl(self, batch_id: str) -> None: + """Implementation of batch processing logic.""" + errors: list[BatchError] = [] + batch = await self.retrieve_batch(batch_id) + + errors, requests = await self._validate_input(batch) + if errors: + await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)) + logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors") + return + + logger.info(f"Processing {len(requests)} requests for batch {batch_id}") + + total_requests = len(requests) + await self._update_batch( + batch_id, + status="in_progress", + request_counts={"total": total_requests, "completed": 0, "failed": 0}, + ) + + error_results = [] + success_results = [] + completed_count = 0 + failed_count = 0 + + for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch): + # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled + async with asyncio.TaskGroup() as tg: + chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk] + + chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) + + for result in chunk_results: + if isinstance(result, dict) and result.get("error") is not None: # error response from inference + failed_count += 1 + error_results.append(result) + elif isinstance(result, 
dict) and result.get("response") is not None: # successful inference + completed_count += 1 + success_results.append(result) + else: # unexpected result + failed_count += 1 + errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}")) + + await self._update_batch( + batch_id, + request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count}, + ) + + if errors: + await self._update_batch( + batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors) + ) + return + + try: + output_file_id = await self._create_output_file(batch_id, success_results, "success") + await self._update_batch(batch_id, output_file_id=output_file_id) + + error_file_id = await self._create_output_file(batch_id, error_results, "error") + await self._update_batch(batch_id, error_file_id=error_file_id) + + await self._update_batch(batch_id, status="completed", completed_at=int(time.time())) + + logger.info( + f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed" + ) + except Exception as e: + # note: errors is empty at this point, so we don't lose anything by ignoring it + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="output_failed", message=str(e))]), + ) + + async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict: + """Process a single request from the batch.""" + request_id = f"batch_req_{batch_id}_{request.line_num}" + + try: + # TODO(SECURITY): review body for security issues + request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] + chat_response = await self.inference_api.openai_chat_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + except Exception as e: + logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") + return { + "id": request_id, + "custom_id": request.custom_id, + "error": {"type": "request_failed", "message": str(e)}, + } + + async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str: + """ + Create an output file with batch results. + + This function filters results based on the specified file_type + and uploads the file to the Files API. + """ + output_lines = [json.dumps(result) for result in results] + + with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer: + file_buffer.filename = f"{batch_id}_{file_type}.jsonl" + uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) + return uploaded_file.id diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py new file mode 100644 index 000000000..d8d06868b --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/config.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
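Editor's note: the `_validate_input` logic in the reference batches provider accepts a JSONL file in which every line is a JSON object carrying `custom_id`, `method`, `url`, and `body`, where the `url` matches the batch endpoint, `body.model` resolves through the Models API, `body.messages` is an array, and streaming is not requested. A minimal sketch of building such an input file; the model id is a placeholder, not something this patch defines:

```python
import json

# Hypothetical input for the reference batches provider. Every line must be a
# JSON object with custom_id, method, url, and body; the url must match the
# batch endpoint, body.model must be a registered model, body.messages must be
# an array, and body.stream must be absent or false.
requests = [
    {
        "custom_id": f"req-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "llama3.2:3b",  # placeholder; checked via models_api.get_model()
            "messages": [{"role": "user", "content": f"Say hello, request {i}"}],
        },
    }
    for i in range(3)
]

with open("batch_input.jsonl", "w") as f:
    f.write("\n".join(json.dumps(r) for r in requests))
```

The resulting file would then be uploaded through the Files API (purpose `batch`) and referenced as `input_file_id` when the batch is created.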
+ +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig + + +class ReferenceBatchesImplConfig(BaseModel): + """Configuration for the Reference Batches implementation.""" + + kvstore: KVStoreConfig = Field( + description="Configuration for the key-value store backend.", + ) + + max_concurrent_batches: int = Field( + default=1, + description="Maximum number of concurrent batches to process simultaneously.", + ge=1, + ) + + max_concurrent_requests_per_batch: int = Field( + default=10, + description="Maximum number of concurrent requests to process per batch.", + ge=1, + ) + + # TODO: add a max requests per second rate limiter + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict: + return { + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="batches.db", + ), + } diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 1e9dca3b5..4f6d571a4 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -11,6 +11,7 @@ from typing import Annotated from fastapi import File, Form, Response, UploadFile +from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.responses import Order from llama_stack.apis.files import ( Files, @@ -20,12 +21,15 @@ from llama_stack.apis.files import ( OpenAIFilePurpose, ) from llama_stack.core.datatypes import AccessRule +from llama_stack.log import get_logger from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from .config import LocalfsFilesImplConfig +logger = get_logger(name=__name__, category="files") + class LocalfsFilesImpl(Files): def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None: @@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files): """Get the filesystem path for a file ID.""" return Path(self.config.storage_dir) / file_id + async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]: + """Look up a OpenAIFileObject and filesystem path from its ID.""" + if not self.sql_store: + raise RuntimeError("Files provider not initialized") + + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "client.files.list()") + + file_path = Path(row.pop("file_path")) + return OpenAIFileObject(**row), file_path + # OpenAI Files API Implementation async def openai_upload_file( self, @@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files): async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: """Returns information about a specific file.""" - if not self.sql_store: - raise RuntimeError("Files provider not initialized") + file_obj, _ = await self._lookup_file_id(file_id) - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) - if not row: - raise ValueError(f"File with id {file_id} not found") - - return OpenAIFileObject( - id=row["id"], - filename=row["filename"], - purpose=OpenAIFilePurpose(row["purpose"]), - bytes=row["bytes"], - created_at=row["created_at"], - expires_at=row["expires_at"], - ) + return file_obj async def openai_delete_file(self, file_id: str) -> 
OpenAIFileDeleteResponse: """Delete a file.""" - if not self.sql_store: - raise RuntimeError("Files provider not initialized") - - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) - if not row: - raise ValueError(f"File with id {file_id} not found") - # Delete physical file - file_path = Path(row["file_path"]) + _, file_path = await self._lookup_file_id(file_id) if file_path.exists(): file_path.unlink() # Delete metadata from database + assert self.sql_store is not None, "Files provider not initialized" await self.sql_store.delete("openai_files", where={"id": file_id}) return OpenAIFileDeleteResponse( @@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files): async def openai_retrieve_file_content(self, file_id: str) -> Response: """Returns the contents of the specified file.""" - if not self.sql_store: - raise RuntimeError("Files provider not initialized") - - # Get file metadata - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) - if not row: - raise ValueError(f"File with id {file_id} not found") - # Read file content - file_path = Path(row["file_path"]) - if not file_path.exists(): - raise ValueError(f"File content not found on disk: {file_path}") + file_obj, file_path = await self._lookup_file_id(file_id) - with open(file_path, "rb") as f: - content = f.read() + if not file_path.exists(): + logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.") + await self.openai_delete_file(file_id) + raise ResourceNotFoundError(file_id, "File", "client.files.list()") # Return as binary response with appropriate content type return Response( - content=content, + content=file_path.read_bytes(), media_type="application/octet-stream", - headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'}, ) diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 7ade75032..bb6a1bd03 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -12,7 +12,6 @@ import copy import json -import logging import multiprocessing import os import tempfile @@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import ( from pydantic import BaseModel, Field from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, ) -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class ProcessingMessageName(str, Enum): diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index fea8a8189..34665b63e 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -4,13 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging from collections.abc import AsyncGenerator from llama_stack.apis.inference import ( CompletionResponse, InferenceProvider, - InterleavedContent, LogProbConfig, Message, ResponseFormat, @@ -21,6 +19,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import ModelType +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -32,7 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( from .config import SentenceTransformersInferenceConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( @@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl( tool_config: ToolConfig | None = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") - - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Sentence Transformers") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers") diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index 2574b995b..d9ee3d2a8 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -6,7 +6,6 @@ import gc import json -import logging import multiprocessing from pathlib import Path from typing import Any @@ -28,6 +27,7 @@ from llama_stack.apis.post_training import ( LoraFinetuningConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -44,7 +44,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") class HFFinetuningSingleDevice: diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py index a7c19faac..b39a24c66 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import gc -import logging import multiprocessing from pathlib import Path from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.post_training import ( DPOAlignmentConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -40,7 +40,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") class HFDPOAlignmentSingleDevice: diff --git a/llama_stack/providers/inline/post_training/huggingface/utils.py b/llama_stack/providers/inline/post_training/huggingface/utils.py index 3147c19ab..f229c87dd 100644 --- a/llama_stack/providers/inline/post_training/huggingface/utils.py +++ b/llama_stack/providers/inline/post_training/huggingface/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import os import signal import sys @@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.post_training import Checkpoint, TrainingConfig +from llama_stack.log import get_logger from .config import HuggingFacePostTrainingConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") def setup_environment(): diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 49e1c95b8..8b1462862 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
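Editor's note: the same logging migration repeats across the provider modules in this patch, swapping the stdlib `logging.getLogger(__name__)` for the project logger with an API-specific category. A minimal sketch of the pattern as it appears in these hunks (the category string varies per API):

```python
from llama_stack.log import get_logger

# Category groups output by API surface, e.g. "inference", "safety",
# "post_training", "telemetry", "vector_io".
logger = get_logger(name=__name__, category="post_training")

logger.info("checkpoint saved")  # used exactly like a stdlib logger
```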
-import logging import os import time from datetime import UTC, datetime @@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import modules, training from torchtune import utils as torchtune_utils from torchtune.data import padded_collate_sft +from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( get_adapter_params, @@ -45,6 +45,7 @@ from llama_stack.apis.post_training import ( ) from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils @@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import ( ) from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset -log = logging.getLogger(__name__) - -from torchtune.models.llama3._tokenizer import Llama3Tokenizer +log = get_logger(name=__name__, category="post_training") class LoraFinetuningSingleDevice: diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index be05ee436..5e25c559f 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -4,8 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging -from typing import Any +import uuid +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from codeshield.cs import CodeShieldScanResult from llama_stack.apis.inference import Message from llama_stack.apis.safety import ( @@ -14,18 +17,20 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) +from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) from .config import CodeScannerConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") ALLOWED_CODE_SCANNER_MODEL_IDS = [ - "CodeScanner", - "CodeShield", + "code-scanner", + "code-shield", ] @@ -69,3 +74,55 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])}, ) return RunShieldResponse(violation=violation) + + def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults: + categories = {} + category_scores = {} + category_applied_input_types = {} + + flagged = scan_result.is_insecure + user_message = None + metadata = {} + + if scan_result.is_insecure: + pattern_ids = [issue.pattern_id for issue in scan_result.issues_found] + categories = dict.fromkeys(pattern_ids, True) + category_scores = dict.fromkeys(pattern_ids, 1.0) + category_applied_input_types = {key: ["text"] for key in pattern_ids} + user_message = f"Security concerns detected in the code. 
{scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}" + metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])} + + return ModerationObjectResults( + flagged=flagged, + categories=categories, + category_scores=category_scores, + category_applied_input_types=category_applied_input_types, + user_message=user_message, + metadata=metadata, + ) + + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + inputs = input if isinstance(input, list) else [input] + results = [] + + from codeshield.cs import CodeShield + + for text_input in inputs: + log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...") + try: + scan_result = await CodeShield.scan_code(text_input) + moderation_result = self.get_moderation_object_results(scan_result) + except Exception as e: + log.error(f"CodeShield.scan_code failed: {e}") + # create safe fallback response on scanner failure to avoid blocking legitimate requests + moderation_result = ModerationObjectResults( + flagged=False, + categories={}, + category_scores={}, + category_applied_input_types={}, + user_message=None, + metadata={"scanner_error": str(e)}, + ) + results.append(moderation_result) + + return ModerationObject(id=str(uuid.uuid4()), model=model, results=results) diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index bae744010..5c7f30aa7 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -4,18 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
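Editor's note: the code scanner's `run_moderation` above wraps CodeShield scan results into the OpenAI-style moderation shape and falls back to an unflagged result if the scanner itself fails. A rough usage sketch, assuming `scanner` is an already constructed `MetaReferenceCodeScannerSafetyImpl` instance (not shown in this patch):

```python
# Sketch only: `scanner` is assumed to be an initialized
# MetaReferenceCodeScannerSafetyImpl instance.
async def moderate(scanner):
    result = await scanner.run_moderation(
        input=["import os\nos.system(user_supplied)"],  # potentially unsafe code, passed as data
        model="code-scanner",
    )
    for item in result.results:
        # categories / category_scores are keyed by CodeShield pattern ids
        print("flagged:", item.flagged, "categories:", list(item.categories or {}))
```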
-import logging import re import uuid from string import Template from typing import Any from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem -from llama_stack.apis.inference import ( - Inference, - Message, - UserMessage, -) +from llama_stack.apis.inference import Inference, Message, UserMessage from llama_stack.apis.safety import ( RunShieldResponse, Safety, @@ -25,6 +20,7 @@ from llama_stack.apis.safety import ( from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield from llama_stack.core.datatypes import Api +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import Role from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -72,7 +68,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = { } SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()} - DEFAULT_LG_V3_SAFETY_CATEGORIES = [ CAT_VIOLENT_CRIMES, CAT_NON_VIOLENT_CRIMES, @@ -137,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}") +logger = get_logger(name=__name__, category="safety") + class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): def __init__(self, config: LlamaGuardConfig, deps) -> None: @@ -412,7 +409,7 @@ class LlamaGuardShield: unsafe_code_list = [code.strip() for code in unsafe_code.split(",")] invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP] if invalid_codes: - logging.warning(f"Invalid safety codes returned: {invalid_codes}") + logger.warning(f"Invalid safety codes returned: {invalid_codes}") # just returning safe object, as we don't know what the invalid codes can map to return ModerationObject( id=f"modr-{uuid.uuid4()}", @@ -460,7 +457,7 @@ class LlamaGuardShield: def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool: """Check if content is safe based on response and unsafe code.""" - if response.strip() == SAFE_RESPONSE: + if response.strip().lower().startswith(SAFE_RESPONSE): return True if unsafe_code: diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index c760f0fd1..6fb6c4407 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
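Editor's note: the relaxed safe-response check in `is_content_safe` above tolerates casing and trailing text from the model instead of requiring an exact match. Ignoring the `unsafe_code` branch and assuming `SAFE_RESPONSE` is the lowercase literal `"safe"` (its definition is not shown in this hunk), the change amounts to:

```python
SAFE_RESPONSE = "safe"  # assumed value; the constant itself is defined elsewhere

def is_content_safe_simplified(response: str) -> bool:
    # Previously: response.strip() == SAFE_RESPONSE
    return response.strip().lower().startswith(SAFE_RESPONSE)

assert is_content_safe_simplified("safe")
assert is_content_safe_simplified("Safe.")
assert not is_content_safe_simplified("unsafe\nS1,S2")
```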
-import logging from typing import Any import torch @@ -21,6 +20,7 @@ from llama_stack.apis.safety import ( from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import PromptGuardConfig, PromptGuardType -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") PROMPT_GUARD_MODEL = "Prompt-Guard-86M" diff --git a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py index b74c3826e..c9358101d 100644 --- a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +++ b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py @@ -7,7 +7,6 @@ import collections import functools import json -import logging import random import re import string @@ -20,7 +19,9 @@ import nltk from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai from pythainlp.tokenize import word_tokenize as word_tokenize_thai -logger = logging.getLogger() +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scoring") WORD_LIST = [ "western", diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index d99255c79..9224c3792 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -4,13 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging +import datetime import threading from typing import Any from opentelemetry import metrics, trace - -logger = logging.getLogger(__name__) from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.metrics import MeterProvider @@ -40,6 +38,7 @@ from llama_stack.apis.telemetry import ( UnstructuredLogEvent, ) from llama_stack.core.datatypes import Api +from llama_stack.log import get_logger from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) @@ -61,6 +60,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = { _global_lock = threading.Lock() _TRACER_PROVIDER = None +logger = get_logger(name=__name__, category="telemetry") + def is_tracing_enabled(tracer): with tracer.start_as_current_span("check_tracing") as span: @@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): metric_name: str, start_time: int, end_time: int | None = None, - granularity: str | None = "1d", + granularity: str | None = None, query_type: MetricQueryType = MetricQueryType.RANGE, label_matchers: list[MetricLabelMatcher] | None = None, ) -> QueryMetricsResponse: - raise NotImplementedError("Querying metrics is not implemented") + """Query metrics from the telemetry store. 
+ + Args: + metric_name: The name of the metric to query (e.g., "prompt_tokens") + start_time: Start time as Unix timestamp + end_time: End time as Unix timestamp (defaults to now if None) + granularity: Time granularity for aggregation + query_type: Type of query (RANGE or INSTANT) + label_matchers: Label filters to apply + + Returns: + QueryMetricsResponse with metric time series data + """ + # Convert timestamps to datetime objects + start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC) + end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None + + # Use SQLite trace store if available + if hasattr(self, "trace_store") and self.trace_store: + return await self.trace_store.query_metrics( + metric_name=metric_name, + start_time=start_dt, + end_time=end_dt, + granularity=granularity, + query_type=query_type, + label_matchers=label_matchers, + ) + else: + raise ValueError( + f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks" + ) def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 6a7c7885c..a1543457b 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import secrets import string from typing import Any @@ -32,6 +31,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( @@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import RagToolRuntimeConfig from .context_retriever import generate_rag_query -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="tool_runtime") def make_random_string(length: int = 8): diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index af61da59b..258c6e7aa 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -8,7 +8,6 @@ import asyncio import base64 import io import json -import logging from typing import Any import faiss @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( HealthResponse, HealthStatus, @@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import FaissVectorIOConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index cc1982f3b..7cf163960 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
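Editor's note: `TelemetryAdapter.query_metrics` above now converts Unix timestamps to UTC datetimes and delegates to the SQLite trace store, raising if the sqlite telemetry sink is not configured. A rough calling sketch, assuming `telemetry` is an initialized `TelemetryAdapter` with that sink enabled:

```python
import time

# Sketch only: `telemetry` is assumed to be an initialized TelemetryAdapter.
async def print_prompt_token_metrics(telemetry):
    now = int(time.time())
    response = await telemetry.query_metrics(
        metric_name="prompt_tokens",
        start_time=now - 3600,  # last hour, as Unix seconds
        end_time=None,          # per the docstring, None means "up to now"
    )
    print(response)
```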
import asyncio -import logging import re import sqlite3 import struct @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") # Specifying search mode is dependent on the VectorIO provider. VECTOR_SEARCH = "vector" diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py new file mode 100644 index 000000000..de7886efb --- /dev/null +++ b/llama_stack/providers/registry/batches.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.batches, + provider_type="inline::reference", + pip_packages=["openai"], + module="llama_stack.providers.inline.batches.reference", + config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig", + api_dependencies=[ + Api.inference, + Api.files, + Api.models, + ], + description="Reference implementation of batches API with KVStore persistence.", + ), + ] diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index e894debaf..ebe90310c 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -5,9 +5,11 @@ # the root directory of this source tree. from llama_stack.providers.datatypes import ( + AdapterSpec, Api, InlineProviderSpec, ProviderSpec, + remote_provider_spec, ) from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages @@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", description="Local filesystem-based file storage provider for managing files and documents locally.", ), + remote_provider_spec( + api=Api.files, + adapter=AdapterSpec( + adapter_type="s3", + pip_packages=["boto3"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.s3", + config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", + description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", + ), + ), ] diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index ffd64ef7c..4443f4df1 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -5,34 +5,74 @@ # the root directory of this source tree. +from typing import cast + from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +# We provide two versions of these providers so that distributions can package the appropriate version of torch. +# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. 
+torchtune_def = dict( + api=Api.post_training, + pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"], + module="llama_stack.providers.inline.post_training.torchtune", + config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.", +) + +huggingface_def = dict( + api=Api.post_training, + pip_packages=["trl", "transformers", "peft", "datasets"], + module="llama_stack.providers.inline.post_training.huggingface", + config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", +) + def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( - api=Api.post_training, - provider_type="inline::torchtune", - pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"], - module="llama_stack.providers.inline.post_training.torchtune", - config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", - api_dependencies=[ - Api.datasetio, - Api.datasets, - ], - description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.", + **{ + **torchtune_def, + "provider_type": "inline::torchtune-cpu", + "pip_packages": ( + cast(list[str], torchtune_def["pip_packages"]) + + ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"] + ), + }, ), InlineProviderSpec( - api=Api.post_training, - provider_type="inline::huggingface", - pip_packages=["torch", "trl", "transformers", "peft", "datasets"], - module="llama_stack.providers.inline.post_training.huggingface", - config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", - api_dependencies=[ - Api.datasetio, - Api.datasets, - ], - description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", + **{ + **huggingface_def, + "provider_type": "inline::huggingface-cpu", + "pip_packages": ( + cast(list[str], huggingface_def["pip_packages"]) + + ["torch --index-url https://download.pytorch.org/whl/cpu"] + ), + }, + ), + InlineProviderSpec( + **{ + **torchtune_def, + "provider_type": "inline::torchtune-gpu", + "pip_packages": ( + cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"] + ), + }, + ), + InlineProviderSpec( + **{ + **huggingface_def, + "provider_type": "inline::huggingface-gpu", + "pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]), + }, ), remote_provider_spec( api=Api.post_training, diff --git a/llama_stack/providers/remote/files/s3/README.md b/llama_stack/providers/remote/files/s3/README.md new file mode 100644 index 000000000..0f33122c7 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/README.md @@ -0,0 +1,237 @@ +# S3 Files Provider + +A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence. 
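Editor's note: once the provider is wired into a run config (see the sections below), files are managed through the standard OpenAI-compatible Files API. A minimal sketch, assuming a stack at `http://localhost:8321` and the `llama_stack_client` Python client; the `create` and `retrieve` method names are assumed from the OpenAI Files surface, while `client.files.list()` is the call the provider's own error hints point to:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # hypothetical endpoint

# Upload: bytes go to the S3 bucket, metadata goes to the SQL store.
uploaded = client.files.create(file=open("notes.txt", "rb"), purpose="assistants")

# Listing and retrieval use the same OpenAI-style surface.
listing = client.files.list()
print([f.id for f in listing.data])
print(client.files.retrieve(uploaded.id).filename)  # assumed method name
```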
+ +## Features + +- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage +- **Metadata Management**: Uses SQL database for efficient file metadata queries +- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints +- **Flexible Authentication**: Support for IAM roles and access keys +- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services + +## Configuration + +### Basic Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Advanced Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + aws_access_key_id: YOUR_ACCESS_KEY + aws_secret_access_key: YOUR_SECRET_KEY + endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Environment Variables + +The configuration supports environment variable substitution: + +```yaml +config: + bucket_name: "${env.S3_BUCKET_NAME}" + region: "${env.AWS_REGION:=us-east-1}" + aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}" + aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}" + endpoint_url: "${env.S3_ENDPOINT_URL:=}" +``` + +Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique. + +## Authentication + +### IAM Roles (Recommended) + +For production deployments, use IAM roles: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + # No credentials needed - will use IAM role +``` + +### Access Keys + +For development or specific use cases: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + aws_access_key_id: AKIAIOSFODNN7EXAMPLE + aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY +``` + +## S3 Bucket Setup + +### Required Permissions + +The S3 provider requires the following permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket" + ], + "Resource": [ + "arn:aws:s3:::your-bucket-name", + "arn:aws:s3:::your-bucket-name/*" + ] + } + ] +} +``` + +### Automatic Bucket Creation + +By default, the S3 provider expects the bucket to already exist. 
If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration: + +```yaml +config: + bucket_name: my-bucket + auto_create_bucket: true # Will create bucket if it doesn't exist + region: us-east-1 +``` + +**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket", + "s3:CreateBucket" + ], + "Resource": [ + "arn:aws:s3:::your-bucket-name", + "arn:aws:s3:::your-bucket-name/*" + ] + } + ] +} +``` + +### Bucket Policy (Optional) + +For additional security, you can add a bucket policy: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "LlamaStackAccess", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole" + }, + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject" + ], + "Resource": "arn:aws:s3:::your-bucket-name/*" + }, + { + "Sid": "LlamaStackBucketAccess", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole" + }, + "Action": [ + "s3:ListBucket" + ], + "Resource": "arn:aws:s3:::your-bucket-name" + } + ] +} +``` + +## Features + +### Metadata Persistence + +File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes: + +- File ID +- Original filename +- Purpose (assistants, batch, etc.) +- File size in bytes +- Created and expiration timestamps + +### TTL and Cleanup + +Files currently have a fixed long expiration time (100 years). + +## Development and Testing + +### Using MinIO + +For self-hosted S3-compatible storage: + +```yaml +config: + bucket_name: test-bucket + region: us-east-1 + endpoint_url: http://localhost:9000 + aws_access_key_id: minioadmin + aws_secret_access_key: minioadmin +``` + +## Monitoring and Logging + +The provider logs important operations and errors. For production deployments, consider: + +- CloudWatch monitoring for S3 operations +- Custom metrics for file upload/download rates +- Error rate monitoring +- Performance metrics tracking + +## Error Handling + +The provider handles various error scenarios: + +- S3 connectivity issues +- Bucket access permissions +- File not found errors +- Metadata consistency checks + +## Known Limitations + +- Fixed long TTL (100 years) instead of configurable expiration +- No server-side encryption enabled by default +- No support for AWS session tokens +- No S3 key prefix organization support +- No multipart upload support (all files uploaded as single objects) diff --git a/llama_stack/providers/remote/files/s3/__init__.py b/llama_stack/providers/remote/files/s3/__init__.py new file mode 100644 index 000000000..3f5dfc88a --- /dev/null +++ b/llama_stack/providers/remote/files/s3/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from llama_stack.core.datatypes import Api + +from .config import S3FilesImplConfig + + +async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]): + from .files import S3FilesImpl + + # TODO: authorization policies and user separation + impl = S3FilesImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/files/s3/config.py b/llama_stack/providers/remote/files/s3/config.py new file mode 100644 index 000000000..da20d8668 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/config.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig + + +class S3FilesImplConfig(BaseModel): + """Configuration for S3-based files provider.""" + + bucket_name: str = Field(description="S3 bucket name to store files") + region: str = Field(default="us-east-1", description="AWS region where the bucket is located") + aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)") + aws_secret_access_key: str | None = Field( + default=None, description="AWS secret access key (optional if using IAM roles)" + ) + endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)") + auto_create_bucket: bool = Field( + default=False, description="Automatically create the S3 bucket if it doesn't exist" + ) + metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata") + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: + return { + "bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique + "region": "${env.AWS_REGION:=us-east-1}", + "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}", + "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}", + "endpoint_url": "${env.S3_ENDPOINT_URL:=}", + "auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}", + "metadata_store": SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="s3_files_metadata.db", + ), + } diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py new file mode 100644 index 000000000..52e0cbbf4 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/files.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
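Editor's note: for local development against MinIO (matching the README example), `S3FilesImplConfig` above can be instantiated directly. A sketch under the assumption that `SqliteSqlStoreConfig` exposes a `db_path` field, as the README's YAML snippets suggest; treat the field name as unverified:

```python
from llama_stack.providers.remote.files.s3.config import S3FilesImplConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

config = S3FilesImplConfig(
    bucket_name="test-bucket",
    region="us-east-1",
    endpoint_url="http://localhost:9000",   # MinIO / LocalStack style endpoint
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
    auto_create_bucket=True,                 # provider creates the bucket on startup
    metadata_store=SqliteSqlStoreConfig(db_path="./s3_files_metadata.db"),  # field name assumed
)
```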
+ +import time +import uuid +from typing import Annotated + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError +from fastapi import File, Form, Response, UploadFile + +from llama_stack.apis.common.errors import ResourceNotFoundError +from llama_stack.apis.common.responses import Order +from llama_stack.apis.files import ( + Files, + ListOpenAIFileResponse, + OpenAIFileDeleteResponse, + OpenAIFileObject, + OpenAIFilePurpose, +) +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType +from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl + +from .config import S3FilesImplConfig + +# TODO: provider data for S3 credentials + + +def _create_s3_client(config: S3FilesImplConfig) -> boto3.client: + try: + s3_config = { + "region_name": config.region, + } + + # endpoint URL if specified (for MinIO, LocalStack, etc.) + if config.endpoint_url: + s3_config["endpoint_url"] = config.endpoint_url + + if config.aws_access_key_id and config.aws_secret_access_key: + s3_config.update( + { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + } + ) + + return boto3.client("s3", **s3_config) + + except (BotoCoreError, NoCredentialsError) as e: + raise RuntimeError(f"Failed to initialize S3 client: {e}") from e + + +async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None: + try: + client.head_bucket(Bucket=config.bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code == "404": + if not config.auto_create_bucket: + raise RuntimeError( + f"S3 bucket '{config.bucket_name}' does not exist. " + f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration." 
+ ) from e + try: + # For us-east-1, we can't specify LocationConstraint + if config.region == "us-east-1": + client.create_bucket(Bucket=config.bucket_name) + else: + client.create_bucket( + Bucket=config.bucket_name, + CreateBucketConfiguration={"LocationConstraint": config.region}, + ) + except ClientError as create_error: + raise RuntimeError( + f"Failed to create S3 bucket '{config.bucket_name}': {create_error}" + ) from create_error + elif error_code == "403": + raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e + else: + raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e + + +class S3FilesImpl(Files): + """S3-based implementation of the Files API.""" + + # TODO: implement expiration, for now a silly offset + _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60 + + def __init__(self, config: S3FilesImplConfig) -> None: + self._config = config + self._client: boto3.client | None = None + self._sql_store: SqlStore | None = None + + async def initialize(self) -> None: + self._client = _create_s3_client(self._config) + await _create_bucket_if_not_exists(self._client, self._config) + + self._sql_store = sqlstore_impl(self._config.metadata_store) + await self._sql_store.create_table( + "openai_files", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "filename": ColumnType.STRING, + "purpose": ColumnType.STRING, + "bytes": ColumnType.INTEGER, + "created_at": ColumnType.INTEGER, + "expires_at": ColumnType.INTEGER, + # TODO: add s3_etag field for integrity checking + }, + ) + + async def shutdown(self) -> None: + pass + + @property + def client(self) -> boto3.client: + assert self._client is not None, "Provider not initialized" + return self._client + + @property + def sql_store(self) -> SqlStore: + assert self._sql_store is not None, "Provider not initialized" + return self._sql_store + + async def openai_upload_file( + self, + file: Annotated[UploadFile, File()], + purpose: Annotated[OpenAIFilePurpose, Form()], + ) -> OpenAIFileObject: + file_id = f"file-{uuid.uuid4().hex}" + + filename = getattr(file, "filename", None) or "uploaded_file" + + created_at = int(time.time()) + expires_at = created_at + self._SILLY_EXPIRATION_OFFSET + content = await file.read() + file_size = len(content) + + await self.sql_store.insert( + "openai_files", + { + "id": file_id, + "filename": filename, + "purpose": purpose.value, + "bytes": file_size, + "created_at": created_at, + "expires_at": expires_at, + }, + ) + + try: + self.client.put_object( + Bucket=self._config.bucket_name, + Key=file_id, + Body=content, + # TODO: enable server-side encryption + ) + except ClientError as e: + await self.sql_store.delete("openai_files", where={"id": file_id}) + + raise RuntimeError(f"Failed to upload file to S3: {e}") from e + + return OpenAIFileObject( + id=file_id, + filename=filename, + purpose=purpose, + bytes=file_size, + created_at=created_at, + expires_at=expires_at, + ) + + async def openai_list_files( + self, + after: str | None = None, + limit: int | None = 10000, + order: Order | None = Order.desc, + purpose: OpenAIFilePurpose | None = None, + ) -> ListOpenAIFileResponse: + # this purely defensive. it should not happen because the router also default to Order.desc. 
+ if not order: + order = Order.desc + + where_conditions = {} + if purpose: + where_conditions["purpose"] = purpose.value + + paginated_result = await self.sql_store.fetch_all( + table="openai_files", + where=where_conditions if where_conditions else None, + order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, + limit=limit, + ) + + files = [ + OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + for row in paginated_result.data + ] + + return ListOpenAIFileResponse( + data=files, + has_more=paginated_result.has_more, + # empty string or None? spec says str, ref impl returns str | None, we go with spec + first_id=files[0].id if files else "", + last_id=files[-1].id if files else "", + ) + + async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + return OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + + async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + self.client.delete_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + except ClientError as e: + if e.response["Error"]["Code"] != "NoSuchKey": + raise RuntimeError(f"Failed to delete file from S3: {e}") from e + + await self.sql_store.delete("openai_files", where={"id": file_id}) + + return OpenAIFileDeleteResponse(id=file_id, deleted=True) + + async def openai_retrieve_file_content(self, file_id: str) -> Response: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + response = self.client.get_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + # TODO: can we stream this instead of loading it into memory + content = response["Body"].read() + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + await self.sql_store.delete("openai_files", where={"id": file_id}) + raise ResourceNotFoundError(file_id, "File", "files.list()") from e + raise RuntimeError(f"Failed to download file from S3: {e}") from e + + return Response( + content=content, + media_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index bd86f7238..e907e8ec6 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::fireworks") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git 
a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 4857c6723..f2069b5e5 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -3,15 +3,14 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging - +from llama_stack.log import get_logger from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference::llama_openai_compat") class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index 4a072215c..d96b29fef 100644 --- a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -41,6 +41,11 @@ client.initialize() ### Create Completion +> Note on Completion API +> +> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does. + + ```python response = client.inference.completion( model_id="meta-llama/Llama-3.1-8B-Instruct", @@ -76,7 +81,78 @@ response = client.inference.chat_completion( print(f"Response: {response.completion_message.content}") ``` +### Tool Calling Example ### +```python +from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + +tool_definition = ToolDefinition( + tool_name="get_weather", + description="Get current weather information for a location", + parameters={ + "location": ToolParamDefinition( + param_type="string", + description="The city and state, e.g. 
San Francisco, CA", + required=True, + ), + "unit": ToolParamDefinition( + param_type="string", + description="Temperature unit (celsius or fahrenheit)", + required=False, + default="celsius", + ), + }, +) + +tool_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], + tools=[tool_definition], +) + +print(f"Tool Response: {tool_response.completion_message.content}") +if tool_response.completion_message.tool_calls: + for tool_call in tool_response.completion_message.tool_calls: + print(f"Tool Called: {tool_call.tool_name}") + print(f"Arguments: {tool_call.arguments}") +``` + +### Structured Output Example +```python +from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType + +person_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "occupation": {"type": "string"}, + }, + "required": ["name", "age", "occupation"], +} + +response_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, json_schema=person_schema +) + +structured_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + { + "role": "user", + "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ", + } + ], + response_format=response_format, +) + +print(f"Structured Response: {structured_response.completion_message.content}") +``` + ### Create Embeddings +> Note on OpenAI embeddings compatibility +> +> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`. + ```python response = client.inference.embeddings( model_id="nvidia/llama-3.2-nv-embedqa-1b-v2", diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7bc3fd0c9..a5475bc92 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -4,11 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
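Editor's note: as the embeddings compatibility note in NVIDIA.md says, passage/document embeddings should go through the `embeddings` API with `task_type="document"`. A sketch of that call; the `contents` parameter name is assumed from the Inference API, and the model id matches the NIM used elsewhere in the doc:

```python
passage_response = client.inference.embeddings(
    model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
    contents=["Paris is the capital and largest city of France."],  # parameter name assumed
    task_type="document",  # passage/document side of the asymmetric model
)
print(len(passage_response.embeddings[0]))
```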
-import logging import warnings from collections.abc import AsyncIterator -from openai import APIConnectionError, BadRequestError +from openai import NOT_GIVEN, APIConnectionError from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -27,12 +26,16 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, ResponseFormat, SamplingParams, TextTruncation, ToolChoice, ToolConfig, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, @@ -54,7 +57,7 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference::nvidia") class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): @@ -194,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): } extra_body["input_type"] = task_type_options[task_type] - try: - response = await self.client.embeddings.create( - model=provider_model_id, - input=input, - extra_body=extra_body, - ) - except BadRequestError as e: - raise ValueError(f"Failed to get embeddings: {e}") from e - + response = await self.client.embeddings.create( + model=provider_model_id, + input=input, + extra_body=extra_body, + ) # # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...) # -> @@ -210,6 +209,57 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + """ + OpenAI-compatible embeddings for NVIDIA NIM. + + Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API. + We default this to "query" to ensure requests succeed when using the + OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with + `task_type='document'`. + """ + extra_body: dict[str, object] = {"input_type": "query"} + logger.warning( + "NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. " + "For passage embeddings, use the embeddings API with task_type='document'." 
+ ) + + response = await self.client.embeddings.create( + model=await self._get_provider_model_id(model), + input=input, + encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN, + dimensions=dimensions if dimensions is not None else NOT_GIVEN, + user=user if user is not None else NOT_GIVEN, + extra_body=extra_body, + ) + + data = [] + for i, embedding_data in enumerate(response.data): + data.append( + OpenAIEmbeddingData( + embedding=embedding_data.embedding, + index=i, + ) + ) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response.usage.prompt_tokens, + total_tokens=response.usage.total_tokens, + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=response.model, + usage=usage, + ) + async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 74019999e..b8431e859 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging - import httpx +from llama_stack.log import get_logger + from . import NVIDIAConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference::nvidia") def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index a93421536..fcaf5ee92 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::ollama") class OllamaInferenceAdapter( @@ -619,28 +619,6 @@ class OllamaInferenceAdapter( response.id = id return response - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Ollama") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for Ollama") - async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def _convert_content(content) -> dict: diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 865258559..0f73c9321 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -4,15 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging - +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference::openai") # diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 323831845..97c72d14c 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -5,7 +5,6 @@ # the root directory of this source tree. -import logging from collections.abc import AsyncGenerator from huggingface_hub import AsyncInferenceClient, HfApi @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( @@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference::tgi") def build_hf_repo_model_entries(): diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a06e4173b..54c76607f 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::together") class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index ac626874c..9e9a80ca5 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference::vllm") def build_hf_repo_model_entries(): @@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): user=user, ) return await self.client.chat.completions.create(**params) # type: ignore - - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Ollama") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise 
NotImplementedError("Batch chat completion is not supported for Ollama") diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py index d6e1016b2..162951ff3 100644 --- a/llama_stack/providers/remote/post_training/nvidia/utils.py +++ b/llama_stack/providers/remote/post_training/nvidia/utils.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import warnings from typing import Any from pydantic import BaseModel from llama_stack.apis.post_training import TrainingConfig +from llama_stack.log import get_logger from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig from .config import NvidiaPostTrainingConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training::nvidia") def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 1895e7507..8855e02a4 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging from typing import Any from llama_stack.apis.inference import Message @@ -16,12 +15,13 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety::bedrock") class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 7f17b1cb6..65f901da2 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -4,20 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
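A pattern repeated across these hunks is swapping stdlib `logging` for the project logger with an `api::provider` category. A minimal sketch of that convention (the log message here is purely illustrative):

```python
from llama_stack.log import get_logger

# Category follows the "api::provider" convention seen throughout this patch,
# e.g. "safety::nvidia" or "vector_io::chroma", so logs can be filtered per provider.
logger = get_logger(name=__name__, category="safety::nvidia")
logger.info("guardrails adapter initialized")  # illustrative message
```
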
-import logging from typing import Any import requests from llama_stack.apis.inference import Message -from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel +from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from .config import NVIDIASafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety::nvidia") class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): @@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): self.shield = NeMoGuardrails(self.config, shield.shield_id) return await self.shield.run(messages) + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation") + class NeMoGuardrails: """ diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 6c7190afe..2beb5e0ea 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging from typing import Any import litellm @@ -20,12 +19,13 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from .config import SambaNovaSafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety::sambanova") CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 8f252711b..a9ec644ef 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import asyncio import json -import logging from typing import Any from urllib.parse import urlparse @@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::chroma") ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 0eaae81b3..e07e8ff12 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import os from typing import Any @@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io::milvus") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" @@ -413,15 +413,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP index = await self._get_and_cache_vector_db_index(vector_db_id) if not index: raise VectorStoreNotFoundError(vector_db_id) - - if params and params.get("mode") == "keyword": - # Check if this is inline Milvus (Milvus-Lite) - if hasattr(self.config, "db_path"): - raise NotImplementedError( - "Keyword search is not supported in Milvus-Lite. " - "Please use a remote Milvus server for keyword search functionality." - ) - return await index.query_chunks(query, params) async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index d2a5d910b..1c8d361c2 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging from typing import Any import psycopg2 @@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import PGVectorVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::pgvector") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 018015780..0a0faa23a 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import uuid from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( VectorStoreChunkingStrategy, VectorStoreFileObject, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::qdrant") CHUNK_ID_KEY = "_chunk_id" # KV store prefixes for vector databases diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 966724848..59b6bf124 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json -import logging from typing import Any import weaviate @@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import WeaviateVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::weaviate") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 32e89f987..65ba2854b 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -5,10 +5,11 @@ # the root directory of this source tree. 
import base64 -import logging import struct from typing import TYPE_CHECKING +from llama_stack.log import get_logger + if TYPE_CHECKING: from sentence_transformers import SentenceTransformer @@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con EMBEDDING_MODELS = {} -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class SentenceTransformerEmbeddingMixin: diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index da2e634f6..9bd43e4c9 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="providers::utils") class LiteLLMOpenAIMixin( @@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin( ) return await litellm.acompletion(**params) - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for OpenAI Compat") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat") - async def check_model_availability(self, model: str) -> bool: """ Check if a specific model is available via LiteLLM for the current diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index ddb3bda8c..44add8f9e 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 9a77c8cc4..55c2ac0ad 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import base64 import json -import logging import struct import time import uuid @@ -31,15 +30,21 @@ from openai.types.chat import ( from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) + +try: + from openai.types.chat import ( + ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall, + ) +except ImportError: + from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall, + ) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessage, ) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) from openai.types.chat import ( ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, ) @@ -116,6 +121,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.inference import ( OpenAIChoice as OpenAIChatCompletionChoice, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, @@ -128,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="providers::utils") class OpenAICompatCompletionChoiceDelta(BaseModel): @@ -633,7 +639,7 @@ async def convert_message_to_openai_dict_new( ) elif isinstance(message, CompletionMessage): tool_calls = [ - OpenAIChatCompletionMessageToolCall( + OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), @@ -903,7 +909,7 @@ def _convert_openai_request_response_format( def _convert_openai_tool_calls( - tool_calls: list[OpenAIChatCompletionMessageToolCall], + tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall], ) -> list[ToolCall]: """ Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. 
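The `try`/`except ImportError` in the hunk above binds a single alias across openai releases that renamed the tool-call type. A minimal standalone sketch of the same fallback pattern; the `ToolCallType` alias name is just for illustration, while the import paths are the ones shown in the hunk:

```python
# Newer openai releases expose ChatCompletionMessageFunctionToolCall; older ones only
# have ChatCompletionMessageToolCall. Either way, downstream code uses one name.
try:
    from openai.types.chat import (
        ChatCompletionMessageFunctionToolCall as ToolCallType,
    )
except ImportError:  # older openai package layout
    from openai.types.chat.chat_completion_message_tool_call import (
        ChatCompletionMessageToolCall as ToolCallType,
    )

print(ToolCallType.__name__)  # shows which variant this environment provides
```
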
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 72286dffb..f60deee6e 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -25,7 +25,7 @@ from llama_stack.apis.inference import ( from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class OpenAIMixin(ABC): diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index bb9a91b97..a93326e41 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="providers::utils") class ChatCompletionRequestWithRawContent(ChatCompletionRequest): diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index f00cb1f8b..d1747d65b 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -75,6 +75,8 @@ class PostgresKVStoreConfig(CommonConfig): db: str = "llamastack" user: str password: str | None = None + ssl_mode: str | None = None + ca_cert_path: str | None = None table_name: str = "llamastack_kvstore" @classmethod diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index 3842773d9..bab87a4aa 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -4,16 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from datetime import datetime from pymongo import AsyncMongoClient +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore from ..config import MongoDBKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class MongoDBKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index bd35decfc..56d6dbb48 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
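The new `ssl_mode` and `ca_cert_path` fields added to `PostgresKVStoreConfig` above are passed straight through to `psycopg2.connect` as `sslmode`/`sslrootcert` in the hunk below. A hedged configuration sketch; the credentials, certificate path, and chosen SSL level are placeholders, and the common host/port fields are left at their defaults:

```python
from llama_stack.providers.utils.kvstore.config import PostgresKVStoreConfig

# Placeholder values: ssl_mode takes any libpq sslmode string; "verify-full"
# checks the server certificate against ca_cert_path, while "require" only
# enforces encryption without certificate verification.
config = PostgresKVStoreConfig(
    db="llamastack",
    user="llamastack",
    password="change-me",
    ssl_mode="verify-full",
    ca_cert_path="/etc/ssl/certs/postgres-ca.pem",
)
# Downstream, the store does roughly:
# psycopg2.connect(..., sslmode=config.ssl_mode, sslrootcert=config.ca_cert_path)
```
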
-import logging from datetime import datetime import psycopg2 from psycopg2.extras import DictCursor +from llama_stack.log import get_logger + from ..api import KVStore from ..config import PostgresKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class PostgresKVStoreImpl(KVStore): @@ -30,6 +31,8 @@ class PostgresKVStoreImpl(KVStore): database=self.config.db, user=self.config.user, password=self.config.password, + sslmode=self.config.ssl_mode, + sslrootcert=self.config.ca_cert_path, ) self.conn.autocommit = True self.cursor = self.conn.cursor(cursor_factory=DictCursor) diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 7c7775691..cc6ee2f71 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -45,7 +45,7 @@ from llama_stack.providers.utils.memory.vector_store import ( make_overlapped_chunks, ) -logger = get_logger(__name__, category="vector_io") +logger = get_logger(name=__name__, category="providers::utils") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 6ae5bb521..b74080384 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import base64 import io -import logging import re import time from abc import ABC, abstractmethod @@ -26,6 +25,7 @@ from llama_stack.apis.common.content_types import ( from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse +from llama_stack.log import get_logger from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class ChunkForDeletion(BaseModel): diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index 65c3d2898..146591b2f 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -17,7 +17,7 @@ from pydantic import BaseModel from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="scheduler") +logger = get_logger(name=__name__, category="providers::utils") # TODO: revisit the list of possible statuses when defining a more coherent diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index ccc835768..867ba2f55 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore from .sqlstore import SqlStoreType -logger = get_logger(name=__name__, category="authorized_sqlstore") +logger = get_logger(name=__name__, category="providers::utils") # Hardcoded copy 
of the default policy that our SQL filtering implements # WARNING: If default_policy() changes, this constant must be updated accordingly diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 6414929db..f75c35314 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -22,6 +22,7 @@ from sqlalchemy import ( text, ) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.log import get_logger @@ -29,7 +30,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig -logger = get_logger(name=__name__, category="sqlstore") +logger = get_logger(name=__name__, category="providers::utils") TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, @@ -45,9 +46,12 @@ TYPE_MAPPING: dict[ColumnType, Any] = { class SqlAlchemySqlStoreImpl(SqlStore): def __init__(self, config: SqlAlchemySqlStoreConfig): self.config = config - self.async_session = async_sessionmaker(create_async_engine(config.engine_str)) + self.async_session = async_sessionmaker(self.create_engine()) self.metadata = MetaData() + def create_engine(self) -> AsyncEngine: + return create_async_engine(self.config.engine_str, pool_pre_ping=True) + async def create_table( self, table: str, @@ -83,7 +87,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): else: sqlalchemy_table = self.metadata.tables[table] - engine = create_async_engine(self.config.engine_str) + engine = self.create_engine() async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) @@ -241,7 +245,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): nullable: bool = True, ) -> None: """Add a column to an existing table if the column doesn't already exist.""" - engine = create_async_engine(self.config.engine_str) + engine = self.create_engine() try: async with engine.begin() as conn: diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py index 8dd6061a6..71480364c 100644 --- a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py +++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py @@ -5,12 +5,23 @@ # the root directory of this source tree. import json -from datetime import datetime +from datetime import UTC, datetime from typing import Protocol import aiosqlite -from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Trace +from llama_stack.apis.telemetry import ( + MetricDataPoint, + MetricLabel, + MetricLabelMatcher, + MetricQueryType, + MetricSeries, + QueryCondition, + QueryMetricsResponse, + Span, + SpanWithStatus, + Trace, +) class TraceStore(Protocol): @@ -29,11 +40,192 @@ class TraceStore(Protocol): max_depth: int | None = None, ) -> dict[str, SpanWithStatus]: ... + async def query_metrics( + self, + metric_name: str, + start_time: datetime, + end_time: datetime | None = None, + granularity: str | None = "1d", + query_type: MetricQueryType = MetricQueryType.RANGE, + label_matchers: list[MetricLabelMatcher] | None = None, + ) -> QueryMetricsResponse: ... 
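For orientation, a hedged sketch of calling the new `query_metrics` protocol method. The parameter and matcher field names come from the code above; the metric name, label name, and model value are illustrative assumptions:

```python
from datetime import UTC, datetime, timedelta

from llama_stack.apis.telemetry import MetricLabelMatcher, MetricQueryType


async def show_token_usage(store) -> None:
    # Query an assumed counter-style metric over the last day, bucketed hourly,
    # filtered to a single (made-up) model id. The "=" operator string mirrors
    # the comparisons in the SQLite implementation below.
    response = await store.query_metrics(
        metric_name="completion_tokens",  # assumed metric name
        start_time=datetime.now(UTC) - timedelta(days=1),
        granularity="1h",
        query_type=MetricQueryType.RANGE,
        label_matchers=[
            MetricLabelMatcher(name="model_id", operator="=", value="my-model"),
        ],
    )
    for series in response.data:
        print(series.metric, [(p.timestamp, p.value) for p in series.values])
```

Each returned `MetricSeries` carries the matcher labels plus any extra attribute labels the store found, as the implementation below shows.
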
+ class SQLiteTraceStore(TraceStore): def __init__(self, conn_string: str): self.conn_string = conn_string + async def query_metrics( + self, + metric_name: str, + start_time: datetime, + end_time: datetime | None = None, + granularity: str | None = None, + query_type: MetricQueryType = MetricQueryType.RANGE, + label_matchers: list[MetricLabelMatcher] | None = None, + ) -> QueryMetricsResponse: + if end_time is None: + end_time = datetime.now(UTC) + + # Build base query + if query_type == MetricQueryType.INSTANT: + query = """ + SELECT + se.name, + SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + else: + if granularity: + time_format = self._get_time_format_for_granularity(granularity) + query = f""" + SELECT + se.name, + SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes, + strftime('{time_format}', se.timestamp) as bucket_start + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + else: + query = """ + SELECT + se.name, + json_extract(se.attributes, '$.value') as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes, + se.timestamp + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + + params = [f"metric.{metric_name}", start_time.isoformat(), end_time.isoformat()] + + # Labels that will be attached to the MetricSeries (preserve matcher labels) + all_labels: list[MetricLabel] = [] + matcher_label_names = set() + if label_matchers: + for matcher in label_matchers: + json_path = f"$.{matcher.name}" + if matcher.operator == "=": + query += f" AND json_extract(se.attributes, '{json_path}') = ?" + params.append(matcher.value) + elif matcher.operator == "!=": + query += f" AND json_extract(se.attributes, '{json_path}') != ?" + params.append(matcher.value) + elif matcher.operator == "=~": + query += f" AND json_extract(se.attributes, '{json_path}') LIKE ?" + params.append(f"%{matcher.value}%") + elif matcher.operator == "!~": + query += f" AND json_extract(se.attributes, '{json_path}') NOT LIKE ?" + params.append(f"%{matcher.value}%") + # Preserve filter context in output + all_labels.append(MetricLabel(name=matcher.name, value=str(matcher.value))) + matcher_label_names.add(matcher.name) + + # GROUP BY / ORDER BY logic + if query_type == MetricQueryType.RANGE and granularity: + group_time_format = self._get_time_format_for_granularity(granularity) + query += f" GROUP BY strftime('{group_time_format}', se.timestamp), json_extract(se.attributes, '$.unit')" + query += " ORDER BY bucket_start" + elif query_type == MetricQueryType.INSTANT: + query += " GROUP BY json_extract(se.attributes, '$.unit')" + else: + query += " ORDER BY se.timestamp" + + # Execute query + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + + if not rows: + return QueryMetricsResponse(data=[]) + + data_points = [] + # We want to add attribute labels, but only those not already present as matcher labels. + attr_label_names = set() + for row in rows: + # Parse JSON attributes safely, if there are no attributes (weird), just don't add the labels to the result. 
+ try: + attributes = json.loads(row["attributes"] or "{}") + except (TypeError, json.JSONDecodeError): + attributes = {} + + value = row["value"] + unit = row["unit"] or "" + + # Add labels from attributes without duplicating matcher labels, if we don't do this, there will be a lot of duplicate label in the result. + for k, v in attributes.items(): + if k not in ["value", "unit"] and k not in matcher_label_names and k not in attr_label_names: + all_labels.append(MetricLabel(name=k, value=str(v))) + attr_label_names.add(k) + + # Determine timestamp + if query_type == MetricQueryType.RANGE and granularity: + try: + bucket_start_raw = row["bucket_start"] + except KeyError as e: + raise ValueError( + "DB did not have a bucket_start time in row when using granularity, this indicates improper formatting" + ) from e + # this value could also be there, but be NULL, I think. + if bucket_start_raw is None: + raise ValueError("bucket_start is None check time format and data") + bucket_start = datetime.fromisoformat(bucket_start_raw) + timestamp = int(bucket_start.timestamp()) + elif query_type == MetricQueryType.INSTANT: + timestamp = int(datetime.now(UTC).timestamp()) + else: + try: + timestamp_raw = row["timestamp"] + except KeyError as e: + raise ValueError( + "DB did not have a timestamp in row, this indicates improper formatting" + ) from e + # this value could also be there, but be NULL, I think. + if timestamp_raw is None: + raise ValueError("timestamp is None check time format and data") + timestamp_iso = datetime.fromisoformat(timestamp_raw) + timestamp = int(timestamp_iso.timestamp()) + + data_points.append( + MetricDataPoint( + timestamp=timestamp, + value=value, + unit=unit, + ) + ) + + metric_series = [MetricSeries(metric=metric_name, labels=all_labels, values=data_points)] + return QueryMetricsResponse(data=metric_series) + + def _get_time_format_for_granularity(self, granularity: str | None) -> str: + """Get the SQLite strftime format string for a given granularity. + Args: + granularity: Granularity string (e.g., "1m", "5m", "1h", "1d") + Returns: + SQLite strftime format string for the granularity + """ + if granularity is None: + raise ValueError("granularity cannot be None for this method - use separate logic for no aggregation") + + if granularity.endswith("d"): + return "%Y-%m-%d 00:00:00" + elif granularity.endswith("h"): + return "%Y-%m-%d %H:00:00" + elif granularity.endswith("m"): + return "%Y-%m-%d %H:%M:00" + else: + return "%Y-%m-%d %H:%M:00" # Default to most granular which will give us the most timestamps. 
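The strftime formats above effectively truncate timestamps into day/hour/minute buckets. A small self-contained illustration of the same mapping; Python's `strftime` accepts these exact format strings, so the expected outputs can be checked directly:

```python
from datetime import datetime


def bucket_format(granularity: str) -> str:
    # Mirrors _get_time_format_for_granularity: the granularity suffix picks the
    # bucket width, and anything unrecognized falls back to minute buckets.
    if granularity.endswith("d"):
        return "%Y-%m-%d 00:00:00"
    if granularity.endswith("h"):
        return "%Y-%m-%d %H:00:00"
    return "%Y-%m-%d %H:%M:00"


ts = datetime(2025, 3, 7, 14, 23, 45)
print(ts.strftime(bucket_format("1d")))  # 2025-03-07 00:00:00
print(ts.strftime(bucket_format("1h")))  # 2025-03-07 14:00:00
print(ts.strftime(bucket_format("5m")))  # 2025-03-07 14:23:00
```
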
+ async def query_traces( self, attribute_filters: list[QueryCondition] | None = None, diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 7080e774a..7694003b5 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -6,7 +6,7 @@ import asyncio import contextvars -import logging +import logging # allow-direct-logging import queue import random import sys diff --git a/llama_stack/testing/inference_recorder.py b/llama_stack/testing/inference_recorder.py index 478f77773..4a6958399 100644 --- a/llama_stack/testing/inference_recorder.py +++ b/llama_stack/testing/inference_recorder.py @@ -261,7 +261,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint else: raise RuntimeError( f"No recorded response found for request hash: {request_hash}\n" - f"Endpoint: {endpoint}\n" + f"Request: {method} {url} {body}\n" f"Model: {body.get('model', 'unknown')}\n" f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record" ) diff --git a/llama_stack/ui/.nvmrc b/llama_stack/ui/.nvmrc new file mode 100644 index 000000000..1384ff6a1 --- /dev/null +++ b/llama_stack/ui/.nvmrc @@ -0,0 +1 @@ +22.5.1 diff --git a/llama_stack/ui/.prettierignore b/llama_stack/ui/.prettierignore index 1b8ac8894..b737ae6ed 100644 --- a/llama_stack/ui/.prettierignore +++ b/llama_stack/ui/.prettierignore @@ -1,3 +1,12 @@ # Ignore artifacts: build coverage +.next +node_modules +dist +*.lock +*.log + +# Generated files +*.min.js +*.min.css diff --git a/llama_stack/ui/.prettierrc b/llama_stack/ui/.prettierrc index 0967ef424..059475a24 100644 --- a/llama_stack/ui/.prettierrc +++ b/llama_stack/ui/.prettierrc @@ -1 +1,10 @@ -{} +{ + "semi": true, + "trailingComma": "es5", + "singleQuote": false, + "printWidth": 80, + "tabWidth": 2, + "useTabs": false, + "bracketSpacing": true, + "arrowParens": "avoid" +} diff --git a/llama_stack/ui/app/api/v1/[...path]/route.ts b/llama_stack/ui/app/api/v1/[...path]/route.ts index 1959f9099..51c1f8004 100644 --- a/llama_stack/ui/app/api/v1/[...path]/route.ts +++ b/llama_stack/ui/app/api/v1/[...path]/route.ts @@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) { const responseText = await response.text(); console.log( - `Response from FastAPI: ${response.status} ${response.statusText}`, + `Response from FastAPI: ${response.status} ${response.statusText}` ); // Create response with same status and headers @@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) { backend_url: BACKEND_URL, timestamp: new Date().toISOString(), }, - { status: 500 }, + { status: 500 } ); } } diff --git a/llama_stack/ui/app/auth/signin/page.tsx b/llama_stack/ui/app/auth/signin/page.tsx index c9510fd6b..0ccb4a397 100644 --- a/llama_stack/ui/app/auth/signin/page.tsx +++ b/llama_stack/ui/app/auth/signin/page.tsx @@ -51,9 +51,9 @@ export default function SignInPage() { onClick={() => { console.log("Signing in with GitHub..."); signIn("github", { callbackUrl: "/auth/signin" }).catch( - (error) => { + error => { console.error("Sign in error:", error); - }, + } ); }} className="w-full" diff --git a/llama_stack/ui/app/chat-playground/page.test.tsx b/llama_stack/ui/app/chat-playground/page.test.tsx new file mode 100644 index 000000000..54c15f95a --- /dev/null +++ b/llama_stack/ui/app/chat-playground/page.test.tsx @@ -0,0 +1,587 @@ +import React from "react"; +import { + render, + screen, + fireEvent, + waitFor, + act, 
+} from "@testing-library/react"; +import "@testing-library/jest-dom"; +import ChatPlaygroundPage from "./page"; + +const mockClient = { + agents: { + list: jest.fn(), + create: jest.fn(), + retrieve: jest.fn(), + delete: jest.fn(), + session: { + list: jest.fn(), + create: jest.fn(), + delete: jest.fn(), + retrieve: jest.fn(), + }, + turn: { + create: jest.fn(), + }, + }, + models: { + list: jest.fn(), + }, + toolgroups: { + list: jest.fn(), + }, +}; + +jest.mock("@/hooks/use-auth-client", () => ({ + useAuthClient: jest.fn(() => mockClient), +})); + +jest.mock("@/components/chat-playground/chat", () => ({ + Chat: jest.fn( + ({ + className, + messages, + handleSubmit, + input, + handleInputChange, + isGenerating, + append, + suggestions, + }) => ( +
+
{messages.length}
+ + + {suggestions?.map((suggestion: string, index: number) => ( + + ))} +
+ ) + ), +})); + +jest.mock("@/components/chat-playground/conversations", () => ({ + SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => ( +
+ {selectedAgentId && ( + <> +
{selectedAgentId}
+ + + )} +
+ )), + SessionUtils: { + saveCurrentSessionId: jest.fn(), + loadCurrentSessionId: jest.fn(), + loadCurrentAgentId: jest.fn(), + saveCurrentAgentId: jest.fn(), + clearCurrentSession: jest.fn(), + saveSessionData: jest.fn(), + loadSessionData: jest.fn(), + saveAgentConfig: jest.fn(), + loadAgentConfig: jest.fn(), + clearAgentCache: jest.fn(), + createDefaultSession: jest.fn(() => ({ + id: "test-session-123", + name: "Default Session", + messages: [], + selectedModel: "", + systemMessage: "You are a helpful assistant.", + agentId: "test-agent-123", + createdAt: Date.now(), + updatedAt: Date.now(), + })), + }, +})); + +const mockAgents = [ + { + agent_id: "agent_123", + agent_config: { + name: "Test Agent", + instructions: "You are a test assistant.", + }, + }, + { + agent_id: "agent_456", + agent_config: { + agent_name: "Another Agent", + instructions: "You are another assistant.", + }, + }, +]; + +const mockModels = [ + { + identifier: "test-model-1", + model_type: "llm", + }, + { + identifier: "test-model-2", + model_type: "llm", + }, +]; + +const mockToolgroups = [ + { + identifier: "builtin::rag", + provider_id: "test-provider", + type: "tool_group", + provider_resource_id: "test-resource", + }, +]; + +describe("ChatPlaygroundPage", () => { + beforeEach(() => { + jest.clearAllMocks(); + Element.prototype.scrollIntoView = jest.fn(); + mockClient.agents.list.mockResolvedValue({ data: mockAgents }); + mockClient.models.list.mockResolvedValue(mockModels); + mockClient.toolgroups.list.mockResolvedValue(mockToolgroups); + mockClient.agents.session.create.mockResolvedValue({ + session_id: "new-session-123", + }); + mockClient.agents.session.list.mockResolvedValue({ data: [] }); + mockClient.agents.session.retrieve.mockResolvedValue({ + session_id: "test-session", + session_name: "Test Session", + started_at: new Date().toISOString(), + turns: [], + }); // No turns by default + mockClient.agents.retrieve.mockResolvedValue({ + agent_id: "test-agent", + agent_config: { + toolgroups: ["builtin::rag"], + instructions: "Test instructions", + model: "test-model", + }, + }); + mockClient.agents.delete.mockResolvedValue(undefined); + }); + + describe("Agent Selector Rendering", () => { + test("shows agent selector when agents are available", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByText("Agent Session:")).toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(2); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.getByText("Clear Chat")).toBeInTheDocument(); + }); + }); + + test("does not show agent selector when no agents are available", async () => { + mockClient.agents.list.mockResolvedValue({ data: [] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); + }); + }); + + test("does not show agent selector while loading", async () => { + mockClient.agents.list.mockImplementation(() => new Promise(() => {})); + + await act(async () => { + render(); + }); + + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); 
+ }); + + test("shows agent options in selector", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Test Agent")).toHaveLength(2); + expect(screen.getByText("Another Agent")).toBeInTheDocument(); + }); + }); + + test("displays agent ID when no name is available", async () => { + const agentWithoutName = { + agent_id: "agent_789", + agent_config: { + instructions: "You are an agent without a name.", + }, + }; + + mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Agent agent_78") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2); + }); + }); + }); + + describe("Agent Creation Modal", () => { + test("opens agent creation modal when + New Agent is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument(); + expect(screen.getAllByText("Model")).toHaveLength(2); + expect(screen.getByText("System Instructions")).toBeInTheDocument(); + expect(screen.getByText("Tools (optional)")).toBeInTheDocument(); + }); + + test("closes modal when Cancel is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + const cancelButton = screen.getByText("Cancel"); + fireEvent.click(cancelButton); + + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + + test("creates agent when Create Agent is clicked", async () => { + mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" }); + mockClient.agents.list + .mockResolvedValueOnce({ data: mockAgents }) + .mockResolvedValueOnce({ + data: [ + ...mockAgents, + { agent_id: "new-agent-123", agent_config: { name: "New Agent" } }, + ], + }); + + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + await act(async () => { + fireEvent.click(newAgentButton); + }); + + await waitFor(() => { + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + }); + + const nameInput = screen.getByPlaceholderText("My Custom Agent"); + await act(async () => { + fireEvent.change(nameInput, { target: { value: "Test Agent Name" } }); + }); + + const instructionsTextarea = screen.getByDisplayValue( + "You are a helpful assistant." 
+ ); + await act(async () => { + fireEvent.change(instructionsTextarea, { + target: { value: "Custom instructions" }, + }); + }); + + await waitFor(() => { + const modalModelSelectors = screen + .getAllByRole("combobox") + .filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + expect(modalModelSelectors.length).toBeGreaterThan(0); + }); + + const modalModelSelectors = screen.getAllByRole("combobox").filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + + await act(async () => { + fireEvent.click(modalModelSelectors[0]); + }); + + await waitFor(() => { + const modelOptions = screen.getAllByText("test-model-1"); + expect(modelOptions.length).toBeGreaterThan(0); + }); + + const modelOptions = screen.getAllByText("test-model-1"); + const dropdownOption = modelOptions.find( + option => + option.closest('[role="option"]') || + option.id?.includes("radix") || + option.getAttribute("aria-selected") !== null + ); + + await act(async () => { + fireEvent.click( + dropdownOption || modelOptions[modelOptions.length - 1] + ); + }); + + await waitFor(() => { + const createButton = screen.getByText("Create Agent"); + expect(createButton).not.toBeDisabled(); + }); + + const createButton = screen.getByText("Create Agent"); + await act(async () => { + fireEvent.click(createButton); + }); + + await waitFor(() => { + expect(mockClient.agents.create).toHaveBeenCalledWith({ + agent_config: { + model: expect.any(String), + instructions: "Custom instructions", + name: "Test Agent Name", + enable_session_persistence: true, + }, + }); + }); + + await waitFor(() => { + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + }); + }); + + describe("Agent Selection", () => { + test("creates default session when agent is selected", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + // first agent should be auto-selected + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_123", + { session_name: "Default Session" } + ); + }); + }); + + test("switches agent when different agent is selected", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + const anotherAgentOption = screen.getByText("Another Agent"); + fireEvent.click(anotherAgentOption); + }); + + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_456", + { session_name: "Default Session" } + ); + }); + }); + + describe("Agent Deletion", () => { + test("shows delete button when multiple agents exist", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + }); + + test("hides delete button when only one agent exists", async () => { + mockClient.agents.list.mockResolvedValue({ + data: [mockAgents[0]], + }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect( + screen.queryByTitle("Delete current agent") + ).not.toBeInTheDocument(); + }); + }); + + test("deletes agent and switches to another when confirmed", 
async () => { + global.confirm = jest.fn(() => true); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + mockClient.agents.delete.mockResolvedValue(undefined); + mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents }); + mockClient.agents.list.mockResolvedValueOnce({ + data: [mockAgents[1]], + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123"); + expect(global.confirm).toHaveBeenCalledWith( + "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions." + ); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + + test("does not delete agent when cancelled", async () => { + global.confirm = jest.fn(() => false); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(global.confirm).toHaveBeenCalled(); + expect(mockClient.agents.delete).not.toHaveBeenCalled(); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + }); + + describe("Error Handling", () => { + test("handles agent loading errors gracefully", async () => { + mockClient.agents.list.mockRejectedValue( + new Error("Failed to load agents") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching agents:", + expect.any(Error) + ); + }); + + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + + consoleSpy.mockRestore(); + }); + + test("handles model loading errors gracefully", async () => { + mockClient.models.list.mockRejectedValue( + new Error("Failed to load models") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching models:", + expect.any(Error) + ); + }); + + consoleSpy.mockRestore(); + }); + }); +}); diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index d8094af85..f26791a41 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback, useRef } from "react"; import { flushSync } from "react-dom"; import { Button } from "@/components/ui/button"; import { @@ -10,14 +10,22 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { Card } from "@/components/ui/card"; +import { Input } from "@/components/ui/input"; +import { Trash2 } from "lucide-react"; import { Chat } from "@/components/chat-playground/chat"; import { type Message } from "@/components/chat-playground/chat-message"; import { useAuthClient } from "@/hooks/use-auth-client"; -import type { CompletionCreateParams } from "llama-stack-client/resources/chat/completions"; import type { Model } from "llama-stack-client/resources/models"; - +import type { TurnCreateParams } from 
"llama-stack-client/resources/agents/turn"; +import { + SessionUtils, + type ChatSession, +} from "@/components/chat-playground/conversations"; export default function ChatPlaygroundPage() { - const [messages, setMessages] = useState([]); + const [currentSession, setCurrentSession] = useState( + null + ); const [input, setInput] = useState(""); const [isGenerating, setIsGenerating] = useState(false); const [error, setError] = useState(null); @@ -25,10 +33,522 @@ export default function ChatPlaygroundPage() { const [selectedModel, setSelectedModel] = useState(""); const [modelsLoading, setModelsLoading] = useState(true); const [modelsError, setModelsError] = useState(null); + const [agents, setAgents] = useState< + Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }> + >([]); + const [selectedAgentConfig, setSelectedAgentConfig] = useState<{ + toolgroups?: Array< + string | { name: string; args: Record } + >; + } | null>(null); + const [selectedAgentId, setSelectedAgentId] = useState(""); + const [agentsLoading, setAgentsLoading] = useState(true); + const [showCreateAgent, setShowCreateAgent] = useState(false); + const [newAgentName, setNewAgentName] = useState(""); + const [newAgentInstructions, setNewAgentInstructions] = useState( + "You are a helpful assistant." + ); + const [selectedToolgroups, setSelectedToolgroups] = useState([]); + const [availableToolgroups, setAvailableToolgroups] = useState< + Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }> + >([]); const client = useAuthClient(); + const abortControllerRef = useRef(null); const isModelsLoading = modelsLoading ?? true; + const loadAgentConfig = useCallback( + async (agentId: string) => { + try { + console.log("Loading agent config for:", agentId); + + // try to load from cache first + const cachedConfig = SessionUtils.loadAgentConfig(agentId); + if (cachedConfig) { + console.log("✅ Loaded agent config from cache:", cachedConfig); + setSelectedAgentConfig({ + toolgroups: cachedConfig.toolgroups, + }); + return; + } + + console.log("📡 Fetching agent config from API..."); + const agentDetails = await client.agents.retrieve(agentId); + console.log("Agent details retrieved:", agentDetails); + console.log("Agent config:", agentDetails.agent_config); + console.log("Agent toolgroups:", agentDetails.agent_config?.toolgroups); + + // cache the config + SessionUtils.saveAgentConfig(agentId, agentDetails.agent_config); + + setSelectedAgentConfig({ + toolgroups: agentDetails.agent_config?.toolgroups, + }); + } catch (error) { + console.error("Error loading agent config:", error); + setSelectedAgentConfig(null); + } + }, + [client] + ); + + const createDefaultSession = useCallback( + async (agentId: string) => { + try { + const response = await client.agents.session.create(agentId, { + session_name: "Default Session", + }); + + const defaultSession: ChatSession = { + id: response.session_id, + name: "Default Session", + messages: [], + selectedModel: selectedModel, // Use current selected model + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(defaultSession); + console.log( + `💾 Saving default session ID for agent ${agentId}:`, + defaultSession.id + ); + SessionUtils.saveCurrentSessionId(defaultSession.id, agentId); + // cache entire session data + SessionUtils.saveSessionData(agentId, defaultSession); + } 
catch (error) { + console.error("Error creating default session:", error); + } + }, + [client, selectedModel] + ); + + const loadSessionMessages = useCallback( + async (agentId: string, sessionId: string): Promise => { + try { + const session = await client.agents.session.retrieve( + agentId, + sessionId + ); + + if (!session || !session.turns || !Array.isArray(session.turns)) { + return []; + } + + const messages: Message[] = []; + for (const turn of session.turns) { + // add user messages + if (turn.input_messages && Array.isArray(turn.input_messages)) { + for (const input of turn.input_messages) { + if (input.role === "user" && input.content) { + messages.push({ + id: `${turn.turn_id}-user-${messages.length}`, + role: "user", + content: + typeof input.content === "string" + ? input.content + : JSON.stringify(input.content), + createdAt: new Date(turn.started_at || Date.now()), + }); + } + } + } + + // add assistant message from output_message + if (turn.output_message && turn.output_message.content) { + messages.push({ + id: `${turn.turn_id}-assistant-${messages.length}`, + role: "assistant", + content: + typeof turn.output_message.content === "string" + ? turn.output_message.content + : JSON.stringify(turn.output_message.content), + createdAt: new Date( + turn.completed_at || turn.started_at || Date.now() + ), + }); + } + } + + return messages; + } catch (error) { + console.error("Error loading session messages:", error); + return []; + } + }, + [client] + ); + + const loadAgentSessions = useCallback( + async (agentId: string) => { + try { + console.log("Loading sessions for agent:", agentId); + const response = await client.agents.session.list(agentId); + console.log("Available sessions:", response.data); + + if ( + response.data && + Array.isArray(response.data) && + response.data.length > 0 + ) { + // check for a previously saved session ID for this specific agent + const savedSessionId = SessionUtils.loadCurrentSessionId(agentId); + console.log(`Saved session ID for agent ${agentId}:`, savedSessionId); + + // try to load cached session data first + if (savedSessionId) { + const cachedSession = SessionUtils.loadSessionData( + agentId, + savedSessionId + ); + if (cachedSession) { + console.log("✅ Loaded session from cache:", cachedSession.id); + setCurrentSession(cachedSession); + SessionUtils.saveCurrentSessionId(cachedSession.id, agentId); + return; + } + console.log("📡 Cache miss, fetching session from API..."); + } + + let sessionToLoad = response.data[0] as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "Default session to load (first in list):", + sessionToLoad.session_id + ); + + // try to find saved session id in available sessions + if (savedSessionId) { + const foundSession = response.data.find( + (s: { session_id: string }) => s.session_id === savedSessionId + ); + console.log("Found saved session in list:", foundSession); + if (foundSession) { + sessionToLoad = foundSession as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "✅ Restored previously selected session:", + savedSessionId + ); + } else { + console.log( + "❌ Previously selected session not found, using latest session" + ); + } + } else { + console.log("❌ No saved session ID found, using latest session"); + } + + const messages = await loadSessionMessages( + agentId, + sessionToLoad.session_id + ); + + const session: ChatSession = { + id: sessionToLoad.session_id, + name: sessionToLoad.session_name || "Session", + 
messages, + selectedModel: selectedModel || "", // Preserve current model or use empty + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: sessionToLoad.started_at + ? new Date(sessionToLoad.started_at).getTime() + : Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(session); + console.log(`💾 Saving session ID for agent ${agentId}:`, session.id); + SessionUtils.saveCurrentSessionId(session.id, agentId); + // cache session data + SessionUtils.saveSessionData(agentId, session); + } else { + // no sessions, create a new one + await createDefaultSession(agentId); + } + } catch (error) { + console.error("Error loading agent sessions:", error); + // fallback to creating a new session + await createDefaultSession(agentId); + } + }, + [client, loadSessionMessages, createDefaultSession, selectedModel] + ); + + useEffect(() => { + const fetchAgents = async () => { + try { + setAgentsLoading(true); + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + if (agentList.data && agentList.data.length > 0) { + // check if there's a previously selected agent + const savedAgentId = SessionUtils.loadCurrentAgentId(); + + let agentToSelect = agentList.data[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + + // if we have a saved agent ID, find it in the available agents + if (savedAgentId) { + const foundAgent = agentList.data.find( + (a: { agent_id: string }) => a.agent_id === savedAgentId + ); + if (foundAgent) { + agentToSelect = foundAgent as typeof agentToSelect; + } else { + console.log("Previously selected agent not found"); + } + } + setSelectedAgentId(agentToSelect.agent_id); + SessionUtils.saveCurrentAgentId(agentToSelect.agent_id); + // load agent config immediately + await loadAgentConfig(agentToSelect.agent_id); + // Note: loadAgentSessions will be called after models are loaded + } + } catch (error) { + console.error("Error fetching agents:", error); + } finally { + setAgentsLoading(false); + } + }; + + fetchAgents(); + + // fetch available toolgroups + const fetchToolgroups = async () => { + try { + console.log("Fetching toolgroups..."); + const toolgroups = await client.toolgroups.list(); + console.log("Toolgroups response:", toolgroups); + + // The client returns data directly, not wrapped in .data + const toolGroupsArray = Array.isArray(toolgroups) + ? toolgroups + : toolgroups && + typeof toolgroups === "object" && + "data" in toolgroups && + Array.isArray((toolgroups as { data: unknown }).data) + ?
( + toolgroups as { + data: Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }>; + } + ).data + : []; + + if (toolGroupsArray && Array.isArray(toolGroupsArray)) { + setAvailableToolgroups(toolGroupsArray); + console.log("Set toolgroups:", toolGroupsArray); + } else { + console.error("Invalid toolgroups data format:", toolgroups); + } + } catch (error) { + console.error("Error fetching toolgroups:", error); + if (error instanceof Error) { + console.error("Error details:", { + name: error.name, + message: error.message, + stack: error.stack, + }); + } + } + }; + + fetchToolgroups(); + }, [client, loadAgentSessions, loadAgentConfig]); + + const createNewAgent = useCallback( + async ( + name: string, + instructions: string, + model: string, + toolgroups: string[] = [] + ) => { + try { + console.log("Creating agent with toolgroups:", toolgroups); + const agentConfig = { + model, + instructions, + name: name || undefined, + enable_session_persistence: true, + toolgroups: toolgroups.length > 0 ? toolgroups : undefined, + }; + console.log("Agent config being sent:", agentConfig); + + const response = await client.agents.create({ + agent_config: agentConfig, + }); + + // refresh agents list + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + // set the new agent as selected + setSelectedAgentId(response.agent_id); + await loadAgentConfig(response.agent_id); + await loadAgentSessions(response.agent_id); + + return response.agent_id; + } catch (error) { + console.error("Error creating agent:", error); + throw error; + } + }, + [client, loadAgentSessions, loadAgentConfig] + ); + + const deleteAgent = useCallback( + async (agentId: string) => { + if (agents.length <= 1) { + return; + } + + if ( + confirm( + "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions." + ) + ) { + try { + await client.agents.delete(agentId); + + // clear cached data for agent + SessionUtils.clearAgentCache(agentId); + + // Refresh agents list + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + // if we deleted the current agent, switch to another one + if (selectedAgentId === agentId) { + const remainingAgents = agentList.data?.filter( + (a: { agent_id: string }) => a.agent_id !== agentId + ); + if (remainingAgents && remainingAgents.length > 0) { + const newAgent = remainingAgents[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + setSelectedAgentId(newAgent.agent_id); + SessionUtils.saveCurrentAgentId(newAgent.agent_id); + await loadAgentConfig(newAgent.agent_id); + await loadAgentSessions(newAgent.agent_id); + } else { + // No agents left + setSelectedAgentId(""); + setCurrentSession(null); + setSelectedAgentConfig(null); + } + } + } catch (error) { + console.error("Error deleting agent:", error); + } + } + }, + [agents.length, client, selectedAgentId, loadAgentConfig, loadAgentSessions] + ); + + const handleModelChange = useCallback((newModel: string) => { + setSelectedModel(newModel); + setCurrentSession(prev => + prev + ? 
{ + ...prev, + selectedModel: newModel, + updatedAt: Date.now(), + } + : prev + ); + }, []); + + useEffect(() => { + if (currentSession) { + console.log( + `💾 Auto-saving session ID for agent ${currentSession.agentId}:`, + currentSession.id + ); + SessionUtils.saveCurrentSessionId( + currentSession.id, + currentSession.agentId + ); + // cache session data + SessionUtils.saveSessionData(currentSession.agentId, currentSession); + // only update selectedModel if the session has a valid model and it's different from current + if ( + currentSession.selectedModel && + currentSession.selectedModel !== selectedModel + ) { + setSelectedModel(currentSession.selectedModel); + } + } + }, [currentSession, selectedModel]); useEffect(() => { const fetchModels = async () => { @@ -36,10 +556,10 @@ export default function ChatPlaygroundPage() { setModelsLoading(true); setModelsError(null); const modelList = await client.models.list(); - const llmModels = modelList.filter(model => model.model_type === 'llm'); + const llmModels = modelList.filter(model => model.model_type === "llm"); setModels(llmModels); if (llmModels.length > 0) { - setSelectedModel(llmModels[0].identifier); + handleModelChange(llmModels[0].identifier); } } catch (err) { console.error("Error fetching models:", err); @@ -50,106 +570,207 @@ export default function ChatPlaygroundPage() { }; fetchModels(); - }, [client]); + }, [client, handleModelChange]); - const extractTextContent = (content: unknown): string => { - if (typeof content === 'string') { - return content; + // load agent sessions after both agents and models are ready + useEffect(() => { + if ( + selectedAgentId && + !agentsLoading && + !modelsLoading && + selectedModel && + !currentSession + ) { + loadAgentSessions(selectedAgentId); } - if (Array.isArray(content)) { - return content - .filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text') - .map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '') - .join(''); - } - if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) { - return String(content.text) || ''; - } - return ''; - }; + }, [ + selectedAgentId, + agentsLoading, + modelsLoading, + selectedModel, + currentSession, + loadAgentSessions, + ]); const handleInputChange = (e: React.ChangeEvent) => { setInput(e.target.value); }; -const handleSubmit = async (event?: { preventDefault?: () => void }) => { - event?.preventDefault?.(); - if (!input.trim()) return; + const handleSubmit = async (event?: { preventDefault?: () => void }) => { + event?.preventDefault?.(); + if (!input.trim()) return; - // Add user message to chat - const userMessage: Message = { - id: Date.now().toString(), - role: "user", - content: input.trim(), - createdAt: new Date(), - }; - - setMessages(prev => [...prev, userMessage]); - setInput(""); - - // Use the helper function with the content - await handleSubmitWithContent(userMessage.content); -}; - -const handleSubmitWithContent = async (content: string) => { - setIsGenerating(true); - setError(null); - - try { - const messageParams: CompletionCreateParams["messages"] = [ - ...messages.map(msg => { - const msgContent = typeof msg.content === 'string' ? 
msg.content : extractTextContent(msg.content); - if (msg.role === "user") { - return { role: "user" as const, content: msgContent }; - } else if (msg.role === "assistant") { - return { role: "assistant" as const, content: msgContent }; - } else { - return { role: "system" as const, content: msgContent }; - } - }), - { role: "user" as const, content } - ]; - - const response = await client.chat.completions.create({ - model: selectedModel, - messages: messageParams, - stream: true, - }); - - const assistantMessage: Message = { - id: (Date.now() + 1).toString(), - role: "assistant", - content: "", + const userMessage: Message = { + id: Date.now().toString(), + role: "user", + content: input.trim(), createdAt: new Date(), }; - setMessages(prev => [...prev, assistantMessage]); - let fullContent = ""; - for await (const chunk of response) { - if (chunk.choices && chunk.choices[0]?.delta?.content) { - const deltaContent = chunk.choices[0].delta.content; - fullContent += deltaContent; + setCurrentSession(prev => { + if (!prev) return prev; + const updatedSession = { + ...prev, + messages: [...prev.messages, userMessage], + updatedAt: Date.now(), + }; + // Update cache with new message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); + setInput(""); - flushSync(() => { - setMessages(prev => { - const newMessages = [...prev]; - const lastMessage = newMessages[newMessages.length - 1]; - if (lastMessage.role === "assistant") { - lastMessage.content = fullContent; - } - return newMessages; - }); - }); - } + await handleSubmitWithContent(userMessage.content); + }; + + const handleSubmitWithContent = async (content: string) => { + if (!currentSession || !selectedAgentId) return; + + setIsGenerating(true); + setError(null); + + if (abortControllerRef.current) { + abortControllerRef.current.abort(); } - } catch (err) { - console.error("Error sending message:", err); - setError("Failed to send message. 
Please try again."); - setMessages(prev => prev.slice(0, -1)); - } finally { - setIsGenerating(false); - } -}; + + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + try { + const userMessage = { + role: "user" as const, + content, + }; + + const turnParams: TurnCreateParams = { + messages: [userMessage], + stream: true, + }; + + const response = await client.agents.turn.create( + selectedAgentId, + currentSession.id, + turnParams, + { + signal: abortController.signal, + } as { signal: AbortSignal } + ); + + const assistantMessage: Message = { + id: (Date.now() + 1).toString(), + role: "assistant", + content: "", + createdAt: new Date(), + }; + + const extractDeltaText = (chunk: unknown): string | null => { + // streaming chunks arrive in several different shapes depending on the provider; check each known format in turn + if (chunk?.delta?.text && typeof chunk.delta.text === "string") { + return chunk.delta.text; + } + + if ( + chunk?.event?.delta?.text && + typeof chunk.event.delta.text === "string" + ) { + return chunk.event.delta.text; + } + + if ( + chunk?.choices?.[0]?.delta?.content && + typeof chunk.choices[0].delta.content === "string" + ) { + return chunk.choices[0].delta.content; + } + + if (typeof chunk === "string") { + return chunk; + } + + if ( + chunk?.event?.payload?.delta?.text && + typeof chunk.event.payload.delta.text === "string" + ) { + return chunk.event.payload.delta.text; + } + + if (process.env.NODE_ENV !== "production") { + console.debug("Unrecognized chunk format:", chunk); + } + + return null; + }; + setCurrentSession(prev => { + if (!prev) return null; + const updatedSession = { + ...prev, + messages: [...prev.messages, assistantMessage], + updatedAt: Date.now(), + }; + // update cache with assistant message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); + + let fullContent = ""; + for await (const chunk of response) { + const deltaText = extractDeltaText(chunk); + + if (deltaText) { + fullContent += deltaText; + + flushSync(() => { + setCurrentSession(prev => { + if (!prev) return null; + const newMessages = [...prev.messages]; + const last = newMessages[newMessages.length - 1]; + if (last.role === "assistant") { + last.content = fullContent; + } + const updatedSession = { + ...prev, + messages: newMessages, + updatedAt: Date.now(), + }; + // update cache with streaming content (throttled) + if (fullContent.length % 100 === 0) { + // Only cache every 100 characters to avoid spam + SessionUtils.saveSessionData(prev.agentId, updatedSession); + } + return updatedSession; + }); + }); + } + } + } catch (err) { + if (err instanceof Error && err.name === "AbortError") { + console.log("Request aborted"); + return; + } + + console.error("Error sending message:", err); + setError("Failed to send message. Please try again."); + setCurrentSession(prev => + prev + ?
{ + ...prev, + messages: prev.messages.slice(0, -1), + updatedAt: Date.now(), + } + : prev + ); + } finally { + setIsGenerating(false); + abortControllerRef.current = null; + // cache final session state after streaming completes + setCurrentSession(prev => { + if (prev) { + SessionUtils.saveSessionData(prev.agentId, prev); + } + return prev; + }); + } + }; const suggestions = [ "Write a Python function that prints 'Hello, World!'", "Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?", @@ -163,61 +784,457 @@ const handleSubmitWithContent = async (content: string) => { content: message.content, createdAt: new Date(), }; - setMessages(prev => [...prev, newMessage]) + setCurrentSession(prev => + prev + ? { + ...prev, + messages: [...prev.messages, newMessage], + updatedAt: Date.now(), + } + : prev + ); handleSubmitWithContent(newMessage.content); }; const clearChat = () => { - setMessages([]); + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + setIsGenerating(false); + } + + setCurrentSession(prev => + prev ? { ...prev, messages: [], updatedAt: Date.now() } : prev + ); setError(null); }; return ( -
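// SessionUtils and the ChatSession type are imported above from
// "@/components/chat-playground/conversations" and their implementation is not
// part of this diff. As a rough, hedged sketch of the localStorage-style
// caching contract the page relies on (key names, date handling, and exact
// shapes are assumptions inferred from the call sites above, not the actual
// helper), it might look something like this:

import { type Message } from "@/components/chat-playground/chat-message";

export interface ChatSession {
  id: string;
  name: string;
  messages: Message[];
  selectedModel: string;
  systemMessage: string;
  agentId: string;
  createdAt: number;
  updatedAt: number;
}

type CachedAgentConfig = {
  toolgroups?: Array<string | { name: string; args: Record<string, unknown> }>;
};

export const SessionUtils = {
  // remember which agent was selected last, across page reloads
  saveCurrentAgentId(agentId: string): void {
    localStorage.setItem("chat-playground:current-agent", agentId);
  },
  loadCurrentAgentId(): string | null {
    return localStorage.getItem("chat-playground:current-agent");
  },

  // remember the active session per agent
  saveCurrentSessionId(sessionId: string, agentId: string): void {
    localStorage.setItem(`chat-playground:${agentId}:current-session`, sessionId);
  },
  loadCurrentSessionId(agentId: string): string | null {
    return localStorage.getItem(`chat-playground:${agentId}:current-session`);
  },

  // cache full session data so reloads don't need a round trip to the API
  saveSessionData(agentId: string, session: ChatSession): void {
    localStorage.setItem(
      `chat-playground:${agentId}:session:${session.id}`,
      JSON.stringify(session)
    );
  },
  loadSessionData(agentId: string, sessionId: string): ChatSession | null {
    const raw = localStorage.getItem(
      `chat-playground:${agentId}:session:${sessionId}`
    );
    // note: Message.createdAt comes back as a string after JSON round-tripping;
    // the real helper presumably revives it, which this sketch skips for brevity
    return raw ? (JSON.parse(raw) as ChatSession) : null;
  },

  // cache the agent config (toolgroups etc.) fetched via client.agents.retrieve
  saveAgentConfig(agentId: string, config: unknown): void {
    localStorage.setItem(
      `chat-playground:${agentId}:config`,
      JSON.stringify(config)
    );
  },
  loadAgentConfig(agentId: string): CachedAgentConfig | null {
    const raw = localStorage.getItem(`chat-playground:${agentId}:config`);
    return raw ? (JSON.parse(raw) as CachedAgentConfig) : null;
  },

  // drop everything cached for an agent, used after client.agents.delete
  clearAgentCache(agentId: string): void {
    Object.keys(localStorage)
      .filter(key => key.startsWith(`chat-playground:${agentId}:`))
      .forEach(key => localStorage.removeItem(key));
  },
};

// The page above depends only on these call signatures; quota handling and
// cache invalidation details in the real module may differ.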
-
-

Chat Playground (Completions)

-
- - +
+ {/* Header */} +
+
+

Agent Session

+
+ {!agentsLoading && agents.length > 0 && ( +
+ + + {selectedAgentId && agents.length > 1 && ( + + )} +
+ )} + + {!agentsLoading && agents.length > 0 && ( + + )} +
+
+
+ {/* Main Two-Column Layout */} +
+ {/* Left Column - Configuration Panel */} +
+

+ Settings +

+ + {/* Model Configuration */} +
+

+ Model Configuration +

+
+
+ + + {modelsError && ( +

{modelsError}

+ )} +
+ +
+ +
+ {(selectedAgentId && + agents.find(a => a.agent_id === selectedAgentId) + ?.agent_config?.instructions) || + "No agent selected"} +
+

+ Instructions are set when creating an agent and cannot be + changed. +

+
+
+
+ + {/* Agent Tools */} +
+

+ Agent Tools +

+
+
+ +
+ {selectedAgentConfig?.toolgroups && + selectedAgentConfig.toolgroups.length > 0 ? ( + selectedAgentConfig.toolgroups.map( + ( + toolgroup: + | string + | { name: string; args: Record }, + index: number + ) => { + const toolName = + typeof toolgroup === "string" + ? toolgroup + : toolgroup.name; + const toolArgs = + typeof toolgroup === "object" ? toolgroup.args : null; + + return ( +
+
+ + {toolName} + + + {toolName.includes("rag") + ? "🔍 RAG" + : toolName.includes("search") + ? "🌐 Search" + : "🔧 Tool"} + +
+ {toolArgs && Object.keys(toolArgs).length > 0 && ( +
+ Args:{" "} + {Object.entries(toolArgs) + .map( + ([key, value]) => + `${key}: ${JSON.stringify(value)}` + ) + .join(", ")} +
+ )} +
+ ); + } + ) + ) : ( +
+

+ No tools configured +

+

+ This agent only has text generation capabilities +

+
+ )} +
+

+ Tools are configured when creating an agent and provide + additional capabilities like web search, math calculations, or + RAG document retrieval. +

+
+
+
+
+ + {/* Right Column - Chat Interface */} +
+ {error && ( +
+

{error}

+
+ )} + + + setCurrentSession(prev => + prev ? { ...prev, messages, updatedAt: Date.now() } : prev + ) + } + />
- {modelsError && ( -
-

{modelsError}

+ {/* Create Agent Modal */} + {showCreateAgent && ( +
+ +

Create New Agent

+ +
+
+ + setNewAgentName(e.target.value)} + placeholder="My Custom Agent" + /> +
+ +
+ + +
+ +
+ +