mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 17:22:38 +00:00
Merge branch 'main' into precommit-ui-linter
This commit is contained in:
commit
8f6c4eb9e6
114 changed files with 3912 additions and 454 deletions
12
.github/dependabot.yml
vendored
12
.github/dependabot.yml
vendored
|
|
@ -9,6 +9,7 @@ updates:
|
||||||
day: "saturday"
|
day: "saturday"
|
||||||
commit-message:
|
commit-message:
|
||||||
prefix: chore(github-deps)
|
prefix: chore(github-deps)
|
||||||
|
|
||||||
- package-ecosystem: "uv"
|
- package-ecosystem: "uv"
|
||||||
directory: "/"
|
directory: "/"
|
||||||
schedule:
|
schedule:
|
||||||
|
|
@ -19,3 +20,14 @@ updates:
|
||||||
- python
|
- python
|
||||||
commit-message:
|
commit-message:
|
||||||
prefix: chore(python-deps)
|
prefix: chore(python-deps)
|
||||||
|
|
||||||
|
- package-ecosystem: npm
|
||||||
|
directory: "/llama_stack/ui"
|
||||||
|
schedule:
|
||||||
|
interval: "weekly"
|
||||||
|
day: "saturday"
|
||||||
|
labels:
|
||||||
|
- type/dependencies
|
||||||
|
- javascript
|
||||||
|
commit-message:
|
||||||
|
prefix: chore(ui-deps)
|
||||||
|
|
|
||||||
2
.github/workflows/changelog.yml
vendored
2
.github/workflows/changelog.yml
vendored
|
|
@ -17,7 +17,7 @@ jobs:
|
||||||
pull-requests: write # for peter-evans/create-pull-request to create a PR
|
pull-requests: write # for peter-evans/create-pull-request to create a PR
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
with:
|
with:
|
||||||
ref: main
|
ref: main
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
|
||||||
4
.github/workflows/install-script-ci.yml
vendored
4
.github/workflows/install-script-ci.yml
vendored
|
|
@ -16,14 +16,14 @@ jobs:
|
||||||
lint:
|
lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
|
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # 5.0.0
|
||||||
- name: Run ShellCheck on install.sh
|
- name: Run ShellCheck on install.sh
|
||||||
run: shellcheck scripts/install.sh
|
run: shellcheck scripts/install.sh
|
||||||
smoke-test-on-dev:
|
smoke-test-on-dev:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
4
.github/workflows/integration-auth-tests.yml
vendored
4
.github/workflows/integration-auth-tests.yml
vendored
|
|
@ -18,7 +18,7 @@ on:
|
||||||
- '.github/workflows/integration-auth-tests.yml' # This workflow
|
- '.github/workflows/integration-auth-tests.yml' # This workflow
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -31,7 +31,7 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ on:
|
||||||
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
|
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -44,7 +44,7 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
2
.github/workflows/integration-tests.yml
vendored
2
.github/workflows/integration-tests.yml
vendored
|
|
@ -65,7 +65,7 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Setup test environment
|
- name: Setup test environment
|
||||||
uses: ./.github/actions/setup-test-environment
|
uses: ./.github/actions/setup-test-environment
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
4
.github/workflows/pre-commit.yml
vendored
4
.github/workflows/pre-commit.yml
vendored
|
|
@ -8,7 +8,7 @@ on:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -20,7 +20,7 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
with:
|
with:
|
||||||
# For dependabot PRs, we need to checkout with a token that can push changes
|
# For dependabot PRs, we need to checkout with a token that can push changes
|
||||||
token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
|
token: ${{ github.actor == 'dependabot[bot]' && secrets.GITHUB_TOKEN || github.token }}
|
||||||
|
|
|
||||||
20
.github/workflows/providers-build.yml
vendored
20
.github/workflows/providers-build.yml
vendored
|
|
@ -26,7 +26,7 @@ on:
|
||||||
- 'pyproject.toml'
|
- 'pyproject.toml'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -36,7 +36,7 @@ jobs:
|
||||||
distros: ${{ steps.set-matrix.outputs.distros }}
|
distros: ${{ steps.set-matrix.outputs.distros }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Generate Distribution List
|
- name: Generate Distribution List
|
||||||
id: set-matrix
|
id: set-matrix
|
||||||
|
|
@ -55,7 +55,7 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
@ -79,7 +79,7 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
@ -92,7 +92,7 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
@ -106,6 +106,10 @@ jobs:
|
||||||
- name: Inspect the container image entrypoint
|
- name: Inspect the container image entrypoint
|
||||||
run: |
|
run: |
|
||||||
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
|
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
|
||||||
|
if [ -z "$IMAGE_ID" ]; then
|
||||||
|
echo "No image found"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
|
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
|
||||||
echo "Entrypoint: $entrypoint"
|
echo "Entrypoint: $entrypoint"
|
||||||
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
|
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
|
||||||
|
|
@ -117,7 +121,7 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
@ -140,6 +144,10 @@ jobs:
|
||||||
- name: Inspect UBI9 image
|
- name: Inspect UBI9 image
|
||||||
run: |
|
run: |
|
||||||
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
|
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
|
||||||
|
if [ -z "$IMAGE_ID" ]; then
|
||||||
|
echo "No image found"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
|
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
|
||||||
echo "Entrypoint: $entrypoint"
|
echo "Entrypoint: $entrypoint"
|
||||||
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
|
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
|
||||||
|
|
|
||||||
4
.github/workflows/python-build-test.yml
vendored
4
.github/workflows/python-build-test.yml
vendored
|
|
@ -21,10 +21,10 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@e92bafb6253dcd438e0484186d7669ea7a8ca1cc # v6.4.3
|
uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
activate-environment: true
|
activate-environment: true
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ jobs:
|
||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
|
|
|
||||||
2
.github/workflows/semantic-pr.yml
vendored
2
.github/workflows/semantic-pr.yml
vendored
|
|
@ -22,6 +22,6 @@ jobs:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check PR Title's semantic conformance
|
- name: Check PR Title's semantic conformance
|
||||||
uses: amannn/action-semantic-pull-request@0723387faaf9b38adef4775cd42cfd5155ed6017 # v5.5.3
|
uses: amannn/action-semantic-pull-request@7f33ba792281b034f64e96f4c0b5496782dd3b37 # v6.1.0
|
||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ jobs:
|
||||||
# container and point 'uv pip install' to the correct path...
|
# container and point 'uv pip install' to the correct path...
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
2
.github/workflows/test-external.yml
vendored
2
.github/workflows/test-external.yml
vendored
|
|
@ -27,7 +27,7 @@ jobs:
|
||||||
# container and point 'uv pip install' to the correct path...
|
# container and point 'uv pip install' to the correct path...
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
6
.github/workflows/ui-unit-tests.yml
vendored
6
.github/workflows/ui-unit-tests.yml
vendored
|
|
@ -13,7 +13,7 @@ on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -26,10 +26,10 @@ jobs:
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Setup Node.js
|
- name: Setup Node.js
|
||||||
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
|
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
|
||||||
with:
|
with:
|
||||||
node-version: ${{ matrix.node-version }}
|
node-version: ${{ matrix.node-version }}
|
||||||
cache: 'npm'
|
cache: 'npm'
|
||||||
|
|
|
||||||
4
.github/workflows/unit-tests.yml
vendored
4
.github/workflows/unit-tests.yml
vendored
|
|
@ -18,7 +18,7 @@ on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -32,7 +32,7 @@ jobs:
|
||||||
- "3.13"
|
- "3.13"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
4
.github/workflows/update-readthedocs.yml
vendored
4
.github/workflows/update-readthedocs.yml
vendored
|
|
@ -27,7 +27,7 @@ on:
|
||||||
- '.github/workflows/update-readthedocs.yml'
|
- '.github/workflows/update-readthedocs.yml'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}
|
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
|
@ -37,7 +37,7 @@ jobs:
|
||||||
TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
|
TOKEN: ${{ secrets.READTHEDOCS_TOKEN }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
uses: ./.github/actions/setup-runner
|
uses: ./.github/actions/setup-runner
|
||||||
|
|
|
||||||
132
docs/_static/llama-stack-spec.html
vendored
132
docs/_static/llama-stack-spec.html
vendored
|
|
@ -4605,6 +4605,49 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"/v1/inference/rerank": {
|
||||||
|
"post": {
|
||||||
|
"responses": {
|
||||||
|
"200": {
|
||||||
|
"description": "RerankResponse with indices sorted by relevance score (descending).",
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/RerankResponse"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"400": {
|
||||||
|
"$ref": "#/components/responses/BadRequest400"
|
||||||
|
},
|
||||||
|
"429": {
|
||||||
|
"$ref": "#/components/responses/TooManyRequests429"
|
||||||
|
},
|
||||||
|
"500": {
|
||||||
|
"$ref": "#/components/responses/InternalServerError500"
|
||||||
|
},
|
||||||
|
"default": {
|
||||||
|
"$ref": "#/components/responses/DefaultError"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"Inference"
|
||||||
|
],
|
||||||
|
"description": "Rerank a list of documents based on their relevance to a query.",
|
||||||
|
"parameters": [],
|
||||||
|
"requestBody": {
|
||||||
|
"content": {
|
||||||
|
"application/json": {
|
||||||
|
"schema": {
|
||||||
|
"$ref": "#/components/schemas/RerankRequest"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
|
"/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": {
|
||||||
"post": {
|
"post": {
|
||||||
"responses": {
|
"responses": {
|
||||||
|
|
@ -16587,6 +16630,95 @@
|
||||||
],
|
],
|
||||||
"title": "RegisterVectorDbRequest"
|
"title": "RegisterVectorDbRequest"
|
||||||
},
|
},
|
||||||
|
"RerankRequest": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"model": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The identifier of the reranking model to use."
|
||||||
|
},
|
||||||
|
"query": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length."
|
||||||
|
},
|
||||||
|
"items": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartTextParam"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/OpenAIChatCompletionContentPartImageParam"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"description": "List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length."
|
||||||
|
},
|
||||||
|
"max_num_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "(Optional) Maximum number of results to return. Default: returns all."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"model",
|
||||||
|
"query",
|
||||||
|
"items"
|
||||||
|
],
|
||||||
|
"title": "RerankRequest"
|
||||||
|
},
|
||||||
|
"RerankData": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"index": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "The original index of the document in the input list"
|
||||||
|
},
|
||||||
|
"relevance_score": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"index",
|
||||||
|
"relevance_score"
|
||||||
|
],
|
||||||
|
"title": "RerankData",
|
||||||
|
"description": "A single rerank result from a reranking response."
|
||||||
|
},
|
||||||
|
"RerankResponse": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"data": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"$ref": "#/components/schemas/RerankData"
|
||||||
|
},
|
||||||
|
"description": "List of rerank result objects, sorted by relevance score (descending)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"data"
|
||||||
|
],
|
||||||
|
"title": "RerankResponse",
|
||||||
|
"description": "Response from a reranking request."
|
||||||
|
},
|
||||||
"ResumeAgentTurnRequest": {
|
"ResumeAgentTurnRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
|
||||||
101
docs/_static/llama-stack-spec.yaml
vendored
101
docs/_static/llama-stack-spec.yaml
vendored
|
|
@ -3264,6 +3264,37 @@ paths:
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/QueryTracesRequest'
|
$ref: '#/components/schemas/QueryTracesRequest'
|
||||||
required: true
|
required: true
|
||||||
|
/v1/inference/rerank:
|
||||||
|
post:
|
||||||
|
responses:
|
||||||
|
'200':
|
||||||
|
description: >-
|
||||||
|
RerankResponse with indices sorted by relevance score (descending).
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/RerankResponse'
|
||||||
|
'400':
|
||||||
|
$ref: '#/components/responses/BadRequest400'
|
||||||
|
'429':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/TooManyRequests429
|
||||||
|
'500':
|
||||||
|
$ref: >-
|
||||||
|
#/components/responses/InternalServerError500
|
||||||
|
default:
|
||||||
|
$ref: '#/components/responses/DefaultError'
|
||||||
|
tags:
|
||||||
|
- Inference
|
||||||
|
description: >-
|
||||||
|
Rerank a list of documents based on their relevance to a query.
|
||||||
|
parameters: []
|
||||||
|
requestBody:
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: '#/components/schemas/RerankRequest'
|
||||||
|
required: true
|
||||||
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
|
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume:
|
||||||
post:
|
post:
|
||||||
responses:
|
responses:
|
||||||
|
|
@ -12337,6 +12368,76 @@ components:
|
||||||
- vector_db_id
|
- vector_db_id
|
||||||
- embedding_model
|
- embedding_model
|
||||||
title: RegisterVectorDbRequest
|
title: RegisterVectorDbRequest
|
||||||
|
RerankRequest:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
model:
|
||||||
|
type: string
|
||||||
|
description: >-
|
||||||
|
The identifier of the reranking model to use.
|
||||||
|
query:
|
||||||
|
oneOf:
|
||||||
|
- type: string
|
||||||
|
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
|
||||||
|
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
|
||||||
|
description: >-
|
||||||
|
The search query to rank items against. Can be a string, text content
|
||||||
|
part, or image content part. The input must not exceed the model's max
|
||||||
|
input token length.
|
||||||
|
items:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
oneOf:
|
||||||
|
- type: string
|
||||||
|
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartTextParam'
|
||||||
|
- $ref: '#/components/schemas/OpenAIChatCompletionContentPartImageParam'
|
||||||
|
description: >-
|
||||||
|
List of items to rerank. Each item can be a string, text content part,
|
||||||
|
or image content part. Each input must not exceed the model's max input
|
||||||
|
token length.
|
||||||
|
max_num_results:
|
||||||
|
type: integer
|
||||||
|
description: >-
|
||||||
|
(Optional) Maximum number of results to return. Default: returns all.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- model
|
||||||
|
- query
|
||||||
|
- items
|
||||||
|
title: RerankRequest
|
||||||
|
RerankData:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
index:
|
||||||
|
type: integer
|
||||||
|
description: >-
|
||||||
|
The original index of the document in the input list
|
||||||
|
relevance_score:
|
||||||
|
type: number
|
||||||
|
description: >-
|
||||||
|
The relevance score from the model output. Values are inverted when applicable
|
||||||
|
so that higher scores indicate greater relevance.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- index
|
||||||
|
- relevance_score
|
||||||
|
title: RerankData
|
||||||
|
description: >-
|
||||||
|
A single rerank result from a reranking response.
|
||||||
|
RerankResponse:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
data:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/RerankData'
|
||||||
|
description: >-
|
||||||
|
List of rerank result objects, sorted by relevance score (descending)
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- data
|
||||||
|
title: RerankResponse
|
||||||
|
description: Response from a reranking request.
|
||||||
ResumeAgentTurnRequest:
|
ResumeAgentTurnRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
||||||
|
|
@ -225,8 +225,32 @@ server:
|
||||||
port: 8321 # Port to listen on (default: 8321)
|
port: 8321 # Port to listen on (default: 8321)
|
||||||
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
|
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
|
||||||
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
|
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
|
||||||
|
cors: true # Optional: Enable CORS (dev mode) or full config object
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### CORS Configuration
|
||||||
|
|
||||||
|
CORS (Cross-Origin Resource Sharing) can be configured in two ways:
|
||||||
|
|
||||||
|
**Local development** (allows localhost origins only):
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors: true
|
||||||
|
```
|
||||||
|
|
||||||
|
**Explicit configuration** (custom origins and settings):
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors:
|
||||||
|
allow_origins: ["https://myapp.com", "https://app.example.com"]
|
||||||
|
allow_methods: ["GET", "POST", "PUT", "DELETE"]
|
||||||
|
allow_headers: ["Content-Type", "Authorization"]
|
||||||
|
allow_credentials: true
|
||||||
|
max_age: 3600
|
||||||
|
```
|
||||||
|
|
||||||
|
When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
|
||||||
|
|
||||||
### Authentication Configuration
|
### Authentication Configuration
|
||||||
|
|
||||||
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
|
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
|
||||||
|
|
@ -618,6 +642,54 @@ Content-Type: application/json
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### CORS Configuration
|
||||||
|
|
||||||
|
Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
|
||||||
|
|
||||||
|
#### Quick Setup
|
||||||
|
|
||||||
|
For development, use the simple boolean flag:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors: true # Auto-enables localhost with any port
|
||||||
|
```
|
||||||
|
|
||||||
|
This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
|
||||||
|
|
||||||
|
#### Custom Configuration
|
||||||
|
|
||||||
|
For specific origins and full control:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors:
|
||||||
|
allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
|
||||||
|
allow_credentials: true
|
||||||
|
allow_methods: ["GET", "POST", "PUT", "DELETE"]
|
||||||
|
allow_headers: ["Content-Type", "Authorization"]
|
||||||
|
allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
|
||||||
|
expose_headers: ["X-Total-Count"]
|
||||||
|
max_age: 86400
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Configuration Options
|
||||||
|
|
||||||
|
| Field | Description | Default |
|
||||||
|
| -------------------- | ---------------------------------------------- | ------- |
|
||||||
|
| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
|
||||||
|
| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
|
||||||
|
| `allow_methods` | Allowed HTTP methods. | `["*"]` |
|
||||||
|
| `allow_headers` | Allowed headers. | `["*"]` |
|
||||||
|
| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
|
||||||
|
| `expose_headers` | Headers exposed to browser. | `[]` |
|
||||||
|
| `max_age` | Preflight cache time (seconds). | `600` |
|
||||||
|
|
||||||
|
**Security Notes**:
|
||||||
|
- `allow_credentials: true` requires explicit origins (no wildcards)
|
||||||
|
- `cors: true` enables localhost access only (secure for development)
|
||||||
|
- For public APIs, always specify exact allowed origins
|
||||||
|
|
||||||
## Extending to handle Safety
|
## Extending to handle Safety
|
||||||
|
|
||||||
Configuring Safety can be a little involved so it is instructive to go through an example.
|
Configuring Safety can be a little involved so it is instructive to go through an example.
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
|
||||||
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
|
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
|
||||||
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
|
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
|
||||||
)
|
)
|
||||||
client.initialize()
|
|
||||||
```
|
```
|
||||||
|
|
||||||
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
|
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
|
||||||
|
|
@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
|
||||||
|
|
||||||
```python
|
```python
|
||||||
client = LlamaStackAsLibraryClient(config_path)
|
client = LlamaStackAsLibraryClient(config_path)
|
||||||
client.initialize()
|
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -473,6 +473,28 @@ class EmbeddingsResponse(BaseModel):
|
||||||
embeddings: list[list[float]]
|
embeddings: list[list[float]]
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RerankData(BaseModel):
|
||||||
|
"""A single rerank result from a reranking response.
|
||||||
|
|
||||||
|
:param index: The original index of the document in the input list
|
||||||
|
:param relevance_score: The relevance score from the model output. Values are inverted when applicable so that higher scores indicate greater relevance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
index: int
|
||||||
|
relevance_score: float
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RerankResponse(BaseModel):
|
||||||
|
"""Response from a reranking request.
|
||||||
|
|
||||||
|
:param data: List of rerank result objects, sorted by relevance score (descending)
|
||||||
|
"""
|
||||||
|
|
||||||
|
data: list[RerankData]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
class OpenAIChatCompletionContentPartTextParam(BaseModel):
|
||||||
"""Text content part for OpenAI-compatible chat completion messages.
|
"""Text content part for OpenAI-compatible chat completion messages.
|
||||||
|
|
@ -1131,6 +1153,24 @@ class InferenceProvider(Protocol):
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/inference/rerank", method="POST", experimental=True)
|
||||||
|
async def rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||||
|
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||||
|
max_num_results: int | None = None,
|
||||||
|
) -> RerankResponse:
|
||||||
|
"""Rerank a list of documents based on their relevance to a query.
|
||||||
|
|
||||||
|
:param model: The identifier of the reranking model to use.
|
||||||
|
:param query: The search query to rank items against. Can be a string, text content part, or image content part. The input must not exceed the model's max input token length.
|
||||||
|
:param items: List of items to rerank. Each item can be a string, text content part, or image content part. Each input must not exceed the model's max input token length.
|
||||||
|
:param max_num_results: (Optional) Maximum number of results to return. Default: returns all.
|
||||||
|
:returns: RerankResponse with indices sorted by relevance score (descending).
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Reranking is not implemented")
|
||||||
|
|
||||||
@webmethod(route="/openai/v1/completions", method="POST")
|
@webmethod(route="/openai/v1/completions", method="POST")
|
||||||
async def openai_completion(
|
async def openai_completion(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="server")
|
logger = get_logger(name=__name__, category="cli")
|
||||||
|
|
||||||
|
|
||||||
class StackRun(Subcommand):
|
class StackRun(Subcommand):
|
||||||
|
|
|
||||||
|
|
@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
|
||||||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||||
|
|
||||||
|
|
||||||
|
class CORSConfig(BaseModel):
|
||||||
|
allow_origins: list[str] = Field(default_factory=list)
|
||||||
|
allow_origin_regex: str | None = Field(default=None)
|
||||||
|
allow_methods: list[str] = Field(default=["OPTIONS"])
|
||||||
|
allow_headers: list[str] = Field(default_factory=list)
|
||||||
|
allow_credentials: bool = Field(default=False)
|
||||||
|
expose_headers: list[str] = Field(default_factory=list)
|
||||||
|
max_age: int = Field(default=600, ge=0)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_credentials_config(self) -> Self:
|
||||||
|
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
|
||||||
|
raise ValueError("Cannot use wildcard origins with credentials enabled")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
|
||||||
|
if cors_config is False or cors_config is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cors_config is True:
|
||||||
|
# dev mode: allow localhost on any port
|
||||||
|
return CORSConfig(
|
||||||
|
allow_origins=[],
|
||||||
|
allow_origin_regex=r"https?://localhost:\d+",
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cors_config, CORSConfig):
|
||||||
|
return cors_config
|
||||||
|
|
||||||
|
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
|
||||||
|
|
||||||
|
|
||||||
class ServerConfig(BaseModel):
|
class ServerConfig(BaseModel):
|
||||||
port: int = Field(
|
port: int = Field(
|
||||||
default=8321,
|
default=8321,
|
||||||
|
|
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Per client quota request configuration",
|
description="Per client quota request configuration",
|
||||||
)
|
)
|
||||||
|
cors: bool | CORSConfig | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="CORS configuration for cross-origin requests. Can be:\n"
|
||||||
|
"- true: Enable localhost CORS for development\n"
|
||||||
|
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StackRunConfig(BaseModel):
|
class StackRunConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||||
config_path_or_distro_name, custom_provider_registry, provider_data
|
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
|
||||||
)
|
)
|
||||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||||
self.skip_logger_removal = skip_logger_removal
|
|
||||||
self.provider_data = provider_data
|
self.provider_data = provider_data
|
||||||
|
|
||||||
self.loop = asyncio.new_event_loop()
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
def initialize(self):
|
|
||||||
if in_notebook():
|
|
||||||
import nest_asyncio
|
|
||||||
|
|
||||||
nest_asyncio.apply()
|
|
||||||
if not self.skip_logger_removal:
|
|
||||||
self._remove_root_logger_handlers()
|
|
||||||
|
|
||||||
# use a new event loop to avoid interfering with the main event loop
|
# use a new event loop to avoid interfering with the main event loop
|
||||||
loop = asyncio.new_event_loop()
|
loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(loop)
|
asyncio.set_event_loop(loop)
|
||||||
try:
|
try:
|
||||||
return loop.run_until_complete(self.async_client.initialize())
|
loop.run_until_complete(self.async_client.initialize())
|
||||||
finally:
|
finally:
|
||||||
asyncio.set_event_loop(None)
|
asyncio.set_event_loop(None)
|
||||||
|
|
||||||
def _remove_root_logger_handlers(self):
|
def initialize(self):
|
||||||
"""
|
"""
|
||||||
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
Deprecated method for backward compatibility.
|
||||||
"""
|
"""
|
||||||
root_logger = logging.getLogger()
|
pass
|
||||||
|
|
||||||
for handler in root_logger.handlers[:]:
|
|
||||||
root_logger.removeHandler(handler)
|
|
||||||
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
|
||||||
|
|
||||||
def request(self, *args, **kwargs):
|
def request(self, *args, **kwargs):
|
||||||
loop = self.loop
|
loop = self.loop
|
||||||
|
|
@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
config_path_or_distro_name: str,
|
config_path_or_distro_name: str,
|
||||||
custom_provider_registry: ProviderRegistry | None = None,
|
custom_provider_registry: ProviderRegistry | None = None,
|
||||||
provider_data: dict[str, Any] | None = None,
|
provider_data: dict[str, Any] | None = None,
|
||||||
|
skip_logger_removal: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# when using the library client, we should not log to console since many
|
# when using the library client, we should not log to console since many
|
||||||
|
|
@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
|
||||||
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
|
||||||
|
|
||||||
|
if in_notebook():
|
||||||
|
import nest_asyncio
|
||||||
|
|
||||||
|
nest_asyncio.apply()
|
||||||
|
if not skip_logger_removal:
|
||||||
|
self._remove_root_logger_handlers()
|
||||||
|
|
||||||
if config_path_or_distro_name.endswith(".yaml"):
|
if config_path_or_distro_name.endswith(".yaml"):
|
||||||
config_path = Path(config_path_or_distro_name)
|
config_path = Path(config_path_or_distro_name)
|
||||||
if not config_path.exists():
|
if not config_path.exists():
|
||||||
|
|
@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||||
self.provider_data = provider_data
|
self.provider_data = provider_data
|
||||||
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
|
||||||
|
|
||||||
|
def _remove_root_logger_handlers(self):
|
||||||
|
"""
|
||||||
|
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
|
||||||
|
"""
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
|
||||||
|
for handler in root_logger.handlers[:]:
|
||||||
|
root_logger.removeHandler(handler)
|
||||||
|
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
|
||||||
|
|
||||||
async def initialize(self) -> bool:
|
async def initialize(self) -> bool:
|
||||||
|
"""
|
||||||
|
Initialize the async client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if initialization was successful
|
||||||
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.route_impls = None
|
self.route_impls = None
|
||||||
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
self.impls = await construct_stack(self.config, self.custom_provider_registry)
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routers")
|
||||||
|
|
||||||
|
|
||||||
class DatasetIORouter(DatasetIO):
|
class DatasetIORouter(DatasetIO):
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routers")
|
||||||
|
|
||||||
|
|
||||||
class ScoringRouter(Scoring):
|
class ScoringRouter(Scoring):
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
|
||||||
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
from llama_stack.providers.utils.inference.inference_store import InferenceStore
|
||||||
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
from llama_stack.providers.utils.telemetry.tracing import get_current_span
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="core::routers")
|
||||||
|
|
||||||
|
|
||||||
class InferenceRouter(Inference):
|
class InferenceRouter(Inference):
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import RoutingTable
|
from llama_stack.providers.datatypes import RoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routers")
|
||||||
|
|
||||||
|
|
||||||
class SafetyRouter(Safety):
|
class SafetyRouter(Safety):
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routers")
|
||||||
|
|
||||||
|
|
||||||
class ToolRuntimeRouter(ToolRuntime):
|
class ToolRuntimeRouter(ToolRuntime):
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routers")
|
||||||
|
|
||||||
|
|
||||||
class VectorIORouter(VectorIO):
|
class VectorIORouter(VectorIO):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import Api, RoutingTable
|
from llama_stack.providers.datatypes import Api, RoutingTable
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
def get_impl_api(p: Any) -> Api:
|
def get_impl_api(p: Any) -> Api:
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl, lookup_model
|
from .common import CommonRoutingTableImpl, lookup_model
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl
|
from .common import CommonRoutingTableImpl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
|
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .common import CommonRoutingTableImpl, lookup_model
|
from .common import CommonRoutingTableImpl, lookup_model
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="core::routing_tables")
|
||||||
|
|
||||||
|
|
||||||
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
|
||||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="core::auth")
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationMiddleware:
|
class AuthenticationMiddleware:
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
|
||||||
)
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="auth")
|
logger = get_logger(name=__name__, category="core::auth")
|
||||||
|
|
||||||
|
|
||||||
class AuthResponse(BaseModel):
|
class AuthResponse(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="quota")
|
logger = get_logger(name=__name__, category="core::server")
|
||||||
|
|
||||||
|
|
||||||
class QuotaMiddleware:
|
class QuotaMiddleware:
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from aiohttp import hdrs
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi import Path as FastapiPath
|
from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from openai import BadRequestError
|
from openai import BadRequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
|
||||||
AuthenticationRequiredError,
|
AuthenticationRequiredError,
|
||||||
LoggingConfig,
|
LoggingConfig,
|
||||||
StackRunConfig,
|
StackRunConfig,
|
||||||
|
process_cors_config,
|
||||||
)
|
)
|
||||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||||
|
|
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="server")
|
logger = get_logger(name=__name__, category="core::server")
|
||||||
|
|
||||||
|
|
||||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
|
||||||
|
|
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
|
||||||
config_contents = yaml.safe_load(fp)
|
config_contents = yaml.safe_load(fp)
|
||||||
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
|
||||||
logger_config = LoggingConfig(**cfg)
|
logger_config = LoggingConfig(**cfg)
|
||||||
logger = get_logger(name=__name__, category="server", config=logger_config)
|
logger = get_logger(name=__name__, category="core::server", config=logger_config)
|
||||||
if args.env:
|
if args.env:
|
||||||
for env_pair in args.env:
|
for env_pair in args.env:
|
||||||
try:
|
try:
|
||||||
|
|
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
|
||||||
window_seconds=window_seconds,
|
window_seconds=window_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.server.cors:
|
||||||
|
logger.info("Enabling CORS")
|
||||||
|
cors_config = process_cors_config(config.server.cors)
|
||||||
|
if cors_config:
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
|
|
||||||
logger = get_logger(__name__, category="core")
|
logger = get_logger(__name__, category="core::registry")
|
||||||
|
|
||||||
|
|
||||||
class DistributionRegistry(Protocol):
|
class DistributionRegistry(Protocol):
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from pathlib import Path
|
||||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="config_resolution")
|
logger = get_logger(name=__name__, category="core")
|
||||||
|
|
||||||
|
|
||||||
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
|
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
|
||||||
|
|
||||||
MP_SCALE = 8
|
MP_SCALE = 8
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="models")
|
logger = get_logger(name=__name__, category="models::llama")
|
||||||
|
|
||||||
|
|
||||||
def reduce_from_tensor_model_parallel_region(input_):
|
def reduce_from_tensor_model_parallel_region(input_):
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="models::llama")
|
||||||
|
|
||||||
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
|
||||||
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
|
||||||
from ..model import Transformer, TransformerBlock
|
from ..model import Transformer, TransformerBlock
|
||||||
from ..moe import MoE
|
from ..moe import MoE
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="models")
|
log = get_logger(name=__name__, category="models::llama")
|
||||||
|
|
||||||
|
|
||||||
def swiglu_wrapper_no_reduce(
|
def swiglu_wrapper_no_reduce(
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import collections
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="llama")
|
log = get_logger(name=__name__, category="models::llama")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
import fbgemm_gpu.experimental.gen_ai # noqa: F401
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
|
||||||
WEB_SEARCH_TOOL = "web_search"
|
WEB_SEARCH_TOOL = "web_search"
|
||||||
RAG_TOOL_GROUP = "builtin::rag"
|
RAG_TOOL_GROUP = "builtin::rag"
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="agents")
|
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(ShieldRunnerMixin):
|
class ChatAgent(ShieldRunnerMixin):
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
|
||||||
from .persistence import AgentInfo
|
from .persistence import AgentInfo
|
||||||
from .responses.openai_responses import OpenAIResponsesImpl
|
from .responses.openai_responses import OpenAIResponsesImpl
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="agents")
|
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceAgentsImpl(Agents):
|
class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.kvstore import KVStore
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="agents")
|
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||||
|
|
||||||
|
|
||||||
class AgentSessionInfo(Session):
|
class AgentSessionInfo(Session):
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ from .utils import (
|
||||||
convert_response_text_to_chat_response_format,
|
convert_response_text_to_chat_response_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="responses")
|
logger = get_logger(name=__name__, category="openai::responses")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ from llama_stack.log import get_logger
|
||||||
from .types import ChatCompletionContext, ChatCompletionResult
|
from .types import ChatCompletionContext, ChatCompletionResult
|
||||||
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="responses")
|
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||||
|
|
||||||
|
|
||||||
class StreamingResponseOrchestrator:
|
class StreamingResponseOrchestrator:
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .types import ChatCompletionContext, ToolExecutionResult
|
from .types import ChatCompletionContext, ToolExecutionResult
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="responses")
|
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||||
|
|
||||||
|
|
||||||
class ToolExecutor:
|
class ToolExecutor:
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,8 @@ from llama_stack.apis.agents.openai_responses import (
|
||||||
OpenAIResponseOutputMessageContent,
|
OpenAIResponseOutputMessageContent,
|
||||||
OpenAIResponseOutputMessageContentOutputText,
|
OpenAIResponseOutputMessageContentOutputText,
|
||||||
OpenAIResponseOutputMessageFunctionToolCall,
|
OpenAIResponseOutputMessageFunctionToolCall,
|
||||||
|
OpenAIResponseOutputMessageMCPCall,
|
||||||
|
OpenAIResponseOutputMessageMCPListTools,
|
||||||
OpenAIResponseText,
|
OpenAIResponseText,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
|
|
@ -117,6 +119,25 @@ async def convert_response_input_to_chat_messages(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||||
|
elif isinstance(input_item, OpenAIResponseOutputMessageMCPCall):
|
||||||
|
tool_call = OpenAIChatCompletionToolCall(
|
||||||
|
index=0,
|
||||||
|
id=input_item.id,
|
||||||
|
function=OpenAIChatCompletionToolCallFunction(
|
||||||
|
name=input_item.name,
|
||||||
|
arguments=input_item.arguments,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
|
||||||
|
messages.append(
|
||||||
|
OpenAIToolMessageParam(
|
||||||
|
content=input_item.output,
|
||||||
|
tool_call_id=input_item.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(input_item, OpenAIResponseOutputMessageMCPListTools):
|
||||||
|
# the tool list will be handled separately
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
content = await convert_response_content_to_chat_content(input_item.content)
|
content = await convert_response_content_to_chat_content(input_item.content)
|
||||||
message_type = await get_message_type_by_role(input_item.role)
|
message_type = await get_message_type_by_role(input_item.role)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.telemetry import tracing
|
from llama_stack.providers.utils.telemetry import tracing
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="agents")
|
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||||
|
|
||||||
|
|
||||||
class SafetyException(Exception): # noqa: N818
|
class SafetyException(Exception): # noqa: N818
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,9 @@ from llama_stack.apis.inference import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
RerankResponse,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
StopReason,
|
StopReason,
|
||||||
|
|
@ -442,6 +445,15 @@ class MetaReferenceInferenceImpl(
|
||||||
results = await self._nonstream_chat_completion(request_batch)
|
results = await self._nonstream_chat_completion(request_batch)
|
||||||
return BatchChatCompletionResponse(batch=results)
|
return BatchChatCompletionResponse(batch=results)
|
||||||
|
|
||||||
|
async def rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||||
|
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||||
|
max_num_results: int | None = None,
|
||||||
|
) -> RerankResponse:
|
||||||
|
raise NotImplementedError("Reranking is not supported for Meta Reference")
|
||||||
|
|
||||||
async def _nonstream_chat_completion(
|
async def _nonstream_chat_completion(
|
||||||
self, request_batch: list[ChatCompletionRequest]
|
self, request_batch: list[ChatCompletionRequest]
|
||||||
) -> list[ChatCompletionResponse]:
|
) -> list[ChatCompletionResponse]:
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,9 @@ from llama_stack.apis.inference import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
LogProbConfig,
|
LogProbConfig,
|
||||||
Message,
|
Message,
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
RerankResponse,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
|
|
@ -122,3 +125,12 @@ class SentenceTransformersInferenceImpl(
|
||||||
logprobs: LogProbConfig | None = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
):
|
):
|
||||||
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
|
raise NotImplementedError("Batch chat completion is not supported for Sentence Transformers")
|
||||||
|
|
||||||
|
async def rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||||
|
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||||
|
max_num_results: int | None = None,
|
||||||
|
) -> RerankResponse:
|
||||||
|
raise NotImplementedError("Reranking is not supported for Sentence Transformers")
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
from .config import FireworksImplConfig
|
from .config import FireworksImplConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference::fireworks")
|
||||||
|
|
||||||
|
|
||||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,11 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
|
RerankResponse,
|
||||||
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
|
|
@ -10,7 +15,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
|
||||||
|
|
||||||
|
|
||||||
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
|
|
@ -54,3 +59,12 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
await super().shutdown()
|
await super().shutdown()
|
||||||
|
|
||||||
|
async def rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||||
|
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||||
|
max_num_results: int | None = None,
|
||||||
|
) -> RerankResponse:
|
||||||
|
raise NotImplementedError("Reranking is not supported for Llama OpenAI Compat")
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,11 @@ client.initialize()
|
||||||
|
|
||||||
### Create Completion
|
### Create Completion
|
||||||
|
|
||||||
|
> Note on Completion API
|
||||||
|
>
|
||||||
|
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
|
||||||
|
|
||||||
|
|
||||||
```python
|
```python
|
||||||
response = client.inference.completion(
|
response = client.inference.completion(
|
||||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
|
@ -76,6 +81,73 @@ response = client.inference.chat_completion(
|
||||||
print(f"Response: {response.completion_message.content}")
|
print(f"Response: {response.completion_message.content}")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Tool Calling Example ###
|
||||||
|
```python
|
||||||
|
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||||
|
|
||||||
|
tool_definition = ToolDefinition(
|
||||||
|
tool_name="get_weather",
|
||||||
|
description="Get current weather information for a location",
|
||||||
|
parameters={
|
||||||
|
"location": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The city and state, e.g. San Francisco, CA",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"unit": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="Temperature unit (celsius or fahrenheit)",
|
||||||
|
required=False,
|
||||||
|
default="celsius",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_response = client.inference.chat_completion(
|
||||||
|
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
|
||||||
|
tools=[tool_definition],
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Tool Response: {tool_response.completion_message.content}")
|
||||||
|
if tool_response.completion_message.tool_calls:
|
||||||
|
for tool_call in tool_response.completion_message.tool_calls:
|
||||||
|
print(f"Tool Called: {tool_call.tool_name}")
|
||||||
|
print(f"Arguments: {tool_call.arguments}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Structured Output Example
|
||||||
|
```python
|
||||||
|
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
|
||||||
|
|
||||||
|
person_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"age": {"type": "integer"},
|
||||||
|
"occupation": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["name", "age", "occupation"],
|
||||||
|
}
|
||||||
|
|
||||||
|
response_format = JsonSchemaResponseFormat(
|
||||||
|
type=ResponseFormatType.json_schema, json_schema=person_schema
|
||||||
|
)
|
||||||
|
|
||||||
|
structured_response = client.inference.chat_completion(
|
||||||
|
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
response_format=response_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Structured Response: {structured_response.completion_message.content}")
|
||||||
|
```
|
||||||
|
|
||||||
### Create Embeddings
|
### Create Embeddings
|
||||||
> Note on OpenAI embeddings compatibility
|
> Note on OpenAI embeddings compatibility
|
||||||
>
|
>
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
|
|
||||||
from openai import NOT_GIVEN, APIConnectionError, BadRequestError
|
from openai import NOT_GIVEN, APIConnectionError
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
|
@ -57,7 +57,7 @@ from .openai_utils import (
|
||||||
)
|
)
|
||||||
from .utils import _is_nvidia_hosted
|
from .utils import _is_nvidia_hosted
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||||
|
|
||||||
|
|
||||||
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||||
|
|
@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
|
||||||
}
|
}
|
||||||
extra_body["input_type"] = task_type_options[task_type]
|
extra_body["input_type"] = task_type_options[task_type]
|
||||||
|
|
||||||
try:
|
response = await self.client.embeddings.create(
|
||||||
response = await self.client.embeddings.create(
|
model=provider_model_id,
|
||||||
model=provider_model_id,
|
input=input,
|
||||||
input=input,
|
extra_body=extra_body,
|
||||||
extra_body=extra_body,
|
)
|
||||||
)
|
|
||||||
except BadRequestError as e:
|
|
||||||
raise ValueError(f"Failed to get embeddings: {e}") from e
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
|
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
|
||||||
# ->
|
# ->
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
from . import NVIDIAConfig
|
from . import NVIDIAConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference::nvidia")
|
||||||
|
|
||||||
|
|
||||||
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -37,11 +37,14 @@ from llama_stack.apis.inference import (
|
||||||
Message,
|
Message,
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
OpenAIChatCompletionChunk,
|
OpenAIChatCompletionChunk,
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
OpenAIEmbeddingsResponse,
|
OpenAIEmbeddingsResponse,
|
||||||
OpenAIEmbeddingUsage,
|
OpenAIEmbeddingUsage,
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
|
RerankResponse,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
TextTruncation,
|
TextTruncation,
|
||||||
|
|
@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference::ollama")
|
||||||
|
|
||||||
|
|
||||||
class OllamaInferenceAdapter(
|
class OllamaInferenceAdapter(
|
||||||
|
|
@ -641,6 +644,15 @@ class OllamaInferenceAdapter(
|
||||||
):
|
):
|
||||||
raise NotImplementedError("Batch chat completion is not supported for Ollama")
|
raise NotImplementedError("Batch chat completion is not supported for Ollama")
|
||||||
|
|
||||||
|
async def rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||||
|
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||||
|
max_num_results: int | None = None,
|
||||||
|
) -> RerankResponse:
|
||||||
|
raise NotImplementedError("Reranking is not supported for Ollama")
|
||||||
|
|
||||||
|
|
||||||
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
|
||||||
async def _convert_content(content) -> dict:
|
async def _convert_content(content) -> dict:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
from .config import OpenAIConfig
|
from .config import OpenAIConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference::openai")
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="inference::tgi")
|
||||||
|
|
||||||
|
|
||||||
def build_hf_repo_model_entries():
|
def build_hf_repo_model_entries():
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
from .config import TogetherImplConfig
|
from .config import TogetherImplConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="inference::together")
|
||||||
|
|
||||||
|
|
||||||
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
|
||||||
|
|
|
||||||
|
|
@ -39,12 +39,15 @@ from llama_stack.apis.inference import (
|
||||||
Message,
|
Message,
|
||||||
ModelStore,
|
ModelStore,
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
|
OpenAIChatCompletionContentPartImageParam,
|
||||||
|
OpenAIChatCompletionContentPartTextParam,
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
OpenAIEmbeddingData,
|
OpenAIEmbeddingData,
|
||||||
OpenAIEmbeddingsResponse,
|
OpenAIEmbeddingsResponse,
|
||||||
OpenAIEmbeddingUsage,
|
OpenAIEmbeddingUsage,
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
|
RerankResponse,
|
||||||
ResponseFormat,
|
ResponseFormat,
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
TextTruncation,
|
TextTruncation,
|
||||||
|
|
@ -85,7 +88,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import VLLMInferenceAdapterConfig
|
from .config import VLLMInferenceAdapterConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="inference::vllm")
|
||||||
|
|
||||||
|
|
||||||
def build_hf_repo_model_entries():
|
def build_hf_repo_model_entries():
|
||||||
|
|
@ -732,4 +735,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
response_format: ResponseFormat | None = None,
|
response_format: ResponseFormat | None = None,
|
||||||
logprobs: LogProbConfig | None = None,
|
logprobs: LogProbConfig | None = None,
|
||||||
):
|
):
|
||||||
raise NotImplementedError("Batch chat completion is not supported for Ollama")
|
raise NotImplementedError("Batch chat completion is not supported for vLLM")
|
||||||
|
|
||||||
|
async def rerank(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||||
|
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||||
|
max_num_results: int | None = None,
|
||||||
|
) -> RerankResponse:
|
||||||
|
raise NotImplementedError("Reranking is not supported for vLLM")
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
|
||||||
|
|
||||||
from .config import NvidiaPostTrainingConfig
|
from .config import NvidiaPostTrainingConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="integration")
|
logger = get_logger(name=__name__, category="post_training::nvidia")
|
||||||
|
|
||||||
|
|
||||||
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
|
||||||
|
|
||||||
from .config import BedrockSafetyConfig
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="safety")
|
logger = get_logger(name=__name__, category="safety::bedrock")
|
||||||
|
|
||||||
|
|
||||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from typing import Any
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
from llama_stack.apis.inference import Message
|
from llama_stack.apis.inference import Message
|
||||||
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
|
||||||
from llama_stack.apis.shields import Shield
|
from llama_stack.apis.shields import Shield
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
||||||
|
|
||||||
from .config import NVIDIASafetyConfig
|
from .config import NVIDIASafetyConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="safety")
|
logger = get_logger(name=__name__, category="safety::nvidia")
|
||||||
|
|
||||||
|
|
||||||
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
|
|
@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
self.shield = NeMoGuardrails(self.config, shield.shield_id)
|
||||||
return await self.shield.run(messages)
|
return await self.shield.run(messages)
|
||||||
|
|
||||||
|
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
|
||||||
|
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
|
||||||
|
|
||||||
|
|
||||||
class NeMoGuardrails:
|
class NeMoGuardrails:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
|
||||||
|
|
||||||
from .config import SambaNovaSafetyConfig
|
from .config import SambaNovaSafetyConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="safety")
|
logger = get_logger(name=__name__, category="safety::sambanova")
|
||||||
|
|
||||||
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::chroma")
|
||||||
|
|
||||||
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
||||||
|
|
||||||
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="vector_io")
|
logger = get_logger(name=__name__, category="vector_io::milvus")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import PGVectorVectorIOConfig
|
from .config import PGVectorVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::pgvector")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
|
||||||
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::qdrant")
|
||||||
CHUNK_ID_KEY = "_chunk_id"
|
CHUNK_ID_KEY = "_chunk_id"
|
||||||
|
|
||||||
# KV store prefixes for vector databases
|
# KV store prefixes for vector databases
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
|
||||||
|
|
||||||
from .config import WeaviateVectorIOConfig
|
from .config import WeaviateVectorIOConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="vector_io")
|
log = get_logger(name=__name__, category="vector_io::weaviate")
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
|
||||||
EMBEDDING_MODELS = {}
|
EMBEDDING_MODELS = {}
|
||||||
|
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class SentenceTransformerEmbeddingMixin:
|
class SentenceTransformerEmbeddingMixin:
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class LiteLLMOpenAIMixin(
|
class LiteLLMOpenAIMixin(
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
|
||||||
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class RemoteInferenceProviderConfig(BaseModel):
|
class RemoteInferenceProviderConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
decode_assistant_message,
|
decode_assistant_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="inference")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
class OpenAICompatCompletionChoiceDelta(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="core")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class OpenAIMixin(ABC):
|
class OpenAIMixin(ABC):
|
||||||
|
|
|
||||||
|
|
@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
from llama_stack.providers.utils.inference import supported_inference_models
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="inference")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
from ..config import MongoDBKVStoreConfig
|
from ..config import MongoDBKVStoreConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="kvstore")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class MongoDBKVStoreImpl(KVStore):
|
class MongoDBKVStoreImpl(KVStore):
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from llama_stack.log import get_logger
|
||||||
from ..api import KVStore
|
from ..api import KVStore
|
||||||
from ..config import PostgresKVStoreConfig
|
from ..config import PostgresKVStoreConfig
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="kvstore")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class PostgresKVStoreImpl(KVStore):
|
class PostgresKVStoreImpl(KVStore):
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
||||||
make_overlapped_chunks,
|
make_overlapped_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="memory")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
# Constants for OpenAI vector stores
|
# Constants for OpenAI vector stores
|
||||||
CHUNK_MULTIPLIER = 5
|
CHUNK_MULTIPLIER = 5
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||||
|
|
||||||
log = get_logger(name=__name__, category="memory")
|
log = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
class ChunkForDeletion(BaseModel):
|
class ChunkForDeletion(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="scheduler")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
|
|
||||||
# TODO: revisit the list of possible statuses when defining a more coherent
|
# TODO: revisit the list of possible statuses when defining a more coherent
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from llama_stack.log import get_logger
|
||||||
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
|
||||||
from .sqlstore import SqlStoreType
|
from .sqlstore import SqlStoreType
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="authorized_sqlstore")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
# Hardcoded copy of the default policy that our SQL filtering implements
|
# Hardcoded copy of the default policy that our SQL filtering implements
|
||||||
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
# WARNING: If default_policy() changes, this constant must be updated accordingly
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ from sqlalchemy import (
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||||
|
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
@ -29,7 +30,7 @@ from llama_stack.log import get_logger
|
||||||
from .api import ColumnDefinition, ColumnType, SqlStore
|
from .api import ColumnDefinition, ColumnType, SqlStore
|
||||||
from .sqlstore import SqlAlchemySqlStoreConfig
|
from .sqlstore import SqlAlchemySqlStoreConfig
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="sqlstore")
|
logger = get_logger(name=__name__, category="providers::utils")
|
||||||
|
|
||||||
TYPE_MAPPING: dict[ColumnType, Any] = {
|
TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||||
ColumnType.INTEGER: Integer,
|
ColumnType.INTEGER: Integer,
|
||||||
|
|
@ -45,9 +46,12 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
||||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.async_session = async_sessionmaker(create_async_engine(config.engine_str))
|
self.async_session = async_sessionmaker(self.create_engine())
|
||||||
self.metadata = MetaData()
|
self.metadata = MetaData()
|
||||||
|
|
||||||
|
def create_engine(self) -> AsyncEngine:
|
||||||
|
return create_async_engine(self.config.engine_str, pool_pre_ping=True)
|
||||||
|
|
||||||
async def create_table(
|
async def create_table(
|
||||||
self,
|
self,
|
||||||
table: str,
|
table: str,
|
||||||
|
|
@ -83,7 +87,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
else:
|
else:
|
||||||
sqlalchemy_table = self.metadata.tables[table]
|
sqlalchemy_table = self.metadata.tables[table]
|
||||||
|
|
||||||
engine = create_async_engine(self.config.engine_str)
|
engine = self.create_engine()
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
await conn.run_sync(self.metadata.create_all, tables=[sqlalchemy_table], checkfirst=True)
|
||||||
|
|
||||||
|
|
@ -241,7 +245,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
nullable: bool = True,
|
nullable: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a column to an existing table if the column doesn't already exist."""
|
"""Add a column to an existing table if the column doesn't already exist."""
|
||||||
engine = create_async_engine(self.config.engine_str)
|
engine = self.create_engine()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with engine.begin() as conn:
|
async with engine.begin() as conn:
|
||||||
|
|
|
||||||
587
llama_stack/ui/app/chat-playground/page.test.tsx
Normal file
587
llama_stack/ui/app/chat-playground/page.test.tsx
Normal file
|
|
@ -0,0 +1,587 @@
|
||||||
|
import React from "react";
|
||||||
|
import {
|
||||||
|
render,
|
||||||
|
screen,
|
||||||
|
fireEvent,
|
||||||
|
waitFor,
|
||||||
|
act,
|
||||||
|
} from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import ChatPlaygroundPage from "./page";
|
||||||
|
|
||||||
|
const mockClient = {
|
||||||
|
agents: {
|
||||||
|
list: jest.fn(),
|
||||||
|
create: jest.fn(),
|
||||||
|
retrieve: jest.fn(),
|
||||||
|
delete: jest.fn(),
|
||||||
|
session: {
|
||||||
|
list: jest.fn(),
|
||||||
|
create: jest.fn(),
|
||||||
|
delete: jest.fn(),
|
||||||
|
retrieve: jest.fn(),
|
||||||
|
},
|
||||||
|
turn: {
|
||||||
|
create: jest.fn(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
models: {
|
||||||
|
list: jest.fn(),
|
||||||
|
},
|
||||||
|
toolgroups: {
|
||||||
|
list: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
jest.mock("@/hooks/use-auth-client", () => ({
|
||||||
|
useAuthClient: jest.fn(() => mockClient),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock("@/components/chat-playground/chat", () => ({
|
||||||
|
Chat: jest.fn(
|
||||||
|
({
|
||||||
|
className,
|
||||||
|
messages,
|
||||||
|
handleSubmit,
|
||||||
|
input,
|
||||||
|
handleInputChange,
|
||||||
|
isGenerating,
|
||||||
|
append,
|
||||||
|
suggestions,
|
||||||
|
}) => (
|
||||||
|
<div data-testid="chat-component" className={className}>
|
||||||
|
<div data-testid="messages-count">{messages.length}</div>
|
||||||
|
<input
|
||||||
|
data-testid="chat-input"
|
||||||
|
value={input}
|
||||||
|
onChange={handleInputChange}
|
||||||
|
disabled={isGenerating}
|
||||||
|
/>
|
||||||
|
<button data-testid="submit-button" onClick={handleSubmit}>
|
||||||
|
Submit
|
||||||
|
</button>
|
||||||
|
{suggestions?.map((suggestion: string, index: number) => (
|
||||||
|
<button
|
||||||
|
key={index}
|
||||||
|
data-testid={`suggestion-${index}`}
|
||||||
|
onClick={() => append({ role: "user", content: suggestion })}
|
||||||
|
>
|
||||||
|
{suggestion}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock("@/components/chat-playground/conversations", () => ({
|
||||||
|
SessionManager: jest.fn(({ selectedAgentId, onNewSession }) => (
|
||||||
|
<div data-testid="session-manager">
|
||||||
|
{selectedAgentId && (
|
||||||
|
<>
|
||||||
|
<div data-testid="selected-agent">{selectedAgentId}</div>
|
||||||
|
<button data-testid="new-session-button" onClick={onNewSession}>
|
||||||
|
New Session
|
||||||
|
</button>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)),
|
||||||
|
SessionUtils: {
|
||||||
|
saveCurrentSessionId: jest.fn(),
|
||||||
|
loadCurrentSessionId: jest.fn(),
|
||||||
|
loadCurrentAgentId: jest.fn(),
|
||||||
|
saveCurrentAgentId: jest.fn(),
|
||||||
|
clearCurrentSession: jest.fn(),
|
||||||
|
saveSessionData: jest.fn(),
|
||||||
|
loadSessionData: jest.fn(),
|
||||||
|
saveAgentConfig: jest.fn(),
|
||||||
|
loadAgentConfig: jest.fn(),
|
||||||
|
clearAgentCache: jest.fn(),
|
||||||
|
createDefaultSession: jest.fn(() => ({
|
||||||
|
id: "test-session-123",
|
||||||
|
name: "Default Session",
|
||||||
|
messages: [],
|
||||||
|
selectedModel: "",
|
||||||
|
systemMessage: "You are a helpful assistant.",
|
||||||
|
agentId: "test-agent-123",
|
||||||
|
createdAt: Date.now(),
|
||||||
|
updatedAt: Date.now(),
|
||||||
|
})),
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
|
||||||
|
const mockAgents = [
|
||||||
|
{
|
||||||
|
agent_id: "agent_123",
|
||||||
|
agent_config: {
|
||||||
|
name: "Test Agent",
|
||||||
|
instructions: "You are a test assistant.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
agent_id: "agent_456",
|
||||||
|
agent_config: {
|
||||||
|
agent_name: "Another Agent",
|
||||||
|
instructions: "You are another assistant.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const mockModels = [
|
||||||
|
{
|
||||||
|
identifier: "test-model-1",
|
||||||
|
model_type: "llm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
identifier: "test-model-2",
|
||||||
|
model_type: "llm",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const mockToolgroups = [
|
||||||
|
{
|
||||||
|
identifier: "builtin::rag",
|
||||||
|
provider_id: "test-provider",
|
||||||
|
type: "tool_group",
|
||||||
|
provider_resource_id: "test-resource",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
describe("ChatPlaygroundPage", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
Element.prototype.scrollIntoView = jest.fn();
|
||||||
|
mockClient.agents.list.mockResolvedValue({ data: mockAgents });
|
||||||
|
mockClient.models.list.mockResolvedValue(mockModels);
|
||||||
|
mockClient.toolgroups.list.mockResolvedValue(mockToolgroups);
|
||||||
|
mockClient.agents.session.create.mockResolvedValue({
|
||||||
|
session_id: "new-session-123",
|
||||||
|
});
|
||||||
|
mockClient.agents.session.list.mockResolvedValue({ data: [] });
|
||||||
|
mockClient.agents.session.retrieve.mockResolvedValue({
|
||||||
|
session_id: "test-session",
|
||||||
|
session_name: "Test Session",
|
||||||
|
started_at: new Date().toISOString(),
|
||||||
|
turns: [],
|
||||||
|
}); // No turns by default
|
||||||
|
mockClient.agents.retrieve.mockResolvedValue({
|
||||||
|
agent_id: "test-agent",
|
||||||
|
agent_config: {
|
||||||
|
toolgroups: ["builtin::rag"],
|
||||||
|
instructions: "Test instructions",
|
||||||
|
model: "test-model",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
mockClient.agents.delete.mockResolvedValue(undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Agent Selector Rendering", () => {
|
||||||
|
test("shows agent selector when agents are available", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Agent Session:")).toBeInTheDocument();
|
||||||
|
expect(screen.getAllByRole("combobox")).toHaveLength(2);
|
||||||
|
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Clear Chat")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not show agent selector when no agents are available", async () => {
|
||||||
|
mockClient.agents.list.mockResolvedValue({ data: [] });
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
|
||||||
|
expect(screen.getAllByRole("combobox")).toHaveLength(1);
|
||||||
|
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not show agent selector while loading", async () => {
|
||||||
|
mockClient.agents.list.mockImplementation(() => new Promise(() => {}));
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.queryByText("Agent Session:")).not.toBeInTheDocument();
|
||||||
|
expect(screen.getAllByRole("combobox")).toHaveLength(1);
|
||||||
|
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("Clear Chat")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("shows agent options in selector", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
const agentCombobox = screen.getAllByRole("combobox").find(element => {
|
||||||
|
return (
|
||||||
|
element.textContent?.includes("Test Agent") ||
|
||||||
|
element.textContent?.includes("Select Agent")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
expect(agentCombobox).toBeDefined();
|
||||||
|
fireEvent.click(agentCombobox!);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getAllByText("Test Agent")).toHaveLength(2);
|
||||||
|
expect(screen.getByText("Another Agent")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("displays agent ID when no name is available", async () => {
|
||||||
|
const agentWithoutName = {
|
||||||
|
agent_id: "agent_789",
|
||||||
|
agent_config: {
|
||||||
|
instructions: "You are an agent without a name.",
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
mockClient.agents.list.mockResolvedValue({ data: [agentWithoutName] });
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
const agentCombobox = screen.getAllByRole("combobox").find(element => {
|
||||||
|
return (
|
||||||
|
element.textContent?.includes("Agent agent_78") ||
|
||||||
|
element.textContent?.includes("Select Agent")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
expect(agentCombobox).toBeDefined();
|
||||||
|
fireEvent.click(agentCombobox!);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getAllByText("Agent agent_78...")).toHaveLength(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Agent Creation Modal", () => {
|
||||||
|
test("opens agent creation modal when + New Agent is clicked", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
const newAgentButton = screen.getByText("+ New Agent");
|
||||||
|
fireEvent.click(newAgentButton);
|
||||||
|
|
||||||
|
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Agent Name (optional)")).toBeInTheDocument();
|
||||||
|
expect(screen.getAllByText("Model")).toHaveLength(2);
|
||||||
|
expect(screen.getByText("System Instructions")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("Tools (optional)")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("closes modal when Cancel is clicked", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
const newAgentButton = screen.getByText("+ New Agent");
|
||||||
|
fireEvent.click(newAgentButton);
|
||||||
|
|
||||||
|
const cancelButton = screen.getByText("Cancel");
|
||||||
|
fireEvent.click(cancelButton);
|
||||||
|
|
||||||
|
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("creates agent when Create Agent is clicked", async () => {
|
||||||
|
mockClient.agents.create.mockResolvedValue({ agent_id: "new-agent-123" });
|
||||||
|
mockClient.agents.list
|
||||||
|
.mockResolvedValueOnce({ data: mockAgents })
|
||||||
|
.mockResolvedValueOnce({
|
||||||
|
data: [
|
||||||
|
...mockAgents,
|
||||||
|
{ agent_id: "new-agent-123", agent_config: { name: "New Agent" } },
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
const newAgentButton = screen.getByText("+ New Agent");
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.click(newAgentButton);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Create New Agent")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
const nameInput = screen.getByPlaceholderText("My Custom Agent");
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.change(nameInput, { target: { value: "Test Agent Name" } });
|
||||||
|
});
|
||||||
|
|
||||||
|
const instructionsTextarea = screen.getByDisplayValue(
|
||||||
|
"You are a helpful assistant."
|
||||||
|
);
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.change(instructionsTextarea, {
|
||||||
|
target: { value: "Custom instructions" },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
const modalModelSelectors = screen
|
||||||
|
.getAllByRole("combobox")
|
||||||
|
.filter(el => {
|
||||||
|
return (
|
||||||
|
el.textContent?.includes("Select Model") ||
|
||||||
|
el.closest('[class*="modal"]') ||
|
||||||
|
el.closest('[class*="card"]')
|
||||||
|
);
|
||||||
|
});
|
||||||
|
expect(modalModelSelectors.length).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
const modalModelSelectors = screen.getAllByRole("combobox").filter(el => {
|
||||||
|
return (
|
||||||
|
el.textContent?.includes("Select Model") ||
|
||||||
|
el.closest('[class*="modal"]') ||
|
||||||
|
el.closest('[class*="card"]')
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.click(modalModelSelectors[0]);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
const modelOptions = screen.getAllByText("test-model-1");
|
||||||
|
expect(modelOptions.length).toBeGreaterThan(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
const modelOptions = screen.getAllByText("test-model-1");
|
||||||
|
const dropdownOption = modelOptions.find(
|
||||||
|
option =>
|
||||||
|
option.closest('[role="option"]') ||
|
||||||
|
option.id?.includes("radix") ||
|
||||||
|
option.getAttribute("aria-selected") !== null
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.click(
|
||||||
|
dropdownOption || modelOptions[modelOptions.length - 1]
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
const createButton = screen.getByText("Create Agent");
|
||||||
|
expect(createButton).not.toBeDisabled();
|
||||||
|
});
|
||||||
|
|
||||||
|
const createButton = screen.getByText("Create Agent");
|
||||||
|
await act(async () => {
|
||||||
|
fireEvent.click(createButton);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockClient.agents.create).toHaveBeenCalledWith({
|
||||||
|
agent_config: {
|
||||||
|
model: expect.any(String),
|
||||||
|
instructions: "Custom instructions",
|
||||||
|
name: "Test Agent Name",
|
||||||
|
enable_session_persistence: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.queryByText("Create New Agent")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Agent Selection", () => {
|
||||||
|
test("creates default session when agent is selected", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
// first agent should be auto-selected
|
||||||
|
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
|
||||||
|
"agent_123",
|
||||||
|
{ session_name: "Default Session" }
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("switches agent when different agent is selected", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
const agentCombobox = screen.getAllByRole("combobox").find(element => {
|
||||||
|
return (
|
||||||
|
element.textContent?.includes("Test Agent") ||
|
||||||
|
element.textContent?.includes("Select Agent")
|
||||||
|
);
|
||||||
|
});
|
||||||
|
expect(agentCombobox).toBeDefined();
|
||||||
|
fireEvent.click(agentCombobox!);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
const anotherAgentOption = screen.getByText("Another Agent");
|
||||||
|
fireEvent.click(anotherAgentOption);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
|
||||||
|
"agent_456",
|
||||||
|
{ session_name: "Default Session" }
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Agent Deletion", () => {
|
||||||
|
test("shows delete button when multiple agents exist", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("hides delete button when only one agent exists", async () => {
|
||||||
|
mockClient.agents.list.mockResolvedValue({
|
||||||
|
data: [mockAgents[0]],
|
||||||
|
});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(
|
||||||
|
screen.queryByTitle("Delete current agent")
|
||||||
|
).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("deletes agent and switches to another when confirmed", async () => {
|
||||||
|
global.confirm = jest.fn(() => true);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
mockClient.agents.delete.mockResolvedValue(undefined);
|
||||||
|
mockClient.agents.list.mockResolvedValueOnce({ data: mockAgents });
|
||||||
|
mockClient.agents.list.mockResolvedValueOnce({
|
||||||
|
data: [mockAgents[1]],
|
||||||
|
});
|
||||||
|
|
||||||
|
const deleteButton = screen.getByTitle("Delete current agent");
|
||||||
|
await act(async () => {
|
||||||
|
deleteButton.click();
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
|
||||||
|
expect(global.confirm).toHaveBeenCalledWith(
|
||||||
|
"Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions."
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
(global.confirm as jest.Mock).mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("does not delete agent when cancelled", async () => {
|
||||||
|
global.confirm = jest.fn(() => false);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
const deleteButton = screen.getByTitle("Delete current agent");
|
||||||
|
await act(async () => {
|
||||||
|
deleteButton.click();
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(global.confirm).toHaveBeenCalled();
|
||||||
|
expect(mockClient.agents.delete).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
(global.confirm as jest.Mock).mockRestore();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Error Handling", () => {
|
||||||
|
test("handles agent loading errors gracefully", async () => {
|
||||||
|
mockClient.agents.list.mockRejectedValue(
|
||||||
|
new Error("Failed to load agents")
|
||||||
|
);
|
||||||
|
const consoleSpy = jest
|
||||||
|
.spyOn(console, "error")
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(consoleSpy).toHaveBeenCalledWith(
|
||||||
|
"Error fetching agents:",
|
||||||
|
expect.any(Error)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByText("+ New Agent")).toBeInTheDocument();
|
||||||
|
|
||||||
|
consoleSpy.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("handles model loading errors gracefully", async () => {
|
||||||
|
mockClient.models.list.mockRejectedValue(
|
||||||
|
new Error("Failed to load models")
|
||||||
|
);
|
||||||
|
const consoleSpy = jest
|
||||||
|
.spyOn(console, "error")
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(<ChatPlaygroundPage />);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(consoleSpy).toHaveBeenCalledWith(
|
||||||
|
"Error fetching models:",
|
||||||
|
expect.any(Error)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
consoleSpy.mockRestore();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
File diff suppressed because it is too large
Load diff
Binary file not shown.
|
Before Width: | Height: | Size: 25 KiB |
|
|
@ -120,3 +120,44 @@
|
||||||
@apply bg-background text-foreground;
|
@apply bg-background text-foreground;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@layer utilities {
|
||||||
|
.animate-typing-dot-1 {
|
||||||
|
animation: typing-dot-bounce-1 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.animate-typing-dot-2 {
|
||||||
|
animation: typing-dot-bounce-2 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.animate-typing-dot-3 {
|
||||||
|
animation: typing-dot-bounce-3 0.8s cubic-bezier(0.4, 0, 0.6, 1) infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes typing-dot-bounce-1 {
|
||||||
|
0%, 15%, 85%, 100% {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
7.5% {
|
||||||
|
transform: translateY(-6px);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes typing-dot-bounce-2 {
|
||||||
|
0%, 15%, 35%, 85%, 100% {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
25% {
|
||||||
|
transform: translateY(-6px);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes typing-dot-bounce-3 {
|
||||||
|
0%, 35%, 55%, 85%, 100% {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
45% {
|
||||||
|
transform: translateY(-6px);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,9 @@ const geistMono = Geist_Mono({
|
||||||
export const metadata: Metadata = {
|
export const metadata: Metadata = {
|
||||||
title: "Llama Stack",
|
title: "Llama Stack",
|
||||||
description: "Llama Stack UI",
|
description: "Llama Stack UI",
|
||||||
|
icons: {
|
||||||
|
icon: "/favicon.ico",
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
import { SidebarProvider, SidebarTrigger } from "@/components/ui/sidebar";
|
||||||
|
|
|
||||||
|
|
@ -161,10 +161,12 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
|
|
||||||
const isUser = role === "user";
|
const isUser = role === "user";
|
||||||
|
|
||||||
const formattedTime = createdAt?.toLocaleTimeString("en-US", {
|
const formattedTime = createdAt
|
||||||
hour: "2-digit",
|
? new Date(createdAt).toLocaleTimeString("en-US", {
|
||||||
minute: "2-digit",
|
hour: "2-digit",
|
||||||
});
|
minute: "2-digit",
|
||||||
|
})
|
||||||
|
: undefined;
|
||||||
|
|
||||||
if (isUser) {
|
if (isUser) {
|
||||||
return (
|
return (
|
||||||
|
|
@ -185,7 +187,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
|
|
||||||
{showTimeStamp && createdAt ? (
|
{showTimeStamp && createdAt ? (
|
||||||
<time
|
<time
|
||||||
dateTime={createdAt.toISOString()}
|
dateTime={new Date(createdAt).toISOString()}
|
||||||
className={cn(
|
className={cn(
|
||||||
"mt-1 block px-1 text-xs opacity-50",
|
"mt-1 block px-1 text-xs opacity-50",
|
||||||
animation !== "none" && "duration-500 animate-in fade-in-0"
|
animation !== "none" && "duration-500 animate-in fade-in-0"
|
||||||
|
|
@ -220,7 +222,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
|
|
||||||
{showTimeStamp && createdAt ? (
|
{showTimeStamp && createdAt ? (
|
||||||
<time
|
<time
|
||||||
dateTime={createdAt.toISOString()}
|
dateTime={new Date(createdAt).toISOString()}
|
||||||
className={cn(
|
className={cn(
|
||||||
"mt-1 block px-1 text-xs opacity-50",
|
"mt-1 block px-1 text-xs opacity-50",
|
||||||
animation !== "none" && "duration-500 animate-in fade-in-0"
|
animation !== "none" && "duration-500 animate-in fade-in-0"
|
||||||
|
|
@ -262,7 +264,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
|
||||||
|
|
||||||
{showTimeStamp && createdAt ? (
|
{showTimeStamp && createdAt ? (
|
||||||
<time
|
<time
|
||||||
dateTime={createdAt.toISOString()}
|
dateTime={new Date(createdAt).toISOString()}
|
||||||
className={cn(
|
className={cn(
|
||||||
"mt-1 block px-1 text-xs opacity-50",
|
"mt-1 block px-1 text-xs opacity-50",
|
||||||
animation !== "none" && "duration-500 animate-in fade-in-0"
|
animation !== "none" && "duration-500 animate-in fade-in-0"
|
||||||
|
|
|
||||||
345
llama_stack/ui/components/chat-playground/conversations.test.tsx
Normal file
345
llama_stack/ui/components/chat-playground/conversations.test.tsx
Normal file
|
|
@ -0,0 +1,345 @@
|
||||||
|
import React from "react";
|
||||||
|
import { render, screen, waitFor, act } from "@testing-library/react";
|
||||||
|
import "@testing-library/jest-dom";
|
||||||
|
import { Conversations, SessionUtils } from "./conversations";
|
||||||
|
import type { Message } from "@/components/chat-playground/chat-message";
|
||||||
|
|
||||||
|
interface ChatSession {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
messages: Message[];
|
||||||
|
selectedModel: string;
|
||||||
|
systemMessage: string;
|
||||||
|
agentId: string;
|
||||||
|
createdAt: number;
|
||||||
|
updatedAt: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
const mockOnSessionChange = jest.fn();
|
||||||
|
const mockOnNewSession = jest.fn();
|
||||||
|
|
||||||
|
// Mock the auth client
|
||||||
|
const mockClient = {
|
||||||
|
agents: {
|
||||||
|
session: {
|
||||||
|
list: jest.fn(),
|
||||||
|
create: jest.fn(),
|
||||||
|
delete: jest.fn(),
|
||||||
|
retrieve: jest.fn(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// Mock the useAuthClient hook
|
||||||
|
jest.mock("@/hooks/use-auth-client", () => ({
|
||||||
|
useAuthClient: jest.fn(() => mockClient),
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Mock additional SessionUtils methods that are now being used
|
||||||
|
jest.mock("./conversations", () => {
|
||||||
|
const actual = jest.requireActual("./conversations");
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
SessionUtils: {
|
||||||
|
...actual.SessionUtils,
|
||||||
|
saveSessionData: jest.fn(),
|
||||||
|
loadSessionData: jest.fn(),
|
||||||
|
saveAgentConfig: jest.fn(),
|
||||||
|
loadAgentConfig: jest.fn(),
|
||||||
|
clearAgentCache: jest.fn(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
const localStorageMock = {
|
||||||
|
getItem: jest.fn(),
|
||||||
|
setItem: jest.fn(),
|
||||||
|
removeItem: jest.fn(),
|
||||||
|
clear: jest.fn(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Object.defineProperty(window, "localStorage", {
|
||||||
|
value: localStorageMock,
|
||||||
|
writable: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Mock crypto.randomUUID for test environment
|
||||||
|
let uuidCounter = 0;
|
||||||
|
Object.defineProperty(globalThis, "crypto", {
|
||||||
|
value: {
|
||||||
|
randomUUID: jest.fn(() => `test-uuid-${++uuidCounter}`),
|
||||||
|
},
|
||||||
|
writable: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("SessionManager", () => {
|
||||||
|
const mockSession: ChatSession = {
|
||||||
|
id: "session_123",
|
||||||
|
name: "Test Session",
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
id: "msg_1",
|
||||||
|
role: "user",
|
||||||
|
content: "Hello",
|
||||||
|
createdAt: new Date(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
selectedModel: "test-model",
|
||||||
|
systemMessage: "You are a helpful assistant.",
|
||||||
|
agentId: "agent_123",
|
||||||
|
createdAt: 1710000000,
|
||||||
|
updatedAt: 1710001000,
|
||||||
|
};
|
||||||
|
|
||||||
|
const mockAgentSessions = [
|
||||||
|
{
|
||||||
|
session_id: "session_123",
|
||||||
|
session_name: "Test Session",
|
||||||
|
started_at: "2024-01-01T00:00:00Z",
|
||||||
|
turns: [],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
session_id: "session_456",
|
||||||
|
session_name: "Another Session",
|
||||||
|
started_at: "2024-01-01T01:00:00Z",
|
||||||
|
turns: [],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
localStorageMock.getItem.mockReturnValue(null);
|
||||||
|
localStorageMock.setItem.mockImplementation(() => {});
|
||||||
|
mockClient.agents.session.list.mockResolvedValue({
|
||||||
|
data: mockAgentSessions,
|
||||||
|
});
|
||||||
|
mockClient.agents.session.create.mockResolvedValue({
|
||||||
|
session_id: "new_session_123",
|
||||||
|
});
|
||||||
|
mockClient.agents.session.delete.mockResolvedValue(undefined);
|
||||||
|
mockClient.agents.session.retrieve.mockResolvedValue({
|
||||||
|
session_id: "test-session",
|
||||||
|
session_name: "Test Session",
|
||||||
|
started_at: new Date().toISOString(),
|
||||||
|
turns: [],
|
||||||
|
});
|
||||||
|
uuidCounter = 0; // Reset UUID counter for consistent test behavior
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Component Rendering", () => {
|
||||||
|
test("does not render when no agent is selected", async () => {
|
||||||
|
const { container } = await act(async () => {
|
||||||
|
return render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId=""
|
||||||
|
currentSession={null}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(container.firstChild).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders loading state initially", async () => {
|
||||||
|
mockClient.agents.session.list.mockImplementation(
|
||||||
|
() => new Promise(() => {}) // Never resolves to simulate loading
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId="agent_123"
|
||||||
|
currentSession={null}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByText("Select Session")).toBeInTheDocument();
|
||||||
|
// When loading, the "+ New" button should be disabled
|
||||||
|
expect(screen.getByText("+ New")).toBeDisabled();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders session selector when agent sessions are loaded", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId="agent_123"
|
||||||
|
currentSession={null}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Select Session")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("renders current session name when session is selected", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId="agent_123"
|
||||||
|
currentSession={mockSession}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.getByText("Test Session")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Agent API Integration", () => {
|
||||||
|
test("loads sessions from agent API on mount", async () => {
|
||||||
|
await act(async () => {
|
||||||
|
render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId="agent_123"
|
||||||
|
currentSession={mockSession}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(mockClient.agents.session.list).toHaveBeenCalledWith(
|
||||||
|
"agent_123"
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
test("handles API errors gracefully", async () => {
|
||||||
|
mockClient.agents.session.list.mockRejectedValue(new Error("API Error"));
|
||||||
|
const consoleSpy = jest
|
||||||
|
.spyOn(console, "error")
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId="agent_123"
|
||||||
|
currentSession={mockSession}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(consoleSpy).toHaveBeenCalledWith(
|
||||||
|
"Error loading agent sessions:",
|
||||||
|
expect.any(Error)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
consoleSpy.mockRestore();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("Error Handling", () => {
|
||||||
|
test("component renders without crashing when API is unavailable", async () => {
|
||||||
|
mockClient.agents.session.list.mockRejectedValue(
|
||||||
|
new Error("Network Error")
|
||||||
|
);
|
||||||
|
const consoleSpy = jest
|
||||||
|
.spyOn(console, "error")
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
render(
|
||||||
|
<Conversations
|
||||||
|
selectedAgentId="agent_123"
|
||||||
|
currentSession={mockSession}
|
||||||
|
onSessionChange={mockOnSessionChange}
|
||||||
|
onNewSession={mockOnNewSession}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should still render the session manager with the select trigger
|
||||||
|
expect(screen.getByRole("combobox")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("+ New")).toBeInTheDocument();
|
||||||
|
consoleSpy.mockRestore();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("SessionUtils", () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
jest.clearAllMocks();
|
||||||
|
localStorageMock.getItem.mockReturnValue(null);
|
||||||
|
localStorageMock.setItem.mockImplementation(() => {});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("saveCurrentSessionId", () => {
|
||||||
|
test("saves session ID to localStorage", () => {
|
||||||
|
SessionUtils.saveCurrentSessionId("test-session-id");
|
||||||
|
|
||||||
|
expect(localStorageMock.setItem).toHaveBeenCalledWith(
|
||||||
|
"chat-playground-current-session",
|
||||||
|
"test-session-id"
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("createDefaultSession", () => {
|
||||||
|
test("creates default session with agent ID", () => {
|
||||||
|
const result = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
|
||||||
|
expect(result).toEqual(
|
||||||
|
expect.objectContaining({
|
||||||
|
name: "Default Session",
|
||||||
|
messages: [],
|
||||||
|
selectedModel: "",
|
||||||
|
systemMessage: "You are a helpful assistant.",
|
||||||
|
agentId: "agent_123",
|
||||||
|
})
|
||||||
|
);
|
||||||
|
expect(result.id).toBeTruthy();
|
||||||
|
expect(result.createdAt).toBeTruthy();
|
||||||
|
expect(result.updatedAt).toBeTruthy();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("creates default session with inherited model", () => {
|
||||||
|
const result = SessionUtils.createDefaultSession(
|
||||||
|
"agent_123",
|
||||||
|
"inherited-model"
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.selectedModel).toBe("inherited-model");
|
||||||
|
expect(result.agentId).toBe("agent_123");
|
||||||
|
});
|
||||||
|
|
||||||
|
test("creates unique session IDs", () => {
|
||||||
|
const originalNow = Date.now;
|
||||||
|
let mockTime = 1710005000;
|
||||||
|
Date.now = jest.fn(() => ++mockTime);
|
||||||
|
|
||||||
|
const session1 = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
const session2 = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
|
||||||
|
expect(session1.id).not.toBe(session2.id);
|
||||||
|
|
||||||
|
Date.now = originalNow;
|
||||||
|
});
|
||||||
|
|
||||||
|
test("sets creation and update timestamps", () => {
|
||||||
|
const result = SessionUtils.createDefaultSession("agent_123");
|
||||||
|
|
||||||
|
expect(result.createdAt).toBeTruthy();
|
||||||
|
expect(result.updatedAt).toBeTruthy();
|
||||||
|
expect(typeof result.createdAt).toBe("number");
|
||||||
|
expect(typeof result.updatedAt).toBe("number");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
568
llama_stack/ui/components/chat-playground/conversations.tsx
Normal file
568
llama_stack/ui/components/chat-playground/conversations.tsx
Normal file
|
|
@ -0,0 +1,568 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState, useEffect, useCallback } from "react";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
Select,
|
||||||
|
SelectContent,
|
||||||
|
SelectItem,
|
||||||
|
SelectTrigger,
|
||||||
|
SelectValue,
|
||||||
|
} from "@/components/ui/select";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { Card } from "@/components/ui/card";
|
||||||
|
import { Trash2 } from "lucide-react";
|
||||||
|
import type { Message } from "@/components/chat-playground/chat-message";
|
||||||
|
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||||
|
import type {
|
||||||
|
Session,
|
||||||
|
SessionCreateParams,
|
||||||
|
} from "llama-stack-client/resources/agents";
|
||||||
|
|
||||||
|
export interface ChatSession {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
messages: Message[];
|
||||||
|
selectedModel: string;
|
||||||
|
systemMessage: string;
|
||||||
|
agentId: string;
|
||||||
|
session?: Session;
|
||||||
|
createdAt: number;
|
||||||
|
updatedAt: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SessionManagerProps {
|
||||||
|
currentSession: ChatSession | null;
|
||||||
|
onSessionChange: (session: ChatSession) => void;
|
||||||
|
onNewSession: () => void;
|
||||||
|
selectedAgentId: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CURRENT_SESSION_KEY = "chat-playground-current-session";
|
||||||
|
|
||||||
|
// ensures this only happens client side
|
||||||
|
const safeLocalStorage = {
|
||||||
|
getItem: (key: string): string | null => {
|
||||||
|
if (typeof window === "undefined") return null;
|
||||||
|
try {
|
||||||
|
return localStorage.getItem(key);
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Error accessing localStorage:", err);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
setItem: (key: string, value: string): void => {
|
||||||
|
if (typeof window === "undefined") return;
|
||||||
|
try {
|
||||||
|
localStorage.setItem(key, value);
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Error writing to localStorage:", err);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
removeItem: (key: string): void => {
|
||||||
|
if (typeof window === "undefined") return;
|
||||||
|
try {
|
||||||
|
localStorage.removeItem(key);
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Error removing from localStorage:", err);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const generateSessionId = (): string => {
|
||||||
|
return globalThis.crypto.randomUUID();
|
||||||
|
};
|
||||||
|
|
||||||
|
export function Conversations({
|
||||||
|
currentSession,
|
||||||
|
onSessionChange,
|
||||||
|
selectedAgentId,
|
||||||
|
}: SessionManagerProps) {
|
||||||
|
const [sessions, setSessions] = useState<ChatSession[]>([]);
|
||||||
|
const [showCreateForm, setShowCreateForm] = useState(false);
|
||||||
|
const [newSessionName, setNewSessionName] = useState("");
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const client = useAuthClient();
|
||||||
|
|
||||||
|
const loadAgentSessions = useCallback(async () => {
|
||||||
|
if (!selectedAgentId) return;
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const response = await client.agents.session.list(selectedAgentId);
|
||||||
|
console.log("Sessions response:", response);
|
||||||
|
|
||||||
|
if (!response.data || !Array.isArray(response.data)) {
|
||||||
|
console.warn("Invalid sessions response, starting fresh");
|
||||||
|
setSessions([]);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const agentSessions: ChatSession[] = response.data
|
||||||
|
.filter(sessionData => {
|
||||||
|
const isValid =
|
||||||
|
sessionData &&
|
||||||
|
typeof sessionData === "object" &&
|
||||||
|
sessionData.session_id &&
|
||||||
|
sessionData.session_name;
|
||||||
|
if (!isValid) {
|
||||||
|
console.warn("Filtering out invalid session:", sessionData);
|
||||||
|
}
|
||||||
|
return isValid;
|
||||||
|
})
|
||||||
|
.map(sessionData => ({
|
||||||
|
id: sessionData.session_id,
|
||||||
|
name: sessionData.session_name,
|
||||||
|
messages: [],
|
||||||
|
selectedModel: currentSession?.selectedModel || "",
|
||||||
|
systemMessage:
|
||||||
|
currentSession?.systemMessage || "You are a helpful assistant.",
|
||||||
|
agentId: selectedAgentId,
|
||||||
|
session: sessionData,
|
||||||
|
createdAt: sessionData.started_at
|
||||||
|
? new Date(sessionData.started_at).getTime()
|
||||||
|
: Date.now(),
|
||||||
|
updatedAt: sessionData.started_at
|
||||||
|
? new Date(sessionData.started_at).getTime()
|
||||||
|
: Date.now(),
|
||||||
|
}));
|
||||||
|
setSessions(agentSessions);
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error loading agent sessions:", error);
|
||||||
|
setSessions([]);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
selectedAgentId,
|
||||||
|
client,
|
||||||
|
currentSession?.selectedModel,
|
||||||
|
currentSession?.systemMessage,
|
||||||
|
]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (selectedAgentId) {
|
||||||
|
loadAgentSessions();
|
||||||
|
}
|
||||||
|
}, [selectedAgentId, loadAgentSessions]);
|
||||||
|
|
||||||
|
const createNewSession = async () => {
|
||||||
|
if (!selectedAgentId) return;
|
||||||
|
|
||||||
|
const sessionName =
|
||||||
|
newSessionName.trim() || `Session ${sessions.length + 1}`;
|
||||||
|
setLoading(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await client.agents.session.create(selectedAgentId, {
|
||||||
|
session_name: sessionName,
|
||||||
|
} as SessionCreateParams);
|
||||||
|
|
||||||
|
const newSession: ChatSession = {
|
||||||
|
id: response.session_id,
|
||||||
|
name: sessionName,
|
||||||
|
messages: [],
|
||||||
|
selectedModel: currentSession?.selectedModel || "",
|
||||||
|
systemMessage:
|
||||||
|
currentSession?.systemMessage || "You are a helpful assistant.",
|
||||||
|
agentId: selectedAgentId,
|
||||||
|
createdAt: Date.now(),
|
||||||
|
updatedAt: Date.now(),
|
||||||
|
};
|
||||||
|
|
||||||
|
setSessions(prev => [...prev, newSession]);
|
||||||
|
SessionUtils.saveCurrentSessionId(newSession.id, selectedAgentId);
|
||||||
|
onSessionChange(newSession);
|
||||||
|
|
||||||
|
setNewSessionName("");
|
||||||
|
setShowCreateForm(false);
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error creating session:", error);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const loadSessionMessages = useCallback(
|
||||||
|
async (agentId: string, sessionId: string): Promise<Message[]> => {
|
||||||
|
try {
|
||||||
|
const session = await client.agents.session.retrieve(
|
||||||
|
agentId,
|
||||||
|
sessionId
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!session || !session.turns || !Array.isArray(session.turns)) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const messages: Message[] = [];
|
||||||
|
for (const turn of session.turns) {
|
||||||
|
// Add user messages from input_messages
|
||||||
|
if (turn.input_messages && Array.isArray(turn.input_messages)) {
|
||||||
|
for (const input of turn.input_messages) {
|
||||||
|
if (input.role === "user" && input.content) {
|
||||||
|
messages.push({
|
||||||
|
id: `${turn.turn_id}-user-${messages.length}`,
|
||||||
|
role: "user",
|
||||||
|
content:
|
||||||
|
typeof input.content === "string"
|
||||||
|
? input.content
|
||||||
|
: JSON.stringify(input.content),
|
||||||
|
createdAt: new Date(turn.started_at || Date.now()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add assistant message from output_message
|
||||||
|
if (turn.output_message && turn.output_message.content) {
|
||||||
|
messages.push({
|
||||||
|
id: `${turn.turn_id}-assistant-${messages.length}`,
|
||||||
|
role: "assistant",
|
||||||
|
content:
|
||||||
|
typeof turn.output_message.content === "string"
|
||||||
|
? turn.output_message.content
|
||||||
|
: JSON.stringify(turn.output_message.content),
|
||||||
|
createdAt: new Date(
|
||||||
|
turn.completed_at || turn.started_at || Date.now()
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error loading session messages:", error);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[client]
|
||||||
|
);
|
||||||
|
|
||||||
|
const switchToSession = useCallback(
|
||||||
|
async (sessionId: string) => {
|
||||||
|
const session = sessions.find(s => s.id === sessionId);
|
||||||
|
if (session) {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
// Load messages for this session
|
||||||
|
const messages = await loadSessionMessages(
|
||||||
|
selectedAgentId,
|
||||||
|
sessionId
|
||||||
|
);
|
||||||
|
const sessionWithMessages = {
|
||||||
|
...session,
|
||||||
|
messages,
|
||||||
|
};
|
||||||
|
|
||||||
|
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
|
||||||
|
onSessionChange(sessionWithMessages);
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error switching to session:", error);
|
||||||
|
// Fallback to session without messages
|
||||||
|
SessionUtils.saveCurrentSessionId(sessionId, selectedAgentId);
|
||||||
|
onSessionChange(session);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[sessions, selectedAgentId, loadSessionMessages, onSessionChange]
|
||||||
|
);
|
||||||
|
|
||||||
|
const deleteSession = async (sessionId: string) => {
|
||||||
|
if (sessions.length <= 1 || !selectedAgentId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
confirm(
|
||||||
|
"Are you sure you want to delete this session? This action cannot be undone."
|
||||||
|
)
|
||||||
|
) {
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
await client.agents.session.delete(selectedAgentId, sessionId);
|
||||||
|
|
||||||
|
const updatedSessions = sessions.filter(s => s.id !== sessionId);
|
||||||
|
setSessions(updatedSessions);
|
||||||
|
|
||||||
|
if (currentSession?.id === sessionId) {
|
||||||
|
const newCurrentSession = updatedSessions[0] || null;
|
||||||
|
if (newCurrentSession) {
|
||||||
|
SessionUtils.saveCurrentSessionId(
|
||||||
|
newCurrentSession.id,
|
||||||
|
selectedAgentId
|
||||||
|
);
|
||||||
|
onSessionChange(newCurrentSession);
|
||||||
|
} else {
|
||||||
|
SessionUtils.clearCurrentSession(selectedAgentId);
|
||||||
|
onNewSession();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error deleting session:", error);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (currentSession) {
|
||||||
|
setSessions(prevSessions => {
|
||||||
|
const updatedSessions = prevSessions.map(session =>
|
||||||
|
session.id === currentSession.id ? currentSession : session
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!prevSessions.find(s => s.id === currentSession.id)) {
|
||||||
|
updatedSessions.push(currentSession);
|
||||||
|
}
|
||||||
|
|
||||||
|
return updatedSessions;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [currentSession]);
|
||||||
|
|
||||||
|
// Don't render if no agent is selected
|
||||||
|
if (!selectedAgentId) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="relative">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<Select
|
||||||
|
value={currentSession?.id || ""}
|
||||||
|
onValueChange={switchToSession}
|
||||||
|
>
|
||||||
|
<SelectTrigger className="w-[200px]">
|
||||||
|
<SelectValue placeholder="Select Session" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent>
|
||||||
|
{sessions.map(session => (
|
||||||
|
<SelectItem key={session.id} value={session.id}>
|
||||||
|
{session.name}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
onClick={() => setShowCreateForm(true)}
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
disabled={loading || !selectedAgentId}
|
||||||
|
>
|
||||||
|
+ New
|
||||||
|
</Button>
|
||||||
|
|
||||||
|
{currentSession && sessions.length > 1 && (
|
||||||
|
<Button
|
||||||
|
onClick={() => deleteSession(currentSession.id)}
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
className="text-destructive hover:text-destructive hover:bg-destructive/10"
|
||||||
|
title="Delete current session"
|
||||||
|
>
|
||||||
|
<Trash2 className="h-3 w-3" />
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{showCreateForm && (
|
||||||
|
<Card className="absolute top-full left-0 mt-2 p-4 space-y-3 w-80 z-50 bg-background border shadow-lg">
|
||||||
|
<h3 className="text-md font-semibold">Create New Session</h3>
|
||||||
|
|
||||||
|
<Input
|
||||||
|
value={newSessionName}
|
||||||
|
onChange={e => setNewSessionName(e.target.value)}
|
||||||
|
placeholder="Session name (optional)"
|
||||||
|
onKeyDown={e => {
|
||||||
|
if (e.key === "Enter") {
|
||||||
|
createNewSession();
|
||||||
|
} else if (e.key === "Escape") {
|
||||||
|
setShowCreateForm(false);
|
||||||
|
setNewSessionName("");
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Button
|
||||||
|
onClick={createNewSession}
|
||||||
|
className="flex-1"
|
||||||
|
disabled={loading}
|
||||||
|
>
|
||||||
|
{loading ? "Creating..." : "Create"}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
onClick={() => {
|
||||||
|
setShowCreateForm(false);
|
||||||
|
setNewSessionName("");
|
||||||
|
}}
|
||||||
|
className="flex-1"
|
||||||
|
>
|
||||||
|
Cancel
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{currentSession && sessions.length > 1 && (
|
||||||
|
<div className="absolute top-full left-0 mt-1 text-xs text-gray-500 whitespace-nowrap">
|
||||||
|
{sessions.length} sessions • Current: {currentSession.name}
|
||||||
|
{currentSession.messages.length > 0 &&
|
||||||
|
` • ${currentSession.messages.length} messages`}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export const SessionUtils = {
|
||||||
|
loadCurrentSessionId: (agentId?: string): string | null => {
|
||||||
|
const key = agentId
|
||||||
|
? `${CURRENT_SESSION_KEY}-${agentId}`
|
||||||
|
: CURRENT_SESSION_KEY;
|
||||||
|
return safeLocalStorage.getItem(key);
|
||||||
|
},
|
||||||
|
|
||||||
|
saveCurrentSessionId: (sessionId: string, agentId?: string) => {
|
||||||
|
const key = agentId
|
||||||
|
? `${CURRENT_SESSION_KEY}-${agentId}`
|
||||||
|
: CURRENT_SESSION_KEY;
|
||||||
|
safeLocalStorage.setItem(key, sessionId);
|
||||||
|
},
|
||||||
|
|
||||||
|
createDefaultSession: (
|
||||||
|
agentId: string,
|
||||||
|
inheritModel?: string
|
||||||
|
): ChatSession => ({
|
||||||
|
id: generateSessionId(),
|
||||||
|
name: "Default Session",
|
||||||
|
messages: [],
|
||||||
|
selectedModel: inheritModel || "",
|
||||||
|
systemMessage: "You are a helpful assistant.",
|
||||||
|
agentId,
|
||||||
|
createdAt: Date.now(),
|
||||||
|
updatedAt: Date.now(),
|
||||||
|
}),
|
||||||
|
|
||||||
|
clearCurrentSession: (agentId?: string) => {
|
||||||
|
const key = agentId
|
||||||
|
? `${CURRENT_SESSION_KEY}-${agentId}`
|
||||||
|
: CURRENT_SESSION_KEY;
|
||||||
|
safeLocalStorage.removeItem(key);
|
||||||
|
},
|
||||||
|
|
||||||
|
loadCurrentAgentId: (): string | null => {
|
||||||
|
return safeLocalStorage.getItem("chat-playground-current-agent");
|
||||||
|
},
|
||||||
|
|
||||||
|
saveCurrentAgentId: (agentId: string) => {
|
||||||
|
safeLocalStorage.setItem("chat-playground-current-agent", agentId);
|
||||||
|
},
|
||||||
|
|
||||||
|
// Comprehensive session caching
|
||||||
|
saveSessionData: (agentId: string, sessionData: ChatSession) => {
|
||||||
|
const key = `chat-playground-session-data-${agentId}-${sessionData.id}`;
|
||||||
|
safeLocalStorage.setItem(
|
||||||
|
key,
|
||||||
|
JSON.stringify({
|
||||||
|
...sessionData,
|
||||||
|
cachedAt: Date.now(),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
|
||||||
|
loadSessionData: (agentId: string, sessionId: string): ChatSession | null => {
|
||||||
|
const key = `chat-playground-session-data-${agentId}-${sessionId}`;
|
||||||
|
const cached = safeLocalStorage.getItem(key);
|
||||||
|
if (!cached) return null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(cached);
|
||||||
|
// Check if cache is fresh (less than 1 hour old)
|
||||||
|
const cacheAge = Date.now() - (data.cachedAt || 0);
|
||||||
|
if (cacheAge > 60 * 60 * 1000) {
|
||||||
|
safeLocalStorage.removeItem(key);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert date strings back to Date objects
|
||||||
|
return {
|
||||||
|
...data,
|
||||||
|
messages: data.messages.map(
|
||||||
|
(msg: { createdAt: string; [key: string]: unknown }) => ({
|
||||||
|
...msg,
|
||||||
|
createdAt: new Date(msg.createdAt),
|
||||||
|
})
|
||||||
|
),
|
||||||
|
};
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error parsing cached session data:", error);
|
||||||
|
safeLocalStorage.removeItem(key);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Agent config caching
|
||||||
|
saveAgentConfig: (
|
||||||
|
agentId: string,
|
||||||
|
config: {
|
||||||
|
toolgroups?: Array<
|
||||||
|
string | { name: string; args: Record<string, unknown> }
|
||||||
|
>;
|
||||||
|
[key: string]: unknown;
|
||||||
|
}
|
||||||
|
) => {
|
||||||
|
const key = `chat-playground-agent-config-${agentId}`;
|
||||||
|
safeLocalStorage.setItem(
|
||||||
|
key,
|
||||||
|
JSON.stringify({
|
||||||
|
config,
|
||||||
|
cachedAt: Date.now(),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
|
||||||
|
loadAgentConfig: (
|
||||||
|
agentId: string
|
||||||
|
): {
|
||||||
|
toolgroups?: Array<
|
||||||
|
string | { name: string; args: Record<string, unknown> }
|
||||||
|
>;
|
||||||
|
[key: string]: unknown;
|
||||||
|
} | null => {
|
||||||
|
const key = `chat-playground-agent-config-${agentId}`;
|
||||||
|
const cached = safeLocalStorage.getItem(key);
|
||||||
|
if (!cached) return null;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(cached);
|
||||||
|
// Check if cache is fresh (less than 30 minutes old)
|
||||||
|
const cacheAge = Date.now() - (data.cachedAt || 0);
|
||||||
|
if (cacheAge > 30 * 60 * 1000) {
|
||||||
|
safeLocalStorage.removeItem(key);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return data.config;
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error parsing cached agent config:", error);
|
||||||
|
safeLocalStorage.removeItem(key);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Clear all cached data for an agent
|
||||||
|
clearAgentCache: (agentId: string) => {
|
||||||
|
const keys = Object.keys(localStorage).filter(
|
||||||
|
key =>
|
||||||
|
key.includes(`chat-playground-session-data-${agentId}`) ||
|
||||||
|
key.includes(`chat-playground-agent-config-${agentId}`)
|
||||||
|
);
|
||||||
|
keys.forEach(key => safeLocalStorage.removeItem(key));
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
@ -5,9 +5,9 @@ export function TypingIndicator() {
|
||||||
<div className="justify-left flex space-x-1">
|
<div className="justify-left flex space-x-1">
|
||||||
<div className="rounded-lg bg-muted p-3">
|
<div className="rounded-lg bg-muted p-3">
|
||||||
<div className="flex -space-x-2.5">
|
<div className="flex -space-x-2.5">
|
||||||
<Dot className="h-5 w-5 animate-typing-dot-bounce" />
|
<Dot className="h-5 w-5 animate-typing-dot-1" />
|
||||||
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:90ms]" />
|
<Dot className="h-5 w-5 animate-typing-dot-2" />
|
||||||
<Dot className="h-5 w-5 animate-typing-dot-bounce [animation-delay:180ms]" />
|
<Dot className="h-5 w-5 animate-typing-dot-3" />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue