Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 09:53:45 +00:00

Merge branch 'main' into feat/gunicorn-production-server

Commit 47bd994824: 59 changed files with 3190 additions and 421 deletions
.github/actions/install-llama-stack-client/action.yml (new file, 60 lines)

@@ -0,0 +1,60 @@
name: Install llama-stack-client
description: Install llama-stack-client based on branch context and client-version input

inputs:
  client-version:
    description: 'Client version to install on non-release branches (latest or published). Ignored on release branches.'
    required: false
    default: ""

outputs:
  uv-extra-index-url:
    description: 'UV_EXTRA_INDEX_URL to use (set for release branches)'
    value: ${{ steps.configure.outputs.uv-extra-index-url }}
  install-after-sync:
    description: 'Whether to install client after uv sync'
    value: ${{ steps.configure.outputs.install-after-sync }}
  install-source:
    description: 'Where to install client from after sync'
    value: ${{ steps.configure.outputs.install-source }}

runs:
  using: "composite"
  steps:
    - name: Configure client installation
      id: configure
      shell: bash
      run: |
        # Determine the branch we're working with
        BRANCH="${{ github.base_ref || github.ref }}"
        BRANCH="${BRANCH#refs/heads/}"

        echo "Working with branch: $BRANCH"

        # On release branches: use test.pypi for uv sync, then install from git
        # On non-release branches: install based on client-version after sync
        if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
          echo "Detected release branch: $BRANCH"

          # Check if matching branch exists in client repo
          if ! git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$BRANCH" > /dev/null 2>&1; then
            echo "::error::Branch $BRANCH not found in llama-stack-client-python repository"
            echo "::error::Please create the matching release branch in llama-stack-client-python before testing"
            exit 1
          fi

          # Configure to use test.pypi as extra index (PyPI is primary)
          echo "uv-extra-index-url=https://test.pypi.org/simple/" >> $GITHUB_OUTPUT
          echo "install-after-sync=true" >> $GITHUB_OUTPUT
          echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@$BRANCH" >> $GITHUB_OUTPUT
        elif [ "${{ inputs.client-version }}" = "latest" ]; then
          # Install from main git after sync
          echo "install-after-sync=true" >> $GITHUB_OUTPUT
          echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@main" >> $GITHUB_OUTPUT
        elif [ "${{ inputs.client-version }}" = "published" ]; then
          # Use published version from PyPI (installed by sync)
          echo "install-after-sync=false" >> $GITHUB_OUTPUT
        elif [ -n "${{ inputs.client-version }}" ]; then
          echo "::error::Invalid client-version: ${{ inputs.client-version }}"
          exit 1
        fi
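The whole branching scheme above hinges on one regex: only branch names of the form release-MAJOR.MINOR.x are treated as release branches. A quick standalone check of that pattern (an illustrative bash sketch, not part of the change; the branch names are made up):

#!/bin/bash
# Exercise the same pattern the composite action uses.
for BRANCH in main release-0.3.x release-1.12.x release-0.3.1 feature/release-0.3.x; do
  if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
    echo "$BRANCH -> release branch: test.pypi extra index, client from git@$BRANCH"
  else
    echo "$BRANCH -> non-release branch: the client-version input decides"
  fi
done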
@@ -94,7 +94,7 @@ runs:
       if: ${{ always() }}
       uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
       with:
-        name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }}
+        name: logs-${{ github.run_id }}-${{ github.run_attempt || '1' }}-${{ strategy.job-index || github.job }}-${{ github.action }}
         path: |
           *.log
         retention-days: 1
.github/actions/setup-runner/action.yml (30 changes)

@@ -18,25 +18,35 @@ runs:
       python-version: ${{ inputs.python-version }}
       version: 0.7.6

+  - name: Configure client installation
+    id: client-config
+    uses: ./.github/actions/install-llama-stack-client
+    with:
+      client-version: ${{ inputs.client-version }}
+
   - name: Install dependencies
     shell: bash
+    env:
+      UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
     run: |
+      # Export UV env vars for current step and persist to GITHUB_ENV for subsequent steps
+      if [ -n "$UV_EXTRA_INDEX_URL" ]; then
+        export UV_INDEX_STRATEGY=unsafe-best-match
+        echo "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" >> $GITHUB_ENV
+        echo "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" >> $GITHUB_ENV
+        echo "Exported UV environment variables for current and subsequent steps"
+      fi
+
       echo "Updating project dependencies via uv sync"
       uv sync --all-groups

       echo "Installing ad-hoc dependencies"
       uv pip install faiss-cpu

-      # Install llama-stack-client-python based on the client-version input
-      if [ "${{ inputs.client-version }}" = "latest" ]; then
-        echo "Installing latest llama-stack-client-python from main branch"
-        uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main
-      elif [ "${{ inputs.client-version }}" = "published" ]; then
-        echo "Installing published llama-stack-client-python from PyPI"
-        uv pip install llama-stack-client
-      else
-        echo "Invalid client-version: ${{ inputs.client-version }}"
-        exit 1
+      # Install specific client version after sync if needed
+      if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
+        echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
+        uv pip install ${{ steps.client-config.outputs.install-source }}
       fi

       echo "Installed llama packages"
@@ -42,18 +42,7 @@ runs:
   - name: Build Llama Stack
     shell: bash
     run: |
-      # Install llama-stack-client-python based on the client-version input
-      if [ "${{ inputs.client-version }}" = "latest" ]; then
-        echo "Installing latest llama-stack-client-python from main branch"
-        export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main
-      elif [ "${{ inputs.client-version }}" = "published" ]; then
-        echo "Installing published llama-stack-client-python from PyPI"
-        unset LLAMA_STACK_CLIENT_DIR
-      else
-        echo "Invalid client-version: ${{ inputs.client-version }}"
-        exit 1
-      fi
-
+      # Client is already installed by setup-runner (handles both main and release branches)
       echo "Building Llama Stack"

       LLAMA_STACK_DIR=. \

.github/workflows/README.md (1 change)
@@ -13,7 +13,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
 | Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
 | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
 | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
-| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR |
 | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
 | Test llama stack list-deps | [providers-list-deps.yml](providers-list-deps.yml) | Test llama stack list-deps |
 | Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |

.github/workflows/backward-compat.yml (6 changes)
@@ -4,7 +4,11 @@ run-name: Check backward compatibility for run.yaml configs

 on:
   pull_request:
-    branches: [main]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+'
+      - 'release-[0-9]+.[0-9]+.[0-9]+'
+      - 'release-[0-9]+.[0-9]+'
     paths:
       - 'src/llama_stack/core/datatypes.py'
       - 'src/llama_stack/providers/datatypes.py'
.github/workflows/install-script-ci.yml (10 changes)

@@ -30,10 +30,16 @@ jobs:

   - name: Build a single provider
     run: |
+      BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=starter"
+      if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+      fi
+      if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+      fi
       docker build . \
         -f containers/Containerfile \
-        --build-arg INSTALL_MODE=editable \
-        --build-arg DISTRO_NAME=starter \
+        $BUILD_ARGS \
         --tag llama-stack:starter-ci

   - name: Run installer end-to-end
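On a release branch, where setup-runner has exported the UV variables, the step above amounts to the following manual invocation (a sketch for illustration; the index values are the ones the composite action emits, everything else comes straight from the step):

UV_EXTRA_INDEX_URL=https://test.pypi.org/simple/
UV_INDEX_STRATEGY=unsafe-best-match
docker build . \
  -f containers/Containerfile \
  --build-arg INSTALL_MODE=editable \
  --build-arg DISTRO_NAME=starter \
  --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL \
  --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY \
  --tag llama-stack:starter-ci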
.github/workflows/integration-auth-tests.yml (8 changes)

@@ -4,9 +4,13 @@ run-name: Run the integration test suite with Kubernetes authentication

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'distributions/**'
       - 'src/llama_stack/**'
@@ -4,9 +4,13 @@ run-name: Run the integration test suite with SqlStore

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/providers/utils/sqlstore/**'
       - 'tests/integration/sqlstore/**'
.github/workflows/integration-tests.yml (10 changes)

@@ -4,9 +4,13 @@ run-name: Run the integration test suites from tests/integration in replay mode

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     types: [opened, synchronize, reopened]
     paths:
       - 'src/llama_stack/**'
@@ -47,7 +51,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        client-type: [library, docker]
+        client-type: [library, docker, server]
         # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
         client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
@@ -4,9 +4,13 @@ run-name: Run the integration test suite with various VectorIO providers

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/**'
       - '!src/llama_stack/ui/**'
.github/workflows/pre-commit.yml (52 changes)

@@ -5,7 +5,9 @@ run-name: Run pre-commit checks
 on:
   pull_request:
   push:
-    branches: [main]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'

 concurrency:
   group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
@@ -50,19 +52,34 @@ jobs:
       run: npm ci
       working-directory: src/llama_stack/ui

+    - name: Install pre-commit
+      run: python -m pip install pre-commit
+
+    - name: Cache pre-commit
+      uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
+      with:
+        path: ~/.cache/pre-commit
+        key: pre-commit-3|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
+
     - name: Run pre-commit
       id: precommit
-      uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
-      continue-on-error: true
+      run: |
+        set +e
+        pre-commit run --show-diff-on-failure --color=always --all-files 2>&1 | tee /tmp/precommit.log
+        status=${PIPESTATUS[0]}
+        echo "status=$status" >> $GITHUB_OUTPUT
+        exit 0
       env:
         SKIP: no-commit-to-branch,mypy
         RUFF_OUTPUT_FORMAT: github

     - name: Check pre-commit results
-      if: steps.precommit.outcome == 'failure'
+      if: steps.precommit.outputs.status != '0'
       run: |
         echo "::error::Pre-commit hooks failed. Please run 'pre-commit run --all-files' locally and commit the fixes."
-        echo "::warning::Some pre-commit hooks failed. Check the output above for details."
+        echo ""
+        echo "Failed hooks output:"
+        cat /tmp/precommit.log
         exit 1

     - name: Debug
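The replacement step pipes pre-commit through tee so the log is both shown live and saved; because $? after a pipeline reports only the last command (tee, which always succeeds), the hook's real status has to be read from bash's PIPESTATUS array immediately after the pipeline runs. A minimal standalone illustration of the trick (unrelated to the repository):

#!/bin/bash
false | tee /tmp/out.log
echo "pipeline exit: $?"                     # prints 0: tee succeeded
# PIPESTATUS is reset by every command, so read it right after the pipeline:
false | tee /tmp/out.log
echo "first command: ${PIPESTATUS[0]}"       # prints 1: false failed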
@@ -113,11 +130,34 @@ jobs:
         exit 1
       fi

+    - name: Configure client installation
+      id: client-config
+      uses: ./.github/actions/install-llama-stack-client
+
     - name: Sync dev + type_checking dependencies
-      run: uv sync --group dev --group type_checking
+      env:
+        UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
+      run: |
+        if [ -n "$UV_EXTRA_INDEX_URL" ]; then
+          export UV_INDEX_STRATEGY="unsafe-best-match"
+        fi
+
+        uv sync --group dev --group type_checking
+
+        # Install specific client version after sync if needed
+        if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
+          echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
+          uv pip install ${{ steps.client-config.outputs.install-source }}
+        fi

     - name: Run mypy (full type_checking)
+      env:
+        UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
       run: |
+        if [ -n "$UV_EXTRA_INDEX_URL" ]; then
+          export UV_INDEX_STRATEGY="unsafe-best-match"
+        fi
+
         set +e
         uv run --group dev --group type_checking mypy
         status=$?
.github/workflows/precommit-trigger.yml (deleted, 227 lines)

@@ -1,227 +0,0 @@
name: Pre-commit Bot

run-name: Pre-commit bot for PR #${{ github.event.issue.number }}

on:
  issue_comment:
    types: [created]

jobs:
  pre-commit:
    # Only run on pull request comments
    if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit')
    runs-on: ubuntu-latest
    permissions:
      contents: write
      pull-requests: write

    steps:
      - name: Check comment author and get PR details
        id: check_author
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            // Get PR details
            const pr = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: context.issue.number
            });

            // Check if commenter has write access or is the PR author
            const commenter = context.payload.comment.user.login;
            const prAuthor = pr.data.user.login;

            let hasPermission = false;

            // Check if commenter is PR author
            if (commenter === prAuthor) {
              hasPermission = true;
              console.log(`Comment author ${commenter} is the PR author`);
            } else {
              // Check if commenter has write/admin access
              try {
                const permission = await github.rest.repos.getCollaboratorPermissionLevel({
                  owner: context.repo.owner,
                  repo: context.repo.repo,
                  username: commenter
                });

                const level = permission.data.permission;
                hasPermission = ['write', 'admin', 'maintain'].includes(level);
                console.log(`Comment author ${commenter} has permission: ${level}`);
              } catch (error) {
                console.log(`Could not check permissions for ${commenter}: ${error.message}`);
              }
            }

            if (!hasPermission) {
              await github.rest.issues.createComment({
                owner: context.repo.owner,
                repo: context.repo.repo,
                issue_number: context.issue.number,
                body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.`
              });
              core.setFailed(`User ${commenter} does not have permission`);
              return;
            }

            // Save PR info for later steps
            core.setOutput('pr_number', context.issue.number);
            core.setOutput('pr_head_ref', pr.data.head.ref);
            core.setOutput('pr_head_sha', pr.data.head.sha);
            core.setOutput('pr_head_repo', pr.data.head.repo.full_name);
            core.setOutput('pr_base_ref', pr.data.base.ref);
            core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name);
            core.setOutput('authorized', 'true');

      - name: React to comment
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.reactions.createForIssueComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              comment_id: context.payload.comment.id,
              content: 'rocket'
            });

      - name: Comment starting
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...`
            });

      - name: Checkout PR branch (same-repo)
        if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false'
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          ref: ${{ steps.check_author.outputs.pr_head_ref }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Checkout PR branch (fork)
        if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true'
        uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
        with:
          repository: ${{ steps.check_author.outputs.pr_head_repo }}
          ref: ${{ steps.check_author.outputs.pr_head_ref }}
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Verify checkout
        if: steps.check_author.outputs.authorized == 'true'
        run: |
          echo "Current SHA: $(git rev-parse HEAD)"
          echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}"
          if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then
            echo "::error::Checked out SHA does not match expected SHA"
            exit 1
          fi

      - name: Set up Python
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
        with:
          python-version: '3.12'
          cache: pip
          cache-dependency-path: |
            **/requirements*.txt
            .pre-commit-config.yaml

      - name: Set up Node.js
        if: steps.check_author.outputs.authorized == 'true'
        uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0
        with:
          node-version: '20'
          cache: 'npm'
          cache-dependency-path: 'src/llama_stack/ui/'

      - name: Install npm dependencies
        if: steps.check_author.outputs.authorized == 'true'
        run: npm ci
        working-directory: src/llama_stack/ui

      - name: Run pre-commit
        if: steps.check_author.outputs.authorized == 'true'
        id: precommit
        uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
        continue-on-error: true
        env:
          SKIP: no-commit-to-branch
          RUFF_OUTPUT_FORMAT: github

      - name: Check for changes
        if: steps.check_author.outputs.authorized == 'true'
        id: changes
        run: |
          if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
            echo "has_changes=true" >> $GITHUB_OUTPUT
            echo "Changes detected after pre-commit"
          else
            echo "has_changes=false" >> $GITHUB_OUTPUT
            echo "No changes after pre-commit"
          fi

      - name: Commit and push changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
        run: |
          git config --local user.email "github-actions[bot]@users.noreply.github.com"
          git config --local user.name "github-actions[bot]"

          git add -A
          git commit -m "style: apply pre-commit fixes

          🤖 Applied by @github-actions bot via pre-commit workflow"

          # Push changes
          git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }}

      - name: Comment success with changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.`
            });

      - name: Comment success without changes
        if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success'
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.`
            });

      - name: Comment failure
        if: failure()
        uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: ${{ steps.check_author.outputs.pr_number }},
              body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.`
            });
.github/workflows/providers-build.yml (38 changes)

@@ -72,10 +72,16 @@ jobs:
   - name: Build container image
     if: matrix.image-type == 'container'
     run: |
+      BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=${{ matrix.distro }}"
+      if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+      fi
+      if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+      fi
       docker build . \
         -f containers/Containerfile \
-        --build-arg INSTALL_MODE=editable \
-        --build-arg DISTRO_NAME=${{ matrix.distro }} \
+        $BUILD_ARGS \
         --tag llama-stack:${{ matrix.distro }}-ci

   - name: Print dependencies in the image
@@ -108,12 +114,18 @@ jobs:
   - name: Build container image
     run: |
       BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "python:3.12-slim"' src/llama_stack/distributions/ci-tests/build.yaml)
+      BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
+      BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
+      BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
+      if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+      fi
+      if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+      fi
       docker build . \
         -f containers/Containerfile \
-        --build-arg INSTALL_MODE=editable \
-        --build-arg DISTRO_NAME=ci-tests \
-        --build-arg BASE_IMAGE="$BASE_IMAGE" \
-        --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
+        $BUILD_ARGS \
         -t llama-stack:ci-tests

   - name: Inspect the container image entrypoint
@@ -148,12 +160,18 @@ jobs:
   - name: Build UBI9 container image
     run: |
       BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "registry.access.redhat.com/ubi9:latest"' src/llama_stack/distributions/ci-tests/build.yaml)
+      BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
+      BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
+      BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
+      if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
+      fi
+      if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
+        BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
+      fi
       docker build . \
         -f containers/Containerfile \
-        --build-arg INSTALL_MODE=editable \
-        --build-arg DISTRO_NAME=ci-tests \
-        --build-arg BASE_IMAGE="$BASE_IMAGE" \
-        --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
+        $BUILD_ARGS \
         -t llama-stack:ci-tests-ubi9

   - name: Inspect UBI9 image
.github/workflows/unit-tests.yml (8 changes)

@@ -4,9 +4,13 @@ run-name: Run the unit test suite

 on:
   push:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
   pull_request:
-    branches: [ main ]
+    branches:
+      - main
+      - 'release-[0-9]+.[0-9]+.x'
     paths:
       - 'src/llama_stack/**'
       - '!src/llama_stack/ui/**'
@@ -52,10 +52,6 @@ repos:
       additional_dependencies:
         - black==24.3.0

-  - repo: https://github.com/astral-sh/uv-pre-commit
-    rev: 0.7.20
-    hooks:
-      - id: uv-lock

   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.18.2
@@ -63,22 +59,13 @@ repos:
       - id: mypy
         additional_dependencies:
           - uv==0.6.2
+          - mypy
           - pytest
           - rich
           - types-requests
           - pydantic
-          - httpx
         pass_filenames: false

-  - repo: local
-    hooks:
-      - id: mypy-full
-        name: mypy (full type_checking)
-        entry: uv run --group dev --group type_checking mypy
-        language: system
-        pass_filenames: false
-        stages: [manual]
-
 # - repo: https://github.com/tcort/markdown-link-check
 #   rev: v3.11.2
 #   hooks:
@@ -87,11 +74,26 @@ repos:

   - repo: local
     hooks:
+      - id: uv-lock
+        name: uv-lock
+        additional_dependencies:
+          - uv==0.7.20
+        entry: ./scripts/uv-run-with-index.sh lock
+        language: python
+        pass_filenames: false
+        require_serial: true
+        files: ^(pyproject\.toml|uv\.lock)$
+      - id: mypy-full
+        name: mypy (full type_checking)
+        entry: ./scripts/uv-run-with-index.sh run --group dev --group type_checking mypy
+        language: system
+        pass_filenames: false
+        stages: [manual]
       - id: distro-codegen
         name: Distribution Template Codegen
         additional_dependencies:
           - uv==0.7.8
-        entry: uv run --group codegen ./scripts/distro_codegen.py
+        entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/distro_codegen.py
         language: python
         pass_filenames: false
         require_serial: true
@@ -100,7 +102,7 @@ repos:
         name: Provider Codegen
         additional_dependencies:
           - uv==0.7.8
-        entry: uv run --group codegen ./scripts/provider_codegen.py
+        entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/provider_codegen.py
         language: python
         pass_filenames: false
         require_serial: true
@@ -109,7 +111,7 @@ repos:
         name: API Spec Codegen
         additional_dependencies:
           - uv==0.7.8
-        entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
+        entry: sh -c './scripts/uv-run-with-index.sh run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
         language: python
         pass_filenames: false
         require_serial: true
@@ -150,7 +152,7 @@ repos:
         name: Generate CI documentation
         additional_dependencies:
           - uv==0.7.8
-        entry: uv run ./scripts/gen-ci-docs.py
+        entry: ./scripts/uv-run-with-index.sh run ./scripts/gen-ci-docs.py
         language: python
         pass_filenames: false
         require_serial: true
@@ -162,6 +164,7 @@ repos:
         files: ^src/llama_stack/ui/.*\.(ts|tsx)$
         pass_filenames: false
         require_serial: true
+
       - id: check-log-usage
         name: Ensure 'llama_stack.log' usage for logging
         entry: bash
@@ -197,6 +200,7 @@ repos:
           echo;
           exit 1;
         } || true
+
 ci:
   autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
   autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
@@ -956,7 +956,22 @@ paths:
         List routes.

         List all available API routes with their methods and implementing providers.
-      parameters: []
+      parameters:
+        - name: api_filter
+          in: query
+          description: >-
+            Optional filter to control which routes are returned. Can be an API level
+            ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
+            or 'deprecated' to show deprecated routes across all levels. If not specified,
+            returns only non-deprecated v1 routes.
+          required: false
+          schema:
+            type: string
+            enum:
+              - v1
+              - v1alpha
+              - v1beta
+              - deprecated
       deprecated: false
   /v1/models:
     get:
@@ -19,6 +19,8 @@ ARG KEEP_WORKSPACE=""
 ARG DISTRO_NAME="starter"
 ARG RUN_CONFIG_PATH=""
 ARG UV_HTTP_TIMEOUT=500
+ARG UV_EXTRA_INDEX_URL=""
+ARG UV_INDEX_STRATEGY=""
 ENV UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT}
 ENV PYTHONDONTWRITEBYTECODE=1
 ENV PIP_DISABLE_PIP_VERSION_CHECK=1
@@ -45,7 +47,7 @@ RUN set -eux; \
         exit 1; \
     fi

-RUN pip install --no-cache uv
+RUN pip install --no-cache-dir uv
 ENV UV_SYSTEM_PYTHON=1

 ENV INSTALL_MODE=${INSTALL_MODE}
@@ -62,47 +64,60 @@ COPY . /workspace

 # Install the client package if it is provided
 # NOTE: this is installed before llama-stack since llama-stack depends on llama-stack-client-python
+# Unset UV index env vars to ensure we only use PyPI for the client
 RUN set -eux; \
+    unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
     if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then \
         if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then \
             echo "LLAMA_STACK_CLIENT_DIR is set but $LLAMA_STACK_CLIENT_DIR does not exist" >&2; \
             exit 1; \
         fi; \
-        uv pip install --no-cache -e "$LLAMA_STACK_CLIENT_DIR"; \
+        uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"; \
     fi;

 # Install llama-stack
+# Use UV_EXTRA_INDEX_URL inline only for editable install with RC dependencies
 RUN set -eux; \
+    SAVED_UV_EXTRA_INDEX_URL="${UV_EXTRA_INDEX_URL:-}"; \
+    SAVED_UV_INDEX_STRATEGY="${UV_INDEX_STRATEGY:-}"; \
+    unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
     if [ "$INSTALL_MODE" = "editable" ]; then \
         if [ ! -d "$LLAMA_STACK_DIR" ]; then \
             echo "INSTALL_MODE=editable requires LLAMA_STACK_DIR to point to a directory inside the build context" >&2; \
             exit 1; \
         fi; \
-        uv pip install --no-cache -e "$LLAMA_STACK_DIR"; \
+        if [ -n "$SAVED_UV_EXTRA_INDEX_URL" ] && [ -n "$SAVED_UV_INDEX_STRATEGY" ]; then \
+            UV_EXTRA_INDEX_URL="$SAVED_UV_EXTRA_INDEX_URL" UV_INDEX_STRATEGY="$SAVED_UV_INDEX_STRATEGY" \
+                uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
+        else \
+            uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
+        fi; \
     elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
-        uv pip install --no-cache fastapi libcst; \
+        uv pip install --no-cache-dir fastapi libcst; \
         if [ -n "$TEST_PYPI_VERSION" ]; then \
-            uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
+            uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
         else \
-            uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
+            uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
         fi; \
     else \
         if [ -n "$PYPI_VERSION" ]; then \
-            uv pip install --no-cache "llama-stack==$PYPI_VERSION"; \
+            uv pip install --no-cache-dir "llama-stack==$PYPI_VERSION"; \
         else \
-            uv pip install --no-cache llama-stack; \
+            uv pip install --no-cache-dir llama-stack; \
         fi; \
     fi;

 # Install the dependencies for the distribution
+# Explicitly unset UV index env vars to ensure we only use PyPI for distribution deps
 RUN set -eux; \
+    unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
     if [ -z "$DISTRO_NAME" ]; then \
         echo "DISTRO_NAME must be provided" >&2; \
         exit 1; \
     fi; \
     deps="$(llama stack list-deps "$DISTRO_NAME")"; \
     if [ -n "$deps" ]; then \
-        printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache; \
+        printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache-dir; \
     fi

 # Cleanup
@@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 | `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
 | `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
 | `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
+| `rerank_model_to_url` | `dict[str, str]` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |

 ## Sample Configuration
docs/static/llama-stack-spec.html (18 changes)

@@ -1258,7 +1258,23 @@
             ],
             "summary": "List routes.",
             "description": "List routes.\nList all available API routes with their methods and implementing providers.",
-            "parameters": [],
+            "parameters": [
+                {
+                    "name": "api_filter",
+                    "in": "query",
+                    "description": "Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.",
+                    "required": false,
+                    "schema": {
+                        "type": "string",
+                        "enum": [
+                            "v1",
+                            "v1alpha",
+                            "v1beta",
+                            "deprecated"
+                        ]
+                    }
+                }
+            ],
             "deprecated": false
         }
     },
docs/static/llama-stack-spec.yaml (17 changes)

@@ -953,7 +953,22 @@ paths:
         List routes.

         List all available API routes with their methods and implementing providers.
-      parameters: []
+      parameters:
+        - name: api_filter
+          in: query
+          description: >-
+            Optional filter to control which routes are returned. Can be an API level
+            ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
+            or 'deprecated' to show deprecated routes across all levels. If not specified,
+            returns only non-deprecated v1 routes.
+          required: false
+          schema:
+            type: string
+            enum:
+              - v1
+              - v1alpha
+              - v1beta
+              - deprecated
       deprecated: false
   /v1/models:
     get:
docs/static/stainless-llama-stack-spec.html (18 changes)

@@ -1258,7 +1258,23 @@
             ],
             "summary": "List routes.",
             "description": "List routes.\nList all available API routes with their methods and implementing providers.",
-            "parameters": [],
+            "parameters": [
+                {
+                    "name": "api_filter",
+                    "in": "query",
+                    "description": "Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.",
+                    "required": false,
+                    "schema": {
+                        "type": "string",
+                        "enum": [
+                            "v1",
+                            "v1alpha",
+                            "v1beta",
+                            "deprecated"
+                        ]
+                    }
+                }
+            ],
             "deprecated": false
         }
     },
docs/static/stainless-llama-stack-spec.yaml (17 changes)

@@ -956,7 +956,22 @@ paths:
         List routes.

         List all available API routes with their methods and implementing providers.
-      parameters: []
+      parameters:
+        - name: api_filter
+          in: query
+          description: >-
+            Optional filter to control which routes are returned. Can be an API level
+            ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
+            or 'deprecated' to show deprecated routes across all levels. If not specified,
+            returns only non-deprecated v1 routes.
+          required: false
+          schema:
+            type: string
+            enum:
+              - v1
+              - v1alpha
+              - v1beta
+              - deprecated
       deprecated: false
   /v1/models:
     get:
@@ -215,6 +215,16 @@ build_image() {
         --build-arg "LLAMA_STACK_DIR=/workspace"
     )

+    # Pass UV index configuration for release branches
+    if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
+        echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
+        build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
+    fi
+    if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
+        echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
+        build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
+    fi
+
     if ! "${build_cmd[@]}"; then
        echo "❌ Failed to build Docker image"
        exit 1
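This hunk appends to a bash array rather than a flat string: += adds whole elements, and "${build_cmd[@]}" expands to exactly one word per element, so build arguments containing spaces survive quoting. A minimal standalone illustration (not from the repository):

#!/bin/bash
build_cmd=(docker build . -f containers/Containerfile)
build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=https://test.pypi.org/simple/")
# printf %q shows the word boundaries the array preserves.
printf '%q ' "${build_cmd[@]}"; echo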
@@ -102,7 +102,6 @@ while [[ $# -gt 0 ]]; do
     esac
 done

-
 # Validate required parameters
 if [[ -z "$STACK_CONFIG" && "$COLLECT_ONLY" == false ]]; then
     echo "Error: --stack-config is required"
@@ -208,6 +207,15 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
     echo "=== Starting Llama Stack Server ==="
     export LLAMA_STACK_LOG_WIDTH=120

+    # Configure telemetry collector for server mode
+    # Use a fixed port for the OTEL collector so the server can connect to it
+    COLLECTOR_PORT=4317
+    export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
+    export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}"
+    export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
+    export OTEL_BSP_SCHEDULE_DELAY="200"
+    export OTEL_BSP_EXPORT_TIMEOUT="2000"
+
     # remove "server:" from STACK_CONFIG
     stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
     nohup llama stack run $stack_config >server.log 2>&1 &
@@ -271,6 +279,16 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
         --build-arg "LLAMA_STACK_DIR=/workspace"
     )

+    # Pass UV index configuration for release branches
+    if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
+        echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
+        build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
+    fi
+    if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
+        echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
+        build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
+    fi
+
     if ! "${build_cmd[@]}"; then
         echo "❌ Failed to build Docker image"
         exit 1
@@ -428,17 +446,13 @@ elif [ $exit_code -eq 5 ]; then
 else
     echo "❌ Tests failed"
     echo ""
-    echo "=== Dumping last 100 lines of logs for debugging ==="
-
     # Output server or container logs based on stack config
     if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then
-        echo "--- Last 100 lines of server.log ---"
-        tail -100 server.log
+        echo "--- Server side failures can be located inside server.log (available from artifacts on CI) ---"
     elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then
         docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log"
         if [[ -f "$docker_log_file" ]]; then
-            echo "--- Last 100 lines of $docker_log_file ---"
-            tail -100 "$docker_log_file"
+            echo "--- Server side failures can be located inside $docker_log_file (available from artifacts on CI) ---"
         fi
     fi
scripts/uv-run-with-index.sh (new executable file, 42 lines)

@@ -0,0 +1,42 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -euo pipefail

# Detect current branch and target branch
# In GitHub Actions, use GITHUB_REF/GITHUB_BASE_REF
if [[ -n "${GITHUB_REF:-}" ]]; then
    BRANCH="${GITHUB_REF#refs/heads/}"
else
    BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "")
fi

# For PRs, check the target branch
if [[ -n "${GITHUB_BASE_REF:-}" ]]; then
    TARGET_BRANCH="${GITHUB_BASE_REF}"
else
    TARGET_BRANCH=$(git rev-parse --abbrev-ref HEAD@{upstream} 2>/dev/null | sed 's|origin/||' || echo "")
fi

# Check if on a release branch or targeting one, or LLAMA_STACK_RELEASE_MODE is set
IS_RELEASE=false
if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
    IS_RELEASE=true
elif [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
    IS_RELEASE=true
elif [[ "${LLAMA_STACK_RELEASE_MODE:-}" == "true" ]]; then
    IS_RELEASE=true
fi

# On release branches, use test.pypi as extra index for RC versions
if [[ "$IS_RELEASE" == "true" ]]; then
    export UV_EXTRA_INDEX_URL="https://test.pypi.org/simple/"
    export UV_INDEX_STRATEGY="unsafe-best-match"
fi

# Run uv with all arguments passed through
exec uv "$@"
|
|
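The wrapper is transparent: it only exports the two UV_* variables when release context is detected, then execs `uv`. A usage sketch (hypothetical invocations; any `uv` subcommand passes straight through):

```bash
# Behaves exactly like plain `uv` on main; on a release-X.Y.x branch it
# resolves packages against test.pypi as an extra index.
./scripts/uv-run-with-index.sh sync

# Force release-mode index resolution on any branch via the documented env var:
LLAMA_STACK_RELEASE_MODE=true ./scripts/uv-run-with-index.sh run pytest
```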
@@ -4,14 +4,21 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Protocol, runtime_checkable
+from typing import Literal, Protocol, runtime_checkable

 from pydantic import BaseModel

-from llama_stack.apis.version import LLAMA_STACK_API_V1
+from llama_stack.apis.version import (
+    LLAMA_STACK_API_V1,
+)
 from llama_stack.providers.datatypes import HealthStatus
 from llama_stack.schema_utils import json_schema_type, webmethod

+# Valid values for the route filter parameter.
+# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
+# Special filter value: "deprecated" (shows deprecated routes regardless of level)
+ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]
+

 @json_schema_type
 class RouteInfo(BaseModel):
@@ -64,11 +71,12 @@ class Inspect(Protocol):
     """

     @webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
-    async def list_routes(self) -> ListRoutesResponse:
+    async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
         """List routes.

         List all available API routes with their methods and implementing providers.

+        :param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
         :returns: Response containing information about all available routes.
         """
         ...
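Since `list_routes` is exposed as a GET webmethod, the new argument should surface as a query parameter. A hedged sketch against a local server (the default port 8321 and the query-string name are assumptions inferred from the signature, not confirmed by this diff):

```bash
# Default behavior: non-deprecated v1 routes only
curl http://localhost:8321/v1/inspect/routes

# Routes at the v1alpha level
curl "http://localhost:8321/v1/inspect/routes?api_filter=v1alpha"

# Deprecated routes across all levels
curl "http://localhost:8321/v1/inspect/routes?api_filter=deprecated"
```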
@@ -13,11 +13,23 @@ from pathlib import Path

 import uvicorn
 import yaml
+from termcolor import cprint

 from llama_stack.cli.stack.utils import ImageType
 from llama_stack.cli.subcommand import Subcommand
-from llama_stack.core.datatypes import StackRunConfig
+from llama_stack.core.datatypes import Api, Provider, StackRunConfig
+from llama_stack.core.distribution import get_provider_registry
 from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
+from llama_stack.core.storage.datatypes import (
+    InferenceStoreReference,
+    KVStoreReference,
+    ServerStoresConfig,
+    SqliteKVStoreConfig,
+    SqliteSqlStoreConfig,
+    SqlStoreReference,
+    StorageConfig,
+)
+from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
 from llama_stack.log import LoggingConfig, get_logger
@@ -69,6 +81,12 @@ class StackRun(Subcommand):
             action="store_true",
             help="Start the UI server",
         )
+        self.parser.add_argument(
+            "--providers",
+            type=str,
+            default=None,
+            help="Run a stack with only a list of providers. The list is formatted like: api1=provider1,api1=provider2,api2=provider3. There can be multiple providers per API.",
+        )

     def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
         import yaml
@@ -94,6 +112,49 @@ class StackRun(Subcommand):
                 config_file = resolve_config_or_distro(args.config, Mode.RUN)
             except ValueError as e:
                 self.parser.error(str(e))
+        elif args.providers:
+            provider_list: dict[str, list[Provider]] = dict()
+            for api_provider in args.providers.split(","):
+                if "=" not in api_provider:
+                    cprint(
+                        "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
+                        color="red",
+                        file=sys.stderr,
+                    )
+                    sys.exit(1)
+                api, provider_type = api_provider.split("=")
+                providers_for_api = get_provider_registry().get(Api(api), None)
+                if providers_for_api is None:
+                    cprint(
+                        f"{api} is not a valid API.",
+                        color="red",
+                        file=sys.stderr,
+                    )
+                    sys.exit(1)
+                if provider_type in providers_for_api:
+                    provider = Provider(
+                        provider_type=provider_type,
+                        provider_id=provider_type.split("::")[1],
+                    )
+                    provider_list.setdefault(api, []).append(provider)
+                else:
+                    cprint(
+                        f"{provider_type} is not a valid provider for the {api} API.",
+                        color="red",
+                        file=sys.stderr,
+                    )
+                    sys.exit(1)
+            run_config = self._generate_run_config_from_providers(providers=provider_list)
+            config_dict = run_config.model_dump(mode="json")
+
+            # Write config to disk in providers-run directory
+            distro_dir = DISTRIBS_BASE_DIR / "providers-run"
+            config_file = distro_dir / "run.yaml"
+
+            logger.info(f"Writing generated config to: {config_file}")
+            with open(config_file, "w") as f:
+                yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
+
         else:
             config_file = None
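A sketch of the new flag in use (the provider names below are illustrative; any api=provider_type pairing present in the provider registry should work, and the generated config lands under the providers-run distribution directory):

```bash
# Single provider for one API
llama stack run --providers inference=remote::ollama

# Multiple pairs, comma-separated; the same API may appear more than once
llama stack run --providers inference=remote::ollama,safety=inline::llama-guard
```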
@@ -107,7 +168,8 @@ class StackRun(Subcommand):

         try:
             config = parse_and_maybe_upgrade_config(config_dict)
-            if not os.path.exists(str(config.external_providers_dir)):
+            # Create external_providers_dir if it's specified and doesn't exist
+            if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)):
                 os.makedirs(str(config.external_providers_dir), exist_ok=True)
         except AttributeError as e:
             self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
@@ -128,7 +190,7 @@ class StackRun(Subcommand):
         config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))

         port = args.port or config.server.port
-        host = config.server.host or ["::", "0.0.0.0"]
+        host = config.server.host or "0.0.0.0"

         # Set the config file in environment so create_app can find it
         os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
@@ -140,6 +202,7 @@ class StackRun(Subcommand):
             "lifespan": "on",
             "log_level": logger.getEffectiveLevel(),
             "log_config": logger_config,
+            "workers": config.server.workers,
         }

         keyfile = config.server.tls_keyfile
@@ -340,3 +403,44 @@ class StackRun(Subcommand):
             )
         except Exception as e:
             logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
+
+    def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
+        apis = list(providers.keys())
+        distro_dir = DISTRIBS_BASE_DIR / "providers-run"
+        # need somewhere to put the storage.
+        os.makedirs(distro_dir, exist_ok=True)
+        storage = StorageConfig(
+            backends={
+                "kv_default": SqliteKVStoreConfig(
+                    db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
+                ),
+                "sql_default": SqliteSqlStoreConfig(
+                    db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
+                ),
+            },
+            stores=ServerStoresConfig(
+                metadata=KVStoreReference(
+                    backend="kv_default",
+                    namespace="registry",
+                ),
+                inference=InferenceStoreReference(
+                    backend="sql_default",
+                    table_name="inference_store",
+                ),
+                conversations=SqlStoreReference(
+                    backend="sql_default",
+                    table_name="openai_conversations",
+                ),
+                prompts=KVStoreReference(
+                    backend="kv_default",
+                    namespace="prompts",
+                ),
+            ),
+        )
+
+        return StackRunConfig(
+            image_name="providers-run",
+            apis=apis,
+            providers=providers,
+            storage=storage,
+        )
@@ -17,7 +17,6 @@ from llama_stack.core.distribution import (
     get_provider_registry,
 )
 from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
-from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
 from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.prompt_for_config import prompt_for_config
 from llama_stack.log import get_logger
@@ -194,19 +193,11 @@ def upgrade_from_routing_table(


 def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
-    version = config_dict.get("version", None)
-    if version == LLAMA_STACK_RUN_CONFIG_VERSION:
-        processed_config_dict = replace_env_vars(config_dict)
-        return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
-
     if "routing_table" in config_dict:
         logger.info("Upgrading config...")
         config_dict = upgrade_from_routing_table(config_dict)

     config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION

-    if not config_dict.get("external_providers_dir", None):
-        config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
-
     processed_config_dict = replace_env_vars(config_dict)
     return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
@@ -473,6 +473,10 @@ class ServerConfig(BaseModel):
         "- true: Enable localhost CORS for development\n"
         "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
     )
+    workers: int = Field(
+        default=1,
+        description="Number of workers to use for the server",
+    )


 class StackRunConfig(BaseModel):
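A hedged sketch of how the new field is meant to be used (the YAML keys other than `workers` are assumed from context, not taken from this diff):

```bash
# The diff threads config.server.workers into the uvicorn config dict built
# by `llama stack run`, so a run.yaml containing the (assumed) fragment below
# starts a multi-worker server:
#
#   server:
#     port: 8321
#     workers: 4
#
llama stack run ./run.yaml
```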
@@ -15,6 +15,7 @@ from llama_stack.apis.inspect import (
     RouteInfo,
     VersionInfo,
 )
+from llama_stack.apis.version import LLAMA_STACK_API_V1
 from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.core.external import load_external_apis
 from llama_stack.core.server.routes import get_all_api_routes
@@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect):
     async def initialize(self) -> None:
         pass

-    async def list_routes(self) -> ListRoutesResponse:
+    async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
         run_config: StackRunConfig = self.config.run_config

+        # Helper function to determine if a route should be included based on api_filter
+        def should_include_route(webmethod) -> bool:
+            if api_filter is None:
+                # Default: only non-deprecated v1 APIs
+                return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
+            elif api_filter == "deprecated":
+                # Special filter: show deprecated routes regardless of their actual level
+                return bool(webmethod.deprecated)
+            else:
+                # Filter by API level (non-deprecated routes only)
+                return not webmethod.deprecated and webmethod.level == api_filter
+
         ret = []
         external_apis = load_external_apis(run_config)
         all_endpoints = get_all_api_routes(external_apis)
@@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect):
                         method=next(iter([m for m in e.methods if m != "HEAD"])),
                         provider_types=[],  # These APIs don't have "real" providers - they're internal to the stack
                     )
-                    for e, _ in endpoints
-                    if e.methods is not None
+                    for e, webmethod in endpoints
+                    if e.methods is not None and should_include_route(webmethod)
                 ]
             )
         else:
@@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect):
                         method=next(iter([m for m in e.methods if m != "HEAD"])),
                         provider_types=[p.provider_type for p in providers],
                     )
-                    for e, _ in endpoints
-                    if e.methods is not None
+                    for e, webmethod in endpoints
+                    if e.methods is not None and should_include_route(webmethod)
                 ]
             )
@@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(

 print(f"VLM Response: {vlm_response.choices[0].message.content}")
 ```
+
+### Rerank Example
+
+The following example shows how to rerank documents using an NVIDIA NIM.
+
+```python
+rerank_response = client.alpha.inference.rerank(
+    model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
+    query="query",
+    items=[
+        "item_1",
+        "item_2",
+        "item_3",
+    ],
+)
+
+for i, result in enumerate(rerank_response):
+    print(f"{i+1}. [Index: {result.index}, Score: {result.relevance_score:.3f}]")
+```
@@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
     Attributes:
         url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
         api_key (str): The access key for the hosted NIM endpoints
+        rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints

     There are two ways to access NVIDIA NIMs -
      0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
@@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
         default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
         description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
     )
+    rerank_model_to_url: dict[str, str] = Field(
+        default_factory=lambda: {
+            "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
+            "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
+            "nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
+        },
+        description="Mapping of rerank model identifiers to their API endpoints.",
+    )

     @classmethod
     def sample_run_config(
@@ -5,6 +5,19 @@
 # the root directory of this source tree.

+from collections.abc import Iterable
+
+import aiohttp
+
+from llama_stack.apis.inference import (
+    RerankData,
+    RerankResponse,
+)
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+)
+from llama_stack.apis.models import Model, ModelType
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
         :return: The NVIDIA API base URL
         """
         return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
+
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Return both dynamic model IDs and statically configured rerank model IDs.
+        """
+        dynamic_ids: Iterable[str] = []
+        try:
+            dynamic_ids = await super().list_provider_model_ids()
+        except Exception:
+            # If the dynamic listing fails, proceed with just configured rerank IDs
+            dynamic_ids = []
+
+        configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
+        return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids))  # remove duplicates
+
+    def construct_model_from_identifier(self, identifier: str) -> Model:
+        """
+        Classify rerank models from config; otherwise use the base behavior.
+        """
+        if identifier in self.config.rerank_model_to_url:
+            return Model(
+                provider_id=self.__provider_id__,  # type: ignore[attr-defined]
+                provider_resource_id=identifier,
+                identifier=identifier,
+                model_type=ModelType.rerank,
+            )
+        return super().construct_model_from_identifier(identifier)
+
+    async def rerank(
+        self,
+        model: str,
+        query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
+        items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
+        max_num_results: int | None = None,
+    ) -> RerankResponse:
+        provider_model_id = await self._get_provider_model_id(model)
+
+        ranking_url = self.get_base_url()
+
+        if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
+            ranking_url = self.config.rerank_model_to_url[provider_model_id]
+
+        logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
+
+        # Convert query to text format
+        if isinstance(query, str):
+            query_text = query
+        elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
+            query_text = query.text
+        else:
+            raise ValueError("Query must be a string or text content part")
+
+        # Convert items to text format
+        passages = []
+        for item in items:
+            if isinstance(item, str):
+                passages.append({"text": item})
+            elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
+                passages.append({"text": item.text})
+            else:
+                raise ValueError("Items must be strings or text content parts")
+
+        payload = {
+            "model": provider_model_id,
+            "query": {"text": query_text},
+            "passages": passages,
+        }
+
+        headers = {
+            "Authorization": f"Bearer {self.get_api_key()}",
+            "Content-Type": "application/json",
+        }
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(ranking_url, headers=headers, json=payload) as response:
+                    if response.status != 200:
+                        response_text = await response.text()
+                        raise ConnectionError(
+                            f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
+                        )
+
+                    result = await response.json()
+                    rankings = result.get("rankings", [])
+
+                    # Convert to RerankData format
+                    rerank_data = []
+                    for ranking in rankings:
+                        rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
+
+                    # Apply max_num_results limit
+                    if max_num_results is not None:
+                        rerank_data = rerank_data[:max_num_results]
+
+                    return RerankResponse(data=rerank_data)
+
+        except aiohttp.ClientError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
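The HTTP exchange the adapter performs can be approximated with curl (a sketch: the endpoint and model ID come from the config defaults above, NVIDIA_API_KEY is assumed to be exported, and the "logit" field in each ranking is what becomes relevance_score):

```bash
curl -s "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking" \
  -H "Authorization: Bearer $NVIDIA_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
        "model": "nvidia/llama-3.2-nv-rerankqa-1b-v2",
        "query": {"text": "which passage mentions llamas?"},
        "passages": [{"text": "item_1"}, {"text": "item_2"}]
      }'
# Expected response shape: {"rankings": [{"index": 0, "logit": ...}, ...]}
```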
@@ -51,7 +51,11 @@ async function proxyRequest(request: NextRequest, method: string) {
     );

     // Create response with same status and headers
-    const proxyResponse = new NextResponse(responseText, {
+    // Handle 204 No Content responses specially
+    const proxyResponse =
+      response.status === 204
+        ? new NextResponse(null, { status: 204 })
+        : new NextResponse(responseText, {
       status: response.status,
       statusText: response.statusText,
     });
5 src/llama_stack/ui/app/prompts/page.tsx Normal file

@@ -0,0 +1,5 @@
import { PromptManagement } from "@/components/prompts";

export default function PromptsPage() {
  return <PromptManagement />;
}
@@ -8,6 +8,7 @@ import {
   MessageCircle,
   Settings2,
   Compass,
+  FileText,
 } from "lucide-react";
 import Link from "next/link";
 import { usePathname } from "next/navigation";
@@ -50,6 +51,11 @@ const manageItems = [
     url: "/logs/vector-stores",
     icon: Database,
   },
+  {
+    title: "Prompts",
+    url: "/prompts",
+    icon: FileText,
+  },
   {
     title: "Documentation",
     url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
4 src/llama_stack/ui/components/prompts/index.ts Normal file

@@ -0,0 +1,4 @@
export { PromptManagement } from "./prompt-management";
export { PromptList } from "./prompt-list";
export { PromptEditor } from "./prompt-editor";
export * from "./types";
309 src/llama_stack/ui/components/prompts/prompt-editor.test.tsx Normal file

@@ -0,0 +1,309 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptEditor } from "./prompt-editor";
import type { Prompt, PromptFormData } from "./types";

describe("PromptEditor", () => {
  const mockOnSave = jest.fn();
  const mockOnCancel = jest.fn();
  const mockOnDelete = jest.fn();

  const defaultProps = {
    onSave: mockOnSave,
    onCancel: mockOnCancel,
    onDelete: mockOnDelete,
  };

  beforeEach(() => {
    jest.clearAllMocks();
  });

  describe("Create Mode", () => {
    test("renders create form correctly", () => {
      render(<PromptEditor {...defaultProps} />);

      expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument();
      expect(screen.getByText("Variables")).toBeInTheDocument();
      expect(screen.getByText("Preview")).toBeInTheDocument();
      expect(screen.getByText("Create Prompt")).toBeInTheDocument();
      expect(screen.getByText("Cancel")).toBeInTheDocument();
    });

    test("shows preview placeholder when no content", () => {
      render(<PromptEditor {...defaultProps} />);

      expect(
        screen.getByText("Enter content to preview the compiled prompt")
      ).toBeInTheDocument();
    });

    test("submits form with correct data", () => {
      render(<PromptEditor {...defaultProps} />);

      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, {
        target: { value: "Hello {{name}}, welcome!" },
      });

      fireEvent.click(screen.getByText("Create Prompt"));

      expect(mockOnSave).toHaveBeenCalledWith({
        prompt: "Hello {{name}}, welcome!",
        variables: [],
      });
    });

    test("prevents submission with empty prompt", () => {
      render(<PromptEditor {...defaultProps} />);

      fireEvent.click(screen.getByText("Create Prompt"));

      expect(mockOnSave).not.toHaveBeenCalled();
    });
  });

  describe("Edit Mode", () => {
    const mockPrompt: Prompt = {
      prompt_id: "prompt_123",
      prompt: "Hello {{name}}, how is {{weather}}?",
      version: 1,
      variables: ["name", "weather"],
      is_default: true,
    };

    test("renders edit form with existing data", () => {
      render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);

      expect(
        screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?")
      ).toBeInTheDocument();
      expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview
      expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview
      expect(screen.getByText("Update Prompt")).toBeInTheDocument();
      expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
    });

    test("submits updated data correctly", () => {
      render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);

      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, {
        target: { value: "Updated: Hello {{name}}!" },
      });

      fireEvent.click(screen.getByText("Update Prompt"));

      expect(mockOnSave).toHaveBeenCalledWith({
        prompt: "Updated: Hello {{name}}!",
        variables: ["name", "weather"],
      });
    });
  });

  describe("Variables Management", () => {
    test("adds new variable", () => {
      render(<PromptEditor {...defaultProps} />);

      const variableInput = screen.getByPlaceholderText(
        "Add variable name (e.g. user_name, topic)"
      );
      fireEvent.change(variableInput, { target: { value: "testVar" } });
      fireEvent.click(screen.getByText("Add"));

      expect(screen.getByText("testVar")).toBeInTheDocument();
    });

    test("prevents adding duplicate variables", () => {
      render(<PromptEditor {...defaultProps} />);

      const variableInput = screen.getByPlaceholderText(
        "Add variable name (e.g. user_name, topic)"
      );

      // Add first variable
      fireEvent.change(variableInput, { target: { value: "test" } });
      fireEvent.click(screen.getByText("Add"));

      // Try to add same variable again
      fireEvent.change(variableInput, { target: { value: "test" } });

      // Button should be disabled
      expect(screen.getByText("Add")).toBeDisabled();
    });

    test("removes variable", () => {
      const mockPrompt: Prompt = {
        prompt_id: "prompt_123",
        prompt: "Hello {{name}}",
        version: 1,
        variables: ["name", "location"],
        is_default: true,
      };

      render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);

      // Check that both variables are present initially
      expect(screen.getAllByText("name").length).toBeGreaterThan(0);
      expect(screen.getAllByText("location").length).toBeGreaterThan(0);

      // Remove the location variable by clicking the X button with the specific title
      const removeLocationButton = screen.getByTitle(
        "Remove location variable"
      );
      fireEvent.click(removeLocationButton);

      // Name should still be there, location should be gone from the variables section
      expect(screen.getAllByText("name").length).toBeGreaterThan(0);
      expect(
        screen.queryByTitle("Remove location variable")
      ).not.toBeInTheDocument();
    });

    test("adds variable on Enter key", () => {
      render(<PromptEditor {...defaultProps} />);

      const variableInput = screen.getByPlaceholderText(
        "Add variable name (e.g. user_name, topic)"
      );
      fireEvent.change(variableInput, { target: { value: "enterVar" } });

      // Simulate Enter key press
      fireEvent.keyPress(variableInput, {
        key: "Enter",
        code: "Enter",
        charCode: 13,
        preventDefault: jest.fn(),
      });

      // Check if the variable was added by looking for the badge
      expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0);
    });
  });

  describe("Preview Functionality", () => {
    test("shows live preview with variables", () => {
      render(<PromptEditor {...defaultProps} />);

      // Add prompt content
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, {
        target: { value: "Hello {{name}}, welcome to {{place}}!" },
      });

      // Add variables
      const variableInput = screen.getByPlaceholderText(
        "Add variable name (e.g. user_name, topic)"
      );
      fireEvent.change(variableInput, { target: { value: "name" } });
      fireEvent.click(screen.getByText("Add"));

      fireEvent.change(variableInput, { target: { value: "place" } });
      fireEvent.click(screen.getByText("Add"));

      // Check that preview area shows the content
      expect(screen.getByText("Compiled Prompt")).toBeInTheDocument();
    });

    test("shows variable value inputs in preview", () => {
      const mockPrompt: Prompt = {
        prompt_id: "prompt_123",
        prompt: "Hello {{name}}",
        version: 1,
        variables: ["name"],
        is_default: true,
      };

      render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);

      expect(screen.getByText("Variable Values")).toBeInTheDocument();
      expect(
        screen.getByPlaceholderText("Enter value for name")
      ).toBeInTheDocument();
    });

    test("shows color legend for variable states", () => {
      render(<PromptEditor {...defaultProps} />);

      // Add content to show preview
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, {
        target: { value: "Hello {{name}}" },
      });

      expect(screen.getByText("Used")).toBeInTheDocument();
      expect(screen.getByText("Unused")).toBeInTheDocument();
      expect(screen.getByText("Undefined")).toBeInTheDocument();
    });
  });

  describe("Error Handling", () => {
    test("displays error message", () => {
      const errorMessage = "Prompt contains undeclared variables";
      render(<PromptEditor {...defaultProps} error={errorMessage} />);

      expect(screen.getByText(errorMessage)).toBeInTheDocument();
    });
  });

  describe("Delete Functionality", () => {
    const mockPrompt: Prompt = {
      prompt_id: "prompt_123",
      prompt: "Hello {{name}}",
      version: 1,
      variables: ["name"],
      is_default: true,
    };

    test("shows delete button in edit mode", () => {
      render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);

      expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
    });

    test("hides delete button in create mode", () => {
      render(<PromptEditor {...defaultProps} />);

      expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument();
    });

    test("calls onDelete with confirmation", () => {
      const originalConfirm = window.confirm;
      window.confirm = jest.fn(() => true);

      render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);

      fireEvent.click(screen.getByText("Delete Prompt"));

      expect(window.confirm).toHaveBeenCalledWith(
        "Are you sure you want to delete this prompt? This action cannot be undone."
      );
      expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");

      window.confirm = originalConfirm;
    });

    test("does not delete when confirmation is cancelled", () => {
      const originalConfirm = window.confirm;
      window.confirm = jest.fn(() => false);

      render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);

      fireEvent.click(screen.getByText("Delete Prompt"));

      expect(mockOnDelete).not.toHaveBeenCalled();

      window.confirm = originalConfirm;
    });
  });

  describe("Cancel Functionality", () => {
    test("calls onCancel when cancel button is clicked", () => {
      render(<PromptEditor {...defaultProps} />);

      fireEvent.click(screen.getByText("Cancel"));

      expect(mockOnCancel).toHaveBeenCalled();
    });
  });
});
346 src/llama_stack/ui/components/prompts/prompt-editor.tsx Normal file

@@ -0,0 +1,346 @@
"use client";

import { useState, useEffect } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
import { Badge } from "@/components/ui/badge";
import {
  Card,
  CardContent,
  CardDescription,
  CardHeader,
  CardTitle,
} from "@/components/ui/card";
import { Separator } from "@/components/ui/separator";
import { X, Plus, Save, Trash2 } from "lucide-react";
import { Prompt, PromptFormData } from "./types";

interface PromptEditorProps {
  prompt?: Prompt;
  onSave: (prompt: PromptFormData) => void;
  onCancel: () => void;
  onDelete?: (promptId: string) => void;
  error?: string | null;
}

export function PromptEditor({
  prompt,
  onSave,
  onCancel,
  onDelete,
  error,
}: PromptEditorProps) {
  const [formData, setFormData] = useState<PromptFormData>({
    prompt: "",
    variables: [],
  });

  const [newVariable, setNewVariable] = useState("");
  const [variableValues, setVariableValues] = useState<Record<string, string>>(
    {}
  );

  useEffect(() => {
    if (prompt) {
      setFormData({
        prompt: prompt.prompt || "",
        variables: prompt.variables || [],
      });
    }
  }, [prompt]);

  const handleSubmit = (e: React.FormEvent) => {
    e.preventDefault();
    if (!formData.prompt.trim()) {
      return;
    }
    onSave(formData);
  };

  const addVariable = () => {
    if (
      newVariable.trim() &&
      !formData.variables.includes(newVariable.trim())
    ) {
      setFormData(prev => ({
        ...prev,
        variables: [...prev.variables, newVariable.trim()],
      }));
      setNewVariable("");
    }
  };

  const removeVariable = (variableToRemove: string) => {
    setFormData(prev => ({
      ...prev,
      variables: prev.variables.filter(
        variable => variable !== variableToRemove
      ),
    }));
  };

  const renderPreview = () => {
    const text = formData.prompt;
    if (!text) return text;

    // Split text by variable patterns and process each part
    const parts = text.split(/(\{\{\s*\w+\s*\}\})/g);

    return parts.map((part, index) => {
      const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/);
      if (variableMatch) {
        const variableName = variableMatch[1];
        const isDefined = formData.variables.includes(variableName);
        const value = variableValues[variableName];

        if (!isDefined) {
          // Variable not in variables list - likely a typo/bug (RED)
          return (
            <span
              key={index}
              className="bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200 px-1 rounded font-medium"
            >
              {part}
            </span>
          );
        } else if (value && value.trim()) {
          // Variable defined and has value - show the value (GREEN)
          return (
            <span
              key={index}
              className="bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200 px-1 rounded font-medium"
            >
              {value}
            </span>
          );
        } else {
          // Variable defined but empty (YELLOW)
          return (
            <span
              key={index}
              className="bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200 px-1 rounded font-medium"
            >
              {part}
            </span>
          );
        }
      }
      return part;
    });
  };

  const updateVariableValue = (variable: string, value: string) => {
    setVariableValues(prev => ({
      ...prev,
      [variable]: value,
    }));
  };

  return (
    <form onSubmit={handleSubmit} className="space-y-6">
      {error && (
        <div className="p-4 bg-destructive/10 border border-destructive/20 rounded-md">
          <p className="text-destructive text-sm">{error}</p>
        </div>
      )}
      <div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
        {/* Form Section */}
        <div className="space-y-4">
          <div>
            <Label htmlFor="prompt">Prompt Content *</Label>
            <Textarea
              id="prompt"
              value={formData.prompt}
              onChange={e =>
                setFormData(prev => ({ ...prev, prompt: e.target.value }))
              }
              placeholder="Enter your prompt content here. Use {{variable_name}} for dynamic variables."
              className="min-h-32 font-mono mt-2"
              required
            />
            <p className="text-xs text-muted-foreground mt-2">
              Use double curly braces around variable names, e.g.,{" "}
              {`{{user_name}}`} or {`{{topic}}`}
            </p>
          </div>

          <div className="space-y-3">
            <Label className="text-sm font-medium">Variables</Label>

            <div className="flex gap-2 mt-2">
              <Input
                value={newVariable}
                onChange={e => setNewVariable(e.target.value)}
                placeholder="Add variable name (e.g. user_name, topic)"
                onKeyPress={e =>
                  e.key === "Enter" && (e.preventDefault(), addVariable())
                }
                className="flex-1"
              />
              <Button
                type="button"
                onClick={addVariable}
                size="sm"
                disabled={
                  !newVariable.trim() ||
                  formData.variables.includes(newVariable.trim())
                }
              >
                <Plus className="h-4 w-4" />
                Add
              </Button>
            </div>

            {formData.variables.length > 0 && (
              <div className="border rounded-lg p-3 bg-muted/20">
                <div className="flex flex-wrap gap-2">
                  {formData.variables.map(variable => (
                    <Badge
                      key={variable}
                      variant="secondary"
                      className="text-sm px-2 py-1"
                    >
                      {variable}
                      <button
                        type="button"
                        onClick={() => removeVariable(variable)}
                        className="ml-2 hover:text-destructive transition-colors"
                        title={`Remove ${variable} variable`}
                      >
                        <X className="h-3 w-3" />
                      </button>
                    </Badge>
                  ))}
                </div>
              </div>
            )}

            <p className="text-xs text-muted-foreground">
              Variables that can be used in the prompt template. Each variable
              should match a {`{{variable}}`} placeholder in the content above.
            </p>
          </div>
        </div>

        {/* Preview Section */}
        <div className="space-y-4">
          <Card>
            <CardHeader>
              <CardTitle className="text-lg">Preview</CardTitle>
              <CardDescription>
                Live preview of compiled prompt and variable substitution.
              </CardDescription>
            </CardHeader>
            <CardContent className="space-y-4">
              {formData.prompt ? (
                <>
                  {/* Variable Values */}
                  {formData.variables.length > 0 && (
                    <div className="space-y-3">
                      <Label className="text-sm font-medium">
                        Variable Values
                      </Label>
                      <div className="space-y-2">
                        {formData.variables.map(variable => (
                          <div
                            key={variable}
                            className="grid grid-cols-2 gap-3 items-center"
                          >
                            <div className="text-sm font-mono text-muted-foreground">
                              {variable}
                            </div>
                            <Input
                              id={`var-${variable}`}
                              value={variableValues[variable] || ""}
                              onChange={e =>
                                updateVariableValue(variable, e.target.value)
                              }
                              placeholder={`Enter value for ${variable}`}
                              className="text-sm"
                            />
                          </div>
                        ))}
                      </div>
                      <Separator />
                    </div>
                  )}

                  {/* Live Preview */}
                  <div>
                    <Label className="text-sm font-medium mb-2 block">
                      Compiled Prompt
                    </Label>
                    <div className="bg-muted/50 p-4 rounded-lg border">
                      <div className="text-sm leading-relaxed whitespace-pre-wrap">
                        {renderPreview()}
                      </div>
                    </div>
                    <div className="flex flex-wrap gap-4 mt-2 text-xs">
                      <div className="flex items-center gap-1">
                        <div className="w-3 h-3 bg-green-500 dark:bg-green-400 border rounded"></div>
                        <span className="text-muted-foreground">Used</span>
                      </div>
                      <div className="flex items-center gap-1">
                        <div className="w-3 h-3 bg-yellow-500 dark:bg-yellow-400 border rounded"></div>
                        <span className="text-muted-foreground">Unused</span>
                      </div>
                      <div className="flex items-center gap-1">
                        <div className="w-3 h-3 bg-red-500 dark:bg-red-400 border rounded"></div>
                        <span className="text-muted-foreground">Undefined</span>
                      </div>
                    </div>
                  </div>
                </>
              ) : (
                <div className="text-center py-8">
                  <div className="text-muted-foreground text-sm">
                    Enter content to preview the compiled prompt
                  </div>
                  <div className="text-xs text-muted-foreground mt-2">
                    Use {`{{variable_name}}`} to add dynamic variables
                  </div>
                </div>
              )}
            </CardContent>
          </Card>
        </div>
      </div>

      <Separator />

      <div className="flex justify-between">
        <div>
          {prompt && onDelete && (
            <Button
              type="button"
              variant="destructive"
              onClick={() => {
                if (
                  confirm(
                    `Are you sure you want to delete this prompt? This action cannot be undone.`
                  )
                ) {
                  onDelete(prompt.prompt_id);
                }
              }}
            >
              <Trash2 className="h-4 w-4 mr-2" />
              Delete Prompt
            </Button>
          )}
        </div>
        <div className="flex gap-2">
          <Button type="button" variant="outline" onClick={onCancel}>
            Cancel
          </Button>
          <Button type="submit">
            <Save className="h-4 w-4 mr-2" />
            {prompt ? "Update" : "Create"} Prompt
          </Button>
        </div>
      </div>
    </form>
  );
}
259 src/llama_stack/ui/components/prompts/prompt-list.test.tsx Normal file

@@ -0,0 +1,259 @@
import React from "react";
import { render, screen, fireEvent } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptList } from "./prompt-list";
import type { Prompt } from "./types";

describe("PromptList", () => {
  const mockOnEdit = jest.fn();
  const mockOnDelete = jest.fn();

  const defaultProps = {
    prompts: [],
    onEdit: mockOnEdit,
    onDelete: mockOnDelete,
  };

  beforeEach(() => {
    jest.clearAllMocks();
  });

  describe("Empty State", () => {
    test("renders empty message when no prompts", () => {
      render(<PromptList {...defaultProps} />);

      expect(screen.getByText("No prompts yet")).toBeInTheDocument();
    });

    test("shows filtered empty message when search has no results", () => {
      const prompts: Prompt[] = [
        {
          prompt_id: "prompt_123",
          prompt: "Hello world",
          version: 1,
          variables: [],
          is_default: false,
        },
      ];

      render(<PromptList {...defaultProps} prompts={prompts} />);

      // Search for something that doesn't exist
      const searchInput = screen.getByPlaceholderText("Search prompts...");
      fireEvent.change(searchInput, { target: { value: "nonexistent" } });

      expect(
        screen.getByText("No prompts match your filters")
      ).toBeInTheDocument();
    });
  });

  describe("Prompts Display", () => {
    const mockPrompts: Prompt[] = [
      {
        prompt_id: "prompt_123",
        prompt: "Hello {{name}}, how are you?",
        version: 1,
        variables: ["name"],
        is_default: true,
      },
      {
        prompt_id: "prompt_456",
        prompt: "Summarize this {{text}} in {{length}} words",
        version: 2,
        variables: ["text", "length"],
        is_default: false,
      },
      {
        prompt_id: "prompt_789",
        prompt: "Simple prompt with no variables",
        version: 1,
        variables: [],
        is_default: false,
      },
    ];

    test("renders prompts table with correct headers", () => {
      render(<PromptList {...defaultProps} prompts={mockPrompts} />);

      expect(screen.getByText("ID")).toBeInTheDocument();
      expect(screen.getByText("Content")).toBeInTheDocument();
      expect(screen.getByText("Variables")).toBeInTheDocument();
      expect(screen.getByText("Version")).toBeInTheDocument();
      expect(screen.getByText("Actions")).toBeInTheDocument();
    });

    test("renders prompt data correctly", () => {
      render(<PromptList {...defaultProps} prompts={mockPrompts} />);

      // Check prompt IDs
      expect(screen.getByText("prompt_123")).toBeInTheDocument();
      expect(screen.getByText("prompt_456")).toBeInTheDocument();
      expect(screen.getByText("prompt_789")).toBeInTheDocument();

      // Check content
      expect(
        screen.getByText("Hello {{name}}, how are you?")
      ).toBeInTheDocument();
      expect(
        screen.getByText("Summarize this {{text}} in {{length}} words")
      ).toBeInTheDocument();
      expect(
        screen.getByText("Simple prompt with no variables")
      ).toBeInTheDocument();

      // Check versions
      expect(screen.getAllByText("1")).toHaveLength(2); // Two prompts with version 1
      expect(screen.getByText("2")).toBeInTheDocument();

      // Check default badge
      expect(screen.getByText("Default")).toBeInTheDocument();
    });

    test("renders variables correctly", () => {
      render(<PromptList {...defaultProps} prompts={mockPrompts} />);

      // Check variables display
      expect(screen.getByText("name")).toBeInTheDocument();
      expect(screen.getByText("text")).toBeInTheDocument();
      expect(screen.getByText("length")).toBeInTheDocument();
      expect(screen.getByText("None")).toBeInTheDocument(); // For prompt with no variables
    });

    test("prompt ID links are clickable and call onEdit", () => {
      render(<PromptList {...defaultProps} prompts={mockPrompts} />);

      // Click on the first prompt ID link
      const promptLink = screen.getByRole("button", { name: "prompt_123" });
      fireEvent.click(promptLink);

      expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
    });

    test("edit buttons call onEdit", () => {
      const { container } = render(
        <PromptList {...defaultProps} prompts={mockPrompts} />
      );

      // Find the action buttons in the table - they should be in the last column
      const actionCells = container.querySelectorAll("td:last-child");
      const firstActionCell = actionCells[0];
      const editButton = firstActionCell?.querySelector("button");

      expect(editButton).toBeInTheDocument();
      fireEvent.click(editButton!);

      expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
    });

    test("delete buttons call onDelete with confirmation", () => {
      const originalConfirm = window.confirm;
      window.confirm = jest.fn(() => true);

      const { container } = render(
        <PromptList {...defaultProps} prompts={mockPrompts} />
      );

      // Find the delete button (second button in the first action cell)
      const actionCells = container.querySelectorAll("td:last-child");
      const firstActionCell = actionCells[0];
      const buttons = firstActionCell?.querySelectorAll("button");
      const deleteButton = buttons?.[1]; // Second button should be delete

      expect(deleteButton).toBeInTheDocument();
      fireEvent.click(deleteButton!);

      expect(window.confirm).toHaveBeenCalledWith(
        "Are you sure you want to delete this prompt? This action cannot be undone."
      );
      expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");

      window.confirm = originalConfirm;
    });

    test("delete does not execute when confirmation is cancelled", () => {
      const originalConfirm = window.confirm;
      window.confirm = jest.fn(() => false);

      const { container } = render(
        <PromptList {...defaultProps} prompts={mockPrompts} />
      );

      const actionCells = container.querySelectorAll("td:last-child");
      const firstActionCell = actionCells[0];
      const buttons = firstActionCell?.querySelectorAll("button");
      const deleteButton = buttons?.[1]; // Second button should be delete

      expect(deleteButton).toBeInTheDocument();
      fireEvent.click(deleteButton!);

      expect(mockOnDelete).not.toHaveBeenCalled();

      window.confirm = originalConfirm;
    });
  });

  describe("Search Functionality", () => {
    const mockPrompts: Prompt[] = [
      {
        prompt_id: "user_greeting",
        prompt: "Hello {{name}}, welcome!",
        version: 1,
        variables: ["name"],
        is_default: true,
      },
      {
        prompt_id: "system_summary",
        prompt: "Summarize the following text",
        version: 1,
        variables: [],
        is_default: false,
      },
    ];

    test("filters prompts by prompt ID", () => {
      render(<PromptList {...defaultProps} prompts={mockPrompts} />);

      const searchInput = screen.getByPlaceholderText("Search prompts...");
      fireEvent.change(searchInput, { target: { value: "user" } });

      expect(screen.getByText("user_greeting")).toBeInTheDocument();
      expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
    });

    test("filters prompts by content", () => {
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
fireEvent.change(searchInput, { target: { value: "welcome" } });
|
||||||
|
|
||||||
|
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("search is case insensitive", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
fireEvent.change(searchInput, { target: { value: "HELLO" } });
|
||||||
|
|
||||||
|
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
test("clearing search shows all prompts", () => {
|
||||||
|
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||||
|
|
||||||
|
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||||
|
|
||||||
|
// Filter first
|
||||||
|
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||||
|
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||||
|
|
||||||
|
// Clear search
|
||||||
|
fireEvent.change(searchInput, { target: { value: "" } });
|
||||||
|
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("system_summary")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
164
src/llama_stack/ui/components/prompts/prompt-list.tsx
Normal file
@@ -0,0 +1,164 @@
"use client";

import { useState } from "react";
import { Badge } from "@/components/ui/badge";
import { Button } from "@/components/ui/button";
import {
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from "@/components/ui/table";
import { Input } from "@/components/ui/input";
import { Edit, Search, Trash2 } from "lucide-react";
import { Prompt, PromptFilters } from "./types";

interface PromptListProps {
  prompts: Prompt[];
  onEdit: (prompt: Prompt) => void;
  onDelete: (promptId: string) => void;
}

export function PromptList({ prompts, onEdit, onDelete }: PromptListProps) {
  const [filters, setFilters] = useState<PromptFilters>({});

  const filteredPrompts = prompts.filter(prompt => {
    if (
      filters.searchTerm &&
      !(
        prompt.prompt
          ?.toLowerCase()
          .includes(filters.searchTerm.toLowerCase()) ||
        prompt.prompt_id
          .toLowerCase()
          .includes(filters.searchTerm.toLowerCase())
      )
    ) {
      return false;
    }
    return true;
  });

  return (
    <div className="space-y-4">
      {/* Filters */}
      <div className="flex flex-col sm:flex-row gap-4">
        <div className="relative flex-1">
          <Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
          <Input
            placeholder="Search prompts..."
            value={filters.searchTerm || ""}
            onChange={e =>
              setFilters(prev => ({ ...prev, searchTerm: e.target.value }))
            }
            className="pl-10"
          />
        </div>
      </div>

      {/* Prompts Table */}
      <div className="overflow-auto">
        <Table>
          <TableHeader>
            <TableRow>
              <TableHead>ID</TableHead>
              <TableHead>Content</TableHead>
              <TableHead>Variables</TableHead>
              <TableHead>Version</TableHead>
              <TableHead>Actions</TableHead>
            </TableRow>
          </TableHeader>
          <TableBody>
            {filteredPrompts.map(prompt => (
              <TableRow key={prompt.prompt_id}>
                <TableCell className="max-w-48">
                  <Button
                    variant="link"
                    className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300 max-w-full justify-start"
                    onClick={() => onEdit(prompt)}
                    title={prompt.prompt_id}
                  >
                    <div className="truncate">{prompt.prompt_id}</div>
                  </Button>
                </TableCell>
                <TableCell className="max-w-64">
                  <div
                    className="font-mono text-xs text-muted-foreground truncate"
                    title={prompt.prompt || "No content"}
                  >
                    {prompt.prompt || "No content"}
                  </div>
                </TableCell>
                <TableCell>
                  {prompt.variables.length > 0 ? (
                    <div className="flex flex-wrap gap-1">
                      {prompt.variables.map(variable => (
                        <Badge
                          key={variable}
                          variant="outline"
                          className="text-xs"
                        >
                          {variable}
                        </Badge>
                      ))}
                    </div>
                  ) : (
                    <span className="text-muted-foreground text-sm">None</span>
                  )}
                </TableCell>
                <TableCell className="text-sm">
                  {prompt.version}
                  {prompt.is_default && (
                    <Badge variant="secondary" className="text-xs ml-2">
                      Default
                    </Badge>
                  )}
                </TableCell>
                <TableCell>
                  <div className="flex gap-1">
                    <Button
                      size="sm"
                      variant="outline"
                      onClick={() => onEdit(prompt)}
                      className="h-8 w-8 p-0"
                    >
                      <Edit className="h-3 w-3" />
                    </Button>
                    <Button
                      size="sm"
                      variant="outline"
                      onClick={() => {
                        if (
                          confirm(
                            `Are you sure you want to delete this prompt? This action cannot be undone.`
                          )
                        ) {
                          onDelete(prompt.prompt_id);
                        }
                      }}
                      className="h-8 w-8 p-0 text-destructive hover:text-destructive"
                    >
                      <Trash2 className="h-3 w-3" />
                    </Button>
                  </div>
                </TableCell>
              </TableRow>
            ))}
          </TableBody>
        </Table>
      </div>

      {filteredPrompts.length === 0 && (
        <div className="text-center py-12">
          <div className="text-muted-foreground">
            {prompts.length === 0
              ? "No prompts yet"
              : "No prompts match your filters"}
          </div>
        </div>
      )}
    </div>
  );
}
304
src/llama_stack/ui/components/prompts/prompt-management.test.tsx
Normal file
@@ -0,0 +1,304 @@
import React from "react";
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
import "@testing-library/jest-dom";
import { PromptManagement } from "./prompt-management";
import type { Prompt } from "./types";

// Mock the auth client
const mockPromptsClient = {
  list: jest.fn(),
  create: jest.fn(),
  update: jest.fn(),
  delete: jest.fn(),
};

jest.mock("@/hooks/use-auth-client", () => ({
  useAuthClient: () => ({
    prompts: mockPromptsClient,
  }),
}));

describe("PromptManagement", () => {
  beforeEach(() => {
    jest.clearAllMocks();
  });

  describe("Loading State", () => {
    test("renders loading state initially", () => {
      mockPromptsClient.list.mockReturnValue(new Promise(() => {})); // Never resolves
      render(<PromptManagement />);

      expect(screen.getByText("Loading prompts...")).toBeInTheDocument();
      expect(screen.getByText("Prompts")).toBeInTheDocument();
    });
  });

  describe("Empty State", () => {
    test("renders empty state when no prompts", async () => {
      mockPromptsClient.list.mockResolvedValue([]);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("No prompts found.")).toBeInTheDocument();
      });

      expect(screen.getByText("Create Your First Prompt")).toBeInTheDocument();
    });

    test("opens modal when clicking 'Create Your First Prompt'", async () => {
      mockPromptsClient.list.mockResolvedValue([]);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(
          screen.getByText("Create Your First Prompt")
        ).toBeInTheDocument();
      });

      fireEvent.click(screen.getByText("Create Your First Prompt"));

      expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
    });
  });

  describe("Error State", () => {
    test("renders error state when API fails", async () => {
      const error = new Error("API not found");
      mockPromptsClient.list.mockRejectedValue(error);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText(/Error:/)).toBeInTheDocument();
      });
    });

    test("renders specific error for 404", async () => {
      const error = new Error("404 Not found");
      mockPromptsClient.list.mockRejectedValue(error);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(
          screen.getByText(/Prompts API endpoint not found/)
        ).toBeInTheDocument();
      });
    });
  });

  describe("Prompts List", () => {
    const mockPrompts: Prompt[] = [
      {
        prompt_id: "prompt_123",
        prompt: "Hello {{name}}, how are you?",
        version: 1,
        variables: ["name"],
        is_default: true,
      },
      {
        prompt_id: "prompt_456",
        prompt: "Summarize this {{text}}",
        version: 2,
        variables: ["text"],
        is_default: false,
      },
    ];

    test("renders prompts list correctly", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      expect(screen.getByText("prompt_456")).toBeInTheDocument();
      expect(
        screen.getByText("Hello {{name}}, how are you?")
      ).toBeInTheDocument();
      expect(screen.getByText("Summarize this {{text}}")).toBeInTheDocument();
    });

    test("opens modal when clicking 'New Prompt' button", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      fireEvent.click(screen.getByText("New Prompt"));

      expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
    });
  });

  describe("Modal Operations", () => {
    const mockPrompts: Prompt[] = [
      {
        prompt_id: "prompt_123",
        prompt: "Hello {{name}}",
        version: 1,
        variables: ["name"],
        is_default: true,
      },
    ];

    test("closes modal when clicking cancel", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Open modal
      fireEvent.click(screen.getByText("New Prompt"));
      expect(screen.getByText("Create New Prompt")).toBeInTheDocument();

      // Close modal
      fireEvent.click(screen.getByText("Cancel"));
      expect(screen.queryByText("Create New Prompt")).not.toBeInTheDocument();
    });

    test("creates new prompt successfully", async () => {
      const newPrompt: Prompt = {
        prompt_id: "prompt_new",
        prompt: "New prompt content",
        version: 1,
        variables: [],
        is_default: false,
      };

      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.create.mockResolvedValue(newPrompt);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Open modal
      fireEvent.click(screen.getByText("New Prompt"));

      // Fill form
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, {
        target: { value: "New prompt content" },
      });

      // Submit form
      fireEvent.click(screen.getByText("Create Prompt"));

      await waitFor(() => {
        expect(mockPromptsClient.create).toHaveBeenCalledWith({
          prompt: "New prompt content",
          variables: [],
        });
      });
    });

    test("handles create error gracefully", async () => {
      const error = {
        detail: {
          errors: [{ msg: "Prompt contains undeclared variables: ['test']" }],
        },
      };

      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.create.mockRejectedValue(error);
      render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Open modal
      fireEvent.click(screen.getByText("New Prompt"));

      // Fill form
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, { target: { value: "Hello {{test}}" } });

      // Submit form
      fireEvent.click(screen.getByText("Create Prompt"));

      await waitFor(() => {
        expect(
          screen.getByText("Prompt contains undeclared variables: ['test']")
        ).toBeInTheDocument();
      });
    });

    test("updates existing prompt successfully", async () => {
      const updatedPrompt: Prompt = {
        ...mockPrompts[0],
        prompt: "Updated content",
      };

      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.update.mockResolvedValue(updatedPrompt);
      const { container } = render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Click edit button (first button in the action cell of the first row)
      const actionCells = container.querySelectorAll("td:last-child");
      const firstActionCell = actionCells[0];
      const editButton = firstActionCell?.querySelector("button");

      expect(editButton).toBeInTheDocument();
      fireEvent.click(editButton!);

      expect(screen.getByText("Edit Prompt")).toBeInTheDocument();

      // Update content
      const promptInput = screen.getByLabelText("Prompt Content *");
      fireEvent.change(promptInput, { target: { value: "Updated content" } });

      // Submit form
      fireEvent.click(screen.getByText("Update Prompt"));

      await waitFor(() => {
        expect(mockPromptsClient.update).toHaveBeenCalledWith("prompt_123", {
          prompt: "Updated content",
          variables: ["name"],
          version: 1,
          set_as_default: true,
        });
      });
    });

    test("deletes prompt successfully", async () => {
      mockPromptsClient.list.mockResolvedValue(mockPrompts);
      mockPromptsClient.delete.mockResolvedValue(undefined);

      // Mock window.confirm
      const originalConfirm = window.confirm;
      window.confirm = jest.fn(() => true);

      const { container } = render(<PromptManagement />);

      await waitFor(() => {
        expect(screen.getByText("prompt_123")).toBeInTheDocument();
      });

      // Click delete button (second button in the action cell of the first row)
      const actionCells = container.querySelectorAll("td:last-child");
      const firstActionCell = actionCells[0];
      const buttons = firstActionCell?.querySelectorAll("button");
      const deleteButton = buttons?.[1]; // Second button should be delete

      expect(deleteButton).toBeInTheDocument();
      fireEvent.click(deleteButton!);

      await waitFor(() => {
        expect(mockPromptsClient.delete).toHaveBeenCalledWith("prompt_123");
      });

      // Restore window.confirm
      window.confirm = originalConfirm;
    });
  });
});
233
src/llama_stack/ui/components/prompts/prompt-management.tsx
Normal file
@@ -0,0 +1,233 @@
"use client";

import { useState, useEffect } from "react";
import { Button } from "@/components/ui/button";
import { Plus } from "lucide-react";
import { PromptList } from "./prompt-list";
import { PromptEditor } from "./prompt-editor";
import { Prompt, PromptFormData } from "./types";
import { useAuthClient } from "@/hooks/use-auth-client";

export function PromptManagement() {
  const [prompts, setPrompts] = useState<Prompt[]>([]);
  const [showPromptModal, setShowPromptModal] = useState(false);
  const [editingPrompt, setEditingPrompt] = useState<Prompt | undefined>();
  const [loading, setLoading] = useState(true);
  const [error, setError] = useState<string | null>(null); // For main page errors (loading, etc.)
  const [modalError, setModalError] = useState<string | null>(null); // For form submission errors
  const client = useAuthClient();

  // Load prompts from API on component mount
  useEffect(() => {
    const fetchPrompts = async () => {
      try {
        setLoading(true);
        setError(null);

        const response = await client.prompts.list();
        setPrompts(response || []);
      } catch (err: unknown) {
        console.error("Failed to load prompts:", err);

        // Handle different types of errors
        const error = err as Error & { status?: number };
        if (error?.message?.includes("404") || error?.status === 404) {
          setError(
            "Prompts API endpoint not found. Please ensure your Llama Stack server supports the prompts API."
          );
        } else if (
          error?.message?.includes("not implemented") ||
          error?.message?.includes("not supported")
        ) {
          setError(
            "Prompts API is not yet implemented on this Llama Stack server."
          );
        } else {
          setError(
            `Failed to load prompts: ${error?.message || "Unknown error"}`
          );
        }
      } finally {
        setLoading(false);
      }
    };

    fetchPrompts();
  }, [client]);

  const handleSavePrompt = async (formData: PromptFormData) => {
    try {
      setModalError(null);

      if (editingPrompt) {
        // Update existing prompt
        const response = await client.prompts.update(editingPrompt.prompt_id, {
          prompt: formData.prompt,
          variables: formData.variables,
          version: editingPrompt.version,
          set_as_default: true,
        });

        // Update local state
        setPrompts(prev =>
          prev.map(p =>
            p.prompt_id === editingPrompt.prompt_id ? response : p
          )
        );
      } else {
        // Create new prompt
        const response = await client.prompts.create({
          prompt: formData.prompt,
          variables: formData.variables,
        });

        // Add to local state
        setPrompts(prev => [response, ...prev]);
      }

      setShowPromptModal(false);
      setEditingPrompt(undefined);
    } catch (err) {
      console.error("Failed to save prompt:", err);

      // Extract specific error message from API response
      const error = err as Error & {
        message?: string;
        detail?: { errors?: Array<{ msg?: string }> };
      };

      // Try to parse JSON from error message if it's a string
      let parsedError = error;
      if (typeof error?.message === "string" && error.message.includes("{")) {
        try {
          const jsonMatch = error.message.match(/\d+\s+(.+)/);
          if (jsonMatch) {
            parsedError = JSON.parse(jsonMatch[1]);
          }
        } catch {
          // If parsing fails, use original error
        }
      }

      // Try to get the specific validation error message
      const validationError = parsedError?.detail?.errors?.[0]?.msg;
      if (validationError) {
        // Clean up validation error messages (remove "Value error, " prefix if present)
        const cleanMessage = validationError.replace(/^Value error,\s*/i, "");
        setModalError(cleanMessage);
      } else {
        // For other errors, format them nicely with line breaks
        const statusMatch = error?.message?.match(/(\d+)\s+(.+)/);
        if (statusMatch) {
          const statusCode = statusMatch[1];
          const response = statusMatch[2];
          setModalError(
            `Failed to save prompt: Status Code ${statusCode}\n\nResponse: ${response}`
          );
        } else {
          const message = error?.message || error?.detail || "Unknown error";
          setModalError(`Failed to save prompt: ${message}`);
        }
      }
    }
  };

  const handleEditPrompt = (prompt: Prompt) => {
    setEditingPrompt(prompt);
    setShowPromptModal(true);
    setModalError(null); // Clear any previous modal errors
  };

  const handleDeletePrompt = async (promptId: string) => {
    try {
      setError(null);
      await client.prompts.delete(promptId);
      setPrompts(prev => prev.filter(p => p.prompt_id !== promptId));

      // If we're deleting the currently editing prompt, close the modal
      if (editingPrompt && editingPrompt.prompt_id === promptId) {
        setShowPromptModal(false);
        setEditingPrompt(undefined);
      }
    } catch (err) {
      console.error("Failed to delete prompt:", err);
      setError("Failed to delete prompt");
    }
  };

  const handleCreateNew = () => {
    setEditingPrompt(undefined);
    setShowPromptModal(true);
    setModalError(null); // Clear any previous modal errors
  };

  const handleCancel = () => {
    setShowPromptModal(false);
    setEditingPrompt(undefined);
  };

  const renderContent = () => {
    if (loading) {
      return <div className="text-muted-foreground">Loading prompts...</div>;
    }

    if (error) {
      return <div className="text-destructive">Error: {error}</div>;
    }

    if (!prompts || prompts.length === 0) {
      return (
        <div className="text-center py-12">
          <p className="text-muted-foreground mb-4">No prompts found.</p>
          <Button onClick={handleCreateNew}>
            <Plus className="h-4 w-4 mr-2" />
            Create Your First Prompt
          </Button>
        </div>
      );
    }

    return (
      <PromptList
        prompts={prompts}
        onEdit={handleEditPrompt}
        onDelete={handleDeletePrompt}
      />
    );
  };

  return (
    <div className="space-y-4">
      <div className="flex items-center justify-between">
        <h1 className="text-2xl font-semibold">Prompts</h1>
        <Button onClick={handleCreateNew} disabled={loading}>
          <Plus className="h-4 w-4 mr-2" />
          New Prompt
        </Button>
      </div>
      {renderContent()}

      {/* Create/Edit Prompt Modal */}
      {showPromptModal && (
        <div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
          <div className="bg-background border rounded-lg shadow-lg max-w-4xl w-full mx-4 max-h-[90vh] overflow-hidden">
            <div className="p-6 border-b">
              <h2 className="text-2xl font-bold">
                {editingPrompt ? "Edit Prompt" : "Create New Prompt"}
              </h2>
            </div>
            <div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
              <PromptEditor
                prompt={editingPrompt}
                onSave={handleSavePrompt}
                onCancel={handleCancel}
                onDelete={handleDeletePrompt}
                error={modalError}
              />
            </div>
          </div>
        </div>
      )}
    </div>
  );
}
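A worked example of the fallback error-parsing path in handleSavePrompt above may help. This is a sketch only: the raw message string is a hypothetical server reply in the "<status> <json body>" shape the regex expects, and extractValidationMessage is an illustrative helper name, not an export of this PR.

// Sketch of handleSavePrompt's fallback parsing logic as a standalone helper.
export function extractValidationMessage(raw: string): string | undefined {
  const jsonMatch = raw.match(/\d+\s+(.+)/); // capture everything after the status code
  if (!jsonMatch) return undefined;
  try {
    const parsed = JSON.parse(jsonMatch[1]);
    // First validation message, with the "Value error, " prefix stripped,
    // mirroring what the component passes to setModalError.
    return parsed?.detail?.errors?.[0]?.msg?.replace(/^Value error,\s*/i, "");
  } catch {
    return undefined; // unparseable body: the component falls back to a generic message
  }
}

// extractValidationMessage(
//   '400 {"detail":{"errors":[{"msg":"Value error, Prompt contains undeclared variables"}]}}'
// ) === "Prompt contains undeclared variables"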
16
src/llama_stack/ui/components/prompts/types.ts
Normal file
@@ -0,0 +1,16 @@
export interface Prompt {
  prompt_id: string;
  prompt: string | null;
  version: number;
  variables: string[];
  is_default: boolean;
}

export interface PromptFormData {
  prompt: string;
  variables: string[];
}

export interface PromptFilters {
  searchTerm?: string;
}
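For reference, the search behaviour that PromptList implements over these types (and that the tests above exercise) reduces to a small pure predicate. The sketch below is illustrative; matchesFilters is not an export in this diff:

// Sketch: PromptList's filter logic as a standalone predicate over these types.
import type { Prompt, PromptFilters } from "./types";

export function matchesFilters(prompt: Prompt, filters: PromptFilters): boolean {
  if (!filters.searchTerm) return true; // an empty search shows every prompt
  const term = filters.searchTerm.toLowerCase(); // case-insensitive matching
  return (
    (prompt.prompt?.toLowerCase().includes(term) ?? false) || // match content (may be null)
    prompt.prompt_id.toLowerCase().includes(term) // or match the prompt ID
  );
}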
36
src/llama_stack/ui/components/ui/badge.tsx
Normal file
@@ -0,0 +1,36 @@
import * as React from "react";
import { cva, type VariantProps } from "class-variance-authority";

import { cn } from "@/lib/utils";

const badgeVariants = cva(
  "inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
  {
    variants: {
      variant: {
        default:
          "border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
        secondary:
          "border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
        destructive:
          "border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
        outline: "text-foreground",
      },
    },
    defaultVariants: {
      variant: "default",
    },
  }
);

export interface BadgeProps
  extends React.HTMLAttributes<HTMLDivElement>,
    VariantProps<typeof badgeVariants> {}

function Badge({ className, variant, ...props }: BadgeProps) {
  return (
    <div className={cn(badgeVariants({ variant }), className)} {...props} />
  );
}

export { Badge, badgeVariants };
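At a call site, the cva helper above merges the base class string with the selected variant's classes, falling back to defaultVariants when no variant prop is given. A short sketch of how this PR's components consume it (the VersionCell wrapper is illustrative, not part of the diff):

// Sketch: resolving badge classes and the Badge usages PromptList renders.
import { Badge, badgeVariants } from "@/components/ui/badge";

export const defaultClasses = badgeVariants(); // base + "default" variant
export const secondaryClasses = badgeVariants({ variant: "secondary" }); // base + secondary

// Illustrative wrapper mirroring PromptList's version column.
export function VersionCell({ version, isDefault }: { version: number; isDefault: boolean }) {
  return (
    <>
      {version}
      {isDefault && (
        <Badge variant="secondary" className="text-xs ml-2">
          Default
        </Badge>
      )}
    </>
  );
}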
24
src/llama_stack/ui/components/ui/label.tsx
Normal file
@@ -0,0 +1,24 @@
import * as React from "react";
import * as LabelPrimitive from "@radix-ui/react-label";
import { cva, type VariantProps } from "class-variance-authority";

import { cn } from "@/lib/utils";

const labelVariants = cva(
  "text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
);

const Label = React.forwardRef<
  React.ElementRef<typeof LabelPrimitive.Root>,
  React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root> &
    VariantProps<typeof labelVariants>
>(({ className, ...props }, ref) => (
  <LabelPrimitive.Root
    ref={ref}
    className={cn(labelVariants(), className)}
    {...props}
  />
));
Label.displayName = LabelPrimitive.Root.displayName;

export { Label };
53
src/llama_stack/ui/components/ui/tabs.tsx
Normal file
@@ -0,0 +1,53 @@
import * as React from "react";
import * as TabsPrimitive from "@radix-ui/react-tabs";

import { cn } from "@/lib/utils";

const Tabs = TabsPrimitive.Root;

const TabsList = React.forwardRef<
  React.ElementRef<typeof TabsPrimitive.List>,
  React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>
>(({ className, ...props }, ref) => (
  <TabsPrimitive.List
    ref={ref}
    className={cn(
      "inline-flex h-10 items-center justify-center rounded-md bg-muted p-1 text-muted-foreground",
      className
    )}
    {...props}
  />
));
TabsList.displayName = TabsPrimitive.List.displayName;

const TabsTrigger = React.forwardRef<
  React.ElementRef<typeof TabsPrimitive.Trigger>,
  React.ComponentPropsWithoutRef<typeof TabsPrimitive.Trigger>
>(({ className, ...props }, ref) => (
  <TabsPrimitive.Trigger
    ref={ref}
    className={cn(
      "inline-flex items-center justify-center whitespace-nowrap rounded-sm px-3 py-1.5 text-sm font-medium ring-offset-background transition-all focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 data-[state=active]:bg-background data-[state=active]:text-foreground data-[state=active]:shadow-sm",
      className
    )}
    {...props}
  />
));
TabsTrigger.displayName = TabsPrimitive.Trigger.displayName;

const TabsContent = React.forwardRef<
  React.ElementRef<typeof TabsPrimitive.Content>,
  React.ComponentPropsWithoutRef<typeof TabsPrimitive.Content>
>(({ className, ...props }, ref) => (
  <TabsPrimitive.Content
    ref={ref}
    className={cn(
      "mt-2 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
      className
    )}
    {...props}
  />
));
TabsContent.displayName = TabsPrimitive.Content.displayName;

export { Tabs, TabsList, TabsTrigger, TabsContent };
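Nothing in this diff mounts these primitives yet; for orientation, the standard Radix composition they support looks like the sketch below (the tab values and labels are illustrative, not taken from this PR):

// Sketch: typical composition of the Tabs primitives exported above.
import { Tabs, TabsList, TabsTrigger, TabsContent } from "@/components/ui/tabs";

export function EditorTabs() {
  return (
    <Tabs defaultValue="edit">
      <TabsList>
        <TabsTrigger value="edit">Edit</TabsTrigger>
        <TabsTrigger value="preview">Preview</TabsTrigger>
      </TabsList>
      <TabsContent value="edit">Editor pane</TabsContent>
      <TabsContent value="preview">Preview pane</TabsContent>
    </Tabs>
  );
}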
23
src/llama_stack/ui/components/ui/textarea.tsx
Normal file
@@ -0,0 +1,23 @@
import * as React from "react";

import { cn } from "@/lib/utils";

export type TextareaProps = React.TextareaHTMLAttributes<HTMLTextAreaElement>;

const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
  ({ className, ...props }, ref) => {
    return (
      <textarea
        className={cn(
          "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
          className
        )}
        ref={ref}
        {...props}
      />
    );
  }
);
Textarea.displayName = "Textarea";

export { Textarea };
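These two primitives pair up for the form fields the tests query with getByLabelText("Prompt Content *"). A sketch of the association (the id and wrapper markup are illustrative, since PromptEditor itself is not part of this diff):

// Sketch: Label + Textarea wired so getByLabelText("Prompt Content *") resolves.
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";

export function PromptContentField() {
  return (
    <div className="space-y-2">
      {/* htmlFor/id link the label to the control for a11y and test queries */}
      <Label htmlFor="prompt-content">Prompt Content *</Label>
      <Textarea id="prompt-content" placeholder="Hello {{name}}, how are you?" />
    </div>
  );
}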
62
src/llama_stack/ui/package-lock.json
generated
@@ -11,14 +11,16 @@
         "@radix-ui/react-collapsible": "^1.1.12",
         "@radix-ui/react-dialog": "^1.1.15",
         "@radix-ui/react-dropdown-menu": "^2.1.16",
+        "@radix-ui/react-label": "^2.1.7",
         "@radix-ui/react-select": "^2.2.6",
         "@radix-ui/react-separator": "^1.1.7",
         "@radix-ui/react-slot": "^1.2.3",
+        "@radix-ui/react-tabs": "^1.1.13",
         "@radix-ui/react-tooltip": "^1.2.8",
         "class-variance-authority": "^0.7.1",
         "clsx": "^2.1.1",
         "framer-motion": "^12.23.24",
-        "llama-stack-client": "^0.3.0",
+        "llama-stack-client": "github:llamastack/llama-stack-client-typescript",
         "lucide-react": "^0.545.0",
         "next": "15.5.4",
         "next-auth": "^4.24.11",
@@ -2597,6 +2599,29 @@
         }
       }
     },
+    "node_modules/@radix-ui/react-label": {
+      "version": "2.1.7",
+      "resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.7.tgz",
+      "integrity": "sha512-YT1GqPSL8kJn20djelMX7/cTRp/Y9w5IZHvfxQTVHrOqa2yMl7i/UfMqKRU5V7mEyKTrUVgJXhNQPVCG8PBLoQ==",
+      "license": "MIT",
+      "dependencies": {
+        "@radix-ui/react-primitive": "2.1.3"
+      },
+      "peerDependencies": {
+        "@types/react": "*",
+        "@types/react-dom": "*",
+        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+        "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+      },
+      "peerDependenciesMeta": {
+        "@types/react": {
+          "optional": true
+        },
+        "@types/react-dom": {
+          "optional": true
+        }
+      }
+    },
     "node_modules/@radix-ui/react-menu": {
       "version": "2.1.16",
       "resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.16.tgz",
@@ -2855,6 +2880,36 @@
         }
       }
     },
+    "node_modules/@radix-ui/react-tabs": {
+      "version": "1.1.13",
+      "resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.13.tgz",
+      "integrity": "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A==",
+      "license": "MIT",
+      "dependencies": {
+        "@radix-ui/primitive": "1.1.3",
+        "@radix-ui/react-context": "1.1.2",
+        "@radix-ui/react-direction": "1.1.1",
+        "@radix-ui/react-id": "1.1.1",
+        "@radix-ui/react-presence": "1.1.5",
+        "@radix-ui/react-primitive": "2.1.3",
+        "@radix-ui/react-roving-focus": "1.1.11",
+        "@radix-ui/react-use-controllable-state": "1.2.2"
+      },
+      "peerDependencies": {
+        "@types/react": "*",
+        "@types/react-dom": "*",
+        "react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
+        "react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
+      },
+      "peerDependenciesMeta": {
+        "@types/react": {
+          "optional": true
+        },
+        "@types/react-dom": {
+          "optional": true
+        }
+      }
+    },
     "node_modules/@radix-ui/react-tooltip": {
       "version": "1.2.8",
       "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.8.tgz",
@@ -9629,9 +9684,8 @@
       "license": "MIT"
     },
     "node_modules/llama-stack-client": {
-      "version": "0.3.0",
-      "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.3.0.tgz",
-      "integrity": "sha512-76K/t1doaGmlBbDxCADaral9Vccvys9P8pqAMIhwBhMAqWudCEORrMMhUSg+pjhamWmEKj3wa++d4zeOGbfN/w==",
+      "version": "0.4.0-alpha.1",
+      "resolved": "git+ssh://git@github.com/llamastack/llama-stack-client-typescript.git#78de4862c4b7d77939ac210fa9f9bde77a2c5c5f",
      "license": "MIT",
      "dependencies": {
        "@types/node": "^18.11.18",
@@ -16,14 +16,16 @@
     "@radix-ui/react-collapsible": "^1.1.12",
     "@radix-ui/react-dialog": "^1.1.15",
     "@radix-ui/react-dropdown-menu": "^2.1.16",
+    "@radix-ui/react-label": "^2.1.7",
     "@radix-ui/react-select": "^2.2.6",
     "@radix-ui/react-separator": "^1.1.7",
     "@radix-ui/react-slot": "^1.2.3",
+    "@radix-ui/react-tabs": "^1.1.13",
     "@radix-ui/react-tooltip": "^1.2.8",
     "class-variance-authority": "^0.7.1",
     "clsx": "^2.1.1",
     "framer-motion": "^12.23.24",
-    "llama-stack-client": "^0.3.0",
+    "llama-stack-client": "github:llamastack/llama-stack-client-typescript",
     "lucide-react": "^0.545.0",
     "next": "15.5.4",
     "next-auth": "^4.24.11",
@@ -171,6 +171,10 @@ def pytest_addoption(parser):
         "--embedding-model",
         help="comma-separated list of embedding models. Fixture name: embedding_model_id",
     )
+    parser.addoption(
+        "--rerank-model",
+        help="comma-separated list of rerank models. Fixture name: rerank_model_id",
+    )
     parser.addoption(
         "--safety-shield",
         help="comma-separated list of safety shields. Fixture name: shield_id",
@@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc):
         "shield_id": ("--safety-shield", "shield"),
         "judge_model_id": ("--judge-model", "judge"),
         "embedding_dimension": ("--embedding-dimension", "dim"),
+        "rerank_model_id": ("--rerank-model", "rerank"),
     }

     # Collect all parameters and their values
@@ -153,6 +153,7 @@ def client_with_models(
     vision_model_id,
     embedding_model_id,
     judge_model_id,
+    rerank_model_id,
 ):
     client = llama_stack_client
@@ -170,6 +171,9 @@ def client_with_models(

     if embedding_model_id and embedding_model_id not in model_ids:
         raise ValueError(f"embedding_model_id {embedding_model_id} not found")
+
+    if rerank_model_id and rerank_model_id not in model_ids:
+        raise ValueError(f"rerank_model_id {rerank_model_id} not found")
     return client
@@ -185,7 +189,14 @@ def model_providers(llama_stack_client):

 @pytest.fixture(autouse=True)
 def skip_if_no_model(request):
-    model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
+    model_fixtures = [
+        "text_model_id",
+        "vision_model_id",
+        "embedding_model_id",
+        "judge_model_id",
+        "shield_id",
+        "rerank_model_id",
+    ]
     test_func = request.node.function

     actual_params = inspect.signature(test_func).parameters.keys()
@@ -230,6 +241,7 @@ def instantiate_llama_stack_client(session):

     force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") == "1"
     if force_restart:
+        print(f"Forcing restart of the server on port {port}")
         stop_server_on_port(port)

     # Check if port is available
@@ -721,6 +721,6 @@ def test_openai_chat_completion_structured_output(openai_client, text_model_id,
     print(response.choices[0].message.content)
     answer = AnswerFormat.model_validate_json(response.choices[0].message.content)
     expected = tc["expected"]
-    assert answer.first_name == expected["first_name"]
-    assert answer.last_name == expected["last_name"]
+    assert expected["first_name"].lower() in answer.first_name.lower()
+    assert expected["last_name"].lower() in answer.last_name.lower()
     assert answer.year_of_birth == expected["year_of_birth"]
214
tests/integration/inference/test_rerank.py
Normal file
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types.alpha import InferenceRerankResponse
from llama_stack_client.types.shared.interleaved_content import (
    ImageContentItem,
    ImageContentItemImage,
    ImageContentItemImageURL,
    TextContentItem,
)

from llama_stack.core.library_client import LlamaStackAsLibraryClient

# Test data
DUMMY_STRING = "string_1"
DUMMY_STRING2 = "string_2"
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
DUMMY_IMAGE_URL = ImageContentItem(
    image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")

PROVIDERS_SUPPORTING_MEDIA = {}  # Providers that support media input for rerank models


def skip_if_provider_doesnt_support_rerank(inference_provider_type):
    supported_providers = {"remote::nvidia"}
    if inference_provider_type not in supported_providers:
        pytest.skip(f"{inference_provider_type} doesn't support rerank models")


def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
    """
    Validate that a rerank response has the correct structure and ordering.

    Args:
        response: The InferenceRerankResponse to validate
        items: The original items list that was ranked

    Raises:
        AssertionError: If any validation fails
    """
    seen = set()
    last_score = float("inf")
    for d in response:
        assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
        assert d.index not in seen, f"Duplicate index {d.index} found"
        seen.add(d.index)
        assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
        assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
        last_score = d.relevance_score


def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
    """
    Validate that the expected most relevant item ranks first.

    Args:
        response: The InferenceRerankResponse to validate
        items: The original items list that was ranked
        expected_first_item: The expected first item in the ranking

    Raises:
        AssertionError: If any validation fails
    """
    if not response:
        raise AssertionError("No ranking data returned in response")

    actual_first_index = response[0].index
    actual_first_item = items[actual_first_index]
    assert actual_first_item == expected_first_item, (
        f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
    )


@pytest.mark.parametrize(
    "query,items",
    [
        (DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
        (DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
        (DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
        (DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
    ],
    ids=[
        "string-query-string-items",
        "text-query-text-items",
        "mixed-content-1",
        "mixed-content-2",
    ],
)
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
    skip_if_provider_doesnt_support_rerank(inference_provider_type)

    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
    assert isinstance(response, list)
    # TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
    assert len(response) <= len(items)
    _validate_rerank_response(response, items)


@pytest.mark.parametrize(
    "query,items",
    [
        (DUMMY_IMAGE_URL, [DUMMY_STRING]),
        (DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
        (DUMMY_TEXT, [DUMMY_IMAGE_URL]),
        (DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
        (DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
    ],
    ids=[
        "image-query-url",
        "image-query-base64",
        "text-query-image-item",
        "mixed-content-1",
        "mixed-content-2",
    ],
)
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
    skip_if_provider_doesnt_support_rerank(inference_provider_type)

    if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
        error_type = (
            ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
        )
        with pytest.raises(error_type):
            client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
    else:
        response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)

        assert isinstance(response, list)
        assert len(response) <= len(items)
        _validate_rerank_response(response, items)


def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
    skip_if_provider_doesnt_support_rerank(inference_provider_type)

    items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
    max_num_results = 2

    response = client_with_models.alpha.inference.rerank(
        model=rerank_model_id,
        query=DUMMY_STRING,
        items=items,
        max_num_results=max_num_results,
    )

    assert isinstance(response, list)
    assert len(response) == max_num_results
    _validate_rerank_response(response, items)


def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
    skip_if_provider_doesnt_support_rerank(inference_provider_type)

    items = [DUMMY_STRING, DUMMY_STRING2]
    response = client_with_models.alpha.inference.rerank(
        model=rerank_model_id,
        query=DUMMY_STRING,
        items=items,
        max_num_results=10,  # Larger than items length
    )

    assert isinstance(response, list)
    assert len(response) <= len(items)  # Should return at most len(items)


@pytest.mark.parametrize(
    "query,items,expected_first_item",
    [
        (
            "What is a reranking model? ",
            [
                "A reranking model reranks a list of items based on the query. ",
                "Machine learning algorithms learn patterns from data. ",
                "Python is a programming language. ",
            ],
            "A reranking model reranks a list of items based on the query. ",
        ),
        (
            "What is C++?",
            [
                "Learning new things is interesting. ",
                "C++ is a programming language. ",
                "Books provide knowledge and entertainment. ",
            ],
            "C++ is a programming language. ",
        ),
        (
            "What are good learning habits? ",
            [
                "Cooking pasta is a fun activity. ",
                "Plants need water and sunlight. ",
                "Good learning habits include reading daily and taking notes. ",
            ],
            "Good learning habits include reading daily and taking notes. ",
        ),
    ],
)
def test_rerank_semantic_correctness(
    client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
):
    skip_if_provider_doesnt_support_rerank(inference_provider_type)

    response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)

    _validate_rerank_response(response, items)
    _validate_semantic_ranking(response, items, expected_first_item)
|
|
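The helpers _validate_rerank_response and _validate_semantic_ranking are defined earlier in this test module. As a rough guide to what such validation entails, here is a minimal sketch -- a hypothetical re-implementation, assuming each result exposes index and relevance_score fields (as the NVIDIA unit tests further down do) and that results come back sorted best-first:

def _validate_rerank_response_sketch(response, items):
    # Hypothetical stand-in for the real helper defined earlier in the module.
    seen = set()
    for result in response:
        assert 0 <= result.index < len(items)  # index points into the input items
        assert result.index not in seen  # no input item is ranked twice
        seen.add(result.index)
    scores = [result.relevance_score for result in response]
    assert scores == sorted(scores, reverse=True)  # assumed best-first ordering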
@@ -4,18 +4,75 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
from llama_stack_client import LlamaStackClient

from llama_stack import LlamaStackAsLibraryClient


class TestInspect:
    @pytest.mark.skip(reason="inspect tests disabled")
    def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        health = llama_stack_client.inspect.health()
        assert health is not None
        assert health.status == "OK"

    @pytest.mark.skip(reason="inspect tests disabled")
    def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        version = llama_stack_client.inspect.version()
        assert version is not None
        assert version.version is not None

    @pytest.mark.skip(reason="inspect tests disabled")
    def test_list_routes_default(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        """Test list_routes with default filter (non-deprecated v1 routes)."""
        response = llama_stack_client.routes.list()
        assert response is not None
        assert hasattr(response, "data")
        routes = response.data
        assert len(routes) > 0

        # All routes should be non-deprecated
        # Check that we don't see any /openai/ routes (which are deprecated)
        openai_routes = [r for r in routes if "/openai/" in r.route]
        assert len(openai_routes) == 0, "Default filter should not include deprecated /openai/ routes"

        # Should see standard v1 routes like /inspect/routes, /health, /version
        paths = [r.route for r in routes]
        assert "/inspect/routes" in paths or "/v1/inspect/routes" in paths
        assert "/health" in paths or "/v1/health" in paths

    @pytest.mark.skip(reason="inspect tests disabled")
    def test_list_routes_filter_by_deprecated(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        """Test list_routes with deprecated filter."""
        response = llama_stack_client.routes.list(api_filter="deprecated")
        assert response is not None
        assert hasattr(response, "data")
        routes = response.data

        # When filtering for deprecated, we should get deprecated routes
        # At minimum, we should see some /openai/ routes which are deprecated
        if len(routes) > 0:
            # If there are any deprecated routes, they should include openai routes
            openai_routes = [r for r in routes if "/openai/" in r.route]
            assert len(openai_routes) > 0, "Deprecated filter should include /openai/ routes"

    @pytest.mark.skip(reason="inspect tests disabled")
    def test_list_routes_filter_by_v1(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
        """Test list_routes with v1 filter."""
        response = llama_stack_client.routes.list(api_filter="v1")
        assert response is not None
        assert hasattr(response, "data")
        routes = response.data
        assert len(routes) > 0

        # Should not include deprecated routes
        openai_routes = [r for r in routes if "/openai/" in r.route]
        assert len(openai_routes) == 0

        # Should include v1 routes
        paths = [r.route for r in routes]
        assert any(
            "/v1/" in p or p.startswith("/inspect/") or p.startswith("/health") or p.startswith("/version")
            for p in paths
        )
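The deprecation heuristic these route tests rely on -- /openai/ paths mark the deprecated surface -- can be factored into a tiny standalone helper. A minimal sketch under that assumption (the Route shape below is hypothetical, carrying only the route attribute the tests actually read):

from dataclasses import dataclass


@dataclass
class Route:
    route: str  # path only; the real route objects carry more fields


def split_by_deprecation(routes: list[Route]) -> tuple[list[Route], list[Route]]:
    # Mirrors the heuristic in the tests above: /openai/ paths are deprecated.
    deprecated = [r for r in routes if "/openai/" in r.route]
    current = [r for r in routes if "/openai/" not in r.route]
    return current, deprecated


current, deprecated = split_by_deprecation([Route("/v1/health"), Route("/openai/v1/models")])
assert [r.route for r in current] == ["/v1/health"]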
@@ -10,7 +10,6 @@ import os

 import pytest

-import llama_stack.core.telemetry.telemetry as telemetry_module
 from llama_stack.testing.api_recorder import patch_httpx_for_test_id
 from tests.integration.fixtures.common import instantiate_llama_stack_client
 from tests.integration.telemetry.collectors import InMemoryTelemetryManager, OtlpHttpTestCollector
@@ -22,40 +21,26 @@ def telemetry_test_collector():
     stack_mode = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
 
     if stack_mode == "server":
+        # In server mode, the collector must be started and the server is already running.
+        # The integration test script (scripts/integration-tests.sh) should have set
+        # LLAMA_STACK_TEST_COLLECTOR_PORT and OTEL_EXPORTER_OTLP_ENDPOINT before starting the server.
         try:
             collector = OtlpHttpTestCollector()
         except RuntimeError as exc:
             pytest.skip(str(exc))
-        env_overrides = {
-            "OTEL_EXPORTER_OTLP_ENDPOINT": collector.endpoint,
-            "OTEL_EXPORTER_OTLP_PROTOCOL": "http/protobuf",
-            "OTEL_BSP_SCHEDULE_DELAY": "200",
-            "OTEL_BSP_EXPORT_TIMEOUT": "2000",
-            "LLAMA_STACK_DISABLE_GUNICORN": "true",  # Disable multi-process for telemetry collection
-        }
 
-        previous_env = {key: os.environ.get(key) for key in env_overrides}
-        previous_force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART")
-
-        for key, value in env_overrides.items():
-            os.environ[key] = value
-
-        os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = "1"
-        telemetry_module._TRACER_PROVIDER = None
+        # Verify the collector is listening on the expected endpoint
+        expected_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
+        if expected_endpoint and collector.endpoint != expected_endpoint:
+            pytest.skip(
+                f"Collector endpoint mismatch: expected {expected_endpoint}, got {collector.endpoint}. "
+                "Server was likely started before collector."
+            )
 
         try:
             yield collector
         finally:
             collector.shutdown()
-            for key, prior in previous_env.items():
-                if prior is None:
-                    os.environ.pop(key, None)
-                else:
-                    os.environ[key] = prior
-            if previous_force_restart is None:
-                os.environ.pop("LLAMA_STACK_TEST_FORCE_SERVER_RESTART", None)
-            else:
-                os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = previous_force_restart
     else:
         manager = InMemoryTelemetryManager()
         try:
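The hand-rolled save-and-restore of environment variables deleted in this hunk is a recurring fixture pattern. A minimal standalone sketch of the same bookkeeping as a reusable context manager -- a hypothetical helper, equivalent in spirit to pytest's monkeypatch.setenv, not something this conftest defines:

import contextlib
import os


@contextlib.contextmanager
def scoped_env(**overrides):
    # Apply overrides, then restore the prior values on exit -- the same
    # bookkeeping the removed env_overrides block performed by hand.
    previous = {key: os.environ.get(key) for key in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for key, prior in previous.items():
            if prior is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = prior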
@@ -206,3 +206,65 @@ def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
def test_parse_and_maybe_upgrade_config_image_name_int(config_with_image_name_int):
    result = parse_and_maybe_upgrade_config(config_with_image_name_int)
    assert isinstance(result.image_name, str)


def test_parse_and_maybe_upgrade_config_sets_external_providers_dir(up_to_date_config):
    """Test that external_providers_dir is None when not specified (deprecated field)."""
    # Ensure the config doesn't have external_providers_dir set
    assert "external_providers_dir" not in up_to_date_config

    result = parse_and_maybe_upgrade_config(up_to_date_config)

    # Verify external_providers_dir is None (not set to default)
    # This aligns with the deprecation of external_providers_dir
    assert result.external_providers_dir is None


def test_parse_and_maybe_upgrade_config_preserves_custom_external_providers_dir(up_to_date_config):
    """Test that custom external_providers_dir values are preserved."""
    custom_dir = "/custom/providers/dir"
    up_to_date_config["external_providers_dir"] = custom_dir

    result = parse_and_maybe_upgrade_config(up_to_date_config)

    # Verify the custom value was preserved
    assert str(result.external_providers_dir) == custom_dir


def test_generate_run_config_from_providers():
    """Test that _generate_run_config_from_providers creates a valid config"""
    import argparse

    from llama_stack.cli.stack.run import StackRun
    from llama_stack.core.datatypes import Provider

    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()
    stack_run = StackRun(subparsers)

    providers = {
        "inference": [
            Provider(
                provider_type="inline::meta-reference",
                provider_id="meta-reference",
            )
        ]
    }

    config = stack_run._generate_run_config_from_providers(providers=providers)
    config_dict = config.model_dump(mode="json")

    # Verify basic structure
    assert config_dict["image_name"] == "providers-run"
    assert "inference" in config_dict["apis"]
    assert "inference" in config_dict["providers"]

    # Verify storage has all required stores including prompts
    assert "storage" in config_dict
    stores = config_dict["storage"]["stores"]
    assert "prompts" in stores
    assert stores["prompts"]["namespace"] == "prompts"

    # Verify config can be parsed back
    parsed = parse_and_maybe_upgrade_config(config_dict)
    assert parsed.image_name == "providers-run"
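The closing assertions rely on pydantic's model_dump(mode="json") producing a dict that parse_and_maybe_upgrade_config can consume again. A minimal sketch of that round-trip property in isolation (MiniConfig is a stand-in model invented here, not the real run-config class):

from pydantic import BaseModel


class MiniConfig(BaseModel):
    image_name: str


# Serialize, then reparse: the same invariant the test asserts for the
# generated run config, shown on a toy model.
dumped = MiniConfig(image_name="providers-run").model_dump(mode="json")
assert MiniConfig(**dumped).image_name == "providers-run"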
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
@@ -0,0 +1,251 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import pytest

from llama_stack.apis.models import ModelType
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class MockResponse:
    def __init__(self, status=200, json_data=None, text_data="OK"):
        self.status = status
        self._json_data = json_data or {"rankings": []}
        self._text_data = text_data

    async def json(self):
        return self._json_data

    async def text(self):
        return self._text_data


class MockSession:
    def __init__(self, response):
        self.response = response
        self.post_calls = []

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return False

    def post(self, url, **kwargs):
        self.post_calls.append((url, kwargs))

        class PostContext:
            def __init__(self, response):
                self.response = response

            async def __aenter__(self):
                return self.response

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                return False

        return PostContext(self.response)


def create_adapter(config=None, rerank_endpoints=None):
    if config is None:
        config = NVIDIAConfig(api_key="test-key")

    adapter = NVIDIAInferenceAdapter(config=config)

    class MockModel:
        provider_resource_id = "test-model"
        metadata = {}

    adapter.model_store = AsyncMock()
    adapter.model_store.get_model = AsyncMock(return_value=MockModel())

    if rerank_endpoints is not None:
        adapter.config.rerank_model_to_url = rerank_endpoints

    return adapter


async def test_rerank_basic_functionality():
    adapter = create_adapter()
    mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
    mock_session = MockSession(mock_response)

    with patch("aiohttp.ClientSession", return_value=mock_session):
        result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])

    assert len(result.data) == 1
    assert result.data[0].index == 0
    assert result.data[0].relevance_score == 0.5

    url, kwargs = mock_session.post_calls[0]
    payload = kwargs["json"]
    assert payload["model"] == "test-model"
    assert payload["query"] == {"text": "test query"}
    assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]


async def test_missing_rankings_key():
    adapter = create_adapter()
    mock_session = MockSession(MockResponse(json_data={}))

    with patch("aiohttp.ClientSession", return_value=mock_session):
        result = await adapter.rerank(model="test-model", query="q", items=["a"])

    assert len(result.data) == 0


async def test_hosted_with_endpoint():
    adapter = create_adapter(
        config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
    )
    mock_session = MockSession(MockResponse())

    with patch("aiohttp.ClientSession", return_value=mock_session):
        await adapter.rerank(model="test-model", query="q", items=["a"])

    url, _ = mock_session.post_calls[0]
    assert url == "https://model.endpoint/rerank"


async def test_hosted_without_endpoint():
    adapter = create_adapter(
        config=NVIDIAConfig(api_key="key"),  # This creates hosted config (integrate.api.nvidia.com).
        rerank_endpoints={},  # No endpoint mapping for test-model
    )
    mock_session = MockSession(MockResponse())

    with patch("aiohttp.ClientSession", return_value=mock_session):
        await adapter.rerank(model="test-model", query="q", items=["a"])

    url, _ = mock_session.post_calls[0]
    assert "https://integrate.api.nvidia.com" in url


async def test_hosted_model_not_in_endpoint_mapping():
    adapter = create_adapter(
        config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
    )
    mock_session = MockSession(MockResponse())

    with patch("aiohttp.ClientSession", return_value=mock_session):
        await adapter.rerank(model="test-model", query="q", items=["a"])

    url, _ = mock_session.post_calls[0]
    assert "https://integrate.api.nvidia.com" in url
    assert url != "https://other.endpoint/rerank"


async def test_self_hosted_ignores_endpoint():
    adapter = create_adapter(
        config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
        rerank_endpoints={"test-model": "https://model.endpoint/rerank"},  # This should be ignored for self-hosted.
    )
    mock_session = MockSession(MockResponse())

    with patch("aiohttp.ClientSession", return_value=mock_session):
        await adapter.rerank(model="test-model", query="q", items=["a"])

    url, _ = mock_session.post_calls[0]
    assert "http://localhost:8000" in url
    assert "model.endpoint/rerank" not in url


async def test_max_num_results():
    adapter = create_adapter()
    rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
    mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))

    with patch("aiohttp.ClientSession", return_value=mock_session):
        result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)

    assert len(result.data) == 1
    assert result.data[0].index == 0
    assert result.data[0].relevance_score == 0.8


async def test_http_error():
    adapter = create_adapter()
    mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))

    with patch("aiohttp.ClientSession", return_value=mock_session):
        with pytest.raises(ConnectionError, match="status 500.*Server Error"):
            await adapter.rerank(model="test-model", query="q", items=["a"])


async def test_client_error():
    adapter = create_adapter()
    mock_session = AsyncMock()
    mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")

    with patch("aiohttp.ClientSession", return_value=mock_session):
        with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
            await adapter.rerank(model="test-model", query="q", items=["a"])


async def test_list_models_includes_configured_rerank_models():
    """Test that list_models adds rerank models to the dynamic model list."""
    adapter = create_adapter()
    adapter.__provider_id__ = "nvidia"
    adapter.__provider_spec__ = MagicMock()

    dynamic_ids = ["llm-1", "embedding-1"]
    with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
        result = await adapter.list_models()

    assert result is not None

    # Check that the rerank models are added
    model_ids = [m.identifier for m in result]
    assert "nv-rerank-qa-mistral-4b:1" in model_ids
    assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
    assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids

    rerank_models = [m for m in result if m.model_type == ModelType.rerank]

    assert len(rerank_models) == 3

    for m in rerank_models:
        assert m.provider_id == "nvidia"
        assert m.model_type == ModelType.rerank
        assert m.metadata == {}
        assert m.identifier in adapter._model_cache


async def test_list_provider_model_ids_has_no_duplicates():
    adapter = create_adapter()

    dynamic_ids = [
        "llm-1",
        "nvidia/nv-rerankqa-mistral-4b-v3",  # overlaps configured rerank ids
        "embedding-1",
        "llm-1",
    ]

    with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
        ids = list(await adapter.list_provider_model_ids())

    assert len(ids) == len(set(ids))
    assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1
    assert "nv-rerank-qa-mistral-4b:1" in ids
    assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids


async def test_list_provider_model_ids_uses_configured_on_dynamic_failure():
    adapter = create_adapter()

    # Simulate dynamic listing failure
    with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)):
        ids = list(await adapter.list_provider_model_ids())

    # Should still return configured rerank ids
    configured_ids = list(adapter.config.rerank_model_to_url.keys())
    assert set(ids) == set(configured_ids)
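These unit tests pin down the request payload the adapter sends, so the shape can be reproduced standalone. A minimal sketch, with field names taken from the assertions in test_rerank_basic_functionality rather than from any NVIDIA API reference:

def build_rerank_payload(model: str, query: str, items: list[str]) -> dict:
    # Matches the payload asserted on in test_rerank_basic_functionality.
    return {
        "model": model,
        "query": {"text": query},
        "passages": [{"text": item} for item in items],
    }


assert build_rerank_payload("test-model", "test query", ["item1", "item2"]) == {
    "model": "test-model",
    "query": {"text": "test query"},
    "passages": [{"text": "item1"}, {"text": "item2"}],
}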