diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 134efd93b..f88402a7a 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,6 +9,7 @@ updates: day: "saturday" commit-message: prefix: chore(github-deps) + - package-ecosystem: "uv" directory: "/" schedule: @@ -19,3 +20,14 @@ updates: - python commit-message: prefix: chore(python-deps) + + - package-ecosystem: npm + directory: "/llama_stack/ui" + schedule: + interval: "weekly" + day: "saturday" + labels: + - type/dependencies + - javascript + commit-message: + prefix: chore(ui-deps) diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index e406d99ee..7a75d85f6 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -17,7 +17,7 @@ jobs: pull-requests: write # for peter-evans/create-pull-request to create a PR runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: main fetch-depth: 0 diff --git a/.github/workflows/install-script-ci.yml b/.github/workflows/install-script-ci.yml index 1ecda6d51..a37919f56 100644 --- a/.github/workflows/install-script-ci.yml +++ b/.github/workflows/install-script-ci.yml @@ -16,14 +16,14 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0 - name: Run ShellCheck on install.sh run: shellcheck scripts/install.sh smoke-test-on-dev: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index c328e3b6c..6787806e9 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -18,7 +18,7 @@ on: - '.github/workflows/integration-auth-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index 4e5b64963..3efd970e1 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -16,7 +16,7 @@ on: - '.github/workflows/integration-sql-store-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -44,7 +44,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-tests.yml 
b/.github/workflows/integration-tests.yml index ba18c27c8..57e582b20 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -65,7 +65,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup test environment uses: ./.github/actions/setup-test-environment diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 61b8e004e..de5701073 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -33,7 +33,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 4f1c143d2..2825c3bf4 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -8,7 +8,7 @@ on: branches: [main] concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -20,7 +20,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: # For dependabot PRs, we need to checkout with a token that can push changes token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }} @@ -36,6 +36,17 @@ jobs: **/requirements*.txt .pre-commit-config.yaml + - name: Set up Node.js + uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0 + with: + node-version: '20' + cache: 'npm' + cache-dependency-path: 'llama_stack/ui/' + + - name: Install npm dependencies + run: npm ci + working-directory: llama_stack/ui + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 continue-on-error: true env: diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 929d76760..391acbcf8 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -26,7 +26,7 @@ on: - 'pyproject.toml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -36,7 +36,7 @@ jobs: distros: ${{ steps.set-matrix.outputs.distros }} steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Generate Distribution List id: set-matrix @@ -55,7 +55,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -79,7 +79,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: 
./.github/actions/setup-runner @@ -92,7 +92,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -106,6 +106,10 @@ jobs: - name: Inspect the container image entrypoint run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then @@ -117,7 +121,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -140,6 +144,10 @@ jobs: - name: Inspect UBI9 image run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index fe1dfd58a..9de53f7fb 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -21,10 +21,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install uv - uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3 + uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 22636f209..d4f5586e2 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -46,7 +46,7 @@ jobs: echo "::endgroup::" - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: fetch-depth: 0 diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 57a4df646..4adaca84d 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -22,6 +22,6 @@ jobs: runs-on: ubuntu-latest steps: - name: Check PR Title's semantic conformance - uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3 + uses: amannn/action-semantic-pull-request@7f33ba792281b034f64e96f4c0b5496782dd3b37 # v6.1.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test-external-provider-module.yml b/.github/workflows/test-external-provider-module.yml index d61b0dfe9..8a757b068 100644 --- a/.github/workflows/test-external-provider-module.yml +++ b/.github/workflows/test-external-provider-module.yml @@ -27,7 +27,7 @@ jobs: # container and point 'uv pip install' to the correct path... 
steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml index b9db0ad51..7ee467451 100644 --- a/.github/workflows/test-external.yml +++ b/.github/workflows/test-external.yml @@ -27,7 +27,7 @@ jobs: # container and point 'uv pip install' to the correct path... steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/ui-unit-tests.yml b/.github/workflows/ui-unit-tests.yml index 00c539c58..2afb92bee 100644 --- a/.github/workflows/ui-unit-tests.yml +++ b/.github/workflows/ui-unit-tests.yml @@ -13,7 +13,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -26,10 +26,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Node.js - uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 with: node-version: ${{ matrix.node-version }} cache: 'npm' diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index f2a6c7754..dd2097a45 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -18,7 +18,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -32,7 +32,7 @@ jobs: - "3.13" steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 1dcfdeca5..e12f0adf8 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -27,7 +27,7 @@ on: - '.github/workflows/update-readthedocs.yml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -37,7 +37,7 @@ jobs: TOKEN: ${{ secrets.READTHEDOCS_TOKEN }} steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4309f289a..514fe6d2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -146,20 +146,32 @@ repos: pass_filenames: false require_serial: true files: ^.github/workflows/.*$ - - id: ui-prettier - name: Format UI code with Prettier - entry: bash -c 'cd llama_stack/ui && 
npm run format' + - id: ui-linter + name: Format & Lint UI + entry: bash ./scripts/run-ui-linter.sh language: system files: ^llama_stack/ui/.*\.(ts|tsx)$ pass_filenames: false require_serial: true - - id: ui-eslint - name: Lint UI code with ESLint - entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet' + + - id: check-log-usage + name: Ensure 'llama_stack.log' usage for logging + entry: bash language: system - files: ^llama_stack/ui/.*\.(ts|tsx)$ - pass_filenames: false - require_serial: true + types: [python] + pass_filenames: true + args: + - -c + - | + matches=$(grep -EnH '^[^#]*\b(import\s+logging|from\s+logging\b)' "$@" | grep -v -e '#\s*allow-direct-logging' || true) + if [ -n "$matches" ]; then + # GitHub Actions annotation format + while IFS=: read -r file line_num rest; do + echo "::error file=$file,line=$line_num::Do not use 'import logging' or 'from logging import' in $file. Use the custom log instead: from llama_stack.log import get_logger; logger = get_logger(). If direct logging is truly needed, add: # allow-direct-logging" + done <<< "$matches" + exit 1 + fi + exit 0 + - _ # placeholder for $0: bash -c consumes the next argument as $0, so without this the first filename pre-commit passes would be silently skipped by "$@" ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b36626719..a1f6a6f30 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4605,6 +4605,49 @@ } } }, + "/v1/inference/rerank": { + "post": { + "responses": { + "200": { + "description": "RerankResponse with indices sorted by relevance score (descending).", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Rerank a list of documents based on their relevance to a query.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankRequest" + } + } + }, + "required": true + } + } + }, "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { "post": { "responses": { @@ -16024,12 +16067,16 @@ "value": { "type": "number", "description": "The numeric value of the metric at this timestamp" + }, + "unit": { + "type": "string" } }, "additionalProperties": false, "required": [ "timestamp", - "value" + "value", + "unit" ], "title": "MetricDataPoint", "description": "A single data point in a metric time series." @@ -16587,6 +16634,95 @@ ], "title": "RegisterVectorDbRequest" }, + "RerankRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the reranking model to use." + }, + "query": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ], + "description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
+ }, + "items": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ] + }, + "description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length." + }, + "max_num_results": { + "type": "integer", + "description": "(Optional) Maximum number of results to return. Default: returns all." + } + }, + "additionalProperties": false, + "required": [ + "model", + "query", + "items" + ], + "title": "RerankRequest" + }, + "RerankData": { + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "The original index of the document in the input list" + }, + "relevance_score": { + "type": "number", + "description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance." + } + }, + "additionalProperties": false, + "required": [ + "index", + "relevance_score" + ], + "title": "RerankData", + "description": "A single rerank result from a reranking response." + }, + "RerankResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RerankData" + }, + "description": "List of rerank result objects, sorted by relevance score (descending)" + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "RerankResponse", + "description": "Response from a reranking request." + }, "ResumeAgentTurnRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index e7733b3c3..33142e3ff 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3264,6 +3264,37 @@ paths: schema: $ref: '#/components/schemas/QueryTracesRequest' required: true + /v1/inference/rerank: + post: + responses: + '200': + description: >- + RerankResponse with indices sorted by relevance score (descending). + content: + application/json: + schema: + $ref: '#/components/schemas/RerankResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Rerank a list of documents based on their relevance to a query. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RerankRequest' + required: true /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume: post: responses: @@ -11923,10 +11954,13 @@ components: type: number description: >- The numeric value of the metric at this timestamp + unit: + type: string additionalProperties: false required: - timestamp - value + - unit title: MetricDataPoint description: >- A single data point in a metric time series. @@ -12337,6 +12371,76 @@ components: - vector_db_id - embedding_model title: RegisterVectorDbRequest + RerankRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the reranking model to use. 
+ query: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + The search query to rank items against. Can be a string, text content + part, or image content part. The input must not exceed the model's max + input token length. + items: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + List of items to rerank. Each item can be a string, text content part, + or image content part. Each input must not exceed the model's max input + token length. + max_num_results: + type: integer + description: >- + (Optional) Maximum number of results to return. Default: returns all. + additionalProperties: false + required: + - model + - query + - items + title: RerankRequest + RerankData: + type: object + properties: + index: + type: integer + description: >- + The original index of the document in the input list + relevance_score: + type: number + description: >- + The relevance score from the model output. Values are inverted when applicable + so that higher scores indicate greater relevance. + additionalProperties: false + required: + - index + - relevance_score + title: RerankData + description: >- + A single rerank result from a reranking response. + RerankResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/RerankData' + description: >- + List of rerank result objects, sorted by relevance score (descending) + additionalProperties: false + required: + - data + title: RerankResponse + description: Response from a reranking request. ResumeAgentTurnRequest: type: object properties: diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 335fa3a68..c9677b3b6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -225,8 +225,32 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS + cors: true # Optional: Enable CORS (dev mode) or full config object ``` +### CORS Configuration + +CORS (Cross-Origin Resource Sharing) can be configured in two ways: + +**Local development** (allows localhost origins only): +```yaml +server: + cors: true +``` + +**Explicit configuration** (custom origins and settings): +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://app.example.com"] + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_credentials: true + max_age: 3600 +``` + +When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security. + ### Authentication Configuration > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. @@ -618,6 +642,54 @@ Content-Type: application/json } ``` +### CORS Configuration + +Configure CORS to allow web browsers to make requests from different domains. Disabled by default. 
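To make the `cors: true` shorthand concrete, here is a minimal standalone sketch of the dev-mode expansion performed by the `process_cors_config` helper this PR adds in `llama_stack/core/datatypes.py`. It is simplified to plain dicts; the real helper returns a validated `CORSConfig` model, and the function name here is illustrative only:

```python
# Sketch of the dev-mode CORS expansion added in this PR (plain-dict version).
def expand_cors_setting(cors: bool | dict | None) -> dict | None:
    if cors is False or cors is None:
        return None  # CORS middleware is never installed
    if cors is True:
        # dev mode: allow localhost on any port, over http or https
        return {
            "allow_origins": [],
            "allow_origin_regex": r"https?://localhost:\d+",
            "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
            "allow_headers": ["Content-Type", "Authorization", "X-Requested-With"],
        }
    return cors  # explicit mappings pass through to CORSMiddleware unchanged


print(expand_cors_setting(True)["allow_origin_regex"])  # https?://localhost:\d+
```

The server then feeds the result into FastAPI's `CORSMiddleware` via `app.add_middleware(CORSMiddleware, **cors_config.model_dump())`, as shown in the `server.py` hunk later in this diff.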
+ +#### Quick Setup + +For development, use the simple boolean flag: + +```yaml +server: + cors: true # Auto-enables localhost with any port +``` + +This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults. + +#### Custom Configuration + +For specific origins and full control: + +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://staging.myapp.com"] + allow_credentials: true + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern + expose_headers: ["X-Total-Count"] + max_age: 86400 +``` + +#### Configuration Options + +| Field | Description | Default | +| -------------------- | ---------------------------------------------- | ------- | +| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` | +| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` | +| `allow_methods` | Allowed HTTP methods. | `["*"]` | +| `allow_headers` | Allowed headers. | `["*"]` | +| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` | +| `expose_headers` | Headers exposed to browser. | `[]` | +| `max_age` | Preflight cache time (seconds). | `600` | + +**Security Notes**: +- `allow_credentials: true` requires explicit origins (no wildcards) +- `cors: true` enables localhost access only (secure for development) +- For public APIs, always specify exact allowed origins + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index fbc48dd95..b9b4b065a 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient( # provider_data is optional, but if you need to pass in any provider specific data, you can do so here. provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]}, ) -client.initialize() ``` This will parse your config and set up any inline implementations and remote clients needed for your implementation. @@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/ ```python client = LlamaStackAsLibraryClient(config_path) -client.initialize() ``` diff --git a/docs/source/distributions/k8s-benchmark/benchmark.py b/docs/source/distributions/k8s-benchmark/benchmark.py index 0e7368431..3d0d18150 100644 --- a/docs/source/distributions/k8s-benchmark/benchmark.py +++ b/docs/source/distributions/k8s-benchmark/benchmark.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md index 2a39a626c..d6d2fa9a3 100644 --- a/docs/source/providers/batches/index.md +++ b/docs/source/providers/batches/index.md @@ -2,12 +2,15 @@ ## Overview -Protocol for batch processing API operations. - - The Batches API enables efficient processing of multiple requests in a single operation, +The Batches API enables efficient processing of multiple requests in a single operation, particularly useful for processing large datasets, batch evaluation workflows, and cost-effective inference at scale. + The API is designed to allow use of openai client libraries for seamless integration. 
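Since the overview above emphasizes OpenAI client compatibility, a hedged usage sketch may help. Everything below is illustrative: it assumes a Llama Stack server on `localhost:8321` serving OpenAI-compatible routes under `/openai/v1`, a previously uploaded input file (the `file-abc123` id is made up), and that the new `idempotency_key` parameter can be forwarded via `extra_body` — none of which is confirmed by this diff beyond the parameter itself:

```python
# Hypothetical sketch: driving the Batches API with the stock `openai` client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="none")

batch = client.batches.create(
    input_file_id="file-abc123",  # hypothetical; must reference an uploaded file
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"job": "nightly-eval"},
    # idempotency_key is the extension added in this PR; sending it through
    # extra_body is an assumption about how a stock client would surface it.
    extra_body={"idempotency_key": "nightly-eval-2025-01-01"},
)
print(batch.id, batch.status)
```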
+ + This API provides the following extensions: + - idempotent batch creation + Note: This API is currently under active development and may undergo changes. This section contains documentation for all available providers for the **batches** API. diff --git a/docs/source/providers/files/index.md b/docs/source/providers/files/index.md index 692aad3ca..128953223 100644 --- a/docs/source/providers/files/index.md +++ b/docs/source/providers/files/index.md @@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files* :maxdepth: 1 inline_localfs +remote_s3 ``` diff --git a/docs/source/providers/files/remote_s3.md b/docs/source/providers/files/remote_s3.md new file mode 100644 index 000000000..2e3cebabd --- /dev/null +++ b/docs/source/providers/files/remote_s3.md @@ -0,0 +1,33 @@ +# remote::s3 + +## Description + +AWS S3-based file storage provider for scalable cloud file management with metadata persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `bucket_name` | `` | No | | S3 bucket name to store files | +| `region` | `` | No | us-east-1 | AWS region where the bucket is located | +| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) | +| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) | +| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) | +| `auto_create_bucket` | `` | No | False | Automatically create the S3 bucket if it doesn't exist | +| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata | + +## Sample Configuration + +```yaml +bucket_name: ${env.S3_BUCKET_NAME} +region: ${env.AWS_REGION:=us-east-1} +aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=} +aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=} +endpoint_url: ${env.S3_ENDPOINT_URL:=} +auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db + +``` + diff --git a/docs/source/providers/post_training/index.md b/docs/source/providers/post_training/index.md index c6c92c40e..5ada6f9aa 100644 --- a/docs/source/providers/post_training/index.md +++ b/docs/source/providers/post_training/index.md @@ -9,7 +9,9 @@ This section contains documentation for all available providers for the **post_t ```{toctree} :maxdepth: 1 -inline_huggingface -inline_torchtune +inline_huggingface-cpu +inline_huggingface-gpu +inline_torchtune-cpu +inline_torchtune-gpu remote_nvidia ``` diff --git a/docs/source/providers/post_training/inline_huggingface-cpu.md b/docs/source/providers/post_training/inline_huggingface-cpu.md new file mode 100644 index 000000000..e663fe8f8 --- /dev/null +++ b/docs/source/providers/post_training/inline_huggingface-cpu.md @@ -0,0 +1,41 @@ +# inline::huggingface-cpu + +## Description + +HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem. 
+ +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `device` | `` | No | cuda | | +| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | | +| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | | +| `chat_template` | `` | No | <|user|> +{input} +<|assistant|> +{output} | | +| `model_specific_config` | `` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | | +| `max_seq_length` | `` | No | 2048 | | +| `gradient_checkpointing` | `` | No | False | | +| `save_total_limit` | `` | No | 3 | | +| `logging_steps` | `` | No | 10 | | +| `warmup_ratio` | `` | No | 0.1 | | +| `weight_decay` | `` | No | 0.01 | | +| `dataloader_num_workers` | `` | No | 4 | | +| `dataloader_pin_memory` | `` | No | True | | +| `dpo_beta` | `` | No | 0.1 | | +| `use_reference_model` | `` | No | True | | +| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | | +| `dpo_output_dir` | `` | No | | | + +## Sample Configuration + +```yaml +checkpoint_format: huggingface +distributed_backend: null +device: cpu +dpo_output_dir: ~/.llama/dummy/dpo_output + +``` + diff --git a/docs/source/providers/post_training/inline_huggingface-gpu.md b/docs/source/providers/post_training/inline_huggingface-gpu.md new file mode 100644 index 000000000..21bf965fe --- /dev/null +++ b/docs/source/providers/post_training/inline_huggingface-gpu.md @@ -0,0 +1,41 @@ +# inline::huggingface-gpu + +## Description + +HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `device` | `` | No | cuda | | +| `distributed_backend` | `Literal['fsdp', 'deepspeed']` | No | | | +| `checkpoint_format` | `Literal['full_state', 'huggingface']` | No | huggingface | | +| `chat_template` | `` | No | <|user|> +{input} +<|assistant|> +{output} | | +| `model_specific_config` | `` | No | {'trust_remote_code': True, 'attn_implementation': 'sdpa'} | | +| `max_seq_length` | `` | No | 2048 | | +| `gradient_checkpointing` | `` | No | False | | +| `save_total_limit` | `` | No | 3 | | +| `logging_steps` | `` | No | 10 | | +| `warmup_ratio` | `` | No | 0.1 | | +| `weight_decay` | `` | No | 0.01 | | +| `dataloader_num_workers` | `` | No | 4 | | +| `dataloader_pin_memory` | `` | No | True | | +| `dpo_beta` | `` | No | 0.1 | | +| `use_reference_model` | `` | No | True | | +| `dpo_loss_type` | `Literal['sigmoid', 'hinge', 'ipo', 'kto_pair']` | No | sigmoid | | +| `dpo_output_dir` | `` | No | | | + +## Sample Configuration + +```yaml +checkpoint_format: huggingface +distributed_backend: null +device: cpu +dpo_output_dir: ~/.llama/dummy/dpo_output + +``` + diff --git a/docs/source/providers/post_training/inline_torchtune-cpu.md b/docs/source/providers/post_training/inline_torchtune-cpu.md new file mode 100644 index 000000000..7204e56e8 --- /dev/null +++ b/docs/source/providers/post_training/inline_torchtune-cpu.md @@ -0,0 +1,20 @@ +# inline::torchtune-cpu + +## Description + +TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.
+ +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `torch_seed` | `int \| None` | No | | | +| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | | + +## Sample Configuration + +```yaml +checkpoint_format: meta + +``` + diff --git a/docs/source/providers/post_training/inline_torchtune-gpu.md b/docs/source/providers/post_training/inline_torchtune-gpu.md new file mode 100644 index 000000000..98b94f6f6 --- /dev/null +++ b/docs/source/providers/post_training/inline_torchtune-gpu.md @@ -0,0 +1,20 @@ +# inline::torchtune-gpu + +## Description + +TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `torch_seed` | `int \| None` | No | | | +| `checkpoint_format` | `Literal['meta', 'huggingface']` | No | meta | | + +## Sample Configuration + +```yaml +checkpoint_format: meta + +``` + diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py index 9297d8597..c6bbd92eb 100644 --- a/llama_stack/apis/batches/batches.py +++ b/llama_stack/apis/batches/batches.py @@ -29,12 +29,16 @@ class ListBatchesResponse(BaseModel): @runtime_checkable class Batches(Protocol): - """Protocol for batch processing API operations. - + """ The Batches API enables efficient processing of multiple requests in a single operation, particularly useful for processing large datasets, batch evaluation workflows, and cost-effective inference at scale. + The API is designed to allow use of openai client libraries for seamless integration. + + This API provides the following extensions: + - idempotent batch creation + Note: This API is currently under active development and may undergo changes. """ @@ -45,6 +49,7 @@ class Batches(Protocol): endpoint: str, completion_window: Literal["24h"], metadata: dict[str, str] | None = None, + idempotency_key: str | None = None, ) -> BatchObject: """Create a new batch for processing multiple API requests. @@ -52,6 +57,7 @@ :param endpoint: The endpoint to be used for all requests in the batch. :param completion_window: The time window within which the batch should be processed. :param metadata: Optional metadata for the batch. + :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior. :returns: The created batch object. """ ... diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7e7bd0a3d..bd4737ca7 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel): embeddings: list[list[float]] +@json_schema_type +class RerankData(BaseModel): + """A single rerank result from a reranking response. + + :param index: The original index of the document in the input list + :param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance. + """ + + index: int + relevance_score: float + + +@json_schema_type +class RerankResponse(BaseModel): + """Response from a reranking request.
+ + :param data: List of rerank result objects, sorted by relevance score (descending) + """ + + data: list[RerankData] + + @json_schema_type class OpenAIChatCompletionContentPartTextParam(BaseModel): """Text content part for OpenAI-compatible chat completion messages. @@ -1046,6 +1068,7 @@ class InferenceProvider(Protocol): :returns: A BatchCompletionResponse with the full completions. """ raise NotImplementedError("Batch completion is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete @webmethod(route="/inference/chat-completion", method="POST") async def chat_completion( @@ -1110,6 +1133,7 @@ class InferenceProvider(Protocol): :returns: A BatchChatCompletionResponse with the full completions. """ raise NotImplementedError("Batch chat completion is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete @webmethod(route="/inference/embeddings", method="POST") async def embeddings( @@ -1131,6 +1155,25 @@ class InferenceProvider(Protocol): """ ... + @webmethod(route="/inference/rerank", method="POST", experimental=True) + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + """Rerank a list of documents based on their relevance to a query. + + :param model: The identifier of the reranking model to use. + :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length. + :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length. + :param max_num_results: (Optional) Maximum number of results to return. Default: returns all. + :returns: RerankResponse with indices sorted by relevance score (descending). 
+ """ + raise NotImplementedError("Reranking is not implemented") + return # this is so mypy's safe-super rule will consider the method concrete + @webmethod(route="/openai/v1/completions", method="POST") async def openai_completion( self, diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 92422ac1b..8d1b5d697 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -386,6 +386,7 @@ class MetricDataPoint(BaseModel): timestamp: int value: float + unit: str @json_schema_type @@ -518,7 +519,7 @@ class Telemetry(Protocol): metric_name: str, start_time: int, end_time: int | None = None, - granularity: str | None = "1d", + granularity: str | None = None, query_type: MetricQueryType = MetricQueryType.RANGE, label_matchers: list[MetricLabelMatcher] | None = None, ) -> QueryMetricsResponse: diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index c8ffce034..b32b8b3ae 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="cli") class StackRun(Subcommand): diff --git a/llama_stack/core/build.py b/llama_stack/core/build.py index 4b20588fd..fa1fe632b 100644 --- a/llama_stack/core/build.py +++ b/llama_stack/core/build.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import importlib.resources -import logging import sys from pydantic import BaseModel @@ -17,9 +16,10 @@ from llama_stack.core.external import load_external_apis from llama_stack.core.utils.exec import run_command from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.distributions.template import DistributionTemplate +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # These are the dependencies needed by the distribution server. # `llama-stack` is automatically installed by the installation script. diff --git a/llama_stack/core/configure.py b/llama_stack/core/configure.py index 9e18b438c..64473c053 100644 --- a/llama_stack/core/configure.py +++ b/llama_stack/core/configure.py @@ -3,7 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
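Stepping back to the `rerank` method added to `inference.py` above: it is experimental and unimplemented by default (`raise NotImplementedError`), but the request and response schema is fixed by this PR, so an illustrative client-side sketch is possible. Assumptions: a server on `localhost:8321` and a provider that registers a reranking model; the `example/reranker` identifier is made up:

```python
# Illustrative call to the new experimental /v1/inference/rerank route.
import requests

resp = requests.post(
    "http://localhost:8321/v1/inference/rerank",
    json={
        "model": "example/reranker",  # hypothetical model identifier
        "query": "What is the capital of France?",
        "items": [
            "Paris is the capital of France.",
            "Berlin is the capital of Germany.",
            "France is famous for its cuisine.",
        ],
        "max_num_results": 2,  # optional; omit to get every item back
    },
    timeout=30,
)
resp.raise_for_status()
# RerankResponse.data arrives sorted by relevance_score, descending.
for result in resp.json()["data"]:
    print(result["index"], result["relevance_score"])
```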
-import logging import textwrap from typing import Any @@ -21,9 +20,10 @@ from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.prompt_for_config import prompt_for_config +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, ProviderSpec -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") def configure_single_provider(registry: dict[str, ProviderSpec], provider: Provider) -> Provider: diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index a1b6ad32b..c3940fcbd 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,6 +318,41 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") +class CORSConfig(BaseModel): + allow_origins: list[str] = Field(default_factory=list) + allow_origin_regex: str | None = Field(default=None) + allow_methods: list[str] = Field(default=["OPTIONS"]) + allow_headers: list[str] = Field(default_factory=list) + allow_credentials: bool = Field(default=False) + expose_headers: list[str] = Field(default_factory=list) + max_age: int = Field(default=600, ge=0) + + @model_validator(mode="after") + def validate_credentials_config(self) -> Self: + if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): + raise ValueError("Cannot use wildcard origins with credentials enabled") + return self + + +def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None: + if cors_config is False or cors_config is None: + return None + + if cors_config is True: + # dev mode: allow localhost on any port + return CORSConfig( + allow_origins=[], + allow_origin_regex=r"https?://localhost:\d+", + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + ) + + if isinstance(cors_config, CORSConfig): + return cors_config + + raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -349,6 +384,12 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request configuration", ) + cors: bool | CORSConfig | None = Field( + default=None, + description="CORS configuration for cross-origin requests. 
Can be:\n" + "- true: Enable localhost CORS for development\n" + "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index a93fe509e..9e7a8006c 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -7,7 +7,7 @@ import asyncio import inspect import json -import logging +import logging # allow-direct-logging import os import sys from concurrent.futures import ThreadPoolExecutor @@ -48,6 +48,7 @@ from llama_stack.core.stack import ( from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.context import preserve_contexts_async_generator from llama_stack.core.utils.exec import in_notebook +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry.tracing import ( CURRENT_TRACE_CONTEXT, end_trace, @@ -55,7 +56,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="core") T = TypeVar("T") @@ -145,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_distro_name, custom_provider_registry, provider_data + config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal ) self.pool_executor = ThreadPoolExecutor(max_workers=4) - self.skip_logger_removal = skip_logger_removal self.provider_data = provider_data self.loop = asyncio.new_event_loop() - def initialize(self): - if in_notebook(): - import nest_asyncio - - nest_asyncio.apply() - if not self.skip_logger_removal: - self._remove_root_logger_handlers() - # use a new event loop to avoid interfering with the main event loop loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - return loop.run_until_complete(self.async_client.initialize()) + loop.run_until_complete(self.async_client.initialize()) finally: asyncio.set_event_loop(None) - def _remove_root_logger_handlers(self): + def initialize(self): """ - Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + Deprecated method for backward compatibility. 
""" - root_logger = logging.getLogger() - - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + pass def request(self, *args, **kwargs): loop = self.loop @@ -215,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): config_path_or_distro_name: str, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, + skip_logger_removal: bool = False, ): super().__init__() # when using the library client, we should not log to console since many @@ -222,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") + if in_notebook(): + import nest_asyncio + + nest_asyncio.apply() + if not skip_logger_removal: + self._remove_root_logger_handlers() + if config_path_or_distro_name.endswith(".yaml"): config_path = Path(config_path_or_distro_name) if not config_path.exists(): @@ -238,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.provider_data = provider_data self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError + def _remove_root_logger_handlers(self): + """ + Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + """ + root_logger = logging.getLogger() + + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + async def initialize(self) -> bool: + """ + Initialize the async client. + + Returns: + bool: True if initialization was successful + """ + try: self.route_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) diff --git a/llama_stack/core/request_headers.py b/llama_stack/core/request_headers.py index 35ac72775..f1ce8281f 100644 --- a/llama_stack/core/request_headers.py +++ b/llama_stack/core/request_headers.py @@ -6,15 +6,15 @@ import contextvars import json -import logging from contextlib import AbstractContextManager from typing import Any from llama_stack.core.datatypes import User +from llama_stack.log import get_logger from .utils.dynamic import instantiate_class_type -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="core") # Context variable for request provider data and auth attributes PROVIDER_DATA_VAR = contextvars.ContextVar("provider_data", default=None) diff --git a/llama_stack/core/routers/datasets.py b/llama_stack/core/routers/datasets.py index d7984f729..2f1d5f78e 100644 --- a/llama_stack/core/routers/datasets.py +++ b/llama_stack/core/routers/datasets.py @@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class DatasetIORouter(DatasetIO): diff --git a/llama_stack/core/routers/eval_scoring.py b/llama_stack/core/routers/eval_scoring.py index f7a17eecf..ffca81bf0 100644 --- a/llama_stack/core/routers/eval_scoring.py +++ b/llama_stack/core/routers/eval_scoring.py @@ -16,7 +16,7 @@ from llama_stack.apis.scoring import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = 
get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ScoringRouter(Scoring): diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 6a3f07247..4b66601bb 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="core::routers") class InferenceRouter(Inference): diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index 738ecded3..9ba3327f1 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class SafetyRouter(Safety): diff --git a/llama_stack/core/routers/tool_runtime.py b/llama_stack/core/routers/tool_runtime.py index 5a40bc0c5..fd606f33b 100644 --- a/llama_stack/core/routers/tool_runtime.py +++ b/llama_stack/core/routers/tool_runtime.py @@ -22,7 +22,7 @@ from llama_stack.log import get_logger from ..routing_tables.toolgroups import ToolGroupsRoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ToolRuntimeRouter(ToolRuntime): diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3d0996c49..786b0e391 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class VectorIORouter(VectorIO): diff --git a/llama_stack/core/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py index 74bee8040..c875dee5b 100644 --- a/llama_stack/core/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 339ff6da4..e523746d8 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def get_impl_api(p: Any) -> Api: diff --git a/llama_stack/core/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py index fc6a75df4..b129c9ec5 100644 --- 
a/llama_stack/core/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -26,7 +26,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 34c431e00..b6141efa9 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ModelsRoutingTable(CommonRoutingTableImpl, Models): diff --git a/llama_stack/core/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py index 5874ba941..71e5bed63 100644 --- a/llama_stack/core/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -19,7 +19,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index e08f35bfc..b1918d20a 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index 6910b3906..eeea406c1 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index e8dc46997..00f71b4fe 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -30,7 +30,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): diff --git a/llama_stack/core/server/auth.py b/llama_stack/core/server/auth.py index e4fb4ff2b..c98d3bec0 100644 --- a/llama_stack/core/server/auth.py +++ b/llama_stack/core/server/auth.py @@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger -logger = get_logger(name=__name__, 
category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthenticationMiddleware: diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 73d5581c2..a8af6f75a 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -23,7 +23,7 @@ from llama_stack.core.datatypes import ( ) from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthResponse(BaseModel): diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 1cb850cde..693f224c3 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl -logger = get_logger(name=__name__, category="quota") +logger = get_logger(name=__name__, category="core::server") class QuotaMiddleware: diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index cbef8ef88..d6dfc3435 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -9,7 +9,7 @@ import asyncio import functools import inspect import json -import logging +import logging # allow-direct-logging import os import ssl import sys @@ -28,6 +28,7 @@ from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError @@ -40,6 +41,7 @@ from llama_stack.core.datatypes import ( AuthenticationRequiredError, LoggingConfig, StackRunConfig, + process_cors_config, ) from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis @@ -82,7 +84,7 @@ from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="core::server") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None): config_contents = yaml.safe_load(fp) if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) - logger = get_logger(name=__name__, category="server", config=logger_config) + logger = get_logger(name=__name__, category="core::server", config=logger_config) if args.env: for env_pair in args.env: try: @@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) + if config.server.cors: + logger.info("Enabling CORS") + cors_config = process_cors_config(config.server.cors) + if cors_config: + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 4b60e1001..5f4abe9aa 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -16,7 +16,7 @@ from 
llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig -logger = get_logger(__name__, category="core") +logger = get_logger(__name__, category="core::registry") class DistributionRegistry(Protocol): diff --git a/llama_stack/core/utils/config_resolution.py b/llama_stack/core/utils/config_resolution.py index 30cd71e15..182a571ee 100644 --- a/llama_stack/core/utils/config_resolution.py +++ b/llama_stack/core/utils/config_resolution.py @@ -10,7 +10,7 @@ from pathlib import Path from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="config_resolution") +logger = get_logger(name=__name__, category="core") DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" diff --git a/llama_stack/core/utils/exec.py b/llama_stack/core/utils/exec.py index 1b2b782fe..12fb82d01 100644 --- a/llama_stack/core/utils/exec.py +++ b/llama_stack/core/utils/exec.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging +import importlib import os import signal import subprocess @@ -12,9 +12,9 @@ import sys from termcolor import cprint -log = logging.getLogger(__name__) +from llama_stack.log import get_logger -import importlib +log = get_logger(name=__name__, category="core") def formulate_run_args(image_type: str, image_name: str) -> list: diff --git a/llama_stack/core/utils/prompt_for_config.py b/llama_stack/core/utils/prompt_for_config.py index 26f6920e0..bac0531ed 100644 --- a/llama_stack/core/utils/prompt_for_config.py +++ b/llama_stack/core/utils/prompt_for_config.py @@ -6,7 +6,6 @@ import inspect import json -import logging from enum import Enum from typing import Annotated, Any, Literal, Union, get_args, get_origin @@ -14,7 +13,9 @@ from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefinedType -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="core") def is_list_of_primitives(field_type): diff --git a/llama_stack/distributions/ci-tests/build.yaml b/llama_stack/distributions/ci-tests/build.yaml index 0bf42e7ee..b4701cb81 100644 --- a/llama_stack/distributions/ci-tests/build.yaml +++ b/llama_stack/distributions/ci-tests/build.yaml @@ -34,7 +34,7 @@ distribution_spec: telemetry: - provider_type: inline::meta-reference post_training: - - provider_type: inline::huggingface + - provider_type: inline::huggingface-cpu eval: - provider_type: inline::meta-reference datasetio: diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 02a268462..3acdd20f9 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -156,8 +156,8 @@ providers: sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} post_training: - - provider_id: huggingface - provider_type: inline::huggingface + - provider_id: huggingface-cpu + provider_type: inline::huggingface-cpu config: checkpoint_format: huggingface distributed_backend: null diff --git a/llama_stack/distributions/starter-gpu/__init__.py b/llama_stack/distributions/starter-gpu/__init__.py new file mode 100644 
index 000000000..e762f9b6e --- /dev/null +++ b/llama_stack/distributions/starter-gpu/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .starter_gpu import get_distribution_template # noqa: F401 diff --git a/llama_stack/distributions/starter-gpu/build.yaml b/llama_stack/distributions/starter-gpu/build.yaml new file mode 100644 index 000000000..ae0680cdc --- /dev/null +++ b/llama_stack/distributions/starter-gpu/build.yaml @@ -0,0 +1,59 @@ +version: 2 +distribution_spec: + description: Quick start template for running Llama Stack with several popular providers. + This distribution is intended for GPU-enabled environments. + providers: + inference: + - provider_type: remote::cerebras + - provider_type: remote::ollama + - provider_type: remote::vllm + - provider_type: remote::tgi + - provider_type: remote::fireworks + - provider_type: remote::together + - provider_type: remote::bedrock + - provider_type: remote::nvidia + - provider_type: remote::openai + - provider_type: remote::anthropic + - provider_type: remote::gemini + - provider_type: remote::vertexai + - provider_type: remote::groq + - provider_type: remote::sambanova + - provider_type: inline::sentence-transformers + vector_io: + - provider_type: inline::faiss + - provider_type: inline::sqlite-vec + - provider_type: inline::milvus + - provider_type: remote::chromadb + - provider_type: remote::pgvector + files: + - provider_type: inline::localfs + safety: + - provider_type: inline::llama-guard + - provider_type: inline::code-scanner + agents: + - provider_type: inline::meta-reference + telemetry: + - provider_type: inline::meta-reference + post_training: + - provider_type: inline::torchtune-gpu + eval: + - provider_type: inline::meta-reference + datasetio: + - provider_type: remote::huggingface + - provider_type: inline::localfs + scoring: + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust + tool_runtime: + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: inline::rag-runtime + - provider_type: remote::model-context-protocol + batches: + - provider_type: inline::reference +image_type: venv +additional_pip_packages: +- aiosqlite +- asyncpg +- sqlalchemy[asyncio] diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml new file mode 100644 index 000000000..81c802317 --- /dev/null +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -0,0 +1,238 @@ +version: 2 +image_name: starter-gpu +apis: +- agents +- batches +- datasetio +- eval +- files +- inference +- post_training +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: ${env.CEREBRAS_API_KEY:+cerebras} + provider_type: remote::cerebras + config: + base_url: https://api.cerebras.ai + api_key: ${env.CEREBRAS_API_KEY:=} + - provider_id: ${env.OLLAMA_URL:+ollama} + provider_type: remote::ollama + config: + url: ${env.OLLAMA_URL:=http://localhost:11434} + - provider_id: ${env.VLLM_URL:+vllm} + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: ${env.TGI_URL:+tgi} + provider_type: remote::tgi + config: + url: ${env.TGI_URL:=} + - provider_id: 
fireworks + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY:=} + - provider_id: together + provider_type: remote::together + config: + url: https://api.together.xyz/v1 + api_key: ${env.TOGETHER_API_KEY:=} + - provider_id: bedrock + provider_type: remote::bedrock + - provider_id: ${env.NVIDIA_API_KEY:+nvidia} + provider_type: remote::nvidia + config: + url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com} + api_key: ${env.NVIDIA_API_KEY:=} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True} + - provider_id: openai + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY:=} + base_url: ${env.OPENAI_BASE_URL:=https://api.openai.com/v1} + - provider_id: anthropic + provider_type: remote::anthropic + config: + api_key: ${env.ANTHROPIC_API_KEY:=} + - provider_id: gemini + provider_type: remote::gemini + config: + api_key: ${env.GEMINI_API_KEY:=} + - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai} + provider_type: remote::vertexai + config: + project: ${env.VERTEX_AI_PROJECT:=} + location: ${env.VERTEX_AI_LOCATION:=us-central1} + - provider_id: groq + provider_type: remote::groq + config: + url: https://api.groq.com + api_key: ${env.GROQ_API_KEY:=} + - provider_id: sambanova + provider_type: remote::sambanova + config: + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY:=} + - provider_id: sentence-transformers + provider_type: inline::sentence-transformers + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec + config: + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + - provider_id: ${env.MILVUS_URL:+milvus} + provider_type: inline::milvus + config: + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + - provider_id: ${env.CHROMADB_URL:+chromadb} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db + - provider_id: ${env.PGVECTOR_DB:+pgvector} + provider_type: remote::pgvector + config: + host: ${env.PGVECTOR_HOST:=localhost} + port: ${env.PGVECTOR_PORT:=5432} + db: ${env.PGVECTOR_DB:=} + user: ${env.PGVECTOR_USER:=} + password: ${env.PGVECTOR_PASSWORD:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + - provider_id: code-scanner + provider_type: inline::code-scanner + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/agents_store.db + responses_store: + type: 
sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/responses_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db + otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} + post_training: + - provider_id: torchtune-gpu + provider_type: inline::torchtune-gpu + config: + checkpoint_format: meta + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + batches: + - provider_id: reference + provider_type: inline::reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/batches.db +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/inference_store.db +models: [] +shields: +- shield_id: llama-guard + provider_id: ${env.SAFETY_MODEL:+llama-guard} + provider_shield_id: ${env.SAFETY_MODEL:=} +- shield_id: code-scanner + provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner} + provider_shield_id: ${env.CODE_SCANNER_MODEL:=} +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8321 diff --git a/llama_stack/distributions/starter-gpu/starter_gpu.py b/llama_stack/distributions/starter-gpu/starter_gpu.py new file mode 100644 index 000000000..893df6c17 --- /dev/null +++ b/llama_stack/distributions/starter-gpu/starter_gpu.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
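A note on the `${env.…}` syntax that recurs throughout these run.yaml files: `${env.VAR:=default}` substitutes the variable's value with a fallback, while `${env.VAR:+value}` expands only when the variable is set, which is how optional providers such as `${env.OLLAMA_URL:+ollama}` get switched on and off. A minimal, hypothetical resolver illustrating just these two operators (the stack ships its own substitution logic; everything here is a stand-in for illustration):

```python
import os
import re

# Hypothetical stand-in for the stack's env substitution; only the two
# operators used in the run.yaml files above are handled.
_ENV_EXPR = re.compile(r"\$\{env\.(?P<name>[A-Za-z0-9_]+):(?P<op>[=+])(?P<arg>[^}]*)\}")

def resolve_env_expr(value: str) -> str:
    def _sub(match: re.Match) -> str:
        env_value = os.environ.get(match.group("name"))
        if match.group("op") == "=":
            # ${env.VAR:=default} -> the variable's value, else the default
            return env_value if env_value is not None else match.group("arg")
        # ${env.VAR:+value} -> `value` only when the variable is set (and non-empty)
        return match.group("arg") if env_value else ""
    return _ENV_EXPR.sub(_sub, value)

print(resolve_env_expr("${env.OLLAMA_URL:=http://localhost:11434}"))  # http://localhost:11434
print(resolve_env_expr("${env.OLLAMA_URL:+ollama}"))                  # "" when OLLAMA_URL is unset
```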
+ + +from llama_stack.distributions.template import BuildProvider, DistributionTemplate + +from ..starter.starter import get_distribution_template as get_starter_distribution_template + + +def get_distribution_template() -> DistributionTemplate: + template = get_starter_distribution_template() + name = "starter-gpu" + template.name = name + template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments." + + template.providers["post_training"] = [ + BuildProvider(provider_type="inline::torchtune-gpu"), + ] + return template diff --git a/llama_stack/distributions/starter/build.yaml b/llama_stack/distributions/starter/build.yaml index 2ad12a165..3df0eb129 100644 --- a/llama_stack/distributions/starter/build.yaml +++ b/llama_stack/distributions/starter/build.yaml @@ -1,6 +1,7 @@ version: 2 distribution_spec: - description: Quick start template for running Llama Stack with several popular providers + description: Quick start template for running Llama Stack with several popular providers. + This distribution is intended for CPU-only environments. providers: inference: - provider_type: remote::cerebras @@ -34,7 +35,7 @@ distribution_spec: telemetry: - provider_type: inline::meta-reference post_training: - - provider_type: inline::huggingface + - provider_type: inline::huggingface-cpu eval: - provider_type: inline::meta-reference datasetio: diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml index 7ac4dc6b9..7e1d46a61 100644 --- a/llama_stack/distributions/starter/run.yaml +++ b/llama_stack/distributions/starter/run.yaml @@ -156,8 +156,8 @@ providers: sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} post_training: - - provider_id: huggingface - provider_type: inline::huggingface + - provider_id: huggingface-cpu + provider_type: inline::huggingface-cpu config: checkpoint_format: huggingface distributed_backend: null diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index cad3d72d9..f49da0bb7 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate: ], "agents": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")], - "post_training": [BuildProvider(provider_type="inline::huggingface")], + "post_training": [BuildProvider(provider_type="inline::huggingface-cpu")], "eval": [BuildProvider(provider_type="inline::meta-reference")], "datasetio": [ BuildProvider(provider_type="remote::huggingface"), @@ -178,7 +178,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name=name, distro_type="self_hosted", - description="Quick start template for running Llama Stack with several popular providers", + description="Quick start template for running Llama Stack with several popular providers. 
This distribution is intended for CPU-only environments.", container_image=None, template_path=None, providers=providers, diff --git a/llama_stack/log.py b/llama_stack/log.py index d67bd1b61..cc4c9d4cf 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging +import logging # allow-direct-logging import os import re -from logging.config import dictConfig +from logging.config import dictConfig # allow-direct-logging from rich.console import Console from rich.errors import MarkupError diff --git a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py index 5b5969d89..90ced13b2 100644 --- a/llama_stack/models/llama/llama3/multimodal/encoder_utils.py +++ b/llama_stack/models/llama/llama3/multimodal/encoder_utils.py @@ -13,14 +13,15 @@ # Copyright (c) Meta Platforms, Inc. and its affiliates. import math -from logging import getLogger import torch import torch.nn.functional as F +from llama_stack.log import get_logger + from .utils import get_negative_inf_value, to_2tuple -logger = getLogger() +logger = get_logger(name=__name__, category="models::llama") def resize_local_position_embedding(orig_pos_embed, grid_size): diff --git a/llama_stack/models/llama/llama3/multimodal/image_transform.py b/llama_stack/models/llama/llama3/multimodal/image_transform.py index f2761ee47..7b20a31fa 100644 --- a/llama_stack/models/llama/llama3/multimodal/image_transform.py +++ b/llama_stack/models/llama/llama3/multimodal/image_transform.py @@ -13,7 +13,6 @@ import math from collections import defaultdict -from logging import getLogger from typing import Any import torch @@ -21,9 +20,11 @@ import torchvision.transforms as tv from PIL import Image from torchvision.transforms import functional as F +from llama_stack.log import get_logger + IMAGE_RES = 224 -logger = getLogger() +logger = get_logger(name=__name__, category="models::llama") class VariableSizeImageTransform: diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 5f1c3605c..7b501eb0e 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -3,8 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
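Most of the hunks above and below repeat one pattern: module-level `logging.getLogger(...)` loggers are replaced by `llama_stack.log.get_logger(name=..., category=...)`, and `# allow-direct-logging` marks the few places (such as `log.py` itself) that may keep importing `logging` directly. A minimal sketch of what such a category-aware factory could look like, assuming the helper maps categories onto the logger-name hierarchy; the real implementation in `llama_stack/log.py` also handles rich formatting and per-category levels:

```python
import logging  # allow-direct-logging: this sketch is the logging shim itself

# Hedged sketch of a category-aware logger factory; mapping the category into
# the dotted logger name is an assumption made for illustration.
def get_logger(name: str, category: str = "uncategorized") -> logging.Logger:
    return logging.getLogger(f"llama_stack.{category}.{name}")

logger = get_logger(name=__name__, category="models::llama")
logger.info("loaded vision encoder weights")  # filed under models::llama
```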
- -import logging import math from collections.abc import Callable from functools import partial @@ -22,6 +20,8 @@ from PIL import Image as PIL_Image from torch import Tensor, nn from torch.distributed import _functional_collectives as funcol +from llama_stack.log import get_logger + from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis from .encoder_utils import ( build_encoder_attention_mask, @@ -34,9 +34,10 @@ from .encoder_utils import ( from .image_transform import VariableSizeImageTransform from .utils import get_negative_inf_value, to_2tuple -logger = logging.getLogger(__name__) MP_SCALE = 8 +logger = get_logger(name=__name__, category="models::llama") + def reduce_from_tensor_model_parallel_region(input_): """All-reduce the input tensor across model parallel group.""" @@ -771,7 +772,7 @@ class TilePositionEmbedding(nn.Module): if embed is not None: # reshape the weights to the correct shape nt_old, nt_old, _, w = embed.shape - logging.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") + logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}") embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) # assign the weights to the module state_dict[prefix + "embedding"] = embed_new diff --git a/llama_stack/models/llama/llama3/tokenizer.py b/llama_stack/models/llama/llama3/tokenizer.py index e47b579e3..ad7ced1c5 100644 --- a/llama_stack/models/llama/llama3/tokenizer.py +++ b/llama_stack/models/llama/llama3/tokenizer.py @@ -4,8 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger from pathlib import Path from typing import ( Literal, @@ -14,11 +14,9 @@ from typing import ( import tiktoken +from llama_stack.log import get_logger from llama_stack.models.llama.tokenizer_utils import load_bpe_file -logger = getLogger(__name__) - - # The tiktoken tokenizer can handle <=400k chars without # pyo3_runtime.PanicException. TIKTOKEN_MAX_ENCODE_CHARS = 400_000 @@ -31,6 +29,8 @@ MAX_NO_WHITESPACES_CHARS = 25_000 _INSTANCE = None +logger = get_logger(name=__name__, category="models::llama") + class Tokenizer: """ diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 574080184..d0e3e7671 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -11,7 +11,7 @@ from llama_stack.log import get_logger from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="models::llama") BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)' CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})") diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index 223744a5f..7557a8a64 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree.
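For reference, the builtin tool-call pattern in `tool_utils.py` above parses model output of the `tool.call(query="...")` form via its named groups; a quick usage check (the example string is illustrative):

```python
import re

# Pattern as it appears in tool_utils.py above.
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'

match = re.search(BUILTIN_TOOL_PATTERN, 'brave_search.call(query="latest llama release")')
assert match is not None
print(match.group("tool_name"))  # brave_search
print(match.group("query"))      # latest llama release
```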
-import logging import os from collections.abc import Callable @@ -13,11 +12,13 @@ from fairscale.nn.model_parallel.initialize import get_model_parallel_rank from torch import Tensor, nn from torch.nn import functional as F +from llama_stack.log import get_logger + from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="models::llama") def swiglu_wrapper_no_reduce( diff --git a/llama_stack/models/llama/llama4/tokenizer.py b/llama_stack/models/llama/llama4/tokenizer.py index e12b2cae0..bfbace8f9 100644 --- a/llama_stack/models/llama/llama4/tokenizer.py +++ b/llama_stack/models/llama/llama4/tokenizer.py @@ -5,7 +5,6 @@ # the root directory of this source tree. from collections.abc import Collection, Iterator, Sequence, Set -from logging import getLogger from pathlib import Path from typing import ( Literal, @@ -14,11 +13,9 @@ from typing import ( import tiktoken +from llama_stack.log import get_logger from llama_stack.models.llama.tokenizer_utils import load_bpe_file -logger = getLogger(__name__) - - # The tiktoken tokenizer can handle <=400k chars without # pyo3_runtime.PanicException. TIKTOKEN_MAX_ENCODE_CHARS = 400_000 @@ -101,6 +98,8 @@ BASIC_SPECIAL_TOKENS = [ "<|fim_suffix|>", ] +logger = get_logger(name=__name__, category="models::llama") + class Tokenizer: """ diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py index a6400c5c9..0a205601f 100644 --- a/llama_stack/models/llama/quantize_impls.py +++ b/llama_stack/models/llama/quantize_impls.py @@ -6,9 +6,10 @@ # type: ignore import collections -import logging -log = logging.getLogger(__name__) +from llama_stack.log import get_logger + +log = get_logger(name=__name__, category="models::llama") try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 5f7c90879..fde38515b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" -logger = get_logger(name=__name__, category="agents") +logger = get_logger(name=__name__, category="agents::meta_reference") class ChatAgent(ShieldRunnerMixin): diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 30196c429..8bdde86b0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
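The payoff of the `core::routers` / `models::llama` / `agents::meta_reference` renames in this diff is that log filtering can target a subsystem precisely. Assuming a `LLAMA_STACK_LOGGING`-style specification of semicolon-separated `category=level` pairs, in the spirit of what `llama_stack/log.py` consumes, the parsing could be sketched standalone like this (not the stack's actual parser):

```python
import os

# Sketch only: split "cat=level;cat=level" pairs into a mapping that a
# logging handler could consult per category.
def parse_category_levels(spec: str) -> dict[str, str]:
    levels: dict[str, str] = {}
    for entry in spec.split(";"):
        entry = entry.strip()
        if not entry:
            continue
        category, _, level = entry.partition("=")
        if category and level:
            levels[category.strip()] = level.strip().lower()
    return levels

spec = os.environ.get("LLAMA_STACK_LOGGING", "core=info;agents::meta_reference=debug")
print(parse_category_levels(spec))  # {'core': 'info', 'agents::meta_reference': 'debug'}
```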
-import logging import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime @@ -42,6 +41,7 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.core.datatypes import AccessRule +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from llama_stack.providers.utils.pagination import paginate_records from llama_stack.providers.utils.responses.responses_store import ResponsesStore @@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig from .persistence import AgentInfo from .responses.openai_responses import OpenAIResponsesImpl -logger = logging.getLogger() +logger = get_logger(name=__name__, category="agents::meta_reference") class MetaReferenceAgentsImpl(Agents): diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 0b234d96c..3b7b4729c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging import uuid from datetime import UTC, datetime @@ -15,9 +14,10 @@ from llama_stack.core.access_control.access_control import AccessDeniedError, is from llama_stack.core.access_control.datatypes import AccessRule from llama_stack.core.datatypes import User from llama_stack.core.request_headers import get_authenticated_user +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents::meta_reference") class AgentSessionInfo(Session): diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index e528a4005..c632e61aa 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -41,7 +41,7 @@ from .utils import ( convert_response_text_to_chat_response_format, ) -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="openai::responses") class OpenAIResponsePreviousResponseWithInputItems(BaseModel): diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 0879e978a..3e69fa5cd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -47,7 +47,7 @@ from llama_stack.log import get_logger from .types import ChatCompletionContext, ChatCompletionResult from .utils import convert_chat_choice_to_response_message, is_function_tool_call -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="agents::meta_reference") class StreamingResponseOrchestrator: diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 5b98b4f51..b028c018b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ 
b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -38,7 +38,7 @@ from llama_stack.log import get_logger from .types import ChatCompletionContext, ToolExecutionResult -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="agents::meta_reference") class ToolExecutor: diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index 1507a55c8..7aaeb4cd5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPCall, + OpenAIResponseOutputMessageMCPListTools, OpenAIResponseText, ) from llama_stack.apis.inference import ( @@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages( """ messages: list[OpenAIMessageParam] = [] if isinstance(input, list): + # extract all OpenAIResponseInputFunctionToolCallOutput items + # so their corresponding OpenAIToolMessageParam instances can + # be added immediately following the corresponding + # OpenAIAssistantMessageParam + tool_call_results = {} for input_item in input: if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): - messages.append( - OpenAIToolMessageParam( - content=input_item.output, - tool_call_id=input_item.call_id, - ) + tool_call_results[input_item.call_id] = OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.call_id, ) + + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + # skip as these have been extracted and inserted in order + pass elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): tool_call = OpenAIChatCompletionToolCall( index=0, @@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages( ), ) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + if input_item.call_id in tool_call_results: + messages.append(tool_call_results[input_item.call_id]) + del tool_call_results[input_item.call_id] + elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall): + tool_call = OpenAIChatCompletionToolCall( + index=0, + id=input_item.id, + function=OpenAIChatCompletionToolCallFunction( + name=input_item.name, + arguments=input_item.arguments, + ), + ) + messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + messages.append( + OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.id, + ) + ) + elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools): + # the tool list will be handled separately + pass else: content = await convert_response_content_to_chat_content(input_item.content) message_type = await get_message_type_by_role(input_item.role) @@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages( f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" ) messages.append(message_type(content=content)) + if len(tool_call_results): + raise ValueError( + f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call" + ) else: messages.append(OpenAIUserMessageParam(content=input)) return messages diff --git 
a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 605f387b7..8f3ecf5c9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -5,13 +5,13 @@ # the root directory of this source tree. import asyncio -import logging from llama_stack.apis.inference import Message from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel +from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry import tracing -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="agents::meta_reference") class SafetyException(Exception): # noqa: N818 diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 1ff554e70..26f0ad15a 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -5,6 +5,7 @@ # the root directory of this source tree. import asyncio +import hashlib import itertools import json import time @@ -136,28 +137,45 @@ class ReferenceBatchesImpl(Batches): endpoint: str, completion_window: Literal["24h"], metadata: dict[str, str] | None = None, + idempotency_key: str | None = None, ) -> BatchObject: """ Create a new batch for processing multiple API requests. - Error handling by levels - - 0. Input param handling, results in 40x errors before processing, e.g. - - Wrong completion_window - - Invalid metadata types - - Unknown endpoint - -> no batch created - 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. - - input_file_id missing - - invalid json in file - - missing custom_id, method, url, body - - invalid model - - streaming - -> batch created, validation sends to failed status - 2. Processing errors, result in error_file_id entries, e.g. - - Any error returned from inference endpoint - -> batch created, goes to completed status + This implementation provides optional idempotency: when an idempotency key + (idempotency_key) is provided, a deterministic ID is generated based on the input + parameters. If a batch with the same parameters already exists, it will be + returned instead of creating a duplicate. Without an idempotency key, + each request creates a new batch with a unique ID. + + Args: + input_file_id: The ID of an uploaded file containing requests for the batch. + endpoint: The endpoint to be used for all requests in the batch. + completion_window: The time window within which the batch should be processed. + metadata: Optional metadata for the batch. + idempotency_key: Optional idempotency key for enabling idempotent behavior. + + Returns: + The created or existing batch object. """ + # Error handling by levels - + # 0. Input param handling, results in 40x errors before processing, e.g. + # - Wrong completion_window + # - Invalid metadata types + # - Unknown endpoint + # -> no batch created + # 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. + # - input_file_id missing + # - invalid json in file + # - missing custom_id, method, url, body + # - invalid model + # - streaming + # -> batch created, validation sends to failed status + # 2. Processing errors, result in error_file_id entries, e.g. 
+ # - Any error returned from inference endpoint + # -> batch created, goes to completed status + # TODO: set expiration time for garbage collection if endpoint not in ["/v1/chat/completions"]: @@ -171,6 +189,35 @@ class ReferenceBatchesImpl(Batches): ) batch_id = f"batch_{uuid.uuid4().hex[:16]}" + + # For idempotent requests, use the idempotency key for the batch ID + # This ensures the same key always maps to the same batch ID, + # allowing us to detect parameter conflicts + if idempotency_key is not None: + hash_input = idempotency_key.encode("utf-8") + hash_digest = hashlib.sha256(hash_input).hexdigest()[:24] + batch_id = f"batch_{hash_digest}" + + try: + existing_batch = await self.retrieve_batch(batch_id) + + if ( + existing_batch.input_file_id != input_file_id + or existing_batch.endpoint != endpoint + or existing_batch.completion_window != completion_window + or existing_batch.metadata != metadata + ): + raise ConflictError( + f"Idempotency key '{idempotency_key}' was previously used with different parameters. " + "Either use a new idempotency key or ensure all parameters match the original request." + ) + + logger.info(f"Returning existing batch with ID: {batch_id}") + return existing_batch + except ResourceNotFoundError: + # Batch doesn't exist, continue with creation + pass + current_time = int(time.time()) batch = BatchObject( @@ -185,6 +232,7 @@ class ReferenceBatchesImpl(Batches): ) await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) + logger.info(f"Created new batch with ID: {batch_id}") if self.process_batches: task = asyncio.create_task(self._process_batch(batch_id)) diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 1e9dca3b5..4f6d571a4 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -11,6 +11,7 @@ from typing import Annotated from fastapi import File, Form, Response, UploadFile +from llama_stack.apis.common.errors import ResourceNotFoundError from llama_stack.apis.common.responses import Order from llama_stack.apis.files import ( Files, @@ -20,12 +21,15 @@ from llama_stack.apis.files import ( OpenAIFilePurpose, ) from llama_stack.core.datatypes import AccessRule +from llama_stack.log import get_logger from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from .config import LocalfsFilesImplConfig +logger = get_logger(name=__name__, category="files") + class LocalfsFilesImpl(Files): def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None: @@ -65,6 +69,18 @@ class LocalfsFilesImpl(Files): """Get the filesystem path for a file ID.""" return Path(self.config.storage_dir) / file_id + async def _lookup_file_id(self, file_id: str) -> tuple[OpenAIFileObject, Path]: + """Look up an OpenAIFileObject and filesystem path from its ID.""" + if not self.sql_store: + raise RuntimeError("Files provider not initialized") + + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "client.files.list()") + + file_path = Path(row.pop("file_path")) + return OpenAIFileObject(**row), file_path + # OpenAI Files API Implementation async def openai_upload_file( self, @@ -157,37 +173,19 @@ class LocalfsFilesImpl(Files): async def
openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: """Returns information about a specific file.""" - if not self.sql_store: - raise RuntimeError("Files provider not initialized") + file_obj, _ = await self._lookup_file_id(file_id) - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) - if not row: - raise ValueError(f"File with id {file_id} not found") - - return OpenAIFileObject( - id=row["id"], - filename=row["filename"], - purpose=OpenAIFilePurpose(row["purpose"]), - bytes=row["bytes"], - created_at=row["created_at"], - expires_at=row["expires_at"], - ) + return file_obj async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: """Delete a file.""" - if not self.sql_store: - raise RuntimeError("Files provider not initialized") - - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) - if not row: - raise ValueError(f"File with id {file_id} not found") - # Delete physical file - file_path = Path(row["file_path"]) + _, file_path = await self._lookup_file_id(file_id) if file_path.exists(): file_path.unlink() # Delete metadata from database + assert self.sql_store is not None, "Files provider not initialized" await self.sql_store.delete("openai_files", where={"id": file_id}) return OpenAIFileDeleteResponse( @@ -197,25 +195,17 @@ class LocalfsFilesImpl(Files): async def openai_retrieve_file_content(self, file_id: str) -> Response: """Returns the contents of the specified file.""" - if not self.sql_store: - raise RuntimeError("Files provider not initialized") - - # Get file metadata - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) - if not row: - raise ValueError(f"File with id {file_id} not found") - # Read file content - file_path = Path(row["file_path"]) - if not file_path.exists(): - raise ValueError(f"File content not found on disk: {file_path}") + file_obj, file_path = await self._lookup_file_id(file_id) - with open(file_path, "rb") as f: - content = f.read() + if not file_path.exists(): + logger.warning(f"File '{file_id}'s underlying '{file_path}' is missing, deleting metadata.") + await self.openai_delete_file(file_id) + raise ResourceNotFoundError(file_id, "File", "client.files.list()") # Return as binary response with appropriate content type return Response( - content=content, + content=file_path.read_bytes(), media_type="application/octet-stream", - headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + headers={"Content-Disposition": f'attachment; filename="{file_obj.filename}"'}, ) diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 7ade75032..bb6a1bd03 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -12,7 +12,6 @@ import copy import json -import logging import multiprocessing import os import tempfile @@ -32,13 +31,14 @@ from fairscale.nn.model_parallel.initialize import ( from pydantic import BaseModel, Field from torch.distributed.launcher.api import LaunchConfig, elastic_launch +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import GenerationResult from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, CompletionRequestWithRawContent, ) -log = logging.getLogger(__name__) +log = 
get_logger(name=__name__, category="inference") class ProcessingMessageName(str, Enum): diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index fea8a8189..34665b63e 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -4,13 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from collections.abc import AsyncGenerator from llama_stack.apis.inference import ( CompletionResponse, InferenceProvider, - InterleavedContent, LogProbConfig, Message, ResponseFormat, @@ -21,6 +19,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import ModelType +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, @@ -32,7 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import ( from .config import SentenceTransformersInferenceConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") class SentenceTransformersInferenceImpl( @@ -100,25 +99,3 @@ class SentenceTransformersInferenceImpl( tool_config: ToolConfig | None = None, ) -> AsyncGenerator: raise ValueError("Sentence transformers don't support chat completion") - - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Sentence Transformers") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers") diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py index 2574b995b..d9ee3d2a8 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device.py @@ -6,7 +6,6 @@ import gc import json -import logging import multiprocessing from pathlib import Path from typing import Any @@ -28,6 +27,7 @@ from llama_stack.apis.post_training import ( LoraFinetuningConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -44,7 +44,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") class HFFinetuningSingleDevice: diff --git a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py 
b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py index a7c19faac..b39a24c66 100644 --- a/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py +++ b/llama_stack/providers/inline/post_training/huggingface/recipes/finetune_single_device_dpo.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import gc -import logging import multiprocessing from pathlib import Path from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.post_training import ( DPOAlignmentConfig, TrainingConfig, ) +from llama_stack.log import get_logger from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from ..config import HuggingFacePostTrainingConfig @@ -40,7 +40,7 @@ from ..utils import ( split_dataset, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") class HFDPOAlignmentSingleDevice: diff --git a/llama_stack/providers/inline/post_training/huggingface/utils.py b/llama_stack/providers/inline/post_training/huggingface/utils.py index 3147c19ab..f229c87dd 100644 --- a/llama_stack/providers/inline/post_training/huggingface/utils.py +++ b/llama_stack/providers/inline/post_training/huggingface/utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import os import signal import sys @@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.post_training import Checkpoint, TrainingConfig +from llama_stack.log import get_logger from .config import HuggingFacePostTrainingConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training") def setup_environment(): diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 49e1c95b8..8b1462862 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
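Looking back at the `convert_response_input_to_chat_messages` change in `responses/utils.py` above: function-call outputs are buffered first, then re-emitted immediately after their matching assistant tool-call message, and any leftover output raises. A toy version of that pairing, using simplified `(kind, call_id, payload)` tuples as stand-ins for the OpenAI message types:

```python
# Toy input where the output arrives before its call, as can happen in a
# Responses input list.
items = [
    ("function_call_output", "call_1", "42"),
    ("function_call", "call_1", 'add(query="40+2")'),
]

# First pass: collect outputs keyed by call_id.
tool_call_results = {cid: out for kind, cid, out in items if kind == "function_call_output"}

# Second pass: emit each assistant tool call, then its result right after.
messages = []
for kind, cid, payload in items:
    if kind == "function_call":
        messages.append(("assistant", payload))
        if cid in tool_call_results:
            messages.append(("tool", tool_call_results.pop(cid)))

# Outputs with no matching call are an error, mirroring the ValueError above.
if tool_call_results:
    raise ValueError(f"function_call_output(s) without a matching function_call: {list(tool_call_results)}")

print(messages)  # [('assistant', 'add(query="40+2")'), ('tool', '42')]
```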
-import logging import os import time from datetime import UTC, datetime @@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler from torchtune import modules, training from torchtune import utils as torchtune_utils from torchtune.data import padded_collate_sft +from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( get_adapter_params, @@ -45,6 +45,7 @@ from llama_stack.apis.post_training import ( ) from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import resolve_model from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device from llama_stack.providers.inline.post_training.torchtune.common import utils @@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import ( ) from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset -log = logging.getLogger(__name__) - -from torchtune.models.llama3._tokenizer import Llama3Tokenizer +log = get_logger(name=__name__, category="post_training") class LoraFinetuningSingleDevice: diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 6e05d5b83..5e25c559f 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import uuid from typing import TYPE_CHECKING, Any @@ -20,13 +19,14 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) from .config import CodeScannerConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") ALLOWED_CODE_SCANNER_MODEL_IDS = [ "code-scanner", diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 5d52c5d89..5c7f30aa7 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
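The idempotency scheme added to `batches.py` earlier in this diff is easy to check in isolation: the batch ID is a pure function of the caller's key, so retries land on the same ID and parameter conflicts can be detected by re-reading the store:

```python
import hashlib

# Same derivation as ReferenceBatchesImpl.create_batch above: sha256 of the
# idempotency key, truncated to 24 hex characters.
def batch_id_for(idempotency_key: str) -> str:
    digest = hashlib.sha256(idempotency_key.encode("utf-8")).hexdigest()[:24]
    return f"batch_{digest}"

assert batch_id_for("nightly-eval") == batch_id_for("nightly-eval")  # stable across retries
assert batch_id_for("nightly-eval") != batch_id_for("weekly-eval")
print(batch_id_for("nightly-eval"))
```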
-import logging import re import uuid from string import Template @@ -21,6 +20,7 @@ from llama_stack.apis.safety import ( from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults from llama_stack.apis.shields import Shield from llama_stack.core.datatypes import Api +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import Role from llama_stack.models.llama.sku_types import CoreModelId from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -132,6 +132,8 @@ Provide your safety assessment for ONLY THE LAST $agent_type message in the abov PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}") +logger = get_logger(name=__name__, category="safety") + class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate): def __init__(self, config: LlamaGuardConfig, deps) -> None: @@ -407,7 +409,7 @@ class LlamaGuardShield: unsafe_code_list = [code.strip() for code in unsafe_code.split(",")] invalid_codes = [code for code in unsafe_code_list if code not in SAFETY_CODE_TO_CATEGORIES_MAP] if invalid_codes: - logging.warning(f"Invalid safety codes returned: {invalid_codes}") + logger.warning(f"Invalid safety codes returned: {invalid_codes}") # just returning safe object, as we don't know what the invalid codes can map to return ModerationObject( id=f"modr-{uuid.uuid4()}", diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index c760f0fd1..6fb6c4407 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
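On the `llama_guard.py` hunk above: unrecognized safety codes returned by the model are logged and treated as safe rather than trusted. The validation itself is a simple set-membership split; sketched here with a two-entry stand-in for the real category map:

```python
# Stand-in for SAFETY_CODE_TO_CATEGORIES_MAP; the real map covers the full
# Llama Guard taxonomy.
SAFETY_CODE_TO_CATEGORIES_MAP = {"S1": "Violent Crimes", "S2": "Non-Violent Crimes"}

def split_safety_codes(unsafe_code: str) -> tuple[list[str], list[str]]:
    codes = [code.strip() for code in unsafe_code.split(",")]
    valid = [c for c in codes if c in SAFETY_CODE_TO_CATEGORIES_MAP]
    invalid = [c for c in codes if c not in SAFETY_CODE_TO_CATEGORIES_MAP]
    return valid, invalid

print(split_safety_codes("S1, S99"))  # (['S1'], ['S99']) -> S99 would be logged as invalid
```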
-import logging from typing import Any import torch @@ -21,6 +20,7 @@ from llama_stack.apis.safety import ( from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.shields import Shield from llama_stack.core.utils.model_utils import model_local_dir +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import PromptGuardConfig, PromptGuardType -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="safety") PROMPT_GUARD_MODEL = "Prompt-Guard-86M" diff --git a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py index b74c3826e..c9358101d 100644 --- a/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py +++ b/llama_stack/providers/inline/scoring/basic/utils/ifeval_utils.py @@ -7,7 +7,6 @@ import collections import functools import json -import logging import random import re import string @@ -20,7 +19,9 @@ import nltk from pythainlp.tokenize import sent_tokenize as sent_tokenize_thai from pythainlp.tokenize import word_tokenize as word_tokenize_thai -logger = logging.getLogger() +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="scoring") WORD_LIST = [ "western", diff --git a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py index d99255c79..9224c3792 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/telemetry.py @@ -4,13 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging +import datetime import threading from typing import Any from opentelemetry import metrics, trace - -logger = logging.getLogger(__name__) from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.metrics import MeterProvider @@ -40,6 +38,7 @@ from llama_stack.apis.telemetry import ( UnstructuredLogEvent, ) from llama_stack.core.datatypes import Api +from llama_stack.log import get_logger from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import ( ConsoleSpanProcessor, ) @@ -61,6 +60,8 @@ _GLOBAL_STORAGE: dict[str, dict[str | int, Any]] = { _global_lock = threading.Lock() _TRACER_PROVIDER = None +logger = get_logger(name=__name__, category="telemetry") + def is_tracing_enabled(tracer): with tracer.start_as_current_span("check_tracing") as span: @@ -145,11 +146,41 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry): metric_name: str, start_time: int, end_time: int | None = None, - granularity: str | None = "1d", + granularity: str | None = None, query_type: MetricQueryType = MetricQueryType.RANGE, label_matchers: list[MetricLabelMatcher] | None = None, ) -> QueryMetricsResponse: - raise NotImplementedError("Querying metrics is not implemented") + """Query metrics from the telemetry store. 
+ + Args: + metric_name: The name of the metric to query (e.g., "prompt_tokens") + start_time: Start time as Unix timestamp + end_time: End time as Unix timestamp (defaults to now if None) + granularity: Time granularity for aggregation + query_type: Type of query (RANGE or INSTANT) + label_matchers: Label filters to apply + + Returns: + QueryMetricsResponse with metric time series data + """ + # Convert timestamps to datetime objects + start_dt = datetime.datetime.fromtimestamp(start_time, datetime.UTC) + end_dt = datetime.datetime.fromtimestamp(end_time, datetime.UTC) if end_time else None + + # Use SQLite trace store if available + if hasattr(self, "trace_store") and self.trace_store: + return await self.trace_store.query_metrics( + metric_name=metric_name, + start_time=start_dt, + end_time=end_dt, + granularity=granularity, + query_type=query_type, + label_matchers=label_matchers, + ) + else: + raise ValueError( + f"In order to query_metrics, you must have {TelemetrySink.SQLITE} set in your telemetry sinks" + ) def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None: with self._lock: diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 6a7c7885c..a1543457b 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import secrets import string from typing import Any @@ -32,6 +31,7 @@ from llama_stack.apis.tools import ( ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( @@ -42,7 +42,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import RagToolRuntimeConfig from .context_retriever import generate_rag_query -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="tool_runtime") def make_random_string(length: int = 8): diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index af61da59b..258c6e7aa 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -8,7 +8,6 @@ import asyncio import base64 import io import json -import logging from typing import Any import faiss @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( HealthResponse, HealthStatus, @@ -40,7 +40,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import FaissVectorIOConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index cc1982f3b..7cf163960 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
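The `query_metrics` implementation above converts the Unix timestamps it receives into timezone-aware datetimes before delegating to the trace store, and raises unless a SQLite telemetry sink is configured. A sketch of the expected call shape (hypothetical caller; the signature and the `prompt_tokens` metric name are from the diff above):

```python
import datetime


async def last_hour_of_prompt_tokens(adapter):
    # adapter is assumed to be a TelemetryAdapter with a SQLite sink configured.
    now = int(datetime.datetime.now(datetime.UTC).timestamp())
    return await adapter.query_metrics(
        metric_name="prompt_tokens",
        start_time=now - 3600,
        end_time=None,  # None means "now" inside the adapter
    )
```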
import asyncio -import logging import re import sqlite3 import struct @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io") # Specifying search mode is dependent on the VectorIO provider. VECTOR_SEARCH = "vector" diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index e894debaf..ebe90310c 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -5,9 +5,11 @@ # the root directory of this source tree. from llama_stack.providers.datatypes import ( + AdapterSpec, Api, InlineProviderSpec, ProviderSpec, + remote_provider_spec, ) from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages @@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", description="Local filesystem-based file storage provider for managing files and documents locally.", ), + remote_provider_spec( + api=Api.files, + adapter=AdapterSpec( + adapter_type="s3", + pip_packages=["boto3"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.s3", + config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", + description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", + ), + ), ] diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index ffd64ef7c..4443f4df1 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -5,34 +5,74 @@ # the root directory of this source tree. +from typing import cast + from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +# We provide two versions of these providers so that distributions can package the appropriate version of torch. +# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. 
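Concretely, the registry below keeps one base definition per framework and overrides only `provider_type` and `pip_packages` for each variant. The pattern in reduced, illustrative form (module and package names here are invented for the sketch):

```python
# Shared base definition for one framework.
base = dict(
    module="example.post_training",
    pip_packages=["trl", "peft"],
)

# The CPU build pins torch to the CPU wheel index for smaller images;
# the GPU build takes the default torch wheel.
cpu_variant = {
    **base,
    "provider_type": "inline::example-cpu",
    "pip_packages": base["pip_packages"] + ["torch --index-url https://download.pytorch.org/whl/cpu"],
}
gpu_variant = {
    **base,
    "provider_type": "inline::example-gpu",
    "pip_packages": base["pip_packages"] + ["torch"],
}
```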
+torchtune_def = dict( + api=Api.post_training, + pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"], + module="llama_stack.providers.inline.post_training.torchtune", + config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.", +) + +huggingface_def = dict( + api=Api.post_training, + pip_packages=["trl", "transformers", "peft", "datasets"], + module="llama_stack.providers.inline.post_training.huggingface", + config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", + api_dependencies=[ + Api.datasetio, + Api.datasets, + ], + description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", +) + def available_providers() -> list[ProviderSpec]: return [ InlineProviderSpec( - api=Api.post_training, - provider_type="inline::torchtune", - pip_packages=["torch", "torchtune==0.5.0", "torchao==0.8.0", "numpy"], - module="llama_stack.providers.inline.post_training.torchtune", - config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig", - api_dependencies=[ - Api.datasetio, - Api.datasets, - ], - description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.", + **{ + **torchtune_def, + "provider_type": "inline::torchtune-cpu", + "pip_packages": ( + cast(list[str], torchtune_def["pip_packages"]) + + ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"] + ), + }, ), InlineProviderSpec( - api=Api.post_training, - provider_type="inline::huggingface", - pip_packages=["torch", "trl", "transformers", "peft", "datasets"], - module="llama_stack.providers.inline.post_training.huggingface", - config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig", - api_dependencies=[ - Api.datasetio, - Api.datasets, - ], - description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", + **{ + **huggingface_def, + "provider_type": "inline::huggingface-cpu", + "pip_packages": ( + cast(list[str], huggingface_def["pip_packages"]) + + ["torch --index-url https://download.pytorch.org/whl/cpu"] + ), + }, + ), + InlineProviderSpec( + **{ + **torchtune_def, + "provider_type": "inline::torchtune-gpu", + "pip_packages": ( + cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"] + ), + }, + ), + InlineProviderSpec( + **{ + **huggingface_def, + "provider_type": "inline::huggingface-gpu", + "pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]), + }, ), remote_provider_spec( api=Api.post_training, diff --git a/llama_stack/providers/remote/files/s3/README.md b/llama_stack/providers/remote/files/s3/README.md new file mode 100644 index 000000000..0f33122c7 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/README.md @@ -0,0 +1,237 @@ +# S3 Files Provider + +A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence. 
+ +## Features + +- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage +- **Metadata Management**: Uses SQL database for efficient file metadata queries +- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints +- **Flexible Authentication**: Support for IAM roles and access keys +- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services + +## Configuration + +### Basic Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Advanced Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + aws_access_key_id: YOUR_ACCESS_KEY + aws_secret_access_key: YOUR_SECRET_KEY + endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Environment Variables + +The configuration supports environment variable substitution: + +```yaml +config: + bucket_name: "${env.S3_BUCKET_NAME}" + region: "${env.AWS_REGION:=us-east-1}" + aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}" + aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}" + endpoint_url: "${env.S3_ENDPOINT_URL:=}" +``` + +Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique. + +## Authentication + +### IAM Roles (Recommended) + +For production deployments, use IAM roles: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + # No credentials needed - will use IAM role +``` + +### Access Keys + +For development or specific use cases: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + aws_access_key_id: AKIAIOSFODNN7EXAMPLE + aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY +``` + +## S3 Bucket Setup + +### Required Permissions + +The S3 provider requires the following permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket" + ], + "Resource": [ + "arn:aws:s3:::your-bucket-name", + "arn:aws:s3:::your-bucket-name/*" + ] + } + ] +} +``` + +### Automatic Bucket Creation + +By default, the S3 provider expects the bucket to already exist. 
If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration:
+
+```yaml
+config:
+  bucket_name: my-bucket
+  auto_create_bucket: true  # Will create bucket if it doesn't exist
+  region: us-east-1
+```
+
+**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions:
+
+```json
+{
+  "Version": "2012-10-17",
+  "Statement": [
+    {
+      "Effect": "Allow",
+      "Action": [
+        "s3:GetObject",
+        "s3:PutObject",
+        "s3:DeleteObject",
+        "s3:ListBucket",
+        "s3:CreateBucket"
+      ],
+      "Resource": [
+        "arn:aws:s3:::your-bucket-name",
+        "arn:aws:s3:::your-bucket-name/*"
+      ]
+    }
+  ]
+}
+```
+
+### Bucket Policy (Optional)
+
+For additional security, you can add a bucket policy:
+
+```json
+{
+  "Version": "2012-10-17",
+  "Statement": [
+    {
+      "Sid": "LlamaStackAccess",
+      "Effect": "Allow",
+      "Principal": {
+        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
+      },
+      "Action": [
+        "s3:GetObject",
+        "s3:PutObject",
+        "s3:DeleteObject"
+      ],
+      "Resource": "arn:aws:s3:::your-bucket-name/*"
+    },
+    {
+      "Sid": "LlamaStackBucketAccess",
+      "Effect": "Allow",
+      "Principal": {
+        "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole"
+      },
+      "Action": [
+        "s3:ListBucket"
+      ],
+      "Resource": "arn:aws:s3:::your-bucket-name"
+    }
+  ]
+}
+```
+
+## Implementation Details
+
+### Metadata Persistence
+
+File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes:
+
+- File ID
+- Original filename
+- Purpose (assistants, batch, etc.)
+- File size in bytes
+- Created and expiration timestamps
+
+### TTL and Cleanup
+
+Files currently have a fixed long expiration time (100 years).
+
+## Development and Testing
+
+### Using MinIO
+
+For self-hosted S3-compatible storage:
+
+```yaml
+config:
+  bucket_name: test-bucket
+  region: us-east-1
+  endpoint_url: http://localhost:9000
+  aws_access_key_id: minioadmin
+  aws_secret_access_key: minioadmin
+```
+
+## Monitoring and Logging
+
+The provider logs important operations and errors. For production deployments, consider:
+
+- CloudWatch monitoring for S3 operations
+- Custom metrics for file upload/download rates
+- Error rate monitoring
+- Performance metrics tracking
+
+## Error Handling
+
+The provider handles various error scenarios:
+
+- S3 connectivity issues
+- Bucket access permissions
+- File not found errors
+- Metadata consistency checks
+
+## Known Limitations
+
+- Fixed long TTL (100 years) instead of configurable expiration
+- No server-side encryption enabled by default
+- No support for AWS session tokens
+- No S3 key prefix organization support
+- No multipart upload support (all files uploaded as single objects)
diff --git a/llama_stack/providers/remote/files/s3/__init__.py b/llama_stack/providers/remote/files/s3/__init__.py
new file mode 100644
index 000000000..3f5dfc88a
--- /dev/null
+++ b/llama_stack/providers/remote/files/s3/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
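The metadata fields listed under "Metadata Persistence" map directly onto the `openai_files` table columns the provider creates (see `files.py` below). An illustrative row, with invented values:

```python
row = {
    "id": "file-0123456789abcdef0123456789abcdef",  # f"file-{uuid.uuid4().hex}"
    "filename": "report.pdf",
    "purpose": "assistants",
    "bytes": 1048576,
    "created_at": 1_700_000_000,  # Unix seconds
    "expires_at": 1_700_000_000 + 100 * 365 * 24 * 60 * 60,  # fixed ~100-year offset
}
```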
+ +from typing import Any + +from llama_stack.core.datatypes import Api + +from .config import S3FilesImplConfig + + +async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]): + from .files import S3FilesImpl + + # TODO: authorization policies and user separation + impl = S3FilesImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/files/s3/config.py b/llama_stack/providers/remote/files/s3/config.py new file mode 100644 index 000000000..da20d8668 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/config.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig + + +class S3FilesImplConfig(BaseModel): + """Configuration for S3-based files provider.""" + + bucket_name: str = Field(description="S3 bucket name to store files") + region: str = Field(default="us-east-1", description="AWS region where the bucket is located") + aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)") + aws_secret_access_key: str | None = Field( + default=None, description="AWS secret access key (optional if using IAM roles)" + ) + endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)") + auto_create_bucket: bool = Field( + default=False, description="Automatically create the S3 bucket if it doesn't exist" + ) + metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata") + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: + return { + "bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique + "region": "${env.AWS_REGION:=us-east-1}", + "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}", + "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}", + "endpoint_url": "${env.S3_ENDPOINT_URL:=}", + "auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}", + "metadata_store": SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="s3_files_metadata.db", + ), + } diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py new file mode 100644 index 000000000..52e0cbbf4 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/files.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
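`S3FilesImplConfig` is a plain pydantic model, so it can also be constructed programmatically, e.g. in tests against MinIO. A sketch mirroring the README's MinIO example (this assumes `SqliteSqlStoreConfig` accepts a `db_path`, as its sample config above suggests):

```python
from llama_stack.providers.remote.files.s3.config import S3FilesImplConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

config = S3FilesImplConfig(
    bucket_name="test-bucket",
    region="us-east-1",
    endpoint_url="http://localhost:9000",  # local MinIO
    aws_access_key_id="minioadmin",
    aws_secret_access_key="minioadmin",
    metadata_store=SqliteSqlStoreConfig(db_path="./s3_files_metadata.db"),
)
```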
+ +import time +import uuid +from typing import Annotated + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError +from fastapi import File, Form, Response, UploadFile + +from llama_stack.apis.common.errors import ResourceNotFoundError +from llama_stack.apis.common.responses import Order +from llama_stack.apis.files import ( + Files, + ListOpenAIFileResponse, + OpenAIFileDeleteResponse, + OpenAIFileObject, + OpenAIFilePurpose, +) +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType +from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl + +from .config import S3FilesImplConfig + +# TODO: provider data for S3 credentials + + +def _create_s3_client(config: S3FilesImplConfig) -> boto3.client: + try: + s3_config = { + "region_name": config.region, + } + + # endpoint URL if specified (for MinIO, LocalStack, etc.) + if config.endpoint_url: + s3_config["endpoint_url"] = config.endpoint_url + + if config.aws_access_key_id and config.aws_secret_access_key: + s3_config.update( + { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + } + ) + + return boto3.client("s3", **s3_config) + + except (BotoCoreError, NoCredentialsError) as e: + raise RuntimeError(f"Failed to initialize S3 client: {e}") from e + + +async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None: + try: + client.head_bucket(Bucket=config.bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code == "404": + if not config.auto_create_bucket: + raise RuntimeError( + f"S3 bucket '{config.bucket_name}' does not exist. " + f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration." 
+                ) from e
+            try:
+                # For us-east-1, we can't specify LocationConstraint
+                if config.region == "us-east-1":
+                    client.create_bucket(Bucket=config.bucket_name)
+                else:
+                    client.create_bucket(
+                        Bucket=config.bucket_name,
+                        CreateBucketConfiguration={"LocationConstraint": config.region},
+                    )
+            except ClientError as create_error:
+                raise RuntimeError(
+                    f"Failed to create S3 bucket '{config.bucket_name}': {create_error}"
+                ) from create_error
+        elif error_code == "403":
+            raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e
+        else:
+            raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
+
+
+class S3FilesImpl(Files):
+    """S3-based implementation of the Files API."""
+
+    # TODO: implement expiration, for now a silly offset
+    _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
+
+    def __init__(self, config: S3FilesImplConfig) -> None:
+        self._config = config
+        self._client: boto3.client | None = None
+        self._sql_store: SqlStore | None = None
+
+    async def initialize(self) -> None:
+        self._client = _create_s3_client(self._config)
+        await _create_bucket_if_not_exists(self._client, self._config)
+
+        self._sql_store = sqlstore_impl(self._config.metadata_store)
+        await self._sql_store.create_table(
+            "openai_files",
+            {
+                "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True),
+                "filename": ColumnType.STRING,
+                "purpose": ColumnType.STRING,
+                "bytes": ColumnType.INTEGER,
+                "created_at": ColumnType.INTEGER,
+                "expires_at": ColumnType.INTEGER,
+                # TODO: add s3_etag field for integrity checking
+            },
+        )
+
+    async def shutdown(self) -> None:
+        pass
+
+    @property
+    def client(self) -> boto3.client:
+        assert self._client is not None, "Provider not initialized"
+        return self._client
+
+    @property
+    def sql_store(self) -> SqlStore:
+        assert self._sql_store is not None, "Provider not initialized"
+        return self._sql_store
+
+    async def openai_upload_file(
+        self,
+        file: Annotated[UploadFile, File()],
+        purpose: Annotated[OpenAIFilePurpose, Form()],
+    ) -> OpenAIFileObject:
+        file_id = f"file-{uuid.uuid4().hex}"
+
+        filename = getattr(file, "filename", None) or "uploaded_file"
+
+        created_at = int(time.time())
+        expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
+        content = await file.read()
+        file_size = len(content)
+
+        await self.sql_store.insert(
+            "openai_files",
+            {
+                "id": file_id,
+                "filename": filename,
+                "purpose": purpose.value,
+                "bytes": file_size,
+                "created_at": created_at,
+                "expires_at": expires_at,
+            },
+        )
+
+        try:
+            self.client.put_object(
+                Bucket=self._config.bucket_name,
+                Key=file_id,
+                Body=content,
+                # TODO: enable server-side encryption
+            )
+        except ClientError as e:
+            await self.sql_store.delete("openai_files", where={"id": file_id})
+
+            raise RuntimeError(f"Failed to upload file to S3: {e}") from e
+
+        return OpenAIFileObject(
+            id=file_id,
+            filename=filename,
+            purpose=purpose,
+            bytes=file_size,
+            created_at=created_at,
+            expires_at=expires_at,
+        )
+
+    async def openai_list_files(
+        self,
+        after: str | None = None,
+        limit: int | None = 10000,
+        order: Order | None = Order.desc,
+        purpose: OpenAIFilePurpose | None = None,
+    ) -> ListOpenAIFileResponse:
+        # this is purely defensive; it should not happen because the router also defaults to Order.desc.
+ if not order: + order = Order.desc + + where_conditions = {} + if purpose: + where_conditions["purpose"] = purpose.value + + paginated_result = await self.sql_store.fetch_all( + table="openai_files", + where=where_conditions if where_conditions else None, + order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, + limit=limit, + ) + + files = [ + OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + for row in paginated_result.data + ] + + return ListOpenAIFileResponse( + data=files, + has_more=paginated_result.has_more, + # empty string or None? spec says str, ref impl returns str | None, we go with spec + first_id=files[0].id if files else "", + last_id=files[-1].id if files else "", + ) + + async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + return OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + + async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + self.client.delete_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + except ClientError as e: + if e.response["Error"]["Code"] != "NoSuchKey": + raise RuntimeError(f"Failed to delete file from S3: {e}") from e + + await self.sql_store.delete("openai_files", where={"id": file_id}) + + return OpenAIFileDeleteResponse(id=file_id, deleted=True) + + async def openai_retrieve_file_content(self, file_id: str) -> Response: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + response = self.client.get_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + # TODO: can we stream this instead of loading it into memory + content = response["Body"].read() + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + await self.sql_store.delete("openai_files", where={"id": file_id}) + raise ResourceNotFoundError(file_id, "File", "files.list()") from e + raise RuntimeError(f"Failed to download file from S3: {e}") from e + + return Response( + content=content, + media_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index bd86f7238..e907e8ec6 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::fireworks") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git 
a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 4857c6723..f2069b5e5 100644
--- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
+++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
@@ -3,15 +3,14 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import logging
-
+from llama_stack.log import get_logger
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .models import MODEL_ENTRIES

-logger = logging.getLogger(__name__)
+logger = get_logger(name=__name__, category="inference::llama_openai_compat")


 class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
index 4a072215c..d96b29fef 100644
--- a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
+++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md
@@ -41,6 +41,11 @@ client.initialize()

 ### Create Completion

+> Note on the Completion API
+>
+> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with `NVIDIA_BASE_URL="https://integrate.api.nvidia.com"` do not support the `completion` method, while the locally deployed NIM does.
+
+
 ```python
 response = client.inference.completion(
     model_id="meta-llama/Llama-3.1-8B-Instruct",
@@ -76,7 +81,78 @@ response = client.inference.chat_completion(
 print(f"Response: {response.completion_message.content}")
 ```

+### Tool Calling Example
+```python
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
+
+tool_definition = ToolDefinition(
+    tool_name="get_weather",
+    description="Get current weather information for a location",
+    parameters={
+        "location": ToolParamDefinition(
+            param_type="string",
+            description="The city and state, e.g.
San Francisco, CA", + required=True, + ), + "unit": ToolParamDefinition( + param_type="string", + description="Temperature unit (celsius or fahrenheit)", + required=False, + default="celsius", + ), + }, +) + +tool_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], + tools=[tool_definition], +) + +print(f"Tool Response: {tool_response.completion_message.content}") +if tool_response.completion_message.tool_calls: + for tool_call in tool_response.completion_message.tool_calls: + print(f"Tool Called: {tool_call.tool_name}") + print(f"Arguments: {tool_call.arguments}") +``` + +### Structured Output Example +```python +from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType + +person_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "occupation": {"type": "string"}, + }, + "required": ["name", "age", "occupation"], +} + +response_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, json_schema=person_schema +) + +structured_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + { + "role": "user", + "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ", + } + ], + response_format=response_format, +) + +print(f"Structured Response: {structured_response.completion_message.content}") +``` + ### Create Embeddings +> Note on OpenAI embeddings compatibility +> +> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`. + ```python response = client.inference.embeddings( model_id="nvidia/llama-3.2-nv-embedqa-1b-v2", diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7bc3fd0c9..a5475bc92 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -4,11 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
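The embeddings note above draws a line between query and passage embeddings: the OpenAI-compatible endpoint is pinned to `input_type="query"`, so passages should go through the native embeddings API instead. A sketch of that path in the same client style as the examples above (parameter names follow the Llama Stack embeddings API referenced in the note; treat them as assumptions, since the call itself is not shown in this diff):

```python
doc_response = client.inference.embeddings(
    model_id="nvidia/llama-3.2-nv-embedqa-1b-v2",
    contents=["Paris is the capital of France."],
    task_type="document",  # routed to NVIDIA's input_type for passages, per the note above
)
print(len(doc_response.embeddings[0]))
```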
-import logging import warnings from collections.abc import AsyncIterator -from openai import APIConnectionError, BadRequestError +from openai import NOT_GIVEN, APIConnectionError from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -27,12 +26,16 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, ResponseFormat, SamplingParams, TextTruncation, ToolChoice, ToolConfig, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, @@ -54,7 +57,7 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference::nvidia") class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): @@ -194,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): } extra_body["input_type"] = task_type_options[task_type] - try: - response = await self.client.embeddings.create( - model=provider_model_id, - input=input, - extra_body=extra_body, - ) - except BadRequestError as e: - raise ValueError(f"Failed to get embeddings: {e}") from e - + response = await self.client.embeddings.create( + model=provider_model_id, + input=input, + extra_body=extra_body, + ) # # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...) # -> @@ -210,6 +209,57 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + """ + OpenAI-compatible embeddings for NVIDIA NIM. + + Note: NVIDIA NIM asymmetric embedding models require an "input_type" field not present in the standard OpenAI embeddings API. + We default this to "query" to ensure requests succeed when using the + OpenAI-compatible endpoint. For passage embeddings, use the embeddings API with + `task_type='document'`. + """ + extra_body: dict[str, object] = {"input_type": "query"} + logger.warning( + "NVIDIA OpenAI-compatible embeddings: defaulting to input_type='query'. " + "For passage embeddings, use the embeddings API with task_type='document'." 
+ ) + + response = await self.client.embeddings.create( + model=await self._get_provider_model_id(model), + input=input, + encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN, + dimensions=dimensions if dimensions is not None else NOT_GIVEN, + user=user if user is not None else NOT_GIVEN, + extra_body=extra_body, + ) + + data = [] + for i, embedding_data in enumerate(response.data): + data.append( + OpenAIEmbeddingData( + embedding=embedding_data.embedding, + index=i, + ) + ) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response.usage.prompt_tokens, + total_tokens=response.usage.total_tokens, + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=response.model, + usage=usage, + ) + async def chat_completion( self, model_id: str, diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 74019999e..b8431e859 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -4,13 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging - import httpx +from llama_stack.log import get_logger + from . import NVIDIAConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference::nvidia") def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index a93421536..fcaf5ee92 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::ollama") class OllamaInferenceAdapter( @@ -619,28 +619,6 @@ class OllamaInferenceAdapter( response.id = id return response - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Ollama") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for Ollama") - async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def _convert_content(content) -> dict: diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 865258559..0f73c9321 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -4,15 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
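One detail in `openai_embeddings` above: optional parameters fall back to `NOT_GIVEN`, the openai client's omit-this-field sentinel, instead of `None`. Reduced to its essence (illustrative helper, not from the patch):

```python
from openai import NOT_GIVEN


def or_not_given(value):
    # NOT_GIVEN drops the field from the request entirely,
    # whereas an explicit None would be sent in the payload.
    return value if value is not None else NOT_GIVEN
```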
-import logging - +from llama_stack.log import get_logger from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="inference::openai") # diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 323831845..97c72d14c 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -5,7 +5,6 @@ # the root directory of this source tree. -import logging from collections.abc import AsyncGenerator from huggingface_hub import AsyncInferenceClient, HfApi @@ -34,6 +33,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( @@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference::tgi") def build_hf_repo_model_entries(): diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a06e4173b..54c76607f 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::together") class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index ac626874c..9e9a80ca5 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference::vllm") def build_hf_repo_model_entries(): @@ -711,25 +711,3 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): user=user, ) return await self.client.chat.completions.create(**params) # type: ignore - - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for Ollama") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise 
NotImplementedError("Batch chat completion is not supported for Ollama") diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py index d6e1016b2..162951ff3 100644 --- a/llama_stack/providers/remote/post_training/nvidia/utils.py +++ b/llama_stack/providers/remote/post_training/nvidia/utils.py @@ -4,18 +4,18 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging import warnings from typing import Any from pydantic import BaseModel from llama_stack.apis.post_training import TrainingConfig +from llama_stack.log import get_logger from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig from .config import NvidiaPostTrainingConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="post_training::nvidia") def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 1895e7507..8855e02a4 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging from typing import Any from llama_stack.apis.inference import Message @@ -16,12 +15,13 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety::bedrock") class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 7f17b1cb6..65f901da2 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -4,20 +4,20 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging from typing import Any import requests from llama_stack.apis.inference import Message -from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel +from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel from llama_stack.apis.shields import Shield +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from .config import NVIDIASafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety::nvidia") class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): @@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): self.shield = NeMoGuardrails(self.config, shield.shield_id) return await self.shield.run(messages) + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation") + class NeMoGuardrails: """ diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 6c7190afe..2beb5e0ea 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json -import logging from typing import Any import litellm @@ -20,12 +19,13 @@ from llama_stack.apis.safety import ( ) from llama_stack.apis.shields import Shield from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict_new from .config import SambaNovaSafetyConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="safety::sambanova") CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 8f252711b..a9ec644ef 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import asyncio import json -import logging from typing import Any from urllib.parse import urlparse @@ -20,6 +19,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::chroma") ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index c659bdf6c..e07e8ff12 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import asyncio -import logging import os from typing import Any @@ -21,6 +20,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig from llama_stack.providers.utils.kvstore import kvstore_impl @@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="vector_io::milvus") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index d2a5d910b..1c8d361c2 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from typing import Any import psycopg2 @@ -22,6 +21,7 @@ from llama_stack.apis.vector_io import ( QueryChunksResponse, VectorIO, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import PGVectorVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::pgvector") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 018015780..0a0faa23a 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import asyncio -import logging import uuid from typing import Any @@ -24,6 +23,7 @@ from llama_stack.apis.vector_io import ( VectorStoreChunkingStrategy, VectorStoreFileObject, ) +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::qdrant") CHUNK_ID_KEY = "_chunk_id" # KV store prefixes for vector databases diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 966724848..59b6bf124 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import json -import logging from typing import Any import weaviate @@ -19,6 +18,7 @@ from llama_stack.apis.files.files import Files from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.core.request_headers import NeedsRequestProviderData +from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore @@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import WeaviateVectorIOConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="vector_io::weaviate") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 32e89f987..65ba2854b 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -5,10 +5,11 @@ # the root directory of this source tree. 
import base64 -import logging import struct from typing import TYPE_CHECKING +from llama_stack.log import get_logger + if TYPE_CHECKING: from sentence_transformers import SentenceTransformer @@ -27,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con EMBEDDING_MODELS = {} -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class SentenceTransformerEmbeddingMixin: diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index da2e634f6..9bd43e4c9 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="providers::utils") class LiteLLMOpenAIMixin( @@ -429,28 +429,6 @@ class LiteLLMOpenAIMixin( ) return await litellm.acompletion(**params) - async def batch_completion( - self, - model_id: str, - content_batch: list[InterleavedContent], - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch completion is not supported for OpenAI Compat") - - async def batch_chat_completion( - self, - model_id: str, - messages_batch: list[list[Message]], - sampling_params: SamplingParams | None = None, - tools: list[ToolDefinition] | None = None, - tool_config: ToolConfig | None = None, - response_format: ResponseFormat | None = None, - logprobs: LogProbConfig | None = None, - ): - raise NotImplementedError("Batch chat completion is not supported for OpenAI Compat") - async def check_model_availability(self, model: str) -> bool: """ Check if a specific model is available via LiteLLM for the current diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index ddb3bda8c..44add8f9e 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 5e6c26884..55c2ac0ad 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import base64 import json -import logging import struct import time import uuid @@ -122,6 +121,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.inference import ( OpenAIChoice as OpenAIChatCompletionChoice, ) +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import ( BuiltinTool, StopReason, @@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) -logger = logging.getLogger(__name__) +logger = get_logger(name=__name__, category="providers::utils") class OpenAICompatCompletionChoiceDelta(BaseModel): diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 72286dffb..f60deee6e 100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -25,7 +25,7 @@ from llama_stack.apis.inference import ( from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class OpenAIMixin(ABC): diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index bb9a91b97..a93326e41 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="providers::utils") class ChatCompletionRequestWithRawContent(ChatCompletionRequest): diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index 3842773d9..bab87a4aa 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -4,16 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import logging from datetime import datetime from pymongo import AsyncMongoClient +from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore from ..config import MongoDBKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class MongoDBKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index cabb4c512..56d6dbb48 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -4,16 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import logging from datetime import datetime import psycopg2 from psycopg2.extras import DictCursor +from llama_stack.log import get_logger + from ..api import KVStore from ..config import PostgresKVStoreConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class PostgresKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 120d0d4fc..3acdcf293 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import ( make_overlapped_chunks, ) -logger = get_logger(__name__, category="vector_io") +logger = get_logger(name=__name__, category="providers::utils") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 6ae5bb521..b74080384 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import base64 import io -import logging import re import time from abc import ABC, abstractmethod @@ -26,6 +25,7 @@ from llama_stack.apis.common.content_types import ( from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse +from llama_stack.log import get_logger from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( @@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="providers::utils") class ChunkForDeletion(BaseModel): diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py index 65c3d2898..146591b2f 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -17,7 +17,7 @@ from pydantic import BaseModel from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="scheduler") +logger = get_logger(name=__name__, category="providers::utils") # TODO: revisit the list of possible statuses when defining a more coherent diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index ccc835768..867ba2f55 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore from .sqlstore import SqlStoreType -logger = get_logger(name=__name__, category="authorized_sqlstore") +logger = get_logger(name=__name__, category="providers::utils") # Hardcoded copy of the default policy that our SQL filtering implements # WARNING: If default_policy() changes, this constant must be updated accordingly diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 
6414929db..f75c35314 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -22,6 +22,7 @@ from sqlalchemy import ( text, ) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.log import get_logger @@ -29,7 +30,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig -logger = get_logger(name=__name__, category="sqlstore") +logger = get_logger(name=__name__, category="providers::utils") TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, @@ -45,9 +46,12 @@ TYPE_MAPPING: dict[ColumnType, Any] = { class SqlAlchemySqlStoreImpl(SqlStore): def __init__(self, config: SqlAlchemySqlStoreConfig): self.config = config - self.async_session = async_sessionmaker(create_async_engine(config.engine_str)) + self.async_session = async_sessionmaker(self.create_engine()) self.metadata = MetaData() + def create_engine(self) -> AsyncEngine: + return create_async_engine(self.config.engine_str, pool_pre_ping=True) + async def create_table( self, table: str, @@ -83,7 +87,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): else: sqlalchemy_table = self.metadata.tables[table] - engine = create_async_engine(self.config.engine_str) + engine = self.create_engine() async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) @@ -241,7 +245,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): nullable: bool = True, ) -> None: """Add a column to an existing table if the column doesn't already exist.""" - engine = create_async_engine(self.config.engine_str) + engine = self.create_engine() try: async with engine.begin() as conn: diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py index 8dd6061a6..71480364c 100644 --- a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py +++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py @@ -5,12 +5,23 @@ # the root directory of this source tree. import json -from datetime import datetime +from datetime import UTC, datetime from typing import Protocol import aiosqlite -from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithStatus, Trace +from llama_stack.apis.telemetry import ( + MetricDataPoint, + MetricLabel, + MetricLabelMatcher, + MetricQueryType, + MetricSeries, + QueryCondition, + QueryMetricsResponse, + Span, + SpanWithStatus, + Trace, +) class TraceStore(Protocol): @@ -29,11 +40,192 @@ class TraceStore(Protocol): max_depth: int | None = None, ) -> dict[str, SpanWithStatus]: ... + async def query_metrics( + self, + metric_name: str, + start_time: datetime, + end_time: datetime | None = None, + granularity: str | None = "1d", + query_type: MetricQueryType = MetricQueryType.RANGE, + label_matchers: list[MetricLabelMatcher] | None = None, + ) -> QueryMetricsResponse: ... 
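For context, the new TraceStore.query_metrics contract added above can be exercised roughly as in the minimal sketch below. The store path ("./trace.db"), metric name ("prompt_tokens"), and label name ("model_id") are illustrative assumptions, not taken from the diff, and MetricLabelMatcher is assumed to accept the name/value/operator keywords that the implementation reads.

    import asyncio
    from datetime import UTC, datetime, timedelta

    from llama_stack.apis.telemetry import MetricLabelMatcher, MetricQueryType
    from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore


    async def main() -> None:
        store = SQLiteTraceStore("./trace.db")  # illustrative path to a telemetry SQLite DB
        response = await store.query_metrics(
            metric_name="prompt_tokens",  # queried internally as span event "metric.prompt_tokens"
            start_time=datetime.now(UTC) - timedelta(days=7),
            granularity="1d",  # one data point per daily bucket
            query_type=MetricQueryType.RANGE,
            label_matchers=[
                # "=~" translates to a LIKE '%value%' substring match in the SQL above
                MetricLabelMatcher(name="model_id", value="llama3", operator="=~"),
            ],
        )
        for series in response.data:
            print(series.metric, [(point.timestamp, point.value) for point in series.values])


    asyncio.run(main())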
+ class SQLiteTraceStore(TraceStore): def __init__(self, conn_string: str): self.conn_string = conn_string + async def query_metrics( + self, + metric_name: str, + start_time: datetime, + end_time: datetime | None = None, + granularity: str | None = None, + query_type: MetricQueryType = MetricQueryType.RANGE, + label_matchers: list[MetricLabelMatcher] | None = None, + ) -> QueryMetricsResponse: + if end_time is None: + end_time = datetime.now(UTC) + + # Build base query + if query_type == MetricQueryType.INSTANT: + query = """ + SELECT + se.name, + SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + else: + if granularity: + time_format = self._get_time_format_for_granularity(granularity) + query = f""" + SELECT + se.name, + SUM(CAST(json_extract(se.attributes, '$.value') AS REAL)) as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes, + strftime('{time_format}', se.timestamp) as bucket_start + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + else: + query = """ + SELECT + se.name, + json_extract(se.attributes, '$.value') as value, + json_extract(se.attributes, '$.unit') as unit, + se.attributes, + se.timestamp + FROM span_events se + WHERE se.name = ? + AND se.timestamp BETWEEN ? AND ? + """ + + params = [f"metric.{metric_name}", start_time.isoformat(), end_time.isoformat()] + + # Labels that will be attached to the MetricSeries (preserve matcher labels) + all_labels: list[MetricLabel] = [] + matcher_label_names = set() + if label_matchers: + for matcher in label_matchers: + json_path = f"$.{matcher.name}" + if matcher.operator == "=": + query += f" AND json_extract(se.attributes, '{json_path}') = ?" + params.append(matcher.value) + elif matcher.operator == "!=": + query += f" AND json_extract(se.attributes, '{json_path}') != ?" + params.append(matcher.value) + elif matcher.operator == "=~": + query += f" AND json_extract(se.attributes, '{json_path}') LIKE ?" + params.append(f"%{matcher.value}%") + elif matcher.operator == "!~": + query += f" AND json_extract(se.attributes, '{json_path}') NOT LIKE ?" + params.append(f"%{matcher.value}%") + # Preserve filter context in output + all_labels.append(MetricLabel(name=matcher.name, value=str(matcher.value))) + matcher_label_names.add(matcher.name) + + # GROUP BY / ORDER BY logic + if query_type == MetricQueryType.RANGE and granularity: + group_time_format = self._get_time_format_for_granularity(granularity) + query += f" GROUP BY strftime('{group_time_format}', se.timestamp), json_extract(se.attributes, '$.unit')" + query += " ORDER BY bucket_start" + elif query_type == MetricQueryType.INSTANT: + query += " GROUP BY json_extract(se.attributes, '$.unit')" + else: + query += " ORDER BY se.timestamp" + + # Execute query + async with aiosqlite.connect(self.conn_string) as conn: + conn.row_factory = aiosqlite.Row + async with conn.execute(query, params) as cursor: + rows = await cursor.fetchall() + + if not rows: + return QueryMetricsResponse(data=[]) + + data_points = [] + # We want to add attribute labels, but only those not already present as matcher labels. + attr_label_names = set() + for row in rows: + # Parse JSON attributes safely; if a row has no attributes (unusual), just skip adding labels for it.
+ try: + attributes = json.loads(row["attributes"] or "{}") + except (TypeError, json.JSONDecodeError): + attributes = {} + + value = row["value"] + unit = row["unit"] or "" + + # Add labels from attributes without duplicating matcher labels; otherwise the result would contain many duplicate labels. + for k, v in attributes.items(): + if k not in ["value", "unit"] and k not in matcher_label_names and k not in attr_label_names: + all_labels.append(MetricLabel(name=k, value=str(v))) + attr_label_names.add(k) + + # Determine timestamp + if query_type == MetricQueryType.RANGE and granularity: + try: + bucket_start_raw = row["bucket_start"] + except (KeyError, IndexError) as e: # sqlite3.Row raises IndexError for a missing column + raise ValueError( + "DB did not have a bucket_start time in row when using granularity; this indicates improper formatting" + ) from e + # the column may be present but NULL + if bucket_start_raw is None: + raise ValueError("bucket_start is None; check the time format and data") + bucket_start = datetime.fromisoformat(bucket_start_raw) + timestamp = int(bucket_start.timestamp()) + elif query_type == MetricQueryType.INSTANT: + timestamp = int(datetime.now(UTC).timestamp()) + else: + try: + timestamp_raw = row["timestamp"] + except (KeyError, IndexError) as e: # sqlite3.Row raises IndexError for a missing column + raise ValueError( + "DB did not have a timestamp in row; this indicates improper formatting" + ) from e + # the column may be present but NULL + if timestamp_raw is None: + raise ValueError("timestamp is None; check the time format and data") + timestamp_iso = datetime.fromisoformat(timestamp_raw) + timestamp = int(timestamp_iso.timestamp()) + + data_points.append( + MetricDataPoint( + timestamp=timestamp, + value=value, + unit=unit, + ) + ) + + metric_series = [MetricSeries(metric=metric_name, labels=all_labels, values=data_points)] + return QueryMetricsResponse(data=metric_series) + + def _get_time_format_for_granularity(self, granularity: str | None) -> str: + """Get the SQLite strftime format string for a given granularity. + Args: + granularity: Granularity string (e.g., "1m", "5m", "1h", "1d") + Returns: + SQLite strftime format string for the granularity + """ + if granularity is None: + raise ValueError("granularity cannot be None for this method - use separate logic for no aggregation") + + if granularity.endswith("d"): + return "%Y-%m-%d 00:00:00" + elif granularity.endswith("h"): + return "%Y-%m-%d %H:00:00" + elif granularity.endswith("m"): + return "%Y-%m-%d %H:%M:00" + else: + return "%Y-%m-%d %H:%M:00" # Default to the most granular format, which yields the most timestamps.
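To make the bucketing concrete, here is a small standalone sketch (not part of the diff) that mirrors in Python what the SQL GROUP BY over strftime computes; the sample timestamps are invented. Because the literal "00" parts of each format string pass through strftime unchanged, the finer time fields collapse to zero and rows sharing a prefix fall into the same bucket.

    from datetime import datetime

    # Formats returned by _get_time_format_for_granularity above:
    # "1d" -> "%Y-%m-%d 00:00:00", "1h" -> "%Y-%m-%d %H:00:00", "1m" -> "%Y-%m-%d %H:%M:00"
    samples = ["2025-01-02T10:15:30", "2025-01-02T10:45:05", "2025-01-02T11:05:00"]

    def buckets(fmt: str) -> list[str]:
        # distinct bucket keys produced by formatting each timestamp
        return sorted({datetime.fromisoformat(s).strftime(fmt) for s in samples})

    print(buckets("%Y-%m-%d 00:00:00"))  # ['2025-01-02 00:00:00'] -> one daily bucket
    print(buckets("%Y-%m-%d %H:00:00"))  # 10:00 and 11:00 -> two hourly buckets
    print(buckets("%Y-%m-%d %H:%M:00"))  # three distinct per-minute buckets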
+ async def query_traces( self, attribute_filters: list[QueryCondition] | None = None, diff --git a/llama_stack/providers/utils/telemetry/tracing.py b/llama_stack/providers/utils/telemetry/tracing.py index 7080e774a..7694003b5 100644 --- a/llama_stack/providers/utils/telemetry/tracing.py +++ b/llama_stack/providers/utils/telemetry/tracing.py @@ -6,7 +6,7 @@ import asyncio import contextvars -import logging +import logging # allow-direct-logging import queue import random import sys diff --git a/llama_stack/ui/app/chat-playground/page.test.tsx b/llama_stack/ui/app/chat-playground/page.test.tsx new file mode 100644 index 000000000..54c15f95a --- /dev/null +++ b/llama_stack/ui/app/chat-playground/page.test.tsx @@ -0,0 +1,587 @@ +import React from "react"; +import { + render, + screen, + fireEvent, + waitFor, + act, +} from "@testing-library/react"; +import "@testing-library/jest-dom"; +import ChatPlaygroundPage from "./page"; + +const mockClient = { + agents: { + list: jest.fn(), + create: jest.fn(), + retrieve: jest.fn(), + delete: jest.fn(), + session: { + list: jest.fn(), + create: jest.fn(), + delete: jest.fn(), + retrieve: jest.fn(), + }, + turn: { + create: jest.fn(), + }, + }, + models: { + list: jest.fn(), + }, + toolgroups: { + list: jest.fn(), + }, +}; + +jest.mock("@/hooks/use-auth-client", () => ({ + useAuthClient: jest.fn(() => mockClient), +})); + +jest.mock("@/components/chat-playground/chat", () => ({ + Chat: jest.fn( + ({ + className, + messages, + handleSubmit, + input, + handleInputChange, + isGenerating, + append, + suggestions, + }) => ( +
<div className={className}>
+        <div>{messages.length}</div>
+        <input value={input} onChange={handleInputChange} disabled={isGenerating} />
+        <button onClick={handleSubmit} disabled={isGenerating}>
+          Send
+        </button>
+        {suggestions?.map((suggestion: string, index: number) => (
+          <button key={index} onClick={() => append({ role: "user", content: suggestion })}>
+            {suggestion}
+          </button>
+        ))}
+      </div>
+    )
+  ),
+}));
+
+jest.mock("@/components/chat-playground/conversations", () => ({
+  SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
+    <div>
+      {selectedAgentId && (
+        <>
+          <span>{selectedAgentId}</span>
+          <button onClick={onNewSession}>New Session</button>
+        </>
+      )}
+    </div>
+ )), + SessionUtils: { + saveCurrentSessionId: jest.fn(), + loadCurrentSessionId: jest.fn(), + loadCurrentAgentId: jest.fn(), + saveCurrentAgentId: jest.fn(), + clearCurrentSession: jest.fn(), + saveSessionData: jest.fn(), + loadSessionData: jest.fn(), + saveAgentConfig: jest.fn(), + loadAgentConfig: jest.fn(), + clearAgentCache: jest.fn(), + createDefaultSession: jest.fn(() => ({ + id: "test-session-123", + name: "Default Session", + messages: [], + selectedModel: "", + systemMessage: "You are a helpful assistant.", + agentId: "test-agent-123", + createdAt: Date.now(), + updatedAt: Date.now(), + })), + }, +})); + +const mockAgents = [ + { + agent_id: "agent_123", + agent_config: { + name: "Test Agent", + instructions: "You are a test assistant.", + }, + }, + { + agent_id: "agent_456", + agent_config: { + agent_name: "Another Agent", + instructions: "You are another assistant.", + }, + }, +]; + +const mockModels = [ + { + identifier: "test-model-1", + model_type: "llm", + }, + { + identifier: "test-model-2", + model_type: "llm", + }, +]; + +const mockToolgroups = [ + { + identifier: "builtin::rag", + provider_id: "test-provider", + type: "tool_group", + provider_resource_id: "test-resource", + }, +]; + +describe("ChatPlaygroundPage", () => { + beforeEach(() => { + jest.clearAllMocks(); + Element.prototype.scrollIntoView = jest.fn(); + mockClient.agents.list.mockResolvedValue({ data: mockAgents }); + mockClient.models.list.mockResolvedValue(mockModels); + mockClient.toolgroups.list.mockResolvedValue(mockToolgroups); + mockClient.agents.session.create.mockResolvedValue({ + session_id: "new-session-123", + }); + mockClient.agents.session.list.mockResolvedValue({ data: [] }); + mockClient.agents.session.retrieve.mockResolvedValue({ + session_id: "test-session", + session_name: "Test Session", + started_at: new Date().toISOString(), + turns: [], + }); // No turns by default + mockClient.agents.retrieve.mockResolvedValue({ + agent_id: "test-agent", + agent_config: { + toolgroups: ["builtin::rag"], + instructions: "Test instructions", + model: "test-model", + }, + }); + mockClient.agents.delete.mockResolvedValue(undefined); + }); + + describe("Agent Selector Rendering", () => { + test("shows agent selector when agents are available", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByText("Agent Session:")).toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(2); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.getByText("Clear Chat")).toBeInTheDocument(); + }); + }); + + test("does not show agent selector when no agents are available", async () => { + mockClient.agents.list.mockResolvedValue({ data: [] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); + }); + }); + + test("does not show agent selector while loading", async () => { + mockClient.agents.list.mockImplementation(() => new Promise(() => {})); + + await act(async () => { + render(); + }); + + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); 
+ }); + + test("shows agent options in selector", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Test Agent")).toHaveLength(2); + expect(screen.getByText("Another Agent")).toBeInTheDocument(); + }); + }); + + test("displays agent ID when no name is available", async () => { + const agentWithoutName = { + agent_id: "agent_789", + agent_config: { + instructions: "You are an agent without a name.", + }, + }; + + mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Agent agent_78") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2); + }); + }); + }); + + describe("Agent Creation Modal", () => { + test("opens agent creation modal when + New Agent is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument(); + expect(screen.getAllByText("Model")).toHaveLength(2); + expect(screen.getByText("System Instructions")).toBeInTheDocument(); + expect(screen.getByText("Tools (optional)")).toBeInTheDocument(); + }); + + test("closes modal when Cancel is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + const cancelButton = screen.getByText("Cancel"); + fireEvent.click(cancelButton); + + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + + test("creates agent when Create Agent is clicked", async () => { + mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" }); + mockClient.agents.list + .mockResolvedValueOnce({ data: mockAgents }) + .mockResolvedValueOnce({ + data: [ + ...mockAgents, + { agent_id: "new-agent-123", agent_config: { name: "New Agent" } }, + ], + }); + + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + await act(async () => { + fireEvent.click(newAgentButton); + }); + + await waitFor(() => { + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + }); + + const nameInput = screen.getByPlaceholderText("My Custom Agent"); + await act(async () => { + fireEvent.change(nameInput, { target: { value: "Test Agent Name" } }); + }); + + const instructionsTextarea = screen.getByDisplayValue( + "You are a helpful assistant." 
+ ); + await act(async () => { + fireEvent.change(instructionsTextarea, { + target: { value: "Custom instructions" }, + }); + }); + + await waitFor(() => { + const modalModelSelectors = screen + .getAllByRole("combobox") + .filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + expect(modalModelSelectors.length).toBeGreaterThan(0); + }); + + const modalModelSelectors = screen.getAllByRole("combobox").filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + + await act(async () => { + fireEvent.click(modalModelSelectors[0]); + }); + + await waitFor(() => { + const modelOptions = screen.getAllByText("test-model-1"); + expect(modelOptions.length).toBeGreaterThan(0); + }); + + const modelOptions = screen.getAllByText("test-model-1"); + const dropdownOption = modelOptions.find( + option => + option.closest('[role="option"]') || + option.id?.includes("radix") || + option.getAttribute("aria-selected") !== null + ); + + await act(async () => { + fireEvent.click( + dropdownOption || modelOptions[modelOptions.length - 1] + ); + }); + + await waitFor(() => { + const createButton = screen.getByText("Create Agent"); + expect(createButton).not.toBeDisabled(); + }); + + const createButton = screen.getByText("Create Agent"); + await act(async () => { + fireEvent.click(createButton); + }); + + await waitFor(() => { + expect(mockClient.agents.create).toHaveBeenCalledWith({ + agent_config: { + model: expect.any(String), + instructions: "Custom instructions", + name: "Test Agent Name", + enable_session_persistence: true, + }, + }); + }); + + await waitFor(() => { + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + }); + }); + + describe("Agent Selection", () => { + test("creates default session when agent is selected", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + // first agent should be auto-selected + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_123", + { session_name: "Default Session" } + ); + }); + }); + + test("switches agent when different agent is selected", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + const anotherAgentOption = screen.getByText("Another Agent"); + fireEvent.click(anotherAgentOption); + }); + + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_456", + { session_name: "Default Session" } + ); + }); + }); + + describe("Agent Deletion", () => { + test("shows delete button when multiple agents exist", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + }); + + test("hides delete button when only one agent exists", async () => { + mockClient.agents.list.mockResolvedValue({ + data: [mockAgents[0]], + }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect( + screen.queryByTitle("Delete current agent") + ).not.toBeInTheDocument(); + }); + }); + + test("deletes agent and switches to another when confirmed", 
async () => { + global.confirm = jest.fn(() => true); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + mockClient.agents.delete.mockResolvedValue(undefined); + mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents }); + mockClient.agents.list.mockResolvedValueOnce({ + data: [mockAgents[1]], + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123"); + expect(global.confirm).toHaveBeenCalledWith( + "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions." + ); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + + test("does not delete agent when cancelled", async () => { + global.confirm = jest.fn(() => false); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(global.confirm).toHaveBeenCalled(); + expect(mockClient.agents.delete).not.toHaveBeenCalled(); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + }); + + describe("Error Handling", () => { + test("handles agent loading errors gracefully", async () => { + mockClient.agents.list.mockRejectedValue( + new Error("Failed to load agents") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching agents:", + expect.any(Error) + ); + }); + + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + + consoleSpy.mockRestore(); + }); + + test("handles model loading errors gracefully", async () => { + mockClient.models.list.mockRejectedValue( + new Error("Failed to load models") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching models:", + expect.any(Error) + ); + }); + + consoleSpy.mockRestore(); + }); + }); +}); diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index b8651aca0..f26791a41 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback, useRef } from "react"; import { flushSync } from "react-dom"; import { Button } from "@/components/ui/button"; import { @@ -10,14 +10,22 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { Card } from "@/components/ui/card"; +import { Input } from "@/components/ui/input"; +import { Trash2 } from "lucide-react"; import { Chat } from "@/components/chat-playground/chat"; import { type Message } from "@/components/chat-playground/chat-message"; import { useAuthClient } from "@/hooks/use-auth-client"; -import type { CompletionCreateParams } from "llama-stack-client/resources/chat/completions"; import type { Model } from "llama-stack-client/resources/models"; - +import type { TurnCreateParams } from 
"llama-stack-client/resources/agents/turn"; +import { + SessionUtils, + type ChatSession, +} from "@/components/chat-playground/conversations"; export default function ChatPlaygroundPage() { - const [messages, setMessages] = useState([]); + const [currentSession, setCurrentSession] = useState( + null + ); const [input, setInput] = useState(""); const [isGenerating, setIsGenerating] = useState(false); const [error, setError] = useState(null); @@ -25,10 +33,523 @@ export default function ChatPlaygroundPage() { const [selectedModel, setSelectedModel] = useState(""); const [modelsLoading, setModelsLoading] = useState(true); const [modelsError, setModelsError] = useState(null); + const [agents, setAgents] = useState< + Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }> + >([]); + const [selectedAgentConfig, setSelectedAgentConfig] = useState<{ + toolgroups?: Array< + string | { name: string; args: Record } + >; + } | null>(null); + const [selectedAgentId, setSelectedAgentId] = useState(""); + const [agentsLoading, setAgentsLoading] = useState(true); + const [showCreateAgent, setShowCreateAgent] = useState(false); + const [newAgentName, setNewAgentName] = useState(""); + const [newAgentInstructions, setNewAgentInstructions] = useState( + "You are a helpful assistant." + ); + const [selectedToolgroups, setSelectedToolgroups] = useState([]); + const [availableToolgroups, setAvailableToolgroups] = useState< + Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }> + >([]); const client = useAuthClient(); + const abortControllerRef = useRef(null); const isModelsLoading = modelsLoading ?? true; + const loadAgentConfig = useCallback( + async (agentId: string) => { + try { + console.log("Loading agent config for:", agentId); + + // try to load from cache first + const cachedConfig = SessionUtils.loadAgentConfig(agentId); + if (cachedConfig) { + console.log("✅ Loaded agent config from cache:", cachedConfig); + setSelectedAgentConfig({ + toolgroups: cachedConfig.toolgroups, + }); + return; + } + + console.log("📡 Fetching agent config from API..."); + const agentDetails = await client.agents.retrieve(agentId); + console.log("Agent details retrieved:", agentDetails); + console.log("Agent config:", agentDetails.agent_config); + console.log("Agent toolgroups:", agentDetails.agent_config?.toolgroups); + + // cache the config + SessionUtils.saveAgentConfig(agentId, agentDetails.agent_config); + + setSelectedAgentConfig({ + toolgroups: agentDetails.agent_config?.toolgroups, + }); + } catch (error) { + console.error("Error loading agent config:", error); + setSelectedAgentConfig(null); + } + }, + [client] + ); + + const createDefaultSession = useCallback( + async (agentId: string) => { + try { + const response = await client.agents.session.create(agentId, { + session_name: "Default Session", + }); + + const defaultSession: ChatSession = { + id: response.session_id, + name: "Default Session", + messages: [], + selectedModel: selectedModel, // Use current selected model + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(defaultSession); + console.log( + `💾 Saving default session ID for agent ${agentId}:`, + defaultSession.id + ); + SessionUtils.saveCurrentSessionId(defaultSession.id, agentId); + // cache entire session data + SessionUtils.saveSessionData(agentId, defaultSession); + } 
catch (error) { + console.error("Error creating default session:", error); + } + }, + [client, selectedModel] + ); + + const loadSessionMessages = useCallback( + async (agentId: string, sessionId: string): Promise => { + try { + const session = await client.agents.session.retrieve( + agentId, + sessionId + ); + + if (!session || !session.turns || !Array.isArray(session.turns)) { + return []; + } + + const messages: Message[] = []; + for (const turn of session.turns) { + // add user messages + if (turn.input_messages && Array.isArray(turn.input_messages)) { + for (const input of turn.input_messages) { + if (input.role === "user" && input.content) { + messages.push({ + id: `${turn.turn_id}-user-${messages.length}`, + role: "user", + content: + typeof input.content === "string" + ? input.content + : JSON.stringify(input.content), + createdAt: new Date(turn.started_at || Date.now()), + }); + } + } + } + + // add assistant message from output_message + if (turn.output_message && turn.output_message.content) { + messages.push({ + id: `${turn.turn_id}-assistant-${messages.length}`, + role: "assistant", + content: + typeof turn.output_message.content === "string" + ? turn.output_message.content + : JSON.stringify(turn.output_message.content), + createdAt: new Date( + turn.completed_at || turn.started_at || Date.now() + ), + }); + } + } + + return messages; + } catch (error) { + console.error("Error loading session messages:", error); + return []; + } + }, + [client] + ); + + const loadAgentSessions = useCallback( + async (agentId: string) => { + try { + console.log("Loading sessions for agent:", agentId); + const response = await client.agents.session.list(agentId); + console.log("Available sessions:", response.data); + + if ( + response.data && + Array.isArray(response.data) && + response.data.length > 0 + ) { + // check for a previously saved session ID for this specific agent + const savedSessionId = SessionUtils.loadCurrentSessionId(agentId); + console.log(`Saved session ID for agent ${agentId}:`, savedSessionId); + + // try to load cached session data first + if (savedSessionId) { + const cachedSession = SessionUtils.loadSessionData( + agentId, + savedSessionId + ); + if (cachedSession) { + console.log("✅ Loaded session from cache:", cachedSession.id); + setCurrentSession(cachedSession); + SessionUtils.saveCurrentSessionId(cachedSession.id, agentId); + return; + } + console.log("📡 Cache miss, fetching session from API..."); + } + + let sessionToLoad = response.data[0] as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "Default session to load (first in list):", + sessionToLoad.session_id + ); + + // try to find saved session id in available sessions + if (savedSessionId) { + const foundSession = response.data.find( + (s: { session_id: string }) => s.session_id === savedSessionId + ); + console.log("Found saved session in list:", foundSession); + if (foundSession) { + sessionToLoad = foundSession as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "✅ Restored previously selected session:", + savedSessionId + ); + } else { + console.log( + "❌ Previously selected session not found, using latest session" + ); + } + } else { + console.log("❌ No saved session ID found, using latest session"); + } + + const messages = await loadSessionMessages( + agentId, + sessionToLoad.session_id + ); + + const session: ChatSession = { + id: sessionToLoad.session_id, + name: sessionToLoad.session_name || "Session", + 
messages, + selectedModel: selectedModel || "", // Preserve current model or use empty + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: sessionToLoad.started_at + ? new Date(sessionToLoad.started_at).getTime() + : Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(session); + console.log(`💾 Saving session ID for agent ${agentId}:`, session.id); + SessionUtils.saveCurrentSessionId(session.id, agentId); + // cache session data + SessionUtils.saveSessionData(agentId, session); + } else { + // no sessions, create a new one + await createDefaultSession(agentId); + } + } catch (error) { + console.error("Error loading agent sessions:", error); + // fallback to creating a new session + await createDefaultSession(agentId); + } + }, + [client, loadSessionMessages, createDefaultSession, selectedModel] + ); + + useEffect(() => { + const fetchAgents = async () => { + try { + setAgentsLoading(true); + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + if (agentList.data && agentList.data.length > 0) { + // check if there's a previously selected agent + const savedAgentId = SessionUtils.loadCurrentAgentId(); + + let agentToSelect = agentList.data[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + + // if we have a saved agent ID, find it in the available agents + if (savedAgentId) { + const foundAgent = agentList.data.find( + (a: { agent_id: string }) => a.agent_id === savedAgentId + ); + if (foundAgent) { + agentToSelect = foundAgent as typeof agentToSelect; + } else { + console.log("Previously selected agent not found"); + } + } + setSelectedAgentId(agentToSelect.agent_id); + SessionUtils.saveCurrentAgentId(agentToSelect.agent_id); + // load agent config immediately + await loadAgentConfig(agentToSelect.agent_id); + // Note: loadAgentSessions will be called after models are loaded + } + } catch (error) { + console.error("Error fetching agents:", error); + } finally { + setAgentsLoading(false); + } + }; + + fetchAgents(); + + // fetch available toolgroups + const fetchToolgroups = async () => { + try { + console.log("Fetching toolgroups..."); + const toolgroups = await client.toolgroups.list(); + console.log("Toolgroups response:", toolgroups); + + // The client returns data directly, not wrapped in .data + const toolGroupsArray = Array.isArray(toolgroups) + ? toolgroups + : toolgroups && + typeof toolgroups === "object" && + "data" in toolgroups && + Array.isArray((toolgroups as { data: unknown }).data) + ?
( + toolgroups as { + data: Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }>; + } + ).data + : []; + + if (toolGroupsArray && Array.isArray(toolGroupsArray)) { + setAvailableToolgroups(toolGroupsArray); + console.log("Set toolgroups:", toolGroupsArray); + } else { + console.error("Invalid toolgroups data format:", toolgroups); + } + } catch (error) { + console.error("Error fetching toolgroups:", error); + if (error instanceof Error) { + console.error("Error details:", { + name: error.name, + message: error.message, + stack: error.stack, + }); + } + } + }; + + fetchToolgroups(); + }, [client, loadAgentSessions, loadAgentConfig]); + + const createNewAgent = useCallback( + async ( + name: string, + instructions: string, + model: string, + toolgroups: string[] = [] + ) => { + try { + console.log("Creating agent with toolgroups:", toolgroups); + const agentConfig = { + model, + instructions, + name: name || undefined, + enable_session_persistence: true, + toolgroups: toolgroups.length > 0 ? toolgroups : undefined, + }; + console.log("Agent config being sent:", agentConfig); + + const response = await client.agents.create({ + agent_config: agentConfig, + }); + + // refresh agents list + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + // set the new agent as selected + setSelectedAgentId(response.agent_id); + await loadAgentConfig(response.agent_id); + await loadAgentSessions(response.agent_id); + + return response.agent_id; + } catch (error) { + console.error("Error creating agent:", error); + throw error; + } + }, + [client, loadAgentSessions, loadAgentConfig] + ); + + const deleteAgent = useCallback( + async (agentId: string) => { + if (agents.length <= 1) { + return; + } + + if ( + confirm( + "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions." + ) + ) { + try { + await client.agents.delete(agentId); + + // clear cached data for agent + SessionUtils.clearAgentCache(agentId); + + // Refresh agents list + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + // if we deleted the current agent, switch to another one + if (selectedAgentId === agentId) { + const remainingAgents = agentList.data?.filter( + (a: { agent_id: string }) => a.agent_id !== agentId + ); + if (remainingAgents && remainingAgents.length > 0) { + const newAgent = remainingAgents[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + setSelectedAgentId(newAgent.agent_id); + SessionUtils.saveCurrentAgentId(newAgent.agent_id); + await loadAgentConfig(newAgent.agent_id); + await loadAgentSessions(newAgent.agent_id); + } else { + // No agents left + setSelectedAgentId(""); + setCurrentSession(null); + setSelectedAgentConfig(null); + } + } + } catch (error) { + console.error("Error deleting agent:", error); + } + } + }, + [agents.length, client, selectedAgentId, loadAgentConfig, loadAgentSessions] + ); + + const handleModelChange = useCallback((newModel: string) => { + setSelectedModel(newModel); + setCurrentSession(prev => + prev + ? 
{ + ...prev, + selectedModel: newModel, + updatedAt: Date.now(), + } + : prev + ); + }, []); + + useEffect(() => { + if (currentSession) { + console.log( + `💾 Auto-saving session ID for agent ${currentSession.agentId}:`, + currentSession.id + ); + SessionUtils.saveCurrentSessionId( + currentSession.id, + currentSession.agentId + ); + // cache session data + SessionUtils.saveSessionData(currentSession.agentId, currentSession); + // only update selectedModel if the session has a valid model and it's different from current + if ( + currentSession.selectedModel && + currentSession.selectedModel !== selectedModel + ) { + setSelectedModel(currentSession.selectedModel); + } + } + }, [currentSession, selectedModel]); + useEffect(() => { const fetchModels = async () => { try { @@ -38,7 +559,7 @@ export default function ChatPlaygroundPage() { const llmModels = modelList.filter(model => model.model_type === "llm"); setModels(llmModels); if (llmModels.length > 0) { - setSelectedModel(llmModels[0].identifier); + handleModelChange(llmModels[0].identifier); } } catch (err) { console.error("Error fetching models:", err); @@ -49,39 +570,27 @@ export default function ChatPlaygroundPage() { }; fetchModels(); - }, [client]); + }, [client, handleModelChange]); - const extractTextContent = (content: unknown): string => { - if (typeof content === "string") { - return content; - } - if (Array.isArray(content)) { - return content - .filter( - item => - item && - typeof item === "object" && - "type" in item && - item.type === "text" - ) - .map(item => - item && typeof item === "object" && "text" in item - ? String(item.text) - : "" - ) - .join(""); - } + // load agent sessions after both agents and models are ready + useEffect(() => { if ( - content && - typeof content === "object" && - "type" in content && - content.type === "text" && - "text" in content + selectedAgentId && + !agentsLoading && + !modelsLoading && + selectedModel && + !currentSession ) { - return String(content.text) || ""; + loadAgentSessions(selectedAgentId); } - return ""; - }; + }, [ + selectedAgentId, + agentsLoading, + modelsLoading, + selectedModel, + currentSession, + loadAgentSessions, + ]); const handleInputChange = (e: React.ChangeEvent) => { setInput(e.target.value); @@ -91,7 +600,6 @@ export default function ChatPlaygroundPage() { event?.preventDefault?.(); if (!input.trim()) return; - // Add user message to chat const userMessage: Message = { id: Date.now().toString(), role: "user", @@ -99,40 +607,54 @@ export default function ChatPlaygroundPage() { createdAt: new Date(), }; - setMessages(prev => [...prev, userMessage]); + setCurrentSession(prev => { + if (!prev) return prev; + const updatedSession = { + ...prev, + messages: [...prev.messages, userMessage], + updatedAt: Date.now(), + }; + // Update cache with new message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); setInput(""); - // Use the helper function with the content await handleSubmitWithContent(userMessage.content); }; const handleSubmitWithContent = async (content: string) => { + if (!currentSession || !selectedAgentId) return; + setIsGenerating(true); setError(null); - try { - const messageParams: CompletionCreateParams["messages"] = [ - ...messages.map(msg => { - const msgContent = - typeof msg.content === "string" - ? 
msg.content - : extractTextContent(msg.content); - if (msg.role === "user") { - return { role: "user" as const, content: msgContent }; - } else if (msg.role === "assistant") { - return { role: "assistant" as const, content: msgContent }; - } else { - return { role: "system" as const, content: msgContent }; - } - }), - { role: "user" as const, content }, - ]; + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } - const response = await client.chat.completions.create({ - model: selectedModel, - messages: messageParams, + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + try { + const userMessage = { + role: "user" as const, + content, + }; + + const turnParams: TurnCreateParams = { + messages: [userMessage], stream: true, - }); + }; + + const response = await client.agents.turn.create( + selectedAgentId, + currentSession.id, + turnParams, + { + signal: abortController.signal, + } as { signal: AbortSignal } + ); const assistantMessage: Message = { id: (Date.now() + 1).toString(), @@ -141,31 +663,112 @@ createdAt: new Date(), }; - setMessages(prev => [...prev, assistantMessage]); + const extractDeltaText = (chunk: unknown): string | null => { + // the stream can emit several different chunk shapes, so probe each known format in turn + if (chunk?.delta?.text && typeof chunk.delta.text === "string") { + return chunk.delta.text; + } + + if ( + chunk?.event?.delta?.text && + typeof chunk.event.delta.text === "string" + ) { + return chunk.event.delta.text; + } + + if ( + chunk?.choices?.[0]?.delta?.content && + typeof chunk.choices[0].delta.content === "string" + ) { + return chunk.choices[0].delta.content; + } + + if (typeof chunk === "string") { + return chunk; + } + + if ( + chunk?.event?.payload?.delta?.text && + typeof chunk.event.payload.delta.text === "string" + ) { + return chunk.event.payload.delta.text; + } + + if (process.env.NODE_ENV !== "production") { + console.debug("Unrecognized chunk format:", chunk); + } + + return null; + }; + setCurrentSession(prev => { + if (!prev) return null; + const updatedSession = { + ...prev, + messages: [...prev.messages, assistantMessage], + updatedAt: Date.now(), + }; + // update cache with assistant message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); + let fullContent = ""; for await (const chunk of response) { - if (chunk.choices && chunk.choices[0]?.delta?.content) { - const deltaContent = chunk.choices[0].delta.content; - fullContent += deltaContent; + const deltaText = extractDeltaText(chunk); + + if (deltaText) { + fullContent += deltaText; flushSync(() => { - setMessages(prev => { - const newMessages = [...prev]; - const lastMessage = newMessages[newMessages.length - 1]; - if (lastMessage.role === "assistant") { - lastMessage.content = fullContent; + setCurrentSession(prev => { + if (!prev) return null; + const newMessages = [...prev.messages]; + const last = newMessages[newMessages.length - 1]; + if (last.role === "assistant") { + last.content = fullContent; } - return newMessages; + const updatedSession = { + ...prev, + messages: newMessages, + updatedAt: Date.now(), + }; + // update cache with streaming content (throttled) + if (fullContent.length % 100 === 0) { + // Only cache roughly every 100 characters to avoid excessive writes + SessionUtils.saveSessionData(prev.agentId, updatedSession); + } + return updatedSession; }); }); } } } catch (err) { + if (err instanceof Error &&
err.name === "AbortError") { + console.log("Request aborted"); + return; + } + console.error("Error sending message:", err); setError("Failed to send message. Please try again."); - setMessages(prev => prev.slice(0, -1)); + setCurrentSession(prev => + prev + ? { + ...prev, + messages: prev.messages.slice(0, -1), + updatedAt: Date.now(), + } + : prev + ); } finally { setIsGenerating(false); + abortControllerRef.current = null; + // cache final session state after streaming completes + setCurrentSession(prev => { + if (prev) { + SessionUtils.saveSessionData(prev.agentId, prev); + } + return prev; + }); } }; const suggestions = [ @@ -181,69 +784,457 @@ export default function ChatPlaygroundPage() { content: message.content, createdAt: new Date(), }; - setMessages(prev => [...prev, newMessage]); + setCurrentSession(prev => + prev + ? { + ...prev, + messages: [...prev.messages, newMessage], + updatedAt: Date.now(), + } + : prev + ); handleSubmitWithContent(newMessage.content); }; const clearChat = () => { - setMessages([]); + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + setIsGenerating(false); + } + + setCurrentSession(prev => + prev ? { ...prev, messages: [], updatedAt: Date.now() } : prev + ); setError(null); }; return ( -
-
-

Chat Playground (Completions)

-
- - +
+ {/* Header */} +
+
+

Agent Session

+
+ {!agentsLoading && agents.length > 0 && ( +
+ + + {selectedAgentId && agents.length > 1 && ( + + )} +
+ )} + + {!agentsLoading && agents.length > 0 && ( + + )} +
+
+
+ {/* Main Two-Column Layout */} +
+ {/* Left Column - Configuration Panel */} +
+

+ Settings +

+ + {/* Model Configuration */} +
+

+ Model Configuration +

+
+
+ + + {modelsError && ( +

{modelsError}

+ )} +
+ +
+ +
+ {(selectedAgentId && + agents.find(a => a.agent_id === selectedAgentId) + ?.agent_config?.instructions) || + "No agent selected"} +
+

+ Instructions are set when creating an agent and cannot be + changed. +

+
+
+
+ + {/* Agent Tools */} +
+

+ Agent Tools +

+
+
+ +
+ {selectedAgentConfig?.toolgroups && + selectedAgentConfig.toolgroups.length > 0 ? ( + selectedAgentConfig.toolgroups.map( + ( + toolgroup: + | string + | { name: string; args: Record }, + index: number + ) => { + const toolName = + typeof toolgroup === "string" + ? toolgroup + : toolgroup.name; + const toolArgs = + typeof toolgroup === "object" ? toolgroup.args : null; + + return ( +
+
+ + {toolName} + + + {toolName.includes("rag") + ? "🔍 RAG" + : toolName.includes("search") + ? "🌐 Search" + : "🔧 Tool"} + +
+ {toolArgs && Object.keys(toolArgs).length > 0 && ( +
+ Args:{" "} + {Object.entries(toolArgs) + .map( + ([key, value]) => + `${key}: ${JSON.stringify(value)}` + ) + .join(", ")} +
+ )} +
+ ); + } + ) + ) : ( +
+

+ No tools configured +

+

+ This agent only has text generation capabilities +

+
+ )} +
+

+ Tools are configured when creating an agent and provide + additional capabilities like web search, math calculations, or + RAG document retrieval. +

+
+
+
+
+ + {/* Right Column - Chat Interface */} +
+ {error && ( +
+

{error}

+
+ )} + + + setCurrentSession(prev => + prev ? { ...prev, messages, updatedAt: Date.now() } : prev + ) + } + />
- {modelsError && ( -
-

{modelsError}

+ {/* Create Agent Modal */} + {showCreateAgent && ( +
+ +

Create New Agent

+ +
+
+ + setNewAgentName(e.target.value)} + placeholder="My Custom Agent" + /> +
+ +
+ + +
+ +
+ +