diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 134efd93b..f88402a7a 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -9,6 +9,7 @@ updates: day: "saturday" commit-message: prefix: chore(github-deps) + - package-ecosystem: "uv" directory: "/" schedule: @@ -19,3 +20,14 @@ updates: - python commit-message: prefix: chore(python-deps) + + - package-ecosystem: npm + directory: "/llama_stack/ui" + schedule: + interval: "weekly" + day: "saturday" + labels: + - type/dependencies + - javascript + commit-message: + prefix: chore(ui-deps) diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml index e406d99ee..7a75d85f6 100644 --- a/.github/workflows/changelog.yml +++ b/.github/workflows/changelog.yml @@ -17,7 +17,7 @@ jobs: pull-requests: write # for peter-evans/create-pull-request to create a PR runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: ref: main fetch-depth: 0 diff --git a/.github/workflows/install-script-ci.yml b/.github/workflows/install-script-ci.yml index 1ecda6d51..a37919f56 100644 --- a/.github/workflows/install-script-ci.yml +++ b/.github/workflows/install-script-ci.yml @@ -16,14 +16,14 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2 + - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0 - name: Run ShellCheck on install.sh run: shellcheck scripts/install.sh smoke-test-on-dev: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index c328e3b6c..6787806e9 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -18,7 +18,7 @@ on: - '.github/workflows/integration-auth-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-sql-store-tests.yml b/.github/workflows/integration-sql-store-tests.yml index 4e5b64963..3efd970e1 100644 --- a/.github/workflows/integration-sql-store-tests.yml +++ b/.github/workflows/integration-sql-store-tests.yml @@ -16,7 +16,7 @@ on: - '.github/workflows/integration-sql-store-tests.yml' # This workflow concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -44,7 +44,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/integration-tests.yml 
b/.github/workflows/integration-tests.yml index ba18c27c8..57e582b20 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -65,7 +65,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup test environment uses: ./.github/actions/setup-test-environment diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index 61b8e004e..de5701073 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -33,7 +33,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 99e0d0043..2825c3bf4 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -8,7 +8,7 @@ on: branches: [main] concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -20,7 +20,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: # For dependabot PRs, we need to checkout with a token that can push changes token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }} @@ -36,20 +36,16 @@ jobs: **/requirements*.txt .pre-commit-config.yaml - # npm ci may fail - - # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing. 
- # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18 + - name: Set up Node.js + uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0 + with: + node-version: '20' + cache: 'npm' + cache-dependency-path: 'llama_stack/ui/' - # - name: Set up Node.js - # uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0 - # with: - # node-version: '20' - # cache: 'npm' - # cache-dependency-path: 'llama_stack/ui/' - - # - name: Install npm dependencies - # run: npm ci - # working-directory: llama_stack/ui + - name: Install npm dependencies + run: npm ci + working-directory: llama_stack/ui - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 continue-on-error: true diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 929d76760..391acbcf8 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -26,7 +26,7 @@ on: - 'pyproject.toml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -36,7 +36,7 @@ jobs: distros: ${{ steps.set-matrix.outputs.distros }} steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Generate Distribution List id: set-matrix @@ -55,7 +55,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -79,7 +79,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -92,7 +92,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -106,6 +106,10 @@ jobs: - name: Inspect the container image entrypoint run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then @@ -117,7 +121,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner @@ -140,6 +144,10 @@ jobs: - name: Inspect UBI9 image run: | IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + if [ -z "$IMAGE_ID" ]; then + echo "No image found" + exit 1 + fi entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) echo "Entrypoint: $entrypoint" if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then diff --git 
a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index fe1dfd58a..9de53f7fb 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -21,10 +21,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install uv - uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3 + uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 22636f209..d4f5586e2 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -46,7 +46,7 @@ jobs: echo "::endgroup::" - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 with: fetch-depth: 0 diff --git a/.github/workflows/semantic-pr.yml b/.github/workflows/semantic-pr.yml index 57a4df646..4adaca84d 100644 --- a/.github/workflows/semantic-pr.yml +++ b/.github/workflows/semantic-pr.yml @@ -22,6 +22,6 @@ jobs: runs-on: ubuntu-latest steps: - name: Check PR Title's semantic conformance - uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3 + uses: amannn/action-semantic-pull-request@7f33ba792281b034f64e96f4c0b5496782dd3b37 # v6.1.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test-external-provider-module.yml b/.github/workflows/test-external-provider-module.yml index d61b0dfe9..8a757b068 100644 --- a/.github/workflows/test-external-provider-module.yml +++ b/.github/workflows/test-external-provider-module.yml @@ -27,7 +27,7 @@ jobs: # container and point 'uv pip install' to the correct path... steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/test-external.yml b/.github/workflows/test-external.yml index b9db0ad51..7ee467451 100644 --- a/.github/workflows/test-external.yml +++ b/.github/workflows/test-external.yml @@ -27,7 +27,7 @@ jobs: # container and point 'uv pip install' to the correct path... 
steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/ui-unit-tests.yml b/.github/workflows/ui-unit-tests.yml index 00c539c58..2afb92bee 100644 --- a/.github/workflows/ui-unit-tests.yml +++ b/.github/workflows/ui-unit-tests.yml @@ -13,7 +13,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -26,10 +26,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Setup Node.js - uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0 + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 with: node-version: ${{ matrix.node-version }} cache: 'npm' diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index f2a6c7754..dd2097a45 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -18,7 +18,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -32,7 +32,7 @@ jobs: - "3.13" steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 1dcfdeca5..e12f0adf8 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -27,7 +27,7 @@ on: - '.github/workflows/update-readthedocs.yml' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: @@ -37,7 +37,7 @@ jobs: TOKEN: ${{ secrets.READTHEDOCS_TOKEN }} steps: - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 - name: Install dependencies uses: ./.github/actions/setup-runner diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d25455cf0..514fe6d2e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -146,31 +146,13 @@ repos: pass_filenames: false require_serial: true files: ^.github/workflows/.*$ - # ui-prettier and ui-eslint are disabled until we can avoid `npm ci`, which is slow and may fail - - # npm error `npm ci` can only install packages when your package.json and package-lock.json or npm-shrinkwrap.json are in sync. Please update your lock file with `npm install` before continuing. 
- # npm error Invalid: lock file's llama-stack-client@0.2.17 does not satisfy llama-stack-client@0.2.18 - # and until we have infra for installing prettier and next via npm - - # Lint UI code with ESLint.....................................................Failed - # - hook id: ui-eslint - # - exit code: 127 - # > ui@0.1.0 lint - # > next lint --fix --quiet - # sh: line 1: next: command not found - # - # - id: ui-prettier - # name: Format UI code with Prettier - # entry: bash -c 'cd llama_stack/ui && npm ci && npm run format' - # language: system - # files: ^llama_stack/ui/.*\.(ts|tsx)$ - # pass_filenames: false - # require_serial: true - # - id: ui-eslint - # name: Lint UI code with ESLint - # entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet' - # language: system - # files: ^llama_stack/ui/.*\.(ts|tsx)$ - # pass_filenames: false - # require_serial: true + - id: ui-linter + name: Format & Lint UI + entry: bash ./scripts/run-ui-linter.sh + language: system + files: ^llama_stack/ui/.*\.(ts|tsx)$ + pass_filenames: false + require_serial: true - id: check-log-usage name: Ensure 'llama_stack.log' usage for logging diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b36626719..923d19299 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4605,6 +4605,49 @@ } } }, + "/v1/inference/rerank": { + "post": { + "responses": { + "200": { + "description": "RerankResponse with indices sorted by relevance score (descending).", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankResponse" + } + } + } + }, + "400": { + "$ref": "#/components/responses/BadRequest400" + }, + "429": { + "$ref": "#/components/responses/TooManyRequests429" + }, + "500": { + "$ref": "#/components/responses/InternalServerError500" + }, + "default": { + "$ref": "#/components/responses/DefaultError" + } + }, + "tags": [ + "Inference" + ], + "description": "Rerank a list of documents based on their relevance to a query.", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RerankRequest" + } + } + }, + "required": true + } + } + }, "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { "post": { "responses": { @@ -16587,6 +16630,95 @@ ], "title": "RegisterVectorDbRequest" }, + "RerankRequest": { + "type": "object", + "properties": { + "model": { + "type": "string", + "description": "The identifier of the reranking model to use." + }, + "query": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ], + "description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length." + }, + "items": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam" + }, + { + "$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam" + } + ] + }, + "description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length." + }, + "max_num_results": { + "type": "integer", + "description": "(Optional) Maximum number of results to return. Default: returns all." 
+ } + }, + "additionalProperties": false, + "required": [ + "model", + "query", + "items" + ], + "title": "RerankRequest" + }, + "RerankData": { + "type": "object", + "properties": { + "index": { + "type": "integer", + "description": "The original index of the document in the input list" + }, + "relevance_score": { + "type": "number", + "description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance." + } + }, + "additionalProperties": false, + "required": [ + "index", + "relevance_score" + ], + "title": "RerankData", + "description": "A single rerank result from a reranking response." + }, + "RerankResponse": { + "type": "object", + "properties": { + "data": { + "type": "array", + "items": { + "$ref": "#/components/schemas/RerankData" + }, + "description": "List of rerank result objects, sorted by relevance score (descending)" + } + }, + "additionalProperties": false, + "required": [ + "data" + ], + "title": "RerankResponse", + "description": "Response from a reranking request." + }, "ResumeAgentTurnRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index e7733b3c3..3d8bd33e5 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3264,6 +3264,37 @@ paths: schema: $ref: '#/components/schemas/QueryTracesRequest' required: true + /v1/inference/rerank: + post: + responses: + '200': + description: >- + RerankResponse with indices sorted by relevance score (descending). + content: + application/json: + schema: + $ref: '#/components/schemas/RerankResponse' + '400': + $ref: '#/components/responses/BadRequest400' + '429': + $ref: >- + #/components/responses/TooManyRequests429 + '500': + $ref: >- + #/components/responses/InternalServerError500 + default: + $ref: '#/components/responses/DefaultError' + tags: + - Inference + description: >- + Rerank a list of documents based on their relevance to a query. + parameters: [] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RerankRequest' + required: true /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume: post: responses: @@ -12337,6 +12368,76 @@ components: - vector_db_id - embedding_model title: RegisterVectorDbRequest + RerankRequest: + type: object + properties: + model: + type: string + description: >- + The identifier of the reranking model to use. + query: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + The search query to rank items against. Can be a string, text content + part, or image content part. The input must not exceed the model's max + input token length. + items: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam' + - $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam' + description: >- + List of items to rerank. Each item can be a string, text content part, + or image content part. Each input must not exceed the model's max input + token length. + max_num_results: + type: integer + description: >- + (Optional) Maximum number of results to return. Default: returns all. 
+ additionalProperties: false + required: + - model + - query + - items + title: RerankRequest + RerankData: + type: object + properties: + index: + type: integer + description: >- + The original index of the document in the input list + relevance_score: + type: number + description: >- + The relevance score from the model output. Values are inverted when applicable + so that higher scores indicate greater relevance. + additionalProperties: false + required: + - index + - relevance_score + title: RerankData + description: >- + A single rerank result from a reranking response. + RerankResponse: + type: object + properties: + data: + type: array + items: + $ref: '#/components/schemas/RerankData' + description: >- + List of rerank result objects, sorted by relevance score (descending) + additionalProperties: false + required: + - data + title: RerankResponse + description: Response from a reranking request. ResumeAgentTurnRequest: type: object properties: diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 335fa3a68..c9677b3b6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -225,8 +225,32 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS + cors: true # Optional: Enable CORS (dev mode) or full config object ``` +### CORS Configuration + +CORS (Cross-Origin Resource Sharing) can be configured in two ways: + +**Local development** (allows localhost origins only): +```yaml +server: + cors: true +``` + +**Explicit configuration** (custom origins and settings): +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://app.example.com"] + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_credentials: true + max_age: 3600 +``` + +When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security. + ### Authentication Configuration > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. @@ -618,6 +642,54 @@ Content-Type: application/json } ``` +### CORS Configuration + +Configure CORS to allow web browsers to make requests from different domains. Disabled by default. + +#### Quick Setup + +For development, use the simple boolean flag: + +```yaml +server: + cors: true # Auto-enables localhost with any port +``` + +This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults. + +#### Custom Configuration + +For specific origins and full control: + +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://staging.myapp.com"] + allow_credentials: true + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern + expose_headers: ["X-Total-Count"] + max_age: 86400 +``` + +#### Configuration Options + +| Field | Description | Default | +| -------------------- | ---------------------------------------------- | ------- | +| `allow_origins` | List of allowed origins. Use `["*"]` for any. 
| `["*"]` | +| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` | +| `allow_methods` | Allowed HTTP methods. | `["*"]` | +| `allow_headers` | Allowed headers. | `["*"]` | +| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` | +| `expose_headers` | Headers exposed to browser. | `[]` | +| `max_age` | Preflight cache time (seconds). | `600` | + +**Security Notes**: +- `allow_credentials: true` requires explicit origins (no wildcards) +- `cors: true` enables localhost access only (secure for development) +- For public APIs, always specify exact allowed origins + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/docs/source/distributions/importing_as_library.md b/docs/source/distributions/importing_as_library.md index fbc48dd95..b9b4b065a 100644 --- a/docs/source/distributions/importing_as_library.md +++ b/docs/source/distributions/importing_as_library.md @@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient( # provider_data is optional, but if you need to pass in any provider specific data, you can do so here. provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]}, ) -client.initialize() ``` This will parse your config and set up any inline implementations and remote clients needed for your implementation. @@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/ ```python client = LlamaStackAsLibraryClient(config_path) -client.initialize() ``` diff --git a/docs/source/providers/files/index.md b/docs/source/providers/files/index.md index 692aad3ca..128953223 100644 --- a/docs/source/providers/files/index.md +++ b/docs/source/providers/files/index.md @@ -10,4 +10,5 @@ This section contains documentation for all available providers for the **files* :maxdepth: 1 inline_localfs +remote_s3 ``` diff --git a/docs/source/providers/files/remote_s3.md b/docs/source/providers/files/remote_s3.md new file mode 100644 index 000000000..2e3cebabd --- /dev/null +++ b/docs/source/providers/files/remote_s3.md @@ -0,0 +1,33 @@ +# remote::s3 + +## Description + +AWS S3-based file storage provider for scalable cloud file management with metadata persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `bucket_name` | `` | No | | S3 bucket name to store files | +| `region` | `` | No | us-east-1 | AWS region where the bucket is located | +| `aws_access_key_id` | `str \| None` | No | | AWS access key ID (optional if using IAM roles) | +| `aws_secret_access_key` | `str \| None` | No | | AWS secret access key (optional if using IAM roles) | +| `endpoint_url` | `str \| None` | No | | Custom S3 endpoint URL (for MinIO, LocalStack, etc.) 
| +| `auto_create_bucket` | `` | No | False | Automatically create the S3 bucket if it doesn't exist | +| `metadata_store` | `utils.sqlstore.sqlstore.SqliteSqlStoreConfig \| utils.sqlstore.sqlstore.PostgresSqlStoreConfig` | No | sqlite | SQL store configuration for file metadata | + +## Sample Configuration + +```yaml +bucket_name: ${env.S3_BUCKET_NAME} +region: ${env.AWS_REGION:=us-east-1} +aws_access_key_id: ${env.AWS_ACCESS_KEY_ID:=} +aws_secret_access_key: ${env.AWS_SECRET_ACCESS_KEY:=} +endpoint_url: ${env.S3_ENDPOINT_URL:=} +auto_create_bucket: ${env.S3_AUTO_CREATE_BUCKET:=false} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/s3_files_metadata.db + +``` + diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 7e7bd0a3d..19630bfb8 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel): embeddings: list[list[float]] +@json_schema_type +class RerankData(BaseModel): + """A single rerank result from a reranking response. + + :param index: The original index of the document in the input list + :param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance. + """ + + index: int + relevance_score: float + + +@json_schema_type +class RerankResponse(BaseModel): + """Response from a reranking request. + + :param data: List of rerank result objects, sorted by relevance score (descending) + """ + + data: list[RerankData] + + @json_schema_type class OpenAIChatCompletionContentPartTextParam(BaseModel): """Text content part for OpenAI-compatible chat completion messages. @@ -1131,6 +1153,24 @@ class InferenceProvider(Protocol): """ ... + @webmethod(route="/inference/rerank", method="POST", experimental=True) + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + """Rerank a list of documents based on their relevance to a query. + + :param model: The identifier of the reranking model to use. + :param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length. + :param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length. + :param max_num_results: (Optional) Maximum number of results to return. Default: returns all. + :returns: RerankResponse with indices sorted by relevance score (descending). 
+ """ + raise NotImplementedError("Reranking is not implemented") + @webmethod(route="/openai/v1/completions", method="POST") async def openai_completion( self, diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index c8ffce034..b32b8b3ae 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="cli") class StackRun(Subcommand): diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index a1b6ad32b..c3940fcbd 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,6 +318,41 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") +class CORSConfig(BaseModel): + allow_origins: list[str] = Field(default_factory=list) + allow_origin_regex: str | None = Field(default=None) + allow_methods: list[str] = Field(default=["OPTIONS"]) + allow_headers: list[str] = Field(default_factory=list) + allow_credentials: bool = Field(default=False) + expose_headers: list[str] = Field(default_factory=list) + max_age: int = Field(default=600, ge=0) + + @model_validator(mode="after") + def validate_credentials_config(self) -> Self: + if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): + raise ValueError("Cannot use wildcard origins with credentials enabled") + return self + + +def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None: + if cors_config is False or cors_config is None: + return None + + if cors_config is True: + # dev mode: allow localhost on any port + return CORSConfig( + allow_origins=[], + allow_origin_regex=r"https?://localhost:\d+", + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + ) + + if isinstance(cors_config, CORSConfig): + return cors_config + + raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -349,6 +384,12 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request configuration", ) + cors: bool | CORSConfig | None = Field( + default=None, + description="CORS configuration for cross-origin requests. 
Can be:\n" + "- true: Enable localhost CORS for development\n" + "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index dd1fc8a50..9e7a8006c 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient): ): super().__init__() self.async_client = AsyncLlamaStackAsLibraryClient( - config_path_or_distro_name, custom_provider_registry, provider_data + config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal ) self.pool_executor = ThreadPoolExecutor(max_workers=4) - self.skip_logger_removal = skip_logger_removal self.provider_data = provider_data self.loop = asyncio.new_event_loop() - def initialize(self): - if in_notebook(): - import nest_asyncio - - nest_asyncio.apply() - if not self.skip_logger_removal: - self._remove_root_logger_handlers() - # use a new event loop to avoid interfering with the main event loop loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - return loop.run_until_complete(self.async_client.initialize()) + loop.run_until_complete(self.async_client.initialize()) finally: asyncio.set_event_loop(None) - def _remove_root_logger_handlers(self): + def initialize(self): """ - Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + Deprecated method for backward compatibility. """ - root_logger = logging.getLogger() - - for handler in root_logger.handlers[:]: - root_logger.removeHandler(handler) - logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + pass def request(self, *args, **kwargs): loop = self.loop @@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): config_path_or_distro_name: str, custom_provider_registry: ProviderRegistry | None = None, provider_data: dict[str, Any] | None = None, + skip_logger_removal: bool = False, ): super().__init__() # when using the library client, we should not log to console since many @@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",") os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console") + if in_notebook(): + import nest_asyncio + + nest_asyncio.apply() + if not skip_logger_removal: + self._remove_root_logger_handlers() + if config_path_or_distro_name.endswith(".yaml"): config_path = Path(config_path_or_distro_name) if not config_path.exists(): @@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): self.provider_data = provider_data self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError + def _remove_root_logger_handlers(self): + """ + Remove all handlers from the root logger. Needed to avoid polluting the console with logs. + """ + root_logger = logging.getLogger() + + for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + logger.info(f"Removed handler {handler.__class__.__name__} from root logger") + async def initialize(self) -> bool: + """ + Initialize the async client. 
+ + Returns: + bool: True if initialization was successful + """ + try: self.route_impls = None self.impls = await construct_stack(self.config, self.custom_provider_registry) diff --git a/llama_stack/core/routers/datasets.py b/llama_stack/core/routers/datasets.py index d7984f729..2f1d5f78e 100644 --- a/llama_stack/core/routers/datasets.py +++ b/llama_stack/core/routers/datasets.py @@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class DatasetIORouter(DatasetIO): diff --git a/llama_stack/core/routers/eval_scoring.py b/llama_stack/core/routers/eval_scoring.py index f7a17eecf..ffca81bf0 100644 --- a/llama_stack/core/routers/eval_scoring.py +++ b/llama_stack/core/routers/eval_scoring.py @@ -16,7 +16,7 @@ from llama_stack.apis.scoring import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ScoringRouter(Scoring): diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 6a3f07247..4b66601bb 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin from llama_stack.providers.utils.inference.inference_store import InferenceStore from llama_stack.providers.utils.telemetry.tracing import get_current_span -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="core::routers") class InferenceRouter(Inference): diff --git a/llama_stack/core/routers/safety.py b/llama_stack/core/routers/safety.py index 738ecded3..9ba3327f1 100644 --- a/llama_stack/core/routers/safety.py +++ b/llama_stack/core/routers/safety.py @@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class SafetyRouter(Safety): diff --git a/llama_stack/core/routers/tool_runtime.py b/llama_stack/core/routers/tool_runtime.py index 5a40bc0c5..fd606f33b 100644 --- a/llama_stack/core/routers/tool_runtime.py +++ b/llama_stack/core/routers/tool_runtime.py @@ -22,7 +22,7 @@ from llama_stack.log import get_logger from ..routing_tables.toolgroups import ToolGroupsRoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class ToolRuntimeRouter(ToolRuntime): diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 3d0996c49..786b0e391 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import ( from llama_stack.log import get_logger from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routers") class VectorIORouter(VectorIO): diff --git a/llama_stack/core/routing_tables/benchmarks.py b/llama_stack/core/routing_tables/benchmarks.py index 74bee8040..c875dee5b 100644 --- 
a/llama_stack/core/routing_tables/benchmarks.py +++ b/llama_stack/core/routing_tables/benchmarks.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): diff --git a/llama_stack/core/routing_tables/common.py b/llama_stack/core/routing_tables/common.py index 339ff6da4..e523746d8 100644 --- a/llama_stack/core/routing_tables/common.py +++ b/llama_stack/core/routing_tables/common.py @@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api, RoutingTable -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def get_impl_api(p: Any) -> Api: diff --git a/llama_stack/core/routing_tables/datasets.py b/llama_stack/core/routing_tables/datasets.py index fc6a75df4..b129c9ec5 100644 --- a/llama_stack/core/routing_tables/datasets.py +++ b/llama_stack/core/routing_tables/datasets.py @@ -26,7 +26,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): diff --git a/llama_stack/core/routing_tables/models.py b/llama_stack/core/routing_tables/models.py index 34c431e00..b6141efa9 100644 --- a/llama_stack/core/routing_tables/models.py +++ b/llama_stack/core/routing_tables/models.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ModelsRoutingTable(CommonRoutingTableImpl, Models): diff --git a/llama_stack/core/routing_tables/scoring_functions.py b/llama_stack/core/routing_tables/scoring_functions.py index 5874ba941..71e5bed63 100644 --- a/llama_stack/core/routing_tables/scoring_functions.py +++ b/llama_stack/core/routing_tables/scoring_functions.py @@ -19,7 +19,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): diff --git a/llama_stack/core/routing_tables/shields.py b/llama_stack/core/routing_tables/shields.py index e08f35bfc..b1918d20a 100644 --- a/llama_stack/core/routing_tables/shields.py +++ b/llama_stack/core/routing_tables/shields.py @@ -15,7 +15,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): diff --git a/llama_stack/core/routing_tables/toolgroups.py b/llama_stack/core/routing_tables/toolgroups.py index 6910b3906..eeea406c1 100644 --- a/llama_stack/core/routing_tables/toolgroups.py +++ b/llama_stack/core/routing_tables/toolgroups.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") def 
parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None: diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index e8dc46997..00f71b4fe 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -30,7 +30,7 @@ from llama_stack.log import get_logger from .common import CommonRoutingTableImpl, lookup_model -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="core::routing_tables") class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): diff --git a/llama_stack/core/server/auth.py b/llama_stack/core/server/auth.py index e4fb4ff2b..c98d3bec0 100644 --- a/llama_stack/core/server/auth.py +++ b/llama_stack/core/server/auth.py @@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthenticationMiddleware: diff --git a/llama_stack/core/server/auth_providers.py b/llama_stack/core/server/auth_providers.py index 73d5581c2..a8af6f75a 100644 --- a/llama_stack/core/server/auth_providers.py +++ b/llama_stack/core/server/auth_providers.py @@ -23,7 +23,7 @@ from llama_stack.core.datatypes import ( ) from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="auth") +logger = get_logger(name=__name__, category="core::auth") class AuthResponse(BaseModel): diff --git a/llama_stack/core/server/quota.py b/llama_stack/core/server/quota.py index 1cb850cde..693f224c3 100644 --- a/llama_stack/core/server/quota.py +++ b/llama_stack/core/server/quota.py @@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl -logger = get_logger(name=__name__, category="quota") +logger = get_logger(name=__name__, category="core::server") class QuotaMiddleware: diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 3d94b6e81..d6dfc3435 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -28,6 +28,7 @@ from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError @@ -40,6 +41,7 @@ from llama_stack.core.datatypes import ( AuthenticationRequiredError, LoggingConfig, StackRunConfig, + process_cors_config, ) from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis @@ -82,7 +84,7 @@ from .quota import QuotaMiddleware REPO_ROOT = Path(__file__).parent.parent.parent.parent -logger = get_logger(name=__name__, category="server") +logger = get_logger(name=__name__, category="core::server") def warn_with_traceback(message, category, filename, lineno, file=None, line=None): @@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None): config_contents = yaml.safe_load(fp) if 
isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) - logger = get_logger(name=__name__, category="server", config=logger_config) + logger = get_logger(name=__name__, category="core::server", config=logger_config) if args.env: for env_pair in args.env: try: @@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) + if config.server.cors: + logger.info("Enabling CORS") + cors_config = process_cors_config(config.server.cors) + if cors_config: + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: diff --git a/llama_stack/core/store/registry.py b/llama_stack/core/store/registry.py index 4b60e1001..5f4abe9aa 100644 --- a/llama_stack/core/store/registry.py +++ b/llama_stack/core/store/registry.py @@ -16,7 +16,7 @@ from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig -logger = get_logger(__name__, category="core") +logger = get_logger(__name__, category="core::registry") class DistributionRegistry(Protocol): diff --git a/llama_stack/core/utils/config_resolution.py b/llama_stack/core/utils/config_resolution.py index 30cd71e15..182a571ee 100644 --- a/llama_stack/core/utils/config_resolution.py +++ b/llama_stack/core/utils/config_resolution.py @@ -10,7 +10,7 @@ from pathlib import Path from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="config_resolution") +logger = get_logger(name=__name__, category="core") DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions" diff --git a/llama_stack/models/llama/llama3/multimodal/model.py b/llama_stack/models/llama/llama3/multimodal/model.py index 096156a5f..7b501eb0e 100644 --- a/llama_stack/models/llama/llama3/multimodal/model.py +++ b/llama_stack/models/llama/llama3/multimodal/model.py @@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple MP_SCALE = 8 -logger = get_logger(name=__name__, category="models") +logger = get_logger(name=__name__, category="models::llama") def reduce_from_tensor_model_parallel_region(input_): diff --git a/llama_stack/models/llama/llama3/tool_utils.py b/llama_stack/models/llama/llama3/tool_utils.py index 574080184..d0e3e7671 100644 --- a/llama_stack/models/llama/llama3/tool_utils.py +++ b/llama_stack/models/llama/llama3/tool_utils.py @@ -11,7 +11,7 @@ from llama_stack.log import get_logger from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="models::llama") BUILTIN_TOOL_PATTERN = r'\b(?P\w+)\.call\(query="(?P[^"]*)"\)' CUSTOM_TOOL_CALL_PATTERN = re.compile(r"[^}]+)>(?P{.*?})") diff --git a/llama_stack/models/llama/llama4/quantization/loader.py b/llama_stack/models/llama/llama4/quantization/loader.py index 8220a9040..7557a8a64 100644 --- a/llama_stack/models/llama/llama4/quantization/loader.py +++ b/llama_stack/models/llama/llama4/quantization/loader.py @@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode from ..model import Transformer, TransformerBlock from ..moe import MoE -log = get_logger(name=__name__, category="models") +log = get_logger(name=__name__, category="models::llama") def swiglu_wrapper_no_reduce( 
diff --git a/llama_stack/models/llama/quantize_impls.py b/llama_stack/models/llama/quantize_impls.py index 7fab2d3a6..0a205601f 100644 --- a/llama_stack/models/llama/quantize_impls.py +++ b/llama_stack/models/llama/quantize_impls.py @@ -9,7 +9,7 @@ import collections from llama_stack.log import get_logger -log = get_logger(name=__name__, category="llama") +log = get_logger(name=__name__, category="models::llama") try: import fbgemm_gpu.experimental.gen_ai # noqa: F401 diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 5f7c90879..fde38515b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" -logger = get_logger(name=__name__, category="agents") +logger = get_logger(name=__name__, category="agents::meta_reference") class ChatAgent(ShieldRunnerMixin): diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 5794ad2c0..8bdde86b0 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig from .persistence import AgentInfo from .responses.openai_responses import OpenAIResponsesImpl -logger = get_logger(name=__name__, category="agents") +logger = get_logger(name=__name__, category="agents::meta_reference") class MetaReferenceAgentsImpl(Agents): diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index c19051f86..3b7b4729c 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user from llama_stack.log import get_logger from llama_stack.providers.utils.kvstore import KVStore -log = get_logger(name=__name__, category="agents") +log = get_logger(name=__name__, category="agents::meta_reference") class AgentSessionInfo(Session): diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index e528a4005..c632e61aa 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -41,7 +41,7 @@ from .utils import ( convert_response_text_to_chat_response_format, ) -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="openai::responses") class OpenAIResponsePreviousResponseWithInputItems(BaseModel): diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index 0879e978a..3e69fa5cd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -47,7 +47,7 @@ from llama_stack.log import get_logger from .types import ChatCompletionContext, ChatCompletionResult from .utils import 
convert_chat_choice_to_response_message, is_function_tool_call -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="agents::meta_reference") class StreamingResponseOrchestrator: diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 5b98b4f51..b028c018b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -38,7 +38,7 @@ from llama_stack.log import get_logger from .types import ChatCompletionContext, ToolExecutionResult -logger = get_logger(name=__name__, category="responses") +logger = get_logger(name=__name__, category="agents::meta_reference") class ToolExecutor: diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index 1507a55c8..7aaeb4cd5 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseOutputMessageContent, OpenAIResponseOutputMessageContentOutputText, OpenAIResponseOutputMessageFunctionToolCall, + OpenAIResponseOutputMessageMCPCall, + OpenAIResponseOutputMessageMCPListTools, OpenAIResponseText, ) from llama_stack.apis.inference import ( @@ -99,14 +101,22 @@ async def convert_response_input_to_chat_messages( """ messages: list[OpenAIMessageParam] = [] if isinstance(input, list): + # extract all OpenAIResponseInputFunctionToolCallOutput items + # so their corresponding OpenAIToolMessageParam instances can + # be added immediately following the corresponding + # OpenAIAssistantMessageParam + tool_call_results = {} for input_item in input: if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): - messages.append( - OpenAIToolMessageParam( - content=input_item.output, - tool_call_id=input_item.call_id, - ) + tool_call_results[input_item.call_id] = OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.call_id, ) + + for input_item in input: + if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput): + # skip as these have been extracted and inserted in order + pass elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall): tool_call = OpenAIChatCompletionToolCall( index=0, @@ -117,6 +127,28 @@ async def convert_response_input_to_chat_messages( ), ) messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + if input_item.call_id in tool_call_results: + messages.append(tool_call_results[input_item.call_id]) + del tool_call_results[input_item.call_id] + elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall): + tool_call = OpenAIChatCompletionToolCall( + index=0, + id=input_item.id, + function=OpenAIChatCompletionToolCallFunction( + name=input_item.name, + arguments=input_item.arguments, + ), + ) + messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call])) + messages.append( + OpenAIToolMessageParam( + content=input_item.output, + tool_call_id=input_item.id, + ) + ) + elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools): + # the tool list will be handled separately + pass else: content = await convert_response_content_to_chat_content(input_item.content) message_type = await 
get_message_type_by_role(input_item.role) @@ -125,6 +157,10 @@ async def convert_response_input_to_chat_messages( f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" ) messages.append(message_type(content=content)) + if len(tool_call_results): + raise ValueError( + f"Received function_call_output(s) with call_id(s) {tool_call_results.keys()}, but no corresponding function_call" + ) else: messages.append(OpenAIUserMessageParam(content=input)) return messages diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index b8a5d8a95..8f3ecf5c9 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry import tracing -log = get_logger(name=__name__, category="agents") +log = get_logger(name=__name__, category="agents::meta_reference") class SafetyException(Exception): # noqa: N818 diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 88d7a98ec..904a343d5 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -33,6 +33,9 @@ from llama_stack.apis.inference import ( InterleavedContent, LogProbConfig, Message, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, + RerankResponse, ResponseFormat, SamplingParams, StopReason, @@ -442,6 +445,15 @@ class MetaReferenceInferenceImpl( results = await self._nonstream_chat_completion(request_batch) return BatchChatCompletionResponse(batch=results) + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + raise NotImplementedError("Reranking is not supported for Meta Reference") + async def _nonstream_chat_completion( self, request_batch: list[ChatCompletionRequest] ) -> list[ChatCompletionResponse]: diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 600a5bd37..4b68cc926 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -12,6 +12,9 @@ from llama_stack.apis.inference import ( InterleavedContent, LogProbConfig, Message, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, + RerankResponse, ResponseFormat, SamplingParams, ToolChoice, @@ -122,3 +125,12 @@ class SentenceTransformersInferenceImpl( logprobs: LogProbConfig | None = None, ): raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers") + + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + 
max_num_results: int | None = None, + ) -> RerankResponse: + raise NotImplementedError("Reranking is not supported for Sentence Transformers") diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index e894debaf..ebe90310c 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -5,9 +5,11 @@ # the root directory of this source tree. from llama_stack.providers.datatypes import ( + AdapterSpec, Api, InlineProviderSpec, ProviderSpec, + remote_provider_spec, ) from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages @@ -23,4 +25,14 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", description="Local filesystem-based file storage provider for managing files and documents locally.", ), + remote_provider_spec( + api=Api.files, + adapter=AdapterSpec( + adapter_type="s3", + pip_packages=["boto3"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.s3", + config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", + description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", + ), + ), ] diff --git a/llama_stack/providers/remote/files/s3/README.md b/llama_stack/providers/remote/files/s3/README.md new file mode 100644 index 000000000..0f33122c7 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/README.md @@ -0,0 +1,237 @@ +# S3 Files Provider + +A remote S3-based implementation of the Llama Stack Files API that provides scalable cloud file storage with metadata persistence. + +## Features + +- **AWS S3 Storage**: Store files in AWS S3 buckets for scalable, durable storage +- **Metadata Management**: Uses SQL database for efficient file metadata queries +- **OpenAI API Compatibility**: Full compatibility with OpenAI Files API endpoints +- **Flexible Authentication**: Support for IAM roles and access keys +- **Custom S3 Endpoints**: Support for MinIO and other S3-compatible services + +## Configuration + +### Basic Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Advanced Configuration + +```yaml +api: files +provider_type: remote::s3 +config: + bucket_name: my-llama-stack-files + region: us-east-1 + aws_access_key_id: YOUR_ACCESS_KEY + aws_secret_access_key: YOUR_SECRET_KEY + endpoint_url: https://s3.amazonaws.com # Optional for custom endpoints + metadata_store: + type: sqlite + db_path: ./s3_files_metadata.db +``` + +### Environment Variables + +The configuration supports environment variable substitution: + +```yaml +config: + bucket_name: "${env.S3_BUCKET_NAME}" + region: "${env.AWS_REGION:=us-east-1}" + aws_access_key_id: "${env.AWS_ACCESS_KEY_ID:=}" + aws_secret_access_key: "${env.AWS_SECRET_ACCESS_KEY:=}" + endpoint_url: "${env.S3_ENDPOINT_URL:=}" +``` + +Note: `S3_BUCKET_NAME` has no default value since S3 bucket names must be globally unique. 
+ +## Authentication + +### IAM Roles (Recommended) + +For production deployments, use IAM roles: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + # No credentials needed - will use IAM role +``` + +### Access Keys + +For development or specific use cases: + +```yaml +config: + bucket_name: my-bucket + region: us-east-1 + aws_access_key_id: AKIAIOSFODNN7EXAMPLE + aws_secret_access_key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY +``` + +## S3 Bucket Setup + +### Required Permissions + +The S3 provider requires the following permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket" + ], + "Resource": [ + "arn:aws:s3:::your-bucket-name", + "arn:aws:s3:::your-bucket-name/*" + ] + } + ] +} +``` + +### Automatic Bucket Creation + +By default, the S3 provider expects the bucket to already exist. If you want the provider to automatically create the bucket when it doesn't exist, set `auto_create_bucket: true` in your configuration: + +```yaml +config: + bucket_name: my-bucket + auto_create_bucket: true # Will create bucket if it doesn't exist + region: us-east-1 +``` + +**Note**: When `auto_create_bucket` is enabled, the provider will need additional permissions: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject", + "s3:ListBucket", + "s3:CreateBucket" + ], + "Resource": [ + "arn:aws:s3:::your-bucket-name", + "arn:aws:s3:::your-bucket-name/*" + ] + } + ] +} +``` + +### Bucket Policy (Optional) + +For additional security, you can add a bucket policy: + +```json +{ + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "LlamaStackAccess", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole" + }, + "Action": [ + "s3:GetObject", + "s3:PutObject", + "s3:DeleteObject" + ], + "Resource": "arn:aws:s3:::your-bucket-name/*" + }, + { + "Sid": "LlamaStackBucketAccess", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::YOUR-ACCOUNT:role/LlamaStackRole" + }, + "Action": [ + "s3:ListBucket" + ], + "Resource": "arn:aws:s3:::your-bucket-name" + } + ] +} +``` + +## Features + +### Metadata Persistence + +File metadata is stored in a SQL database for fast queries and OpenAI API compatibility. The metadata includes: + +- File ID +- Original filename +- Purpose (assistants, batch, etc.) +- File size in bytes +- Created and expiration timestamps + +### TTL and Cleanup + +Files currently have a fixed long expiration time (100 years). + +## Development and Testing + +### Using MinIO + +For self-hosted S3-compatible storage: + +```yaml +config: + bucket_name: test-bucket + region: us-east-1 + endpoint_url: http://localhost:9000 + aws_access_key_id: minioadmin + aws_secret_access_key: minioadmin +``` + +## Monitoring and Logging + +The provider logs important operations and errors. 
For production deployments, consider: + +- CloudWatch monitoring for S3 operations +- Custom metrics for file upload/download rates +- Error rate monitoring +- Performance metrics tracking + +## Error Handling + +The provider handles various error scenarios: + +- S3 connectivity issues +- Bucket access permissions +- File not found errors +- Metadata consistency checks + +## Known Limitations + +- Fixed long TTL (100 years) instead of configurable expiration +- No server-side encryption enabled by default +- No support for AWS session tokens +- No S3 key prefix organization support +- No multipart upload support (all files uploaded as single objects) diff --git a/llama_stack/providers/remote/files/s3/__init__.py b/llama_stack/providers/remote/files/s3/__init__.py new file mode 100644 index 000000000..3f5dfc88a --- /dev/null +++ b/llama_stack/providers/remote/files/s3/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.core.datatypes import Api + +from .config import S3FilesImplConfig + + +async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]): + from .files import S3FilesImpl + + # TODO: authorization policies and user separation + impl = S3FilesImpl(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/remote/files/s3/config.py b/llama_stack/providers/remote/files/s3/config.py new file mode 100644 index 000000000..da20d8668 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/config.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, SqlStoreConfig + + +class S3FilesImplConfig(BaseModel): + """Configuration for S3-based files provider.""" + + bucket_name: str = Field(description="S3 bucket name to store files") + region: str = Field(default="us-east-1", description="AWS region where the bucket is located") + aws_access_key_id: str | None = Field(default=None, description="AWS access key ID (optional if using IAM roles)") + aws_secret_access_key: str | None = Field( + default=None, description="AWS secret access key (optional if using IAM roles)" + ) + endpoint_url: str | None = Field(default=None, description="Custom S3 endpoint URL (for MinIO, LocalStack, etc.)") + auto_create_bucket: bool = Field( + default=False, description="Automatically create the S3 bucket if it doesn't exist" + ) + metadata_store: SqlStoreConfig = Field(description="SQL store configuration for file metadata") + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]: + return { + "bucket_name": "${env.S3_BUCKET_NAME}", # no default, buckets must be globally unique + "region": "${env.AWS_REGION:=us-east-1}", + "aws_access_key_id": "${env.AWS_ACCESS_KEY_ID:=}", + "aws_secret_access_key": "${env.AWS_SECRET_ACCESS_KEY:=}", + "endpoint_url": "${env.S3_ENDPOINT_URL:=}", + "auto_create_bucket": "${env.S3_AUTO_CREATE_BUCKET:=false}", + "metadata_store": SqliteSqlStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="s3_files_metadata.db", + ), + } diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py new file mode 100644 index 000000000..52e0cbbf4 --- /dev/null +++ b/llama_stack/providers/remote/files/s3/files.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import time +import uuid +from typing import Annotated + +import boto3 +from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError +from fastapi import File, Form, Response, UploadFile + +from llama_stack.apis.common.errors import ResourceNotFoundError +from llama_stack.apis.common.responses import Order +from llama_stack.apis.files import ( + Files, + ListOpenAIFileResponse, + OpenAIFileDeleteResponse, + OpenAIFileObject, + OpenAIFilePurpose, +) +from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType +from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl + +from .config import S3FilesImplConfig + +# TODO: provider data for S3 credentials + + +def _create_s3_client(config: S3FilesImplConfig) -> boto3.client: + try: + s3_config = { + "region_name": config.region, + } + + # endpoint URL if specified (for MinIO, LocalStack, etc.) 
+ if config.endpoint_url: + s3_config["endpoint_url"] = config.endpoint_url + + if config.aws_access_key_id and config.aws_secret_access_key: + s3_config.update( + { + "aws_access_key_id": config.aws_access_key_id, + "aws_secret_access_key": config.aws_secret_access_key, + } + ) + + return boto3.client("s3", **s3_config) + + except (BotoCoreError, NoCredentialsError) as e: + raise RuntimeError(f"Failed to initialize S3 client: {e}") from e + + +async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImplConfig) -> None: + try: + client.head_bucket(Bucket=config.bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + if error_code == "404": + if not config.auto_create_bucket: + raise RuntimeError( + f"S3 bucket '{config.bucket_name}' does not exist. " + f"Either create the bucket manually or set 'auto_create_bucket: true' in your configuration." + ) from e + try: + # For us-east-1, we can't specify LocationConstraint + if config.region == "us-east-1": + client.create_bucket(Bucket=config.bucket_name) + else: + client.create_bucket( + Bucket=config.bucket_name, + CreateBucketConfiguration={"LocationConstraint": config.region}, + ) + except ClientError as create_error: + raise RuntimeError( + f"Failed to create S3 bucket '{config.bucket_name}': {create_error}" + ) from create_error + elif error_code == "403": + raise RuntimeError(f"Access denied to S3 bucket '{config.bucket_name}'") from e + else: + raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e + + +class S3FilesImpl(Files): + """S3-based implementation of the Files API.""" + + # TODO: implement expiration, for now a silly offset + _SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60 + + def __init__(self, config: S3FilesImplConfig) -> None: + self._config = config + self._client: boto3.client | None = None + self._sql_store: SqlStore | None = None + + async def initialize(self) -> None: + self._client = _create_s3_client(self._config) + await _create_bucket_if_not_exists(self._client, self._config) + + self._sql_store = sqlstore_impl(self._config.metadata_store) + await self._sql_store.create_table( + "openai_files", + { + "id": ColumnDefinition(type=ColumnType.STRING, primary_key=True), + "filename": ColumnType.STRING, + "purpose": ColumnType.STRING, + "bytes": ColumnType.INTEGER, + "created_at": ColumnType.INTEGER, + "expires_at": ColumnType.INTEGER, + # TODO: add s3_etag field for integrity checking + }, + ) + + async def shutdown(self) -> None: + pass + + @property + def client(self) -> boto3.client: + assert self._client is not None, "Provider not initialized" + return self._client + + @property + def sql_store(self) -> SqlStore: + assert self._sql_store is not None, "Provider not initialized" + return self._sql_store + + async def openai_upload_file( + self, + file: Annotated[UploadFile, File()], + purpose: Annotated[OpenAIFilePurpose, Form()], + ) -> OpenAIFileObject: + file_id = f"file-{uuid.uuid4().hex}" + + filename = getattr(file, "filename", None) or "uploaded_file" + + created_at = int(time.time()) + expires_at = created_at + self._SILLY_EXPIRATION_OFFSET + content = await file.read() + file_size = len(content) + + await self.sql_store.insert( + "openai_files", + { + "id": file_id, + "filename": filename, + "purpose": purpose.value, + "bytes": file_size, + "created_at": created_at, + "expires_at": expires_at, + }, + ) + + try: + self.client.put_object( + Bucket=self._config.bucket_name, + Key=file_id, + Body=content, + # TODO: enable 
server-side encryption + ) + except ClientError as e: + await self.sql_store.delete("openai_files", where={"id": file_id}) + + raise RuntimeError(f"Failed to upload file to S3: {e}") from e + + return OpenAIFileObject( + id=file_id, + filename=filename, + purpose=purpose, + bytes=file_size, + created_at=created_at, + expires_at=expires_at, + ) + + async def openai_list_files( + self, + after: str | None = None, + limit: int | None = 10000, + order: Order | None = Order.desc, + purpose: OpenAIFilePurpose | None = None, + ) -> ListOpenAIFileResponse: + # this purely defensive. it should not happen because the router also default to Order.desc. + if not order: + order = Order.desc + + where_conditions = {} + if purpose: + where_conditions["purpose"] = purpose.value + + paginated_result = await self.sql_store.fetch_all( + table="openai_files", + where=where_conditions if where_conditions else None, + order_by=[("created_at", order.value)], + cursor=("id", after) if after else None, + limit=limit, + ) + + files = [ + OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + for row in paginated_result.data + ] + + return ListOpenAIFileResponse( + data=files, + has_more=paginated_result.has_more, + # empty string or None? spec says str, ref impl returns str | None, we go with spec + first_id=files[0].id if files else "", + last_id=files[-1].id if files else "", + ) + + async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + return OpenAIFileObject( + id=row["id"], + filename=row["filename"], + purpose=OpenAIFilePurpose(row["purpose"]), + bytes=row["bytes"], + created_at=row["created_at"], + expires_at=row["expires_at"], + ) + + async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + self.client.delete_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + except ClientError as e: + if e.response["Error"]["Code"] != "NoSuchKey": + raise RuntimeError(f"Failed to delete file from S3: {e}") from e + + await self.sql_store.delete("openai_files", where={"id": file_id}) + + return OpenAIFileDeleteResponse(id=file_id, deleted=True) + + async def openai_retrieve_file_content(self, file_id: str) -> Response: + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + if not row: + raise ResourceNotFoundError(file_id, "File", "files.list()") + + try: + response = self.client.get_object( + Bucket=self._config.bucket_name, + Key=row["id"], + ) + # TODO: can we stream this instead of loading it into memory + content = response["Body"].read() + except ClientError as e: + if e.response["Error"]["Code"] == "NoSuchKey": + await self.sql_store.delete("openai_files", where={"id": file_id}) + raise ResourceNotFoundError(file_id, "File", "files.list()") from e + raise RuntimeError(f"Failed to download file from S3: {e}") from e + + return Response( + content=content, + media_type="application/octet-stream", + headers={"Content-Disposition": f'attachment; filename="{row["filename"]}"'}, + ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py 
b/llama_stack/providers/remote/inference/fireworks/fireworks.py index bd86f7238..e907e8ec6 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::fireworks") class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index cfcfcbf90..0edff882f 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -3,6 +3,11 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.inference import ( + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, + RerankResponse, +) from llama_stack.log import get_logger from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin @@ -10,7 +15,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::llama_openai_compat") class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): @@ -54,3 +59,12 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): async def shutdown(self): await super().shutdown() + + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + raise NotImplementedError("Reranking is not supported for Llama OpenAI Compat") diff --git a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md index 35d26fd0b..d96b29fef 100644 --- a/llama_stack/providers/remote/inference/nvidia/NVIDIA.md +++ b/llama_stack/providers/remote/inference/nvidia/NVIDIA.md @@ -41,6 +41,11 @@ client.initialize() ### Create Completion +> Note on Completion API +> +> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with `NVIDIA_BASE_URL="https://integrate.api.nvidia.com"` do not support the `completion` method, while the locally deployed NIM does. + + ```python response = client.inference.completion( model_id="meta-llama/Llama-3.1-8B-Instruct", @@ -76,6 +81,73 @@ response = client.inference.chat_completion( print(f"Response: {response.completion_message.content}") ``` +### Tool Calling Example +```python +from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition + +tool_definition = ToolDefinition( + tool_name="get_weather", + description="Get current weather information for a location", + parameters={ + "location": ToolParamDefinition( + param_type="string", + description="The city and state, e.g.
San Francisco, CA", + required=True, + ), + "unit": ToolParamDefinition( + param_type="string", + description="Temperature unit (celsius or fahrenheit)", + required=False, + default="celsius", + ), + }, +) + +tool_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}], + tools=[tool_definition], +) + +print(f"Tool Response: {tool_response.completion_message.content}") +if tool_response.completion_message.tool_calls: + for tool_call in tool_response.completion_message.tool_calls: + print(f"Tool Called: {tool_call.tool_name}") + print(f"Arguments: {tool_call.arguments}") +``` + +### Structured Output Example +```python +from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType + +person_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "occupation": {"type": "string"}, + }, + "required": ["name", "age", "occupation"], +} + +response_format = JsonSchemaResponseFormat( + type=ResponseFormatType.json_schema, json_schema=person_schema +) + +structured_response = client.inference.chat_completion( + model_id="meta-llama/Llama-3.1-8B-Instruct", + messages=[ + { + "role": "user", + "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ", + } + ], + response_format=response_format, +) + +print(f"Structured Response: {structured_response.completion_message.content}") +``` + ### Create Embeddings > Note on OpenAI embeddings compatibility > diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 7052cfb57..a5475bc92 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,7 +7,7 @@ import warnings from collections.abc import AsyncIterator -from openai import NOT_GIVEN, APIConnectionError, BadRequestError +from openai import NOT_GIVEN, APIConnectionError from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -57,7 +57,7 @@ from .openai_utils import ( ) from .utils import _is_nvidia_hosted -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::nvidia") class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): @@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): } extra_body["input_type"] = task_type_options[task_type] - try: - response = await self.client.embeddings.create( - model=provider_model_id, - input=input, - extra_body=extra_body, - ) - except BadRequestError as e: - raise ValueError(f"Failed to get embeddings: {e}") from e - + response = await self.client.embeddings.create( + model=provider_model_id, + input=input, + extra_body=extra_body, + ) # # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...) # -> diff --git a/llama_stack/providers/remote/inference/nvidia/utils.py b/llama_stack/providers/remote/inference/nvidia/utils.py index 790bbafd1..b8431e859 100644 --- a/llama_stack/providers/remote/inference/nvidia/utils.py +++ b/llama_stack/providers/remote/inference/nvidia/utils.py @@ -10,7 +10,7 @@ from llama_stack.log import get_logger from . 
import NVIDIAConfig -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::nvidia") def _is_nvidia_hosted(config: NVIDIAConfig) -> bool: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index a93421536..d72a94615 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -37,11 +37,14 @@ from llama_stack.apis.inference import ( Message, OpenAIChatCompletion, OpenAIChatCompletionChunk, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, OpenAICompletion, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, OpenAIMessageParam, OpenAIResponseFormatParam, + RerankResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::ollama") class OllamaInferenceAdapter( @@ -641,6 +644,15 @@ class OllamaInferenceAdapter( ): raise NotImplementedError("Batch chat completion is not supported for Ollama") + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + raise NotImplementedError("Reranking is not supported for Ollama") + async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def _convert_content(content) -> dict: diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 1c72fa0bc..0f73c9321 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::openai") # diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 9da961438..97c72d14c 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference::tgi") def build_hf_repo_model_entries(): diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a06e4173b..54c76607f 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import TogetherImplConfig from .models import MODEL_ENTRIES -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="inference::together") class 
TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index ac626874c..a5f7ba52f 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -39,12 +39,15 @@ from llama_stack.apis.inference import ( Message, ModelStore, OpenAIChatCompletion, + OpenAIChatCompletionContentPartImageParam, + OpenAIChatCompletionContentPartTextParam, OpenAICompletion, OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage, OpenAIMessageParam, OpenAIResponseFormatParam, + RerankResponse, ResponseFormat, SamplingParams, TextTruncation, @@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="inference::vllm") def build_hf_repo_model_entries(): @@ -732,4 +735,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): response_format: ResponseFormat | None = None, logprobs: LogProbConfig | None = None, ): - raise NotImplementedError("Batch chat completion is not supported for Ollama") + raise NotImplementedError("Batch chat completion is not supported for vLLM") + + async def rerank( + self, + model: str, + query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam, + items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam], + max_num_results: int | None = None, + ) -> RerankResponse: + raise NotImplementedError("Reranking is not supported for vLLM") diff --git a/llama_stack/providers/remote/post_training/nvidia/utils.py b/llama_stack/providers/remote/post_training/nvidia/utils.py index 9a6c3b53c..162951ff3 100644 --- a/llama_stack/providers/remote/post_training/nvidia/utils.py +++ b/llama_stack/providers/remote/post_training/nvidia/utils.py @@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa from .config import NvidiaPostTrainingConfig -logger = get_logger(name=__name__, category="integration") +logger = get_logger(name=__name__, category="post_training::nvidia") def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None: diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index 1ca87ae3d..8855e02a4 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig -logger = get_logger(name=__name__, category="safety") +logger = get_logger(name=__name__, category="safety::bedrock") class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): diff --git a/llama_stack/providers/remote/safety/nvidia/nvidia.py b/llama_stack/providers/remote/safety/nvidia/nvidia.py index 0d8d8ba7a..65f901da2 100644 --- a/llama_stack/providers/remote/safety/nvidia/nvidia.py +++ b/llama_stack/providers/remote/safety/nvidia/nvidia.py @@ -9,7 +9,7 @@ from typing import Any import requests from llama_stack.apis.inference import Message -from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel +from llama_stack.apis.safety import ModerationObject, RunShieldResponse, 
Safety, SafetyViolation, ViolationLevel from llama_stack.apis.shields import Shield from llama_stack.log import get_logger from llama_stack.providers.datatypes import ShieldsProtocolPrivate @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_ from .config import NVIDIASafetyConfig -logger = get_logger(name=__name__, category="safety") +logger = get_logger(name=__name__, category="safety::nvidia") class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): @@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate): self.shield = NeMoGuardrails(self.config, shield.shield_id) return await self.shield.run(messages) + async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject: + raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation") + class NeMoGuardrails: """ diff --git a/llama_stack/providers/remote/safety/sambanova/sambanova.py b/llama_stack/providers/remote/safety/sambanova/sambanova.py index 676ee7185..2beb5e0ea 100644 --- a/llama_stack/providers/remote/safety/sambanova/sambanova.py +++ b/llama_stack/providers/remote/safety/sambanova/sambanova.py @@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_ from .config import SambaNovaSafetyConfig -logger = get_logger(name=__name__, category="safety") +logger = get_logger(name=__name__, category="safety::sambanova") CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 0047e6055..a9ec644ef 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::chroma") ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 034ec331c..e07e8ff12 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig -logger = get_logger(name=__name__, category="vector_io") +logger = get_logger(name=__name__, category="vector_io::milvus") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index e829c9e72..1c8d361c2 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import PGVectorVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::pgvector") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::" diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 
8499ff997..0a0faa23a 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import ( from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::qdrant") CHUNK_ID_KEY = "_chunk_id" # KV store prefixes for vector databases diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index ddf95317b..59b6bf124 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti from .config import WeaviateVectorIOConfig -log = get_logger(name=__name__, category="vector_io") +log = get_logger(name=__name__, category="vector_io::weaviate") VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::" diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 05886cdc8..65ba2854b 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con EMBEDDING_MODELS = {} -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="providers::utils") class SentenceTransformerEmbeddingMixin: diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index da2e634f6..880348805 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="providers::utils") class LiteLLMOpenAIMixin( diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index ddb3bda8c..44add8f9e 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class RemoteInferenceProviderConfig(BaseModel): diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index eb32d2de9..55c2ac0ad 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( decode_assistant_message, ) -logger = get_logger(name=__name__, category="inference") +logger = get_logger(name=__name__, category="providers::utils") class OpenAICompatCompletionChoiceDelta(BaseModel): diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py index 72286dffb..f60deee6e 
100644 --- a/llama_stack/providers/utils/inference/openai_mixin.py +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -25,7 +25,7 @@ from llama_stack.apis.inference import ( from llama_stack.log import get_logger from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params -logger = get_logger(name=__name__, category="core") +logger = get_logger(name=__name__, category="providers::utils") class OpenAIMixin(ABC): diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index bb9a91b97..a93326e41 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models -log = get_logger(name=__name__, category="inference") +log = get_logger(name=__name__, category="providers::utils") class ChatCompletionRequestWithRawContent(ChatCompletionRequest): diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index af52f3708..bab87a4aa 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore from ..config import MongoDBKVStoreConfig -log = get_logger(name=__name__, category="kvstore") +log = get_logger(name=__name__, category="providers::utils") class MongoDBKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index 021e90774..56d6dbb48 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -14,7 +14,7 @@ from llama_stack.log import get_logger from ..api import KVStore from ..config import PostgresKVStoreConfig -log = get_logger(name=__name__, category="kvstore") +log = get_logger(name=__name__, category="providers::utils") class PostgresKVStoreImpl(KVStore): diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 0775b31d1..3acdcf293 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import ( make_overlapped_chunks, ) -logger = get_logger(name=__name__, category="memory") +logger = get_logger(name=__name__, category="providers::utils") # Constants for OpenAI vector stores CHUNK_MULTIPLIER = 5 diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index b5d82432d..b74080384 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( ) from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id -log = get_logger(name=__name__, category="memory") +log = get_logger(name=__name__, category="providers::utils") class ChunkForDeletion(BaseModel): diff --git a/llama_stack/providers/utils/scheduler.py b/llama_stack/providers/utils/scheduler.py 
index 65c3d2898..146591b2f 100644 --- a/llama_stack/providers/utils/scheduler.py +++ b/llama_stack/providers/utils/scheduler.py @@ -17,7 +17,7 @@ from pydantic import BaseModel from llama_stack.log import get_logger -logger = get_logger(name=__name__, category="scheduler") +logger = get_logger(name=__name__, category="providers::utils") # TODO: revisit the list of possible statuses when defining a more coherent diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index ccc835768..867ba2f55 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -17,7 +17,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore from .sqlstore import SqlStoreType -logger = get_logger(name=__name__, category="authorized_sqlstore") +logger = get_logger(name=__name__, category="providers::utils") # Hardcoded copy of the default policy that our SQL filtering implements # WARNING: If default_policy() changes, this constant must be updated accordingly diff --git a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py index 6414929db..f75c35314 100644 --- a/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/sqlalchemy_sqlstore.py @@ -22,6 +22,7 @@ from sqlalchemy import ( text, ) from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.log import get_logger @@ -29,7 +30,7 @@ from llama_stack.log import get_logger from .api import ColumnDefinition, ColumnType, SqlStore from .sqlstore import SqlAlchemySqlStoreConfig -logger = get_logger(name=__name__, category="sqlstore") +logger = get_logger(name=__name__, category="providers::utils") TYPE_MAPPING: dict[ColumnType, Any] = { ColumnType.INTEGER: Integer, @@ -45,9 +46,12 @@ TYPE_MAPPING: dict[ColumnType, Any] = { class SqlAlchemySqlStoreImpl(SqlStore): def __init__(self, config: SqlAlchemySqlStoreConfig): self.config = config - self.async_session = async_sessionmaker(create_async_engine(config.engine_str)) + self.async_session = async_sessionmaker(self.create_engine()) self.metadata = MetaData() + def create_engine(self) -> AsyncEngine: + return create_async_engine(self.config.engine_str, pool_pre_ping=True) + async def create_table( self, table: str, @@ -83,7 +87,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): else: sqlalchemy_table = self.metadata.tables[table] - engine = create_async_engine(self.config.engine_str) + engine = self.create_engine() async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True) @@ -241,7 +245,7 @@ class SqlAlchemySqlStoreImpl(SqlStore): nullable: bool = True, ) -> None: """Add a column to an existing table if the column doesn't already exist.""" - engine = create_async_engine(self.config.engine_str) + engine = self.create_engine() try: async with engine.begin() as conn: diff --git a/llama_stack/ui/app/chat-playground/page.test.tsx b/llama_stack/ui/app/chat-playground/page.test.tsx new file mode 100644 index 000000000..54c15f95a --- /dev/null +++ b/llama_stack/ui/app/chat-playground/page.test.tsx @@ -0,0 +1,587 @@ +import React from "react"; +import { + render, + screen, + fireEvent, + 
waitFor, + act, +} from "@testing-library/react"; +import "@testing-library/jest-dom"; +import ChatPlaygroundPage from "./page"; + +const mockClient = { + agents: { + list: jest.fn(), + create: jest.fn(), + retrieve: jest.fn(), + delete: jest.fn(), + session: { + list: jest.fn(), + create: jest.fn(), + delete: jest.fn(), + retrieve: jest.fn(), + }, + turn: { + create: jest.fn(), + }, + }, + models: { + list: jest.fn(), + }, + toolgroups: { + list: jest.fn(), + }, +}; + +jest.mock("@/hooks/use-auth-client", () => ({ + useAuthClient: jest.fn(() => mockClient), +})); + +jest.mock("@/components/chat-playground/chat", () => ({ + Chat: jest.fn( + ({ + className, + messages, + handleSubmit, + input, + handleInputChange, + isGenerating, + append, + suggestions, + }) => ( +
+
{messages.length}
+ + + {suggestions?.map((suggestion: string, index: number) => ( + + ))} +
+ ) + ), +})); + +jest.mock("@/components/chat-playground/conversations", () => ({ + SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => ( +
+ {selectedAgentId && ( + <> +
{selectedAgentId}
+ + + )} +
+ )), + SessionUtils: { + saveCurrentSessionId: jest.fn(), + loadCurrentSessionId: jest.fn(), + loadCurrentAgentId: jest.fn(), + saveCurrentAgentId: jest.fn(), + clearCurrentSession: jest.fn(), + saveSessionData: jest.fn(), + loadSessionData: jest.fn(), + saveAgentConfig: jest.fn(), + loadAgentConfig: jest.fn(), + clearAgentCache: jest.fn(), + createDefaultSession: jest.fn(() => ({ + id: "test-session-123", + name: "Default Session", + messages: [], + selectedModel: "", + systemMessage: "You are a helpful assistant.", + agentId: "test-agent-123", + createdAt: Date.now(), + updatedAt: Date.now(), + })), + }, +})); + +const mockAgents = [ + { + agent_id: "agent_123", + agent_config: { + name: "Test Agent", + instructions: "You are a test assistant.", + }, + }, + { + agent_id: "agent_456", + agent_config: { + agent_name: "Another Agent", + instructions: "You are another assistant.", + }, + }, +]; + +const mockModels = [ + { + identifier: "test-model-1", + model_type: "llm", + }, + { + identifier: "test-model-2", + model_type: "llm", + }, +]; + +const mockToolgroups = [ + { + identifier: "builtin::rag", + provider_id: "test-provider", + type: "tool_group", + provider_resource_id: "test-resource", + }, +]; + +describe("ChatPlaygroundPage", () => { + beforeEach(() => { + jest.clearAllMocks(); + Element.prototype.scrollIntoView = jest.fn(); + mockClient.agents.list.mockResolvedValue({ data: mockAgents }); + mockClient.models.list.mockResolvedValue(mockModels); + mockClient.toolgroups.list.mockResolvedValue(mockToolgroups); + mockClient.agents.session.create.mockResolvedValue({ + session_id: "new-session-123", + }); + mockClient.agents.session.list.mockResolvedValue({ data: [] }); + mockClient.agents.session.retrieve.mockResolvedValue({ + session_id: "test-session", + session_name: "Test Session", + started_at: new Date().toISOString(), + turns: [], + }); // No turns by default + mockClient.agents.retrieve.mockResolvedValue({ + agent_id: "test-agent", + agent_config: { + toolgroups: ["builtin::rag"], + instructions: "Test instructions", + model: "test-model", + }, + }); + mockClient.agents.delete.mockResolvedValue(undefined); + }); + + describe("Agent Selector Rendering", () => { + test("shows agent selector when agents are available", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.getByText("Agent Session:")).toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(2); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.getByText("Clear Chat")).toBeInTheDocument(); + }); + }); + + test("does not show agent selector when no agents are available", async () => { + mockClient.agents.list.mockResolvedValue({ data: [] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); + }); + }); + + test("does not show agent selector while loading", async () => { + mockClient.agents.list.mockImplementation(() => new Promise(() => {})); + + await act(async () => { + render(); + }); + + expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument(); + expect(screen.getAllByRole("combobox")).toHaveLength(1); + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument(); 
+ }); + + test("shows agent options in selector", async () => { + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Test Agent")).toHaveLength(2); + expect(screen.getByText("Another Agent")).toBeInTheDocument(); + }); + }); + + test("displays agent ID when no name is available", async () => { + const agentWithoutName = { + agent_id: "agent_789", + agent_config: { + instructions: "You are an agent without a name.", + }, + }; + + mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] }); + + await act(async () => { + render(); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Agent agent_78") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2); + }); + }); + }); + + describe("Agent Creation Modal", () => { + test("opens agent creation modal when + New Agent is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument(); + expect(screen.getAllByText("Model")).toHaveLength(2); + expect(screen.getByText("System Instructions")).toBeInTheDocument(); + expect(screen.getByText("Tools (optional)")).toBeInTheDocument(); + }); + + test("closes modal when Cancel is clicked", async () => { + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + fireEvent.click(newAgentButton); + + const cancelButton = screen.getByText("Cancel"); + fireEvent.click(cancelButton); + + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + + test("creates agent when Create Agent is clicked", async () => { + mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" }); + mockClient.agents.list + .mockResolvedValueOnce({ data: mockAgents }) + .mockResolvedValueOnce({ + data: [ + ...mockAgents, + { agent_id: "new-agent-123", agent_config: { name: "New Agent" } }, + ], + }); + + await act(async () => { + render(); + }); + + const newAgentButton = screen.getByText("+ New Agent"); + await act(async () => { + fireEvent.click(newAgentButton); + }); + + await waitFor(() => { + expect(screen.getByText("Create New Agent")).toBeInTheDocument(); + }); + + const nameInput = screen.getByPlaceholderText("My Custom Agent"); + await act(async () => { + fireEvent.change(nameInput, { target: { value: "Test Agent Name" } }); + }); + + const instructionsTextarea = screen.getByDisplayValue( + "You are a helpful assistant." 
); + await act(async () => { + fireEvent.change(instructionsTextarea, { + target: { value: "Custom instructions" }, + }); + }); + + await waitFor(() => { + const modalModelSelectors = screen + .getAllByRole("combobox") + .filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + expect(modalModelSelectors.length).toBeGreaterThan(0); + }); + + const modalModelSelectors = screen.getAllByRole("combobox").filter(el => { + return ( + el.textContent?.includes("Select Model") || + el.closest('[class*="modal"]') || + el.closest('[class*="card"]') + ); + }); + + await act(async () => { + fireEvent.click(modalModelSelectors[0]); + }); + + await waitFor(() => { + const modelOptions = screen.getAllByText("test-model-1"); + expect(modelOptions.length).toBeGreaterThan(0); + }); + + const modelOptions = screen.getAllByText("test-model-1"); + const dropdownOption = modelOptions.find( + option => + option.closest('[role="option"]') || + option.id?.includes("radix") || + option.getAttribute("aria-selected") !== null + ); + + await act(async () => { + fireEvent.click( + dropdownOption || modelOptions[modelOptions.length - 1] + ); + }); + + await waitFor(() => { + const createButton = screen.getByText("Create Agent"); + expect(createButton).not.toBeDisabled(); + }); + + const createButton = screen.getByText("Create Agent"); + await act(async () => { + fireEvent.click(createButton); + }); + + await waitFor(() => { + expect(mockClient.agents.create).toHaveBeenCalledWith({ + agent_config: { + model: expect.any(String), + instructions: "Custom instructions", + name: "Test Agent Name", + enable_session_persistence: true, + }, + }); + }); + + await waitFor(() => { + expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument(); + }); + }); + }); + + describe("Agent Selection", () => { + test("creates default session when agent is selected", async () => { + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + // first agent should be auto-selected + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_123", + { session_name: "Default Session" } + ); + }); + }); + + test("switches agent when different agent is selected", async () => { + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + const agentCombobox = screen.getAllByRole("combobox").find(element => { + return ( + element.textContent?.includes("Test Agent") || + element.textContent?.includes("Select Agent") + ); + }); + expect(agentCombobox).toBeDefined(); + fireEvent.click(agentCombobox!); + }); + + await waitFor(() => { + const anotherAgentOption = screen.getByText("Another Agent"); + fireEvent.click(anotherAgentOption); + }); + + expect(mockClient.agents.session.create).toHaveBeenCalledWith( + "agent_456", + { session_name: "Default Session" } + ); + }); + }); + + describe("Agent Deletion", () => { + test("shows delete button when multiple agents exist", async () => { + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + }); + + test("hides delete button when only one agent exists", async () => { + mockClient.agents.list.mockResolvedValue({ + data: [mockAgents[0]], + }); + + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + expect( + screen.queryByTitle("Delete current agent") + ).not.toBeInTheDocument(); + }); + }); + + test("deletes agent and switches to another when confirmed",
async () => { + global.confirm = jest.fn(() => true); + + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + mockClient.agents.delete.mockResolvedValue(undefined); + mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents }); + mockClient.agents.list.mockResolvedValueOnce({ + data: [mockAgents[1]], + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123"); + expect(global.confirm).toHaveBeenCalledWith( + "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions." + ); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + + test("does not delete agent when cancelled", async () => { + global.confirm = jest.fn(() => false); + + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + expect(screen.getByTitle("Delete current agent")).toBeInTheDocument(); + }); + + const deleteButton = screen.getByTitle("Delete current agent"); + await act(async () => { + deleteButton.click(); + }); + + await waitFor(() => { + expect(global.confirm).toHaveBeenCalled(); + expect(mockClient.agents.delete).not.toHaveBeenCalled(); + }); + + (global.confirm as jest.Mock).mockRestore(); + }); + }); + + describe("Error Handling", () => { + test("handles agent loading errors gracefully", async () => { + mockClient.agents.list.mockRejectedValue( + new Error("Failed to load agents") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching agents:", + expect.any(Error) + ); + }); + + expect(screen.getByText("+ New Agent")).toBeInTheDocument(); + + consoleSpy.mockRestore(); + }); + + test("handles model loading errors gracefully", async () => { + mockClient.models.list.mockRejectedValue( + new Error("Failed to load models") + ); + const consoleSpy = jest + .spyOn(console, "error") + .mockImplementation(() => {}); + + await act(async () => { + render(<ChatPlaygroundPage />); + }); + + await waitFor(() => { + expect(consoleSpy).toHaveBeenCalledWith( + "Error fetching models:", + expect.any(Error) + ); + }); + + consoleSpy.mockRestore(); + }); + }); +}); diff --git a/llama_stack/ui/app/chat-playground/page.tsx b/llama_stack/ui/app/chat-playground/page.tsx index b8651aca0..f26791a41 100644 --- a/llama_stack/ui/app/chat-playground/page.tsx +++ b/llama_stack/ui/app/chat-playground/page.tsx @@ -1,6 +1,6 @@ "use client"; -import { useState, useEffect } from "react"; +import { useState, useEffect, useCallback, useRef } from "react"; import { flushSync } from "react-dom"; import { Button } from "@/components/ui/button"; import { @@ -10,14 +10,22 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; +import { Card } from "@/components/ui/card"; +import { Input } from "@/components/ui/input"; +import { Trash2 } from "lucide-react"; import { Chat } from "@/components/chat-playground/chat"; import { type Message } from "@/components/chat-playground/chat-message"; import { useAuthClient } from "@/hooks/use-auth-client"; -import type { CompletionCreateParams } from "llama-stack-client/resources/chat/completions"; import type { Model } from "llama-stack-client/resources/models"; - +import type { TurnCreateParams } from
"llama-stack-client/resources/agents/turn"; +import { + SessionUtils, + type ChatSession, +} from "@/components/chat-playground/conversations"; export default function ChatPlaygroundPage() { - const [messages, setMessages] = useState([]); + const [currentSession, setCurrentSession] = useState( + null + ); const [input, setInput] = useState(""); const [isGenerating, setIsGenerating] = useState(false); const [error, setError] = useState(null); @@ -25,10 +33,523 @@ export default function ChatPlaygroundPage() { const [selectedModel, setSelectedModel] = useState(""); const [modelsLoading, setModelsLoading] = useState(true); const [modelsError, setModelsError] = useState(null); + const [agents, setAgents] = useState< + Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }> + >([]); + const [selectedAgentConfig, setSelectedAgentConfig] = useState<{ + toolgroups?: Array< + string | { name: string; args: Record } + >; + } | null>(null); + const [selectedAgentId, setSelectedAgentId] = useState(""); + const [agentsLoading, setAgentsLoading] = useState(true); + const [showCreateAgent, setShowCreateAgent] = useState(false); + const [newAgentName, setNewAgentName] = useState(""); + const [newAgentInstructions, setNewAgentInstructions] = useState( + "You are a helpful assistant." + ); + const [selectedToolgroups, setSelectedToolgroups] = useState([]); + const [availableToolgroups, setAvailableToolgroups] = useState< + Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }> + >([]); const client = useAuthClient(); + const abortControllerRef = useRef(null); const isModelsLoading = modelsLoading ?? true; + const loadAgentConfig = useCallback( + async (agentId: string) => { + try { + console.log("Loading agent config for:", agentId); + + // try to load from cache first + const cachedConfig = SessionUtils.loadAgentConfig(agentId); + if (cachedConfig) { + console.log("✅ Loaded agent config from cache:", cachedConfig); + setSelectedAgentConfig({ + toolgroups: cachedConfig.toolgroups, + }); + return; + } + + console.log("📡 Fetching agent config from API..."); + const agentDetails = await client.agents.retrieve(agentId); + console.log("Agent details retrieved:", agentDetails); + console.log("Agent config:", agentDetails.agent_config); + console.log("Agent toolgroups:", agentDetails.agent_config?.toolgroups); + + // cache the config + SessionUtils.saveAgentConfig(agentId, agentDetails.agent_config); + + setSelectedAgentConfig({ + toolgroups: agentDetails.agent_config?.toolgroups, + }); + } catch (error) { + console.error("Error loading agent config:", error); + setSelectedAgentConfig(null); + } + }, + [client] + ); + + const createDefaultSession = useCallback( + async (agentId: string) => { + try { + const response = await client.agents.session.create(agentId, { + session_name: "Default Session", + }); + + const defaultSession: ChatSession = { + id: response.session_id, + name: "Default Session", + messages: [], + selectedModel: selectedModel, // Use current selected model + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(defaultSession); + console.log( + `💾 Saving default session ID for agent ${agentId}:`, + defaultSession.id + ); + SessionUtils.saveCurrentSessionId(defaultSession.id, agentId); + // cache entire session data + SessionUtils.saveSessionData(agentId, defaultSession); + } 
catch (error) { + console.error("Error creating default session:", error); + } + }, + [client, selectedModel] + ); + + const loadSessionMessages = useCallback( + async (agentId: string, sessionId: string): Promise => { + try { + const session = await client.agents.session.retrieve( + agentId, + sessionId + ); + + if (!session || !session.turns || !Array.isArray(session.turns)) { + return []; + } + + const messages: Message[] = []; + for (const turn of session.turns) { + // add user messages + if (turn.input_messages && Array.isArray(turn.input_messages)) { + for (const input of turn.input_messages) { + if (input.role === "user" && input.content) { + messages.push({ + id: `${turn.turn_id}-user-${messages.length}`, + role: "user", + content: + typeof input.content === "string" + ? input.content + : JSON.stringify(input.content), + createdAt: new Date(turn.started_at || Date.now()), + }); + } + } + } + + // add assistant message from output_message + if (turn.output_message && turn.output_message.content) { + messages.push({ + id: `${turn.turn_id}-assistant-${messages.length}`, + role: "assistant", + content: + typeof turn.output_message.content === "string" + ? turn.output_message.content + : JSON.stringify(turn.output_message.content), + createdAt: new Date( + turn.completed_at || turn.started_at || Date.now() + ), + }); + } + } + + return messages; + } catch (error) { + console.error("Error loading session messages:", error); + return []; + } + }, + [client] + ); + + const loadAgentSessions = useCallback( + async (agentId: string) => { + try { + console.log("Loading sessions for agent:", agentId); + const response = await client.agents.session.list(agentId); + console.log("Available sessions:", response.data); + + if ( + response.data && + Array.isArray(response.data) && + response.data.length > 0 + ) { + // check for a previously saved session ID for this specific agent + const savedSessionId = SessionUtils.loadCurrentSessionId(agentId); + console.log(`Saved session ID for agent ${agentId}:`, savedSessionId); + + // try to load cached session data first + if (savedSessionId) { + const cachedSession = SessionUtils.loadSessionData( + agentId, + savedSessionId + ); + if (cachedSession) { + console.log("✅ Loaded session from cache:", cachedSession.id); + setCurrentSession(cachedSession); + SessionUtils.saveCurrentSessionId(cachedSession.id, agentId); + return; + } + console.log("📡 Cache miss, fetching session from API..."); + } + + let sessionToLoad = response.data[0] as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "Default session to load (first in list):", + sessionToLoad.session_id + ); + + // try to find saved session id in available sessions + if (savedSessionId) { + const foundSession = response.data.find( + (s: { session_id: string }) => s.session_id === savedSessionId + ); + console.log("Found saved session in list:", foundSession); + if (foundSession) { + sessionToLoad = foundSession as { + session_id: string; + session_name?: string; + started_at?: string; + }; + console.log( + "✅ Restored previously selected session:", + savedSessionId + ); + } else { + console.log( + "❌ Previously selected session not found, using latest session" + ); + } + } else { + console.log("❌ No saved session ID found, using latest session"); + } + + const messages = await loadSessionMessages( + agentId, + sessionToLoad.session_id + ); + + const session: ChatSession = { + id: sessionToLoad.session_id, + name: sessionToLoad.session_name || "Session", + 
 messages, + selectedModel: selectedModel || "", // Preserve current model or use empty + systemMessage: "You are a helpful assistant.", + agentId, + createdAt: sessionToLoad.started_at + ? new Date(sessionToLoad.started_at).getTime() + : Date.now(), + updatedAt: Date.now(), + }; + + setCurrentSession(session); + console.log(`💾 Saving session ID for agent ${agentId}:`, session.id); + SessionUtils.saveCurrentSessionId(session.id, agentId); + // cache session data + SessionUtils.saveSessionData(agentId, session); + } else { + // no sessions, create a new one + await createDefaultSession(agentId); + } + } catch (error) { + console.error("Error loading agent sessions:", error); + // fallback to creating a new session + await createDefaultSession(agentId); + } + }, + [client, loadSessionMessages, createDefaultSession, selectedModel] + ); + + useEffect(() => { + const fetchAgents = async () => { + try { + setAgentsLoading(true); + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + if (agentList.data && agentList.data.length > 0) { + // check if there's a previously selected agent + const savedAgentId = SessionUtils.loadCurrentAgentId(); + + let agentToSelect = agentList.data[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + + // if we have a saved agent ID, find it in the available agents + if (savedAgentId) { + const foundAgent = agentList.data.find( + (a: { agent_id: string }) => a.agent_id === savedAgentId + ); + if (foundAgent) { + agentToSelect = foundAgent as typeof agentToSelect; + } else { + console.log("Previously selected agent not found, using first agent"); + } + } + setSelectedAgentId(agentToSelect.agent_id); + SessionUtils.saveCurrentAgentId(agentToSelect.agent_id); + // load agent config immediately + await loadAgentConfig(agentToSelect.agent_id); + // Note: loadAgentSessions will be called after models are loaded + } + } catch (error) { + console.error("Error fetching agents:", error); + } finally { + setAgentsLoading(false); + } + }; + + fetchAgents(); + + // fetch available toolgroups + const fetchToolgroups = async () => { + try { + console.log("Fetching toolgroups..."); + const toolgroups = await client.toolgroups.list(); + console.log("Toolgroups response:", toolgroups); + + // The client returns data directly, not wrapped in .data + const toolGroupsArray = Array.isArray(toolgroups) + ? toolgroups + : toolgroups && + typeof toolgroups === "object" && + "data" in toolgroups && + Array.isArray((toolgroups as { data: unknown }).data) + ?
( + toolgroups as { + data: Array<{ + identifier: string; + provider_id: string; + type: string; + provider_resource_id?: string; + }>; + } + ).data + : []; + + if (toolGroupsArray && Array.isArray(toolGroupsArray)) { + setAvailableToolgroups(toolGroupsArray); + console.log("Set toolgroups:", toolGroupsArray); + } else { + console.error("Invalid toolgroups data format:", toolgroups); + } + } catch (error) { + console.error("Error fetching toolgroups:", error); + if (error instanceof Error) { + console.error("Error details:", { + name: error.name, + message: error.message, + stack: error.stack, + }); + } + } + }; + + fetchToolgroups(); + }, [client, loadAgentSessions, loadAgentConfig]); + + const createNewAgent = useCallback( + async ( + name: string, + instructions: string, + model: string, + toolgroups: string[] = [] + ) => { + try { + console.log("Creating agent with toolgroups:", toolgroups); + const agentConfig = { + model, + instructions, + name: name || undefined, + enable_session_persistence: true, + toolgroups: toolgroups.length > 0 ? toolgroups : undefined, + }; + console.log("Agent config being sent:", agentConfig); + + const response = await client.agents.create({ + agent_config: agentConfig, + }); + + // refresh agents list + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + // set the new agent as selected + setSelectedAgentId(response.agent_id); + await loadAgentConfig(response.agent_id); + await loadAgentSessions(response.agent_id); + + return response.agent_id; + } catch (error) { + console.error("Error creating agent:", error); + throw error; + } + }, + [client, loadAgentSessions, loadAgentConfig] + ); + + const deleteAgent = useCallback( + async (agentId: string) => { + if (agents.length <= 1) { + return; + } + + if ( + confirm( + "Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions." + ) + ) { + try { + await client.agents.delete(agentId); + + // clear cached data for agent + SessionUtils.clearAgentCache(agentId); + + // Refresh agents list + const agentList = await client.agents.list(); + setAgents( + (agentList.data as Array<{ + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }>) || [] + ); + + // if we deleted the current agent, switch to another one + if (selectedAgentId === agentId) { + const remainingAgents = agentList.data?.filter( + (a: { agent_id: string }) => a.agent_id !== agentId + ); + if (remainingAgents && remainingAgents.length > 0) { + const newAgent = remainingAgents[0] as { + agent_id: string; + agent_config?: { + agent_name?: string; + name?: string; + instructions?: string; + }; + [key: string]: unknown; + }; + setSelectedAgentId(newAgent.agent_id); + SessionUtils.saveCurrentAgentId(newAgent.agent_id); + await loadAgentConfig(newAgent.agent_id); + await loadAgentSessions(newAgent.agent_id); + } else { + // No agents left + setSelectedAgentId(""); + setCurrentSession(null); + setSelectedAgentConfig(null); + } + } + } catch (error) { + console.error("Error deleting agent:", error); + } + } + }, + [agents.length, client, selectedAgentId, loadAgentConfig, loadAgentSessions] + ); + + const handleModelChange = useCallback((newModel: string) => { + setSelectedModel(newModel); + setCurrentSession(prev => + prev + ? 
{ + ...prev, + selectedModel: newModel, + updatedAt: Date.now(), + } + : prev + ); + }, []); + + useEffect(() => { + if (currentSession) { + console.log( + `💾 Auto-saving session ID for agent ${currentSession.agentId}:`, + currentSession.id + ); + SessionUtils.saveCurrentSessionId( + currentSession.id, + currentSession.agentId + ); + // cache session data + SessionUtils.saveSessionData(currentSession.agentId, currentSession); + // only update selectedModel if the session has a valid model and it's different from current + if ( + currentSession.selectedModel && + currentSession.selectedModel !== selectedModel + ) { + setSelectedModel(currentSession.selectedModel); + } + } + }, [currentSession, selectedModel]); + useEffect(() => { const fetchModels = async () => { try { @@ -38,7 +559,7 @@ export default function ChatPlaygroundPage() { const llmModels = modelList.filter(model => model.model_type === "llm"); setModels(llmModels); if (llmModels.length > 0) { - setSelectedModel(llmModels[0].identifier); + handleModelChange(llmModels[0].identifier); } } catch (err) { console.error("Error fetching models:", err); @@ -49,39 +570,27 @@ export default function ChatPlaygroundPage() { }; fetchModels(); - }, [client]); + }, [client, handleModelChange]); - const extractTextContent = (content: unknown): string => { - if (typeof content === "string") { - return content; - } - if (Array.isArray(content)) { - return content - .filter( - item => - item && - typeof item === "object" && - "type" in item && - item.type === "text" - ) - .map(item => - item && typeof item === "object" && "text" in item - ? String(item.text) - : "" - ) - .join(""); - } + // load agent sessions after both agents and models are ready + useEffect(() => { if ( - content && - typeof content === "object" && - "type" in content && - content.type === "text" && - "text" in content + selectedAgentId && + !agentsLoading && + !modelsLoading && + selectedModel && + !currentSession ) { - return String(content.text) || ""; + loadAgentSessions(selectedAgentId); } - return ""; - }; + }, [ + selectedAgentId, + agentsLoading, + modelsLoading, + selectedModel, + currentSession, + loadAgentSessions, + ]); const handleInputChange = (e: React.ChangeEvent) => { setInput(e.target.value); @@ -91,7 +600,6 @@ export default function ChatPlaygroundPage() { event?.preventDefault?.(); if (!input.trim()) return; - // Add user message to chat const userMessage: Message = { id: Date.now().toString(), role: "user", @@ -99,40 +607,54 @@ export default function ChatPlaygroundPage() { createdAt: new Date(), }; - setMessages(prev => [...prev, userMessage]); + setCurrentSession(prev => { + if (!prev) return prev; + const updatedSession = { + ...prev, + messages: [...prev.messages, userMessage], + updatedAt: Date.now(), + }; + // Update cache with new message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); setInput(""); - // Use the helper function with the content await handleSubmitWithContent(userMessage.content); }; const handleSubmitWithContent = async (content: string) => { + if (!currentSession || !selectedAgentId) return; + setIsGenerating(true); setError(null); - try { - const messageParams: CompletionCreateParams["messages"] = [ - ...messages.map(msg => { - const msgContent = - typeof msg.content === "string" - ? 
msg.content - : extractTextContent(msg.content); - if (msg.role === "user") { - return { role: "user" as const, content: msgContent }; - } else if (msg.role === "assistant") { - return { role: "assistant" as const, content: msgContent }; - } else { - return { role: "system" as const, content: msgContent }; - } - }), - { role: "user" as const, content }, - ]; + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + } - const response = await client.chat.completions.create({ - model: selectedModel, - messages: messageParams, + const abortController = new AbortController(); + abortControllerRef.current = abortController; + + try { + const userMessage = { + role: "user" as const, + content, + }; + + const turnParams: TurnCreateParams = { + messages: [userMessage], stream: true, - }); + }; + + const response = await client.agents.turn.create( + selectedAgentId, + currentSession.id, + turnParams, + { + signal: abortController.signal, + } as { signal: AbortSignal } + ); const assistantMessage: Message = { id: (Date.now() + 1).toString(), @@ -141,31 +663,112 @@ export default function ChatPlaygroundPage() { createdAt: new Date(), }; - setMessages(prev => [...prev, assistantMessage]); + const extractDeltaText = (chunk: unknown): string | null => { + // this is an awful way to handle different chunk formats, but i'm not sure if there's much of a better way + if (chunk?.delta?.text && typeof chunk.delta.text === "string") { + return chunk.delta.text; + } + + if ( + chunk?.event?.delta?.text && + typeof chunk.event.delta.text === "string" + ) { + return chunk.event.delta.text; + } + + if ( + chunk?.choices?.[0]?.delta?.content && + typeof chunk.choices[0].delta.content === "string" + ) { + return chunk.choices[0].delta.content; + } + + if (typeof chunk === "string") { + return chunk; + } + + if ( + chunk?.event?.payload?.delta?.text && + typeof chunk.event.payload.delta.text === "string" + ) { + return chunk.event.payload.delta.text; + } + + if (process.env.NODE_ENV !== "production") { + console.debug("Unrecognized chunk format:", chunk); + } + + return null; + }; + setCurrentSession(prev => { + if (!prev) return null; + const updatedSession = { + ...prev, + messages: [...prev.messages, assistantMessage], + updatedAt: Date.now(), + }; + // update cache with assistant message + SessionUtils.saveSessionData(prev.agentId, updatedSession); + return updatedSession; + }); + let fullContent = ""; for await (const chunk of response) { - if (chunk.choices && chunk.choices[0]?.delta?.content) { - const deltaContent = chunk.choices[0].delta.content; - fullContent += deltaContent; + const deltaText = extractDeltaText(chunk); + + if (deltaText) { + fullContent += deltaText; flushSync(() => { - setMessages(prev => { - const newMessages = [...prev]; - const lastMessage = newMessages[newMessages.length - 1]; - if (lastMessage.role === "assistant") { - lastMessage.content = fullContent; + setCurrentSession(prev => { + if (!prev) return null; + const newMessages = [...prev.messages]; + const last = newMessages[newMessages.length - 1]; + if (last.role === "assistant") { + last.content = fullContent; } - return newMessages; + const updatedSession = { + ...prev, + messages: newMessages, + updatedAt: Date.now(), + }; + // update cache with streaming content (throttled) + if (fullContent.length % 100 === 0) { + // Only cache every 100 characters to avoid spam + SessionUtils.saveSessionData(prev.agentId, updatedSession); + } + return updatedSession; }); }); } } } catch (err) { + if (err instanceof Error && 
err.name === "AbortError") { + console.log("Request aborted"); + return; + } + console.error("Error sending message:", err); setError("Failed to send message. Please try again."); - setMessages(prev => prev.slice(0, -1)); + setCurrentSession(prev => + prev + ? { + ...prev, + messages: prev.messages.slice(0, -1), + updatedAt: Date.now(), + } + : prev + ); } finally { setIsGenerating(false); + abortControllerRef.current = null; + // cache final session state after streaming completes + setCurrentSession(prev => { + if (prev) { + SessionUtils.saveSessionData(prev.agentId, prev); + } + return prev; + }); } }; const suggestions = [ @@ -181,69 +784,457 @@ export default function ChatPlaygroundPage() { content: message.content, createdAt: new Date(), }; - setMessages(prev => [...prev, newMessage]); + setCurrentSession(prev => + prev + ? { + ...prev, + messages: [...prev.messages, newMessage], + updatedAt: Date.now(), + } + : prev + ); handleSubmitWithContent(newMessage.content); }; const clearChat = () => { - setMessages([]); + if (abortControllerRef.current) { + abortControllerRef.current.abort(); + abortControllerRef.current = null; + setIsGenerating(false); + } + + setCurrentSession(prev => + prev ? { ...prev, messages: [], updatedAt: Date.now() } : prev + ); setError(null); }; return ( -
-
-

Chat Playground (Completions)

-
- - +
+ {/* Header */} +
+
+

Agent Session

+
+ {!agentsLoading && agents.length > 0 && ( +
+ + + {selectedAgentId && agents.length > 1 && ( + + )} +
+ )} + + {!agentsLoading && agents.length > 0 && ( + + )} +
+
+
+ {/* Main Two-Column Layout */} +
+ {/* Left Column - Configuration Panel */} +
+

+ Settings +

+ + {/* Model Configuration */} +
+

+ Model Configuration +

+
+
+ + + {modelsError && ( +

{modelsError}

+ )} +
+ +
+ +
+ {(selectedAgentId && + agents.find(a => a.agent_id === selectedAgentId) + ?.agent_config?.instructions) || + "No agent selected"} +
+

+ Instructions are set when creating an agent and cannot be + changed. +

+
+
+
+ + {/* Agent Tools */} +
+

+ Agent Tools +

+
+
+ +
+ {selectedAgentConfig?.toolgroups && + selectedAgentConfig.toolgroups.length > 0 ? ( + selectedAgentConfig.toolgroups.map( + ( + toolgroup: + | string + | { name: string; args: Record }, + index: number + ) => { + const toolName = + typeof toolgroup === "string" + ? toolgroup + : toolgroup.name; + const toolArgs = + typeof toolgroup === "object" ? toolgroup.args : null; + + return ( +
+
+ + {toolName} + + + {toolName.includes("rag") + ? "🔍 RAG" + : toolName.includes("search") + ? "🌐 Search" + : "🔧 Tool"} + +
+ {toolArgs && Object.keys(toolArgs).length > 0 && ( +
+ Args:{" "} + {Object.entries(toolArgs) + .map( + ([key, value]) => + `${key}: ${JSON.stringify(value)}` + ) + .join(", ")} +
+ )} +
+ ); + } + ) + ) : ( +
+

+ No tools configured +

+

+ This agent only has text generation capabilities +

+
+ )} +
+

+ Tools are configured when creating an agent and provide + additional capabilities like web search, math calculations, or + RAG document retrieval. +

+
+
+
+
+ + {/* Right Column - Chat Interface */} +
+ {error && ( +
+

{error}

+
+ )} + + + setCurrentSession(prev => + prev ? { ...prev, messages, updatedAt: Date.now() } : prev + ) + } + />
- {modelsError && ( -
-

{modelsError}

+ {/* Create Agent Modal */} + {showCreateAgent && ( +
+ +

Create New Agent

+ +
+
+ + setNewAgentName(e.target.value)} + placeholder="My Custom Agent" + /> +
+ +
+ + +
+ +
+ +