mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Merge branch 'main' into feat/gunicorn-production-server
This commit is contained in:
commit
47bd994824
59 changed files with 3190 additions and 421 deletions
60
.github/actions/install-llama-stack-client/action.yml
vendored
Normal file
60
.github/actions/install-llama-stack-client/action.yml
vendored
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
name: Install llama-stack-client
|
||||
description: Install llama-stack-client based on branch context and client-version input
|
||||
|
||||
inputs:
|
||||
client-version:
|
||||
description: 'Client version to install on non-release branches (latest or published). Ignored on release branches.'
|
||||
required: false
|
||||
default: ""
|
||||
|
||||
outputs:
|
||||
uv-extra-index-url:
|
||||
description: 'UV_EXTRA_INDEX_URL to use (set for release branches)'
|
||||
value: ${{ steps.configure.outputs.uv-extra-index-url }}
|
||||
install-after-sync:
|
||||
description: 'Whether to install client after uv sync'
|
||||
value: ${{ steps.configure.outputs.install-after-sync }}
|
||||
install-source:
|
||||
description: 'Where to install client from after sync'
|
||||
value: ${{ steps.configure.outputs.install-source }}
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Configure client installation
|
||||
id: configure
|
||||
shell: bash
|
||||
run: |
|
||||
# Determine the branch we're working with
|
||||
BRANCH="${{ github.base_ref || github.ref }}"
|
||||
BRANCH="${BRANCH#refs/heads/}"
|
||||
|
||||
echo "Working with branch: $BRANCH"
|
||||
|
||||
# On release branches: use test.pypi for uv sync, then install from git
|
||||
# On non-release branches: install based on client-version after sync
|
||||
if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
|
||||
echo "Detected release branch: $BRANCH"
|
||||
|
||||
# Check if matching branch exists in client repo
|
||||
if ! git ls-remote --exit-code --heads https://github.com/llamastack/llama-stack-client-python.git "$BRANCH" > /dev/null 2>&1; then
|
||||
echo "::error::Branch $BRANCH not found in llama-stack-client-python repository"
|
||||
echo "::error::Please create the matching release branch in llama-stack-client-python before testing"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Configure to use test.pypi as extra index (PyPI is primary)
|
||||
echo "uv-extra-index-url=https://test.pypi.org/simple/" >> $GITHUB_OUTPUT
|
||||
echo "install-after-sync=true" >> $GITHUB_OUTPUT
|
||||
echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@$BRANCH" >> $GITHUB_OUTPUT
|
||||
elif [ "${{ inputs.client-version }}" = "latest" ]; then
|
||||
# Install from main git after sync
|
||||
echo "install-after-sync=true" >> $GITHUB_OUTPUT
|
||||
echo "install-source=git+https://github.com/llamastack/llama-stack-client-python.git@main" >> $GITHUB_OUTPUT
|
||||
elif [ "${{ inputs.client-version }}" = "published" ]; then
|
||||
# Use published version from PyPI (installed by sync)
|
||||
echo "install-after-sync=false" >> $GITHUB_OUTPUT
|
||||
elif [ -n "${{ inputs.client-version }}" ]; then
|
||||
echo "::error::Invalid client-version: ${{ inputs.client-version }}"
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -94,7 +94,7 @@ runs:
|
|||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
with:
|
||||
name: logs-${{ github.run_id }}-${{ github.run_attempt || '' }}-${{ strategy.job-index }}
|
||||
name: logs-${{ github.run_id }}-${{ github.run_attempt || '1' }}-${{ strategy.job-index || github.job }}-${{ github.action }}
|
||||
path: |
|
||||
*.log
|
||||
retention-days: 1
|
||||
|
|
|
|||
30
.github/actions/setup-runner/action.yml
vendored
30
.github/actions/setup-runner/action.yml
vendored
|
|
@ -18,25 +18,35 @@ runs:
|
|||
python-version: ${{ inputs.python-version }}
|
||||
version: 0.7.6
|
||||
|
||||
- name: Configure client installation
|
||||
id: client-config
|
||||
uses: ./.github/actions/install-llama-stack-client
|
||||
with:
|
||||
client-version: ${{ inputs.client-version }}
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
env:
|
||||
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
|
||||
run: |
|
||||
# Export UV env vars for current step and persist to GITHUB_ENV for subsequent steps
|
||||
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
|
||||
export UV_INDEX_STRATEGY=unsafe-best-match
|
||||
echo "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL" >> $GITHUB_ENV
|
||||
echo "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY" >> $GITHUB_ENV
|
||||
echo "Exported UV environment variables for current and subsequent steps"
|
||||
fi
|
||||
|
||||
echo "Updating project dependencies via uv sync"
|
||||
uv sync --all-groups
|
||||
|
||||
echo "Installing ad-hoc dependencies"
|
||||
uv pip install faiss-cpu
|
||||
|
||||
# Install llama-stack-client-python based on the client-version input
|
||||
if [ "${{ inputs.client-version }}" = "latest" ]; then
|
||||
echo "Installing latest llama-stack-client-python from main branch"
|
||||
uv pip install git+https://github.com/llamastack/llama-stack-client-python.git@main
|
||||
elif [ "${{ inputs.client-version }}" = "published" ]; then
|
||||
echo "Installing published llama-stack-client-python from PyPI"
|
||||
uv pip install llama-stack-client
|
||||
else
|
||||
echo "Invalid client-version: ${{ inputs.client-version }}"
|
||||
exit 1
|
||||
# Install specific client version after sync if needed
|
||||
if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
|
||||
echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
|
||||
uv pip install ${{ steps.client-config.outputs.install-source }}
|
||||
fi
|
||||
|
||||
echo "Installed llama packages"
|
||||
|
|
|
|||
|
|
@ -42,18 +42,7 @@ runs:
|
|||
- name: Build Llama Stack
|
||||
shell: bash
|
||||
run: |
|
||||
# Install llama-stack-client-python based on the client-version input
|
||||
if [ "${{ inputs.client-version }}" = "latest" ]; then
|
||||
echo "Installing latest llama-stack-client-python from main branch"
|
||||
export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main
|
||||
elif [ "${{ inputs.client-version }}" = "published" ]; then
|
||||
echo "Installing published llama-stack-client-python from PyPI"
|
||||
unset LLAMA_STACK_CLIENT_DIR
|
||||
else
|
||||
echo "Invalid client-version: ${{ inputs.client-version }}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Client is already installed by setup-runner (handles both main and release branches)
|
||||
echo "Building Llama Stack"
|
||||
|
||||
LLAMA_STACK_DIR=. \
|
||||
|
|
|
|||
1
.github/workflows/README.md
vendored
1
.github/workflows/README.md
vendored
|
|
@ -13,7 +13,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
|
|||
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
|
||||
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
|
||||
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
|
||||
| Pre-commit Bot | [precommit-trigger.yml](precommit-trigger.yml) | Pre-commit bot for PR |
|
||||
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
|
||||
| Test llama stack list-deps | [providers-list-deps.yml](providers-list-deps.yml) | Test llama stack list-deps |
|
||||
| Python Package Build Test | [python-build-test.yml](python-build-test.yml) | Test building the llama-stack PyPI project |
|
||||
|
|
|
|||
6
.github/workflows/backward-compat.yml
vendored
6
.github/workflows/backward-compat.yml
vendored
|
|
@ -4,7 +4,11 @@ run-name: Check backward compatibility for run.yaml configs
|
|||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.[0-9]+.[0-9]+'
|
||||
- 'release-[0-9]+.[0-9]+.[0-9]+'
|
||||
- 'release-[0-9]+.[0-9]+'
|
||||
paths:
|
||||
- 'src/llama_stack/core/datatypes.py'
|
||||
- 'src/llama_stack/providers/datatypes.py'
|
||||
|
|
|
|||
10
.github/workflows/install-script-ci.yml
vendored
10
.github/workflows/install-script-ci.yml
vendored
|
|
@ -30,10 +30,16 @@ jobs:
|
|||
|
||||
- name: Build a single provider
|
||||
run: |
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=starter"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=starter \
|
||||
$BUILD_ARGS \
|
||||
--tag llama-stack:starter-ci
|
||||
|
||||
- name: Run installer end-to-end
|
||||
|
|
|
|||
8
.github/workflows/integration-auth-tests.yml
vendored
8
.github/workflows/integration-auth-tests.yml
vendored
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suite with Kubernetes authentication
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'distributions/**'
|
||||
- 'src/llama_stack/**'
|
||||
|
|
|
|||
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suite with SqlStore
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'src/llama_stack/providers/utils/sqlstore/**'
|
||||
- 'tests/integration/sqlstore/**'
|
||||
|
|
|
|||
10
.github/workflows/integration-tests.yml
vendored
10
.github/workflows/integration-tests.yml
vendored
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suites from tests/integration in replay mode
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
types: [opened, synchronize, reopened]
|
||||
paths:
|
||||
- 'src/llama_stack/**'
|
||||
|
|
@ -47,7 +51,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
client-type: [library, docker]
|
||||
client-type: [library, docker, server]
|
||||
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
|
||||
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
||||
|
|
|
|||
|
|
@ -4,9 +4,13 @@ run-name: Run the integration test suite with various VectorIO providers
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'src/llama_stack/**'
|
||||
- '!src/llama_stack/ui/**'
|
||||
|
|
|
|||
52
.github/workflows/pre-commit.yml
vendored
52
.github/workflows/pre-commit.yml
vendored
|
|
@ -5,7 +5,9 @@ run-name: Run pre-commit checks
|
|||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [main]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
|
||||
|
|
@ -50,19 +52,34 @@ jobs:
|
|||
run: npm ci
|
||||
working-directory: src/llama_stack/ui
|
||||
|
||||
- name: Install pre-commit
|
||||
run: python -m pip install pre-commit
|
||||
|
||||
- name: Cache pre-commit
|
||||
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-3|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
|
||||
- name: Run pre-commit
|
||||
id: precommit
|
||||
uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
continue-on-error: true
|
||||
run: |
|
||||
set +e
|
||||
pre-commit run --show-diff-on-failure --color=always --all-files 2>&1 | tee /tmp/precommit.log
|
||||
status=${PIPESTATUS[0]}
|
||||
echo "status=$status" >> $GITHUB_OUTPUT
|
||||
exit 0
|
||||
env:
|
||||
SKIP: no-commit-to-branch,mypy
|
||||
RUFF_OUTPUT_FORMAT: github
|
||||
|
||||
- name: Check pre-commit results
|
||||
if: steps.precommit.outcome == 'failure'
|
||||
if: steps.precommit.outputs.status != '0'
|
||||
run: |
|
||||
echo "::error::Pre-commit hooks failed. Please run 'pre-commit run --all-files' locally and commit the fixes."
|
||||
echo "::warning::Some pre-commit hooks failed. Check the output above for details."
|
||||
echo ""
|
||||
echo "Failed hooks output:"
|
||||
cat /tmp/precommit.log
|
||||
exit 1
|
||||
|
||||
- name: Debug
|
||||
|
|
@ -113,11 +130,34 @@ jobs:
|
|||
exit 1
|
||||
fi
|
||||
|
||||
- name: Configure client installation
|
||||
id: client-config
|
||||
uses: ./.github/actions/install-llama-stack-client
|
||||
|
||||
- name: Sync dev + type_checking dependencies
|
||||
run: uv sync --group dev --group type_checking
|
||||
env:
|
||||
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
|
||||
run: |
|
||||
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
|
||||
export UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
fi
|
||||
|
||||
uv sync --group dev --group type_checking
|
||||
|
||||
# Install specific client version after sync if needed
|
||||
if [ "${{ steps.client-config.outputs.install-after-sync }}" = "true" ]; then
|
||||
echo "Installing llama-stack-client from: ${{ steps.client-config.outputs.install-source }}"
|
||||
uv pip install ${{ steps.client-config.outputs.install-source }}
|
||||
fi
|
||||
|
||||
- name: Run mypy (full type_checking)
|
||||
env:
|
||||
UV_EXTRA_INDEX_URL: ${{ steps.client-config.outputs.uv-extra-index-url }}
|
||||
run: |
|
||||
if [ -n "$UV_EXTRA_INDEX_URL" ]; then
|
||||
export UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
fi
|
||||
|
||||
set +e
|
||||
uv run --group dev --group type_checking mypy
|
||||
status=$?
|
||||
|
|
|
|||
227
.github/workflows/precommit-trigger.yml
vendored
227
.github/workflows/precommit-trigger.yml
vendored
|
|
@ -1,227 +0,0 @@
|
|||
name: Pre-commit Bot
|
||||
|
||||
run-name: Pre-commit bot for PR #${{ github.event.issue.number }}
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
# Only run on pull request comments
|
||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '@github-actions run precommit')
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- name: Check comment author and get PR details
|
||||
id: check_author
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
// Get PR details
|
||||
const pr = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.issue.number
|
||||
});
|
||||
|
||||
// Check if commenter has write access or is the PR author
|
||||
const commenter = context.payload.comment.user.login;
|
||||
const prAuthor = pr.data.user.login;
|
||||
|
||||
let hasPermission = false;
|
||||
|
||||
// Check if commenter is PR author
|
||||
if (commenter === prAuthor) {
|
||||
hasPermission = true;
|
||||
console.log(`Comment author ${commenter} is the PR author`);
|
||||
} else {
|
||||
// Check if commenter has write/admin access
|
||||
try {
|
||||
const permission = await github.rest.repos.getCollaboratorPermissionLevel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
username: commenter
|
||||
});
|
||||
|
||||
const level = permission.data.permission;
|
||||
hasPermission = ['write', 'admin', 'maintain'].includes(level);
|
||||
console.log(`Comment author ${commenter} has permission: ${level}`);
|
||||
} catch (error) {
|
||||
console.log(`Could not check permissions for ${commenter}: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (!hasPermission) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: `❌ @${commenter} You don't have permission to trigger pre-commit. Only PR authors or repository collaborators can run this command.`
|
||||
});
|
||||
core.setFailed(`User ${commenter} does not have permission`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Save PR info for later steps
|
||||
core.setOutput('pr_number', context.issue.number);
|
||||
core.setOutput('pr_head_ref', pr.data.head.ref);
|
||||
core.setOutput('pr_head_sha', pr.data.head.sha);
|
||||
core.setOutput('pr_head_repo', pr.data.head.repo.full_name);
|
||||
core.setOutput('pr_base_ref', pr.data.base.ref);
|
||||
core.setOutput('is_fork', pr.data.head.repo.full_name !== context.payload.repository.full_name);
|
||||
core.setOutput('authorized', 'true');
|
||||
|
||||
- name: React to comment
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.reactions.createForIssueComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: context.payload.comment.id,
|
||||
content: 'rocket'
|
||||
});
|
||||
|
||||
- name: Comment starting
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `⏳ Running [pre-commit hooks](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) on PR #${{ steps.check_author.outputs.pr_number }}...`
|
||||
});
|
||||
|
||||
- name: Checkout PR branch (same-repo)
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'false'
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
ref: ${{ steps.check_author.outputs.pr_head_ref }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Checkout PR branch (fork)
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.check_author.outputs.is_fork == 'true'
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
repository: ${{ steps.check_author.outputs.pr_head_repo }}
|
||||
ref: ${{ steps.check_author.outputs.pr_head_ref }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Verify checkout
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
run: |
|
||||
echo "Current SHA: $(git rev-parse HEAD)"
|
||||
echo "Expected SHA: ${{ steps.check_author.outputs.pr_head_sha }}"
|
||||
if [[ "$(git rev-parse HEAD)" != "${{ steps.check_author.outputs.pr_head_sha }}" ]]; then
|
||||
echo "::error::Checked out SHA does not match expected SHA"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: pip
|
||||
cache-dependency-path: |
|
||||
**/requirements*.txt
|
||||
.pre-commit-config.yaml
|
||||
|
||||
- name: Set up Node.js
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
uses: actions/setup-node@2028fbc5c25fe9cf00d9f06a71cc4710d4507903 # v6.0.0
|
||||
with:
|
||||
node-version: '20'
|
||||
cache: 'npm'
|
||||
cache-dependency-path: 'src/llama_stack/ui/'
|
||||
|
||||
- name: Install npm dependencies
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
run: npm ci
|
||||
working-directory: src/llama_stack/ui
|
||||
|
||||
- name: Run pre-commit
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
id: precommit
|
||||
uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
continue-on-error: true
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
RUFF_OUTPUT_FORMAT: github
|
||||
|
||||
- name: Check for changes
|
||||
if: steps.check_author.outputs.authorized == 'true'
|
||||
id: changes
|
||||
run: |
|
||||
if ! git diff --exit-code || [ -n "$(git ls-files --others --exclude-standard)" ]; then
|
||||
echo "has_changes=true" >> $GITHUB_OUTPUT
|
||||
echo "Changes detected after pre-commit"
|
||||
else
|
||||
echo "has_changes=false" >> $GITHUB_OUTPUT
|
||||
echo "No changes after pre-commit"
|
||||
fi
|
||||
|
||||
- name: Commit and push changes
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
|
||||
run: |
|
||||
git config --local user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git config --local user.name "github-actions[bot]"
|
||||
|
||||
git add -A
|
||||
git commit -m "style: apply pre-commit fixes
|
||||
|
||||
🤖 Applied by @github-actions bot via pre-commit workflow"
|
||||
|
||||
# Push changes
|
||||
git push origin HEAD:${{ steps.check_author.outputs.pr_head_ref }}
|
||||
|
||||
- name: Comment success with changes
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'true'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `✅ Pre-commit hooks completed successfully!\n\n🔧 Changes have been committed and pushed to the PR branch.`
|
||||
});
|
||||
|
||||
- name: Comment success without changes
|
||||
if: steps.check_author.outputs.authorized == 'true' && steps.changes.outputs.has_changes == 'false' && steps.precommit.outcome == 'success'
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `✅ Pre-commit hooks passed!\n\n✨ No changes needed - your code is already formatted correctly.`
|
||||
});
|
||||
|
||||
- name: Comment failure
|
||||
if: failure()
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ steps.check_author.outputs.pr_number }},
|
||||
body: `❌ Pre-commit workflow failed!\n\nPlease check the [workflow logs](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}) for details.`
|
||||
});
|
||||
38
.github/workflows/providers-build.yml
vendored
38
.github/workflows/providers-build.yml
vendored
|
|
@ -72,10 +72,16 @@ jobs:
|
|||
- name: Build container image
|
||||
if: matrix.image-type == 'container'
|
||||
run: |
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=${{ matrix.distro }}"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=${{ matrix.distro }} \
|
||||
$BUILD_ARGS \
|
||||
--tag llama-stack:${{ matrix.distro }}-ci
|
||||
|
||||
- name: Print dependencies in the image
|
||||
|
|
@ -108,12 +114,18 @@ jobs:
|
|||
- name: Build container image
|
||||
run: |
|
||||
BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "python:3.12-slim"' src/llama_stack/distributions/ci-tests/build.yaml)
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=ci-tests \
|
||||
--build-arg BASE_IMAGE="$BASE_IMAGE" \
|
||||
--build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
|
||||
$BUILD_ARGS \
|
||||
-t llama-stack:ci-tests
|
||||
|
||||
- name: Inspect the container image entrypoint
|
||||
|
|
@ -148,12 +160,18 @@ jobs:
|
|||
- name: Build UBI9 container image
|
||||
run: |
|
||||
BASE_IMAGE=$(yq -r '.distribution_spec.container_image // "registry.access.redhat.com/ubi9:latest"' src/llama_stack/distributions/ci-tests/build.yaml)
|
||||
BUILD_ARGS="--build-arg INSTALL_MODE=editable --build-arg DISTRO_NAME=ci-tests"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg BASE_IMAGE=$BASE_IMAGE"
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml"
|
||||
if [ -n "${UV_EXTRA_INDEX_URL:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL"
|
||||
fi
|
||||
if [ -n "${UV_INDEX_STRATEGY:-}" ]; then
|
||||
BUILD_ARGS="$BUILD_ARGS --build-arg UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY"
|
||||
fi
|
||||
docker build . \
|
||||
-f containers/Containerfile \
|
||||
--build-arg INSTALL_MODE=editable \
|
||||
--build-arg DISTRO_NAME=ci-tests \
|
||||
--build-arg BASE_IMAGE="$BASE_IMAGE" \
|
||||
--build-arg RUN_CONFIG_PATH=/workspace/src/llama_stack/distributions/ci-tests/run.yaml \
|
||||
$BUILD_ARGS \
|
||||
-t llama-stack:ci-tests-ubi9
|
||||
|
||||
- name: Inspect UBI9 image
|
||||
|
|
|
|||
8
.github/workflows/unit-tests.yml
vendored
8
.github/workflows/unit-tests.yml
vendored
|
|
@ -4,9 +4,13 @@ run-name: Run the unit test suite
|
|||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches:
|
||||
- main
|
||||
- 'release-[0-9]+.[0-9]+.x'
|
||||
paths:
|
||||
- 'src/llama_stack/**'
|
||||
- '!src/llama_stack/ui/**'
|
||||
|
|
|
|||
|
|
@ -52,10 +52,6 @@ repos:
|
|||
additional_dependencies:
|
||||
- black==24.3.0
|
||||
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
rev: 0.7.20
|
||||
hooks:
|
||||
- id: uv-lock
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.18.2
|
||||
|
|
@ -63,22 +59,13 @@ repos:
|
|||
- id: mypy
|
||||
additional_dependencies:
|
||||
- uv==0.6.2
|
||||
- mypy
|
||||
- pytest
|
||||
- rich
|
||||
- types-requests
|
||||
- pydantic
|
||||
- httpx
|
||||
pass_filenames: false
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: mypy-full
|
||||
name: mypy (full type_checking)
|
||||
entry: uv run --group dev --group type_checking mypy
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [manual]
|
||||
|
||||
# - repo: https://github.com/tcort/markdown-link-check
|
||||
# rev: v3.11.2
|
||||
# hooks:
|
||||
|
|
@ -87,11 +74,26 @@ repos:
|
|||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: uv-lock
|
||||
name: uv-lock
|
||||
additional_dependencies:
|
||||
- uv==0.7.20
|
||||
entry: ./scripts/uv-run-with-index.sh lock
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
files: ^(pyproject\.toml|uv\.lock)$
|
||||
- id: mypy-full
|
||||
name: mypy (full type_checking)
|
||||
entry: ./scripts/uv-run-with-index.sh run --group dev --group type_checking mypy
|
||||
language: system
|
||||
pass_filenames: false
|
||||
stages: [manual]
|
||||
- id: distro-codegen
|
||||
name: Distribution Template Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: uv run --group codegen ./scripts/distro_codegen.py
|
||||
entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/distro_codegen.py
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -100,7 +102,7 @@ repos:
|
|||
name: Provider Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: uv run --group codegen ./scripts/provider_codegen.py
|
||||
entry: ./scripts/uv-run-with-index.sh run --group codegen ./scripts/provider_codegen.py
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -109,7 +111,7 @@ repos:
|
|||
name: API Spec Codegen
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: sh -c 'uv run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
|
||||
entry: sh -c './scripts/uv-run-with-index.sh run ./docs/openapi_generator/run_openapi_generator.sh > /dev/null'
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -150,7 +152,7 @@ repos:
|
|||
name: Generate CI documentation
|
||||
additional_dependencies:
|
||||
- uv==0.7.8
|
||||
entry: uv run ./scripts/gen-ci-docs.py
|
||||
entry: ./scripts/uv-run-with-index.sh run ./scripts/gen-ci-docs.py
|
||||
language: python
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
|
@ -162,6 +164,7 @@ repos:
|
|||
files: ^src/llama_stack/ui/.*\.(ts|tsx)$
|
||||
pass_filenames: false
|
||||
require_serial: true
|
||||
|
||||
- id: check-log-usage
|
||||
name: Ensure 'llama_stack.log' usage for logging
|
||||
entry: bash
|
||||
|
|
@ -197,6 +200,7 @@ repos:
|
|||
echo;
|
||||
exit 1;
|
||||
} || true
|
||||
|
||||
ci:
|
||||
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
|
||||
autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
|
||||
|
|
|
|||
|
|
@ -956,7 +956,22 @@ paths:
|
|||
List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
parameters: []
|
||||
parameters:
|
||||
- name: api_filter
|
||||
in: query
|
||||
description: >-
|
||||
Optional filter to control which routes are returned. Can be an API level
|
||||
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||
returns only non-deprecated v1 routes.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- v1
|
||||
- v1alpha
|
||||
- v1beta
|
||||
- deprecated
|
||||
deprecated: false
|
||||
/v1/models:
|
||||
get:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ ARG KEEP_WORKSPACE=""
|
|||
ARG DISTRO_NAME="starter"
|
||||
ARG RUN_CONFIG_PATH=""
|
||||
ARG UV_HTTP_TIMEOUT=500
|
||||
ARG UV_EXTRA_INDEX_URL=""
|
||||
ARG UV_INDEX_STRATEGY=""
|
||||
ENV UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT}
|
||||
ENV PYTHONDONTWRITEBYTECODE=1
|
||||
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||
|
|
@ -45,7 +47,7 @@ RUN set -eux; \
|
|||
exit 1; \
|
||||
fi
|
||||
|
||||
RUN pip install --no-cache uv
|
||||
RUN pip install --no-cache-dir uv
|
||||
ENV UV_SYSTEM_PYTHON=1
|
||||
|
||||
ENV INSTALL_MODE=${INSTALL_MODE}
|
||||
|
|
@ -62,47 +64,60 @@ COPY . /workspace
|
|||
|
||||
# Install the client package if it is provided
|
||||
# NOTE: this is installed before llama-stack since llama-stack depends on llama-stack-client-python
|
||||
# Unset UV index env vars to ensure we only use PyPI for the client
|
||||
RUN set -eux; \
|
||||
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
|
||||
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then \
|
||||
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then \
|
||||
echo "LLAMA_STACK_CLIENT_DIR is set but $LLAMA_STACK_CLIENT_DIR does not exist" >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
uv pip install --no-cache -e "$LLAMA_STACK_CLIENT_DIR"; \
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"; \
|
||||
fi;
|
||||
|
||||
# Install llama-stack
|
||||
# Use UV_EXTRA_INDEX_URL inline only for editable install with RC dependencies
|
||||
RUN set -eux; \
|
||||
SAVED_UV_EXTRA_INDEX_URL="${UV_EXTRA_INDEX_URL:-}"; \
|
||||
SAVED_UV_INDEX_STRATEGY="${UV_INDEX_STRATEGY:-}"; \
|
||||
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
|
||||
if [ "$INSTALL_MODE" = "editable" ]; then \
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then \
|
||||
echo "INSTALL_MODE=editable requires LLAMA_STACK_DIR to point to a directory inside the build context" >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
uv pip install --no-cache -e "$LLAMA_STACK_DIR"; \
|
||||
elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
|
||||
uv pip install --no-cache fastapi libcst; \
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then \
|
||||
uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
|
||||
if [ -n "$SAVED_UV_EXTRA_INDEX_URL" ] && [ -n "$SAVED_UV_INDEX_STRATEGY" ]; then \
|
||||
UV_EXTRA_INDEX_URL="$SAVED_UV_EXTRA_INDEX_URL" UV_INDEX_STRATEGY="$SAVED_UV_INDEX_STRATEGY" \
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
|
||||
else \
|
||||
uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
|
||||
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"; \
|
||||
fi; \
|
||||
elif [ "$INSTALL_MODE" = "test-pypi" ]; then \
|
||||
uv pip install --no-cache-dir fastapi libcst; \
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then \
|
||||
uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match "llama-stack==$TEST_PYPI_VERSION"; \
|
||||
else \
|
||||
uv pip install --no-cache-dir --extra-index-url https://test.pypi.org/simple/ --index-strategy unsafe-best-match llama-stack; \
|
||||
fi; \
|
||||
else \
|
||||
if [ -n "$PYPI_VERSION" ]; then \
|
||||
uv pip install --no-cache "llama-stack==$PYPI_VERSION"; \
|
||||
uv pip install --no-cache-dir "llama-stack==$PYPI_VERSION"; \
|
||||
else \
|
||||
uv pip install --no-cache llama-stack; \
|
||||
uv pip install --no-cache-dir llama-stack; \
|
||||
fi; \
|
||||
fi;
|
||||
|
||||
# Install the dependencies for the distribution
|
||||
# Explicitly unset UV index env vars to ensure we only use PyPI for distribution deps
|
||||
RUN set -eux; \
|
||||
unset UV_EXTRA_INDEX_URL UV_INDEX_STRATEGY; \
|
||||
if [ -z "$DISTRO_NAME" ]; then \
|
||||
echo "DISTRO_NAME must be provided" >&2; \
|
||||
exit 1; \
|
||||
fi; \
|
||||
deps="$(llama stack list-deps "$DISTRO_NAME")"; \
|
||||
if [ -n "$deps" ]; then \
|
||||
printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache; \
|
||||
printf '%s\n' "$deps" | xargs -L1 uv pip install --no-cache-dir; \
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
|
|||
| `url` | `<class 'str'>` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
|
||||
| `timeout` | `<class 'int'>` | No | 60 | Timeout for the HTTP requests |
|
||||
| `append_api_version` | `<class 'bool'>` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
|
||||
| `rerank_model_to_url` | `dict[str, str` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
|
|
|
|||
18
docs/static/llama-stack-spec.html
vendored
18
docs/static/llama-stack-spec.html
vendored
|
|
@ -1258,7 +1258,23 @@
|
|||
],
|
||||
"summary": "List routes.",
|
||||
"description": "List routes.\nList all available API routes with their methods and implementing providers.",
|
||||
"parameters": [],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "api_filter",
|
||||
"in": "query",
|
||||
"description": "Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"v1",
|
||||
"v1alpha",
|
||||
"v1beta",
|
||||
"deprecated"
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"deprecated": false
|
||||
}
|
||||
},
|
||||
|
|
|
|||
17
docs/static/llama-stack-spec.yaml
vendored
17
docs/static/llama-stack-spec.yaml
vendored
|
|
@ -953,7 +953,22 @@ paths:
|
|||
List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
parameters: []
|
||||
parameters:
|
||||
- name: api_filter
|
||||
in: query
|
||||
description: >-
|
||||
Optional filter to control which routes are returned. Can be an API level
|
||||
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||
returns only non-deprecated v1 routes.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- v1
|
||||
- v1alpha
|
||||
- v1beta
|
||||
- deprecated
|
||||
deprecated: false
|
||||
/v1/models:
|
||||
get:
|
||||
|
|
|
|||
18
docs/static/stainless-llama-stack-spec.html
vendored
18
docs/static/stainless-llama-stack-spec.html
vendored
|
|
@ -1258,7 +1258,23 @@
|
|||
],
|
||||
"summary": "List routes.",
|
||||
"description": "List routes.\nList all available API routes with their methods and implementing providers.",
|
||||
"parameters": [],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "api_filter",
|
||||
"in": "query",
|
||||
"description": "Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"v1",
|
||||
"v1alpha",
|
||||
"v1beta",
|
||||
"deprecated"
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"deprecated": false
|
||||
}
|
||||
},
|
||||
|
|
|
|||
17
docs/static/stainless-llama-stack-spec.yaml
vendored
17
docs/static/stainless-llama-stack-spec.yaml
vendored
|
|
@ -956,7 +956,22 @@ paths:
|
|||
List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
parameters: []
|
||||
parameters:
|
||||
- name: api_filter
|
||||
in: query
|
||||
description: >-
|
||||
Optional filter to control which routes are returned. Can be an API level
|
||||
('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level,
|
||||
or 'deprecated' to show deprecated routes across all levels. If not specified,
|
||||
returns only non-deprecated v1 routes.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
enum:
|
||||
- v1
|
||||
- v1alpha
|
||||
- v1beta
|
||||
- deprecated
|
||||
deprecated: false
|
||||
/v1/models:
|
||||
get:
|
||||
|
|
|
|||
|
|
@ -215,6 +215,16 @@ build_image() {
|
|||
--build-arg "LLAMA_STACK_DIR=/workspace"
|
||||
)
|
||||
|
||||
# Pass UV index configuration for release branches
|
||||
if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
|
||||
echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
|
||||
build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
|
||||
fi
|
||||
if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
|
||||
echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
|
||||
build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
|
||||
fi
|
||||
|
||||
if ! "${build_cmd[@]}"; then
|
||||
echo "❌ Failed to build Docker image"
|
||||
exit 1
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ COLLECT_ONLY=false
|
|||
|
||||
# Function to display usage
|
||||
usage() {
|
||||
cat << EOF
|
||||
cat <<EOF
|
||||
Usage: $0 [OPTIONS]
|
||||
|
||||
Options:
|
||||
|
|
@ -102,7 +102,6 @@ while [[ $# -gt 0 ]]; do
|
|||
esac
|
||||
done
|
||||
|
||||
|
||||
# Validate required parameters
|
||||
if [[ -z "$STACK_CONFIG" && "$COLLECT_ONLY" == false ]]; then
|
||||
echo "Error: --stack-config is required"
|
||||
|
|
@ -177,12 +176,12 @@ cd $ROOT_DIR
|
|||
# check if "llama" and "pytest" are available. this script does not use `uv run` given
|
||||
# it can be used in a pre-release environment where we have not been able to tell
|
||||
# uv about pre-release dependencies properly (yet).
|
||||
if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &> /dev/null; then
|
||||
if [[ "$COLLECT_ONLY" == false ]] && ! command -v llama &>/dev/null; then
|
||||
echo "llama could not be found, ensure llama-stack is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v pytest &> /dev/null; then
|
||||
if ! command -v pytest &>/dev/null; then
|
||||
echo "pytest could not be found, ensure pytest is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -208,9 +207,18 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
echo "=== Starting Llama Stack Server ==="
|
||||
export LLAMA_STACK_LOG_WIDTH=120
|
||||
|
||||
# Configure telemetry collector for server mode
|
||||
# Use a fixed port for the OTEL collector so the server can connect to it
|
||||
COLLECTOR_PORT=4317
|
||||
export LLAMA_STACK_TEST_COLLECTOR_PORT="${COLLECTOR_PORT}"
|
||||
export OTEL_EXPORTER_OTLP_ENDPOINT="http://127.0.0.1:${COLLECTOR_PORT}"
|
||||
export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
|
||||
export OTEL_BSP_SCHEDULE_DELAY="200"
|
||||
export OTEL_BSP_EXPORT_TIMEOUT="2000"
|
||||
|
||||
# remove "server:" from STACK_CONFIG
|
||||
stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
|
||||
nohup llama stack run $stack_config > server.log 2>&1 &
|
||||
nohup llama stack run $stack_config >server.log 2>&1 &
|
||||
|
||||
echo "Waiting for Llama Stack Server to start..."
|
||||
for i in {1..30}; do
|
||||
|
|
@ -239,7 +247,7 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
container_name="llama-stack-test-$DISTRO"
|
||||
if docker ps -a --format '{{.Names}}' | grep -q "^${container_name}$"; then
|
||||
echo "Dumping container logs before stopping..."
|
||||
docker logs "$container_name" > "docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
|
||||
docker logs "$container_name" >"docker-${DISTRO}-${INFERENCE_MODE}.log" 2>&1 || true
|
||||
echo "Stopping and removing container: $container_name"
|
||||
docker stop "$container_name" 2>/dev/null || true
|
||||
docker rm "$container_name" 2>/dev/null || true
|
||||
|
|
@ -271,6 +279,16 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
|
|||
--build-arg "LLAMA_STACK_DIR=/workspace"
|
||||
)
|
||||
|
||||
# Pass UV index configuration for release branches
|
||||
if [[ -n "${UV_EXTRA_INDEX_URL:-}" ]]; then
|
||||
echo "Adding UV_EXTRA_INDEX_URL to docker build: $UV_EXTRA_INDEX_URL"
|
||||
build_cmd+=(--build-arg "UV_EXTRA_INDEX_URL=$UV_EXTRA_INDEX_URL")
|
||||
fi
|
||||
if [[ -n "${UV_INDEX_STRATEGY:-}" ]]; then
|
||||
echo "Adding UV_INDEX_STRATEGY to docker build: $UV_INDEX_STRATEGY"
|
||||
build_cmd+=(--build-arg "UV_INDEX_STRATEGY=$UV_INDEX_STRATEGY")
|
||||
fi
|
||||
|
||||
if ! "${build_cmd[@]}"; then
|
||||
echo "❌ Failed to build Docker image"
|
||||
exit 1
|
||||
|
|
@ -428,17 +446,13 @@ elif [ $exit_code -eq 5 ]; then
|
|||
else
|
||||
echo "❌ Tests failed"
|
||||
echo ""
|
||||
echo "=== Dumping last 100 lines of logs for debugging ==="
|
||||
|
||||
# Output server or container logs based on stack config
|
||||
if [[ "$STACK_CONFIG" == *"server:"* && -f "server.log" ]]; then
|
||||
echo "--- Last 100 lines of server.log ---"
|
||||
tail -100 server.log
|
||||
echo "--- Server side failures can be located inside server.log (available from artifacts on CI) ---"
|
||||
elif [[ "$STACK_CONFIG" == *"docker:"* ]]; then
|
||||
docker_log_file="docker-${DISTRO}-${INFERENCE_MODE}.log"
|
||||
if [[ -f "$docker_log_file" ]]; then
|
||||
echo "--- Last 100 lines of $docker_log_file ---"
|
||||
tail -100 "$docker_log_file"
|
||||
echo "--- Server side failures can be located inside $docker_log_file (available from artifacts on CI) ---"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
|
|||
42
scripts/uv-run-with-index.sh
Executable file
42
scripts/uv-run-with-index.sh
Executable file
|
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Detect current branch and target branch
|
||||
# In GitHub Actions, use GITHUB_REF/GITHUB_BASE_REF
|
||||
if [[ -n "${GITHUB_REF:-}" ]]; then
|
||||
BRANCH="${GITHUB_REF#refs/heads/}"
|
||||
else
|
||||
BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "")
|
||||
fi
|
||||
|
||||
# For PRs, check the target branch
|
||||
if [[ -n "${GITHUB_BASE_REF:-}" ]]; then
|
||||
TARGET_BRANCH="${GITHUB_BASE_REF}"
|
||||
else
|
||||
TARGET_BRANCH=$(git rev-parse --abbrev-ref HEAD@{upstream} 2>/dev/null | sed 's|origin/||' || echo "")
|
||||
fi
|
||||
|
||||
# Check if on a release branch or targeting one, or LLAMA_STACK_RELEASE_MODE is set
|
||||
IS_RELEASE=false
|
||||
if [[ "$BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
|
||||
IS_RELEASE=true
|
||||
elif [[ "$TARGET_BRANCH" =~ ^release-[0-9]+\.[0-9]+\.x$ ]]; then
|
||||
IS_RELEASE=true
|
||||
elif [[ "${LLAMA_STACK_RELEASE_MODE:-}" == "true" ]]; then
|
||||
IS_RELEASE=true
|
||||
fi
|
||||
|
||||
# On release branches, use test.pypi as extra index for RC versions
|
||||
if [[ "$IS_RELEASE" == "true" ]]; then
|
||||
export UV_EXTRA_INDEX_URL="https://test.pypi.org/simple/"
|
||||
export UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
fi
|
||||
|
||||
# Run uv with all arguments passed through
|
||||
exec uv "$@"
|
||||
|
|
@ -4,14 +4,21 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.apis.version import (
|
||||
LLAMA_STACK_API_V1,
|
||||
)
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
# Valid values for the route filter parameter.
|
||||
# Actual API levels: v1, v1alpha, v1beta (filters by level, excludes deprecated)
|
||||
# Special filter value: "deprecated" (shows deprecated routes regardless of level)
|
||||
ApiFilter = Literal["v1", "v1alpha", "v1beta", "deprecated"]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RouteInfo(BaseModel):
|
||||
|
|
@ -64,11 +71,12 @@ class Inspect(Protocol):
|
|||
"""
|
||||
|
||||
@webmethod(route="/inspect/routes", method="GET", level=LLAMA_STACK_API_V1)
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
async def list_routes(self, api_filter: ApiFilter | None = None) -> ListRoutesResponse:
|
||||
"""List routes.
|
||||
|
||||
List all available API routes with their methods and implementing providers.
|
||||
|
||||
:param api_filter: Optional filter to control which routes are returned. Can be an API level ('v1', 'v1alpha', 'v1beta') to show non-deprecated routes at that level, or 'deprecated' to show deprecated routes across all levels. If not specified, returns only non-deprecated v1 routes.
|
||||
:returns: Response containing information about all available routes.
|
||||
"""
|
||||
...
|
||||
|
|
|
|||
|
|
@ -13,11 +13,23 @@ from pathlib import Path
|
|||
|
||||
import uvicorn
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.stack.utils import ImageType
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.core.distribution import get_provider_registry
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.storage.datatypes import (
|
||||
InferenceStoreReference,
|
||||
KVStoreReference,
|
||||
ServerStoresConfig,
|
||||
SqliteKVStoreConfig,
|
||||
SqliteSqlStoreConfig,
|
||||
SqlStoreReference,
|
||||
StorageConfig,
|
||||
)
|
||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_distro
|
||||
from llama_stack.log import LoggingConfig, get_logger
|
||||
|
||||
|
|
@ -69,6 +81,12 @@ class StackRun(Subcommand):
|
|||
action="store_true",
|
||||
help="Start the UI server",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
"--providers",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Run a stack with only a list of providers. This list is formatted like: api1=provider1,api1=provider2,api2=provider3. Where there can be multiple providers per API.",
|
||||
)
|
||||
|
||||
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
|
||||
import yaml
|
||||
|
|
@ -94,6 +112,49 @@ class StackRun(Subcommand):
|
|||
config_file = resolve_config_or_distro(args.config, Mode.RUN)
|
||||
except ValueError as e:
|
||||
self.parser.error(str(e))
|
||||
elif args.providers:
|
||||
provider_list: dict[str, list[Provider]] = dict()
|
||||
for api_provider in args.providers.split(","):
|
||||
if "=" not in api_provider:
|
||||
cprint(
|
||||
"Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
api, provider_type = api_provider.split("=")
|
||||
providers_for_api = get_provider_registry().get(Api(api), None)
|
||||
if providers_for_api is None:
|
||||
cprint(
|
||||
f"{api} is not a valid API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
if provider_type in providers_for_api:
|
||||
provider = Provider(
|
||||
provider_type=provider_type,
|
||||
provider_id=provider_type.split("::")[1],
|
||||
)
|
||||
provider_list.setdefault(api, []).append(provider)
|
||||
else:
|
||||
cprint(
|
||||
f"{provider} is not a valid provider for the {api} API.",
|
||||
color="red",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
run_config = self._generate_run_config_from_providers(providers=provider_list)
|
||||
config_dict = run_config.model_dump(mode="json")
|
||||
|
||||
# Write config to disk in providers-run directory
|
||||
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||
config_file = distro_dir / "run.yaml"
|
||||
|
||||
logger.info(f"Writing generated config to: {config_file}")
|
||||
with open(config_file, "w") as f:
|
||||
yaml.dump(config_dict, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
else:
|
||||
config_file = None
|
||||
|
||||
|
|
@ -107,7 +168,8 @@ class StackRun(Subcommand):
|
|||
|
||||
try:
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
if not os.path.exists(str(config.external_providers_dir)):
|
||||
# Create external_providers_dir if it's specified and doesn't exist
|
||||
if config.external_providers_dir and not os.path.exists(str(config.external_providers_dir)):
|
||||
os.makedirs(str(config.external_providers_dir), exist_ok=True)
|
||||
except AttributeError as e:
|
||||
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
|
||||
|
|
@ -128,7 +190,7 @@ class StackRun(Subcommand):
|
|||
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
|
||||
|
||||
port = args.port or config.server.port
|
||||
host = config.server.host or ["::", "0.0.0.0"]
|
||||
host = config.server.host or "0.0.0.0"
|
||||
|
||||
# Set the config file in environment so create_app can find it
|
||||
os.environ["LLAMA_STACK_CONFIG"] = str(config_file)
|
||||
|
|
@ -140,6 +202,7 @@ class StackRun(Subcommand):
|
|||
"lifespan": "on",
|
||||
"log_level": logger.getEffectiveLevel(),
|
||||
"log_config": logger_config,
|
||||
"workers": config.server.workers,
|
||||
}
|
||||
|
||||
keyfile = config.server.tls_keyfile
|
||||
|
|
@ -340,3 +403,44 @@ class StackRun(Subcommand):
|
|||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start UI development server in {ui_dir}: {e}")
|
||||
|
||||
def _generate_run_config_from_providers(self, providers: dict[str, list[Provider]]):
|
||||
apis = list(providers.keys())
|
||||
distro_dir = DISTRIBS_BASE_DIR / "providers-run"
|
||||
# need somewhere to put the storage.
|
||||
os.makedirs(distro_dir, exist_ok=True)
|
||||
storage = StorageConfig(
|
||||
backends={
|
||||
"kv_default": SqliteKVStoreConfig(
|
||||
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/kvstore.db",
|
||||
),
|
||||
"sql_default": SqliteSqlStoreConfig(
|
||||
db_path=f"${{env.SQLITE_STORE_DIR:={distro_dir}}}/sql_store.db",
|
||||
),
|
||||
},
|
||||
stores=ServerStoresConfig(
|
||||
metadata=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="registry",
|
||||
),
|
||||
inference=InferenceStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="inference_store",
|
||||
),
|
||||
conversations=SqlStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="openai_conversations",
|
||||
),
|
||||
prompts=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="prompts",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return StackRunConfig(
|
||||
image_name="providers-run",
|
||||
apis=apis,
|
||||
providers=providers,
|
||||
storage=storage,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ from llama_stack.core.distribution import (
|
|||
get_provider_registry,
|
||||
)
|
||||
from llama_stack.core.stack import cast_image_name_to_string, replace_env_vars
|
||||
from llama_stack.core.utils.config_dirs import EXTERNAL_PROVIDERS_DIR
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.core.utils.prompt_for_config import prompt_for_config
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -194,19 +193,11 @@ def upgrade_from_routing_table(
|
|||
|
||||
|
||||
def parse_and_maybe_upgrade_config(config_dict: dict[str, Any]) -> StackRunConfig:
|
||||
version = config_dict.get("version", None)
|
||||
if version == LLAMA_STACK_RUN_CONFIG_VERSION:
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
||||
if "routing_table" in config_dict:
|
||||
logger.info("Upgrading config...")
|
||||
config_dict = upgrade_from_routing_table(config_dict)
|
||||
|
||||
config_dict["version"] = LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
|
||||
if not config_dict.get("external_providers_dir", None):
|
||||
config_dict["external_providers_dir"] = EXTERNAL_PROVIDERS_DIR
|
||||
|
||||
processed_config_dict = replace_env_vars(config_dict)
|
||||
return StackRunConfig(**cast_image_name_to_string(processed_config_dict))
|
||||
|
|
|
|||
|
|
@ -473,6 +473,10 @@ class ServerConfig(BaseModel):
|
|||
"- true: Enable localhost CORS for development\n"
|
||||
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
|
||||
)
|
||||
workers: int = Field(
|
||||
default=1,
|
||||
description="Number of workers to use for the server",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.inspect import (
|
|||
RouteInfo,
|
||||
VersionInfo,
|
||||
)
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.core.datatypes import StackRunConfig
|
||||
from llama_stack.core.external import load_external_apis
|
||||
from llama_stack.core.server.routes import get_all_api_routes
|
||||
|
|
@ -39,9 +40,21 @@ class DistributionInspectImpl(Inspect):
|
|||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def list_routes(self) -> ListRoutesResponse:
|
||||
async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse:
|
||||
run_config: StackRunConfig = self.config.run_config
|
||||
|
||||
# Helper function to determine if a route should be included based on api_filter
|
||||
def should_include_route(webmethod) -> bool:
|
||||
if api_filter is None:
|
||||
# Default: only non-deprecated v1 APIs
|
||||
return not webmethod.deprecated and webmethod.level == LLAMA_STACK_API_V1
|
||||
elif api_filter == "deprecated":
|
||||
# Special filter: show deprecated routes regardless of their actual level
|
||||
return bool(webmethod.deprecated)
|
||||
else:
|
||||
# Filter by API level (non-deprecated routes only)
|
||||
return not webmethod.deprecated and webmethod.level == api_filter
|
||||
|
||||
ret = []
|
||||
external_apis = load_external_apis(run_config)
|
||||
all_endpoints = get_all_api_routes(external_apis)
|
||||
|
|
@ -55,8 +68,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
|
||||
)
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
for e, webmethod in endpoints
|
||||
if e.methods is not None and should_include_route(webmethod)
|
||||
]
|
||||
)
|
||||
else:
|
||||
|
|
@ -69,8 +82,8 @@ class DistributionInspectImpl(Inspect):
|
|||
method=next(iter([m for m in e.methods if m != "HEAD"])),
|
||||
provider_types=[p.provider_type for p in providers],
|
||||
)
|
||||
for e, _ in endpoints
|
||||
if e.methods is not None
|
||||
for e, webmethod in endpoints
|
||||
if e.methods is not None and should_include_route(webmethod)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -181,3 +181,22 @@ vlm_response = client.chat.completions.create(
|
|||
|
||||
print(f"VLM Response: {vlm_response.choices[0].message.content}")
|
||||
```
|
||||
|
||||
### Rerank Example
|
||||
|
||||
The following example shows how to rerank documents using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
rerank_response = client.alpha.inference.rerank(
|
||||
model="nvidia/nvidia/llama-3.2-nv-rerankqa-1b-v2",
|
||||
query="query",
|
||||
items=[
|
||||
"item_1",
|
||||
"item_2",
|
||||
"item_3",
|
||||
],
|
||||
)
|
||||
|
||||
for i, result in enumerate(rerank_response):
|
||||
print(f"{i+1}. [Index: {result.index}, " f"Score: {(result.relevance_score):.3f}]")
|
||||
```
|
||||
|
|
@ -28,6 +28,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
Attributes:
|
||||
url (str): A base url for accessing the NVIDIA NIM, e.g. http://localhost:8000
|
||||
api_key (str): The access key for the hosted NIM endpoints
|
||||
rerank_model_to_url (dict[str, str]): Mapping of rerank model identifiers to their API endpoints
|
||||
|
||||
There are two ways to access NVIDIA NIMs -
|
||||
0. Hosted: Preview APIs hosted at https://integrate.api.nvidia.com
|
||||
|
|
@ -55,6 +56,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
|
|||
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||
)
|
||||
rerank_model_to_url: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking",
|
||||
"nvidia/llama-3.2-nv-rerankqa-1b-v2": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking",
|
||||
},
|
||||
description="Mapping of rerank model identifiers to their API endpoints. ",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
|
|
|
|||
|
|
@ -5,6 +5,19 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import aiohttp
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
RerankData,
|
||||
RerankResponse,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
|
@ -61,3 +74,101 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
|
|||
:return: The NVIDIA API base URL
|
||||
"""
|
||||
return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
|
||||
|
||||
async def list_provider_model_ids(self) -> Iterable[str]:
|
||||
"""
|
||||
Return both dynamic model IDs and statically configured rerank model IDs.
|
||||
"""
|
||||
dynamic_ids: Iterable[str] = []
|
||||
try:
|
||||
dynamic_ids = await super().list_provider_model_ids()
|
||||
except Exception:
|
||||
# If the dynamic listing fails, proceed with just configured rerank IDs
|
||||
dynamic_ids = []
|
||||
|
||||
configured_rerank_ids = list(self.config.rerank_model_to_url.keys())
|
||||
return list(dict.fromkeys(list(dynamic_ids) + configured_rerank_ids)) # remove duplicates
|
||||
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
|
||||
"""
|
||||
Classify rerank models from config; otherwise use the base behavior.
|
||||
"""
|
||||
if identifier in self.config.rerank_model_to_url:
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
return super().construct_model_from_identifier(identifier)
|
||||
|
||||
async def rerank(
|
||||
self,
|
||||
model: str,
|
||||
query: str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam,
|
||||
items: list[str | OpenAIChatCompletionContentPartTextParam | OpenAIChatCompletionContentPartImageParam],
|
||||
max_num_results: int | None = None,
|
||||
) -> RerankResponse:
|
||||
provider_model_id = await self._get_provider_model_id(model)
|
||||
|
||||
ranking_url = self.get_base_url()
|
||||
|
||||
if _is_nvidia_hosted(self.config) and provider_model_id in self.config.rerank_model_to_url:
|
||||
ranking_url = self.config.rerank_model_to_url[provider_model_id]
|
||||
|
||||
logger.debug(f"Using rerank endpoint: {ranking_url} for model: {provider_model_id}")
|
||||
|
||||
# Convert query to text format
|
||||
if isinstance(query, str):
|
||||
query_text = query
|
||||
elif isinstance(query, OpenAIChatCompletionContentPartTextParam):
|
||||
query_text = query.text
|
||||
else:
|
||||
raise ValueError("Query must be a string or text content part")
|
||||
|
||||
# Convert items to text format
|
||||
passages = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
passages.append({"text": item})
|
||||
elif isinstance(item, OpenAIChatCompletionContentPartTextParam):
|
||||
passages.append({"text": item.text})
|
||||
else:
|
||||
raise ValueError("Items must be strings or text content parts")
|
||||
|
||||
payload = {
|
||||
"model": provider_model_id,
|
||||
"query": {"text": query_text},
|
||||
"passages": passages,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.get_api_key()}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(ranking_url, headers=headers, json=payload) as response:
|
||||
if response.status != 200:
|
||||
response_text = await response.text()
|
||||
raise ConnectionError(
|
||||
f"NVIDIA rerank API request failed with status {response.status}: {response_text}"
|
||||
)
|
||||
|
||||
result = await response.json()
|
||||
rankings = result.get("rankings", [])
|
||||
|
||||
# Convert to RerankData format
|
||||
rerank_data = []
|
||||
for ranking in rankings:
|
||||
rerank_data.append(RerankData(index=ranking["index"], relevance_score=ranking["logit"]))
|
||||
|
||||
# Apply max_num_results limit
|
||||
if max_num_results is not None:
|
||||
rerank_data = rerank_data[:max_num_results]
|
||||
|
||||
return RerankResponse(data=rerank_data)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA rerank API at {ranking_url}: {e}") from e
|
||||
|
|
|
|||
|
|
@ -51,7 +51,11 @@ async function proxyRequest(request: NextRequest, method: string) {
|
|||
);
|
||||
|
||||
// Create response with same status and headers
|
||||
const proxyResponse = new NextResponse(responseText, {
|
||||
// Handle 204 No Content responses specially
|
||||
const proxyResponse =
|
||||
response.status === 204
|
||||
? new NextResponse(null, { status: 204 })
|
||||
: new NextResponse(responseText, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
|
|
|
|||
5
src/llama_stack/ui/app/prompts/page.tsx
Normal file
5
src/llama_stack/ui/app/prompts/page.tsx
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
import { PromptManagement } from "@/components/prompts";
|
||||
|
||||
export default function PromptsPage() {
|
||||
return <PromptManagement />;
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ import {
|
|||
MessageCircle,
|
||||
Settings2,
|
||||
Compass,
|
||||
FileText,
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { usePathname } from "next/navigation";
|
||||
|
|
@ -50,6 +51,11 @@ const manageItems = [
|
|||
url: "/logs/vector-stores",
|
||||
icon: Database,
|
||||
},
|
||||
{
|
||||
title: "Prompts",
|
||||
url: "/prompts",
|
||||
icon: FileText,
|
||||
},
|
||||
{
|
||||
title: "Documentation",
|
||||
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
|
||||
|
|
|
|||
4
src/llama_stack/ui/components/prompts/index.ts
Normal file
4
src/llama_stack/ui/components/prompts/index.ts
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
export { PromptManagement } from "./prompt-management";
|
||||
export { PromptList } from "./prompt-list";
|
||||
export { PromptEditor } from "./prompt-editor";
|
||||
export * from "./types";
|
||||
309
src/llama_stack/ui/components/prompts/prompt-editor.test.tsx
Normal file
309
src/llama_stack/ui/components/prompts/prompt-editor.test.tsx
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptEditor } from "./prompt-editor";
|
||||
import type { Prompt, PromptFormData } from "./types";
|
||||
|
||||
describe("PromptEditor", () => {
|
||||
const mockOnSave = jest.fn();
|
||||
const mockOnCancel = jest.fn();
|
||||
const mockOnDelete = jest.fn();
|
||||
|
||||
const defaultProps = {
|
||||
onSave: mockOnSave,
|
||||
onCancel: mockOnCancel,
|
||||
onDelete: mockOnDelete,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Create Mode", () => {
|
||||
test("renders create form correctly", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(screen.getByLabelText("Prompt Content *")).toBeInTheDocument();
|
||||
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||
expect(screen.getByText("Preview")).toBeInTheDocument();
|
||||
expect(screen.getByText("Create Prompt")).toBeInTheDocument();
|
||||
expect(screen.getByText("Cancel")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows preview placeholder when no content", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(
|
||||
screen.getByText("Enter content to preview the compiled prompt")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("submits form with correct data", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}, welcome!" },
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith({
|
||||
prompt: "Hello {{name}}, welcome!",
|
||||
variables: [],
|
||||
});
|
||||
});
|
||||
|
||||
test("prevents submission with empty prompt", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
expect(mockOnSave).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edit Mode", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how is {{weather}}?",
|
||||
version: 1,
|
||||
variables: ["name", "weather"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
test("renders edit form with existing data", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(
|
||||
screen.getByDisplayValue("Hello {{name}}, how is {{weather}}?")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getAllByText("name")).toHaveLength(2); // One in variables, one in preview
|
||||
expect(screen.getAllByText("weather")).toHaveLength(2); // One in variables, one in preview
|
||||
expect(screen.getByText("Update Prompt")).toBeInTheDocument();
|
||||
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("submits updated data correctly", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Updated: Hello {{name}}!" },
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Update Prompt"));
|
||||
|
||||
expect(mockOnSave).toHaveBeenCalledWith({
|
||||
prompt: "Updated: Hello {{name}}!",
|
||||
variables: ["name", "weather"],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Variables Management", () => {
|
||||
test("adds new variable", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "testVar" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
expect(screen.getByText("testVar")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("prevents adding duplicate variables", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
|
||||
// Add first variable
|
||||
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
// Try to add same variable again
|
||||
fireEvent.change(variableInput, { target: { value: "test" } });
|
||||
|
||||
// Button should be disabled
|
||||
expect(screen.getByText("Add")).toBeDisabled();
|
||||
});
|
||||
|
||||
test("removes variable", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name", "location"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
// Check that both variables are present initially
|
||||
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||
expect(screen.getAllByText("location").length).toBeGreaterThan(0);
|
||||
|
||||
// Remove the location variable by clicking the X button with the specific title
|
||||
const removeLocationButton = screen.getByTitle(
|
||||
"Remove location variable"
|
||||
);
|
||||
fireEvent.click(removeLocationButton);
|
||||
|
||||
// Name should still be there, location should be gone from the variables section
|
||||
expect(screen.getAllByText("name").length).toBeGreaterThan(0);
|
||||
expect(
|
||||
screen.queryByTitle("Remove location variable")
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("adds variable on Enter key", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "enterVar" } });
|
||||
|
||||
// Simulate Enter key press
|
||||
fireEvent.keyPress(variableInput, {
|
||||
key: "Enter",
|
||||
code: "Enter",
|
||||
charCode: 13,
|
||||
preventDefault: jest.fn(),
|
||||
});
|
||||
|
||||
// Check if the variable was added by looking for the badge
|
||||
expect(screen.getAllByText("enterVar").length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Preview Functionality", () => {
|
||||
test("shows live preview with variables", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
// Add prompt content
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}, welcome to {{place}}!" },
|
||||
});
|
||||
|
||||
// Add variables
|
||||
const variableInput = screen.getByPlaceholderText(
|
||||
"Add variable name (e.g. user_name, topic)"
|
||||
);
|
||||
fireEvent.change(variableInput, { target: { value: "name" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
fireEvent.change(variableInput, { target: { value: "place" } });
|
||||
fireEvent.click(screen.getByText("Add"));
|
||||
|
||||
// Check that preview area shows the content
|
||||
expect(screen.getByText("Compiled Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows variable value inputs in preview", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(screen.getByText("Variable Values")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByPlaceholderText("Enter value for name")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows color legend for variable states", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
// Add content to show preview
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "Hello {{name}}" },
|
||||
});
|
||||
|
||||
expect(screen.getByText("Used")).toBeInTheDocument();
|
||||
expect(screen.getByText("Unused")).toBeInTheDocument();
|
||||
expect(screen.getByText("Undefined")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Handling", () => {
|
||||
test("displays error message", () => {
|
||||
const errorMessage = "Prompt contains undeclared variables";
|
||||
render(<PromptEditor {...defaultProps} error={errorMessage} />);
|
||||
|
||||
expect(screen.getByText(errorMessage)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Delete Functionality", () => {
|
||||
const mockPrompt: Prompt = {
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
};
|
||||
|
||||
test("shows delete button in edit mode", () => {
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
expect(screen.getByText("Delete Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("hides delete button in create mode", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
expect(screen.queryByText("Delete Prompt")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("calls onDelete with confirmation", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||
|
||||
expect(window.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||
);
|
||||
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
|
||||
test("does not delete when confirmation is cancelled", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => false);
|
||||
|
||||
render(<PromptEditor {...defaultProps} prompt={mockPrompt} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Delete Prompt"));
|
||||
|
||||
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cancel Functionality", () => {
|
||||
test("calls onCancel when cancel button is clicked", () => {
|
||||
render(<PromptEditor {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByText("Cancel"));
|
||||
|
||||
expect(mockOnCancel).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
346
src/llama_stack/ui/components/prompts/prompt-editor.tsx
Normal file
346
src/llama_stack/ui/components/prompts/prompt-editor.tsx
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
Card,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { X, Plus, Save, Trash2 } from "lucide-react";
|
||||
import { Prompt, PromptFormData } from "./types";
|
||||
|
||||
interface PromptEditorProps {
|
||||
prompt?: Prompt;
|
||||
onSave: (prompt: PromptFormData) => void;
|
||||
onCancel: () => void;
|
||||
onDelete?: (promptId: string) => void;
|
||||
error?: string | null;
|
||||
}
|
||||
|
||||
export function PromptEditor({
|
||||
prompt,
|
||||
onSave,
|
||||
onCancel,
|
||||
onDelete,
|
||||
error,
|
||||
}: PromptEditorProps) {
|
||||
const [formData, setFormData] = useState<PromptFormData>({
|
||||
prompt: "",
|
||||
variables: [],
|
||||
});
|
||||
|
||||
const [newVariable, setNewVariable] = useState("");
|
||||
const [variableValues, setVariableValues] = useState<Record<string, string>>(
|
||||
{}
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (prompt) {
|
||||
setFormData({
|
||||
prompt: prompt.prompt || "",
|
||||
variables: prompt.variables || [],
|
||||
});
|
||||
}
|
||||
}, [prompt]);
|
||||
|
||||
const handleSubmit = (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (!formData.prompt.trim()) {
|
||||
return;
|
||||
}
|
||||
onSave(formData);
|
||||
};
|
||||
|
||||
const addVariable = () => {
|
||||
if (
|
||||
newVariable.trim() &&
|
||||
!formData.variables.includes(newVariable.trim())
|
||||
) {
|
||||
setFormData(prev => ({
|
||||
...prev,
|
||||
variables: [...prev.variables, newVariable.trim()],
|
||||
}));
|
||||
setNewVariable("");
|
||||
}
|
||||
};
|
||||
|
||||
const removeVariable = (variableToRemove: string) => {
|
||||
setFormData(prev => ({
|
||||
...prev,
|
||||
variables: prev.variables.filter(
|
||||
variable => variable !== variableToRemove
|
||||
),
|
||||
}));
|
||||
};
|
||||
|
||||
const renderPreview = () => {
|
||||
const text = formData.prompt;
|
||||
if (!text) return text;
|
||||
|
||||
// Split text by variable patterns and process each part
|
||||
const parts = text.split(/(\{\{\s*\w+\s*\}\})/g);
|
||||
|
||||
return parts.map((part, index) => {
|
||||
const variableMatch = part.match(/\{\{\s*(\w+)\s*\}\}/);
|
||||
if (variableMatch) {
|
||||
const variableName = variableMatch[1];
|
||||
const isDefined = formData.variables.includes(variableName);
|
||||
const value = variableValues[variableName];
|
||||
|
||||
if (!isDefined) {
|
||||
// Variable not in variables list - likely a typo/bug (RED)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-red-100 text-red-800 dark:bg-red-900 dark:text-red-200 px-1 rounded font-medium"
|
||||
>
|
||||
{part}
|
||||
</span>
|
||||
);
|
||||
} else if (value && value.trim()) {
|
||||
// Variable defined and has value - show the value (GREEN)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-green-100 text-green-800 dark:bg-green-900 dark:text-green-200 px-1 rounded font-medium"
|
||||
>
|
||||
{value}
|
||||
</span>
|
||||
);
|
||||
} else {
|
||||
// Variable defined but empty (YELLOW)
|
||||
return (
|
||||
<span
|
||||
key={index}
|
||||
className="bg-yellow-100 text-yellow-800 dark:bg-yellow-900 dark:text-yellow-200 px-1 rounded font-medium"
|
||||
>
|
||||
{part}
|
||||
</span>
|
||||
);
|
||||
}
|
||||
}
|
||||
return part;
|
||||
});
|
||||
};
|
||||
|
||||
const updateVariableValue = (variable: string, value: string) => {
|
||||
setVariableValues(prev => ({
|
||||
...prev,
|
||||
[variable]: value,
|
||||
}));
|
||||
};
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
{error && (
|
||||
<div className="p-4 bg-destructive/10 border border-destructive/20 rounded-md">
|
||||
<p className="text-destructive text-sm">{error}</p>
|
||||
</div>
|
||||
)}
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
|
||||
{/* Form Section */}
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<Label htmlFor="prompt">Prompt Content *</Label>
|
||||
<Textarea
|
||||
id="prompt"
|
||||
value={formData.prompt}
|
||||
onChange={e =>
|
||||
setFormData(prev => ({ ...prev, prompt: e.target.value }))
|
||||
}
|
||||
placeholder="Enter your prompt content here. Use {{variable_name}} for dynamic variables."
|
||||
className="min-h-32 font-mono mt-2"
|
||||
required
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground mt-2">
|
||||
Use double curly braces around variable names, e.g.,{" "}
|
||||
{`{{user_name}}`} or {`{{topic}}`}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-3">
|
||||
<Label className="text-sm font-medium">Variables</Label>
|
||||
|
||||
<div className="flex gap-2 mt-2">
|
||||
<Input
|
||||
value={newVariable}
|
||||
onChange={e => setNewVariable(e.target.value)}
|
||||
placeholder="Add variable name (e.g. user_name, topic)"
|
||||
onKeyPress={e =>
|
||||
e.key === "Enter" && (e.preventDefault(), addVariable())
|
||||
}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
onClick={addVariable}
|
||||
size="sm"
|
||||
disabled={
|
||||
!newVariable.trim() ||
|
||||
formData.variables.includes(newVariable.trim())
|
||||
}
|
||||
>
|
||||
<Plus className="h-4 w-4" />
|
||||
Add
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{formData.variables.length > 0 && (
|
||||
<div className="border rounded-lg p-3 bg-muted/20">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{formData.variables.map(variable => (
|
||||
<Badge
|
||||
key={variable}
|
||||
variant="secondary"
|
||||
className="text-sm px-2 py-1"
|
||||
>
|
||||
{variable}
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => removeVariable(variable)}
|
||||
className="ml-2 hover:text-destructive transition-colors"
|
||||
title={`Remove ${variable} variable`}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</button>
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Variables that can be used in the prompt template. Each variable
|
||||
should match a {`{{variable}}`} placeholder in the content above.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Preview Section */}
|
||||
<div className="space-y-4">
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="text-lg">Preview</CardTitle>
|
||||
<CardDescription>
|
||||
Live preview of compiled prompt and variable substitution.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-4">
|
||||
{formData.prompt ? (
|
||||
<>
|
||||
{/* Variable Values */}
|
||||
{formData.variables.length > 0 && (
|
||||
<div className="space-y-3">
|
||||
<Label className="text-sm font-medium">
|
||||
Variable Values
|
||||
</Label>
|
||||
<div className="space-y-2">
|
||||
{formData.variables.map(variable => (
|
||||
<div
|
||||
key={variable}
|
||||
className="grid grid-cols-2 gap-3 items-center"
|
||||
>
|
||||
<div className="text-sm font-mono text-muted-foreground">
|
||||
{variable}
|
||||
</div>
|
||||
<Input
|
||||
id={`var-${variable}`}
|
||||
value={variableValues[variable] || ""}
|
||||
onChange={e =>
|
||||
updateVariableValue(variable, e.target.value)
|
||||
}
|
||||
placeholder={`Enter value for ${variable}`}
|
||||
className="text-sm"
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<Separator />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Live Preview */}
|
||||
<div>
|
||||
<Label className="text-sm font-medium mb-2 block">
|
||||
Compiled Prompt
|
||||
</Label>
|
||||
<div className="bg-muted/50 p-4 rounded-lg border">
|
||||
<div className="text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{renderPreview()}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-4 mt-2 text-xs">
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-green-500 dark:bg-green-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Used</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-yellow-500 dark:bg-yellow-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Unused</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="w-3 h-3 bg-red-500 dark:bg-red-400 border rounded"></div>
|
||||
<span className="text-muted-foreground">Undefined</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="text-center py-8">
|
||||
<div className="text-muted-foreground text-sm">
|
||||
Enter content to preview the compiled prompt
|
||||
</div>
|
||||
<div className="text-xs text-muted-foreground mt-2">
|
||||
Use {`{{variable_name}}`} to add dynamic variables
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex justify-between">
|
||||
<div>
|
||||
{prompt && onDelete && (
|
||||
<Button
|
||||
type="button"
|
||||
variant="destructive"
|
||||
onClick={() => {
|
||||
if (
|
||||
confirm(
|
||||
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||
)
|
||||
) {
|
||||
onDelete(prompt.prompt_id);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Trash2 className="h-4 w-4 mr-2" />
|
||||
Delete Prompt
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button type="button" variant="outline" onClick={onCancel}>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button type="submit">
|
||||
<Save className="h-4 w-4 mr-2" />
|
||||
{prompt ? "Update" : "Create"} Prompt
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
);
|
||||
}
|
||||
259
src/llama_stack/ui/components/prompts/prompt-list.test.tsx
Normal file
259
src/llama_stack/ui/components/prompts/prompt-list.test.tsx
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptList } from "./prompt-list";
|
||||
import type { Prompt } from "./types";
|
||||
|
||||
describe("PromptList", () => {
|
||||
const mockOnEdit = jest.fn();
|
||||
const mockOnDelete = jest.fn();
|
||||
|
||||
const defaultProps = {
|
||||
prompts: [],
|
||||
onEdit: mockOnEdit,
|
||||
onDelete: mockOnDelete,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Empty State", () => {
|
||||
test("renders empty message when no prompts", () => {
|
||||
render(<PromptList {...defaultProps} />);
|
||||
|
||||
expect(screen.getByText("No prompts yet")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("shows filtered empty message when search has no results", () => {
|
||||
const prompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello world",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
render(<PromptList {...defaultProps} prompts={prompts} />);
|
||||
|
||||
// Search for something that doesn't exist
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "nonexistent" } });
|
||||
|
||||
expect(
|
||||
screen.getByText("No prompts match your filters")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Prompts Display", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how are you?",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_456",
|
||||
prompt: "Summarize this {{text}} in {{length}} words",
|
||||
version: 2,
|
||||
variables: ["text", "length"],
|
||||
is_default: false,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_789",
|
||||
prompt: "Simple prompt with no variables",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("renders prompts table with correct headers", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
expect(screen.getByText("ID")).toBeInTheDocument();
|
||||
expect(screen.getByText("Content")).toBeInTheDocument();
|
||||
expect(screen.getByText("Variables")).toBeInTheDocument();
|
||||
expect(screen.getByText("Version")).toBeInTheDocument();
|
||||
expect(screen.getByText("Actions")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders prompt data correctly", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Check prompt IDs
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
expect(screen.getByText("prompt_456")).toBeInTheDocument();
|
||||
expect(screen.getByText("prompt_789")).toBeInTheDocument();
|
||||
|
||||
// Check content
|
||||
expect(
|
||||
screen.getByText("Hello {{name}}, how are you?")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Summarize this {{text}} in {{length}} words")
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Simple prompt with no variables")
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Check versions
|
||||
expect(screen.getAllByText("1")).toHaveLength(2); // Two prompts with version 1
|
||||
expect(screen.getByText("2")).toBeInTheDocument();
|
||||
|
||||
// Check default badge
|
||||
expect(screen.getByText("Default")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("renders variables correctly", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Check variables display
|
||||
expect(screen.getByText("name")).toBeInTheDocument();
|
||||
expect(screen.getByText("text")).toBeInTheDocument();
|
||||
expect(screen.getByText("length")).toBeInTheDocument();
|
||||
expect(screen.getByText("None")).toBeInTheDocument(); // For prompt with no variables
|
||||
});
|
||||
|
||||
test("prompt ID links are clickable and call onEdit", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
// Click on the first prompt ID link
|
||||
const promptLink = screen.getByRole("button", { name: "prompt_123" });
|
||||
fireEvent.click(promptLink);
|
||||
|
||||
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||
});
|
||||
|
||||
test("edit buttons call onEdit", () => {
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
// Find the action buttons in the table - they should be in the last column
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const editButton = firstActionCell?.querySelector("button");
|
||||
|
||||
expect(editButton).toBeInTheDocument();
|
||||
fireEvent.click(editButton!);
|
||||
|
||||
expect(mockOnEdit).toHaveBeenCalledWith(mockPrompts[0]);
|
||||
});
|
||||
|
||||
test("delete buttons call onDelete with confirmation", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
// Find the delete button (second button in the first action cell)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
expect(window.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this prompt? This action cannot be undone."
|
||||
);
|
||||
expect(mockOnDelete).toHaveBeenCalledWith("prompt_123");
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
|
||||
test("delete does not execute when confirmation is cancelled", () => {
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => false);
|
||||
|
||||
const { container } = render(
|
||||
<PromptList {...defaultProps} prompts={mockPrompts} />
|
||||
);
|
||||
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
expect(mockOnDelete).not.toHaveBeenCalled();
|
||||
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
|
||||
describe("Search Functionality", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "user_greeting",
|
||||
prompt: "Hello {{name}}, welcome!",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "system_summary",
|
||||
prompt: "Summarize the following text",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("filters prompts by prompt ID", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("filters prompts by content", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "welcome" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("search is case insensitive", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
fireEvent.change(searchInput, { target: { value: "HELLO" } });
|
||||
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("clearing search shows all prompts", () => {
|
||||
render(<PromptList {...defaultProps} prompts={mockPrompts} />);
|
||||
|
||||
const searchInput = screen.getByPlaceholderText("Search prompts...");
|
||||
|
||||
// Filter first
|
||||
fireEvent.change(searchInput, { target: { value: "user" } });
|
||||
expect(screen.queryByText("system_summary")).not.toBeInTheDocument();
|
||||
|
||||
// Clear search
|
||||
fireEvent.change(searchInput, { target: { value: "" } });
|
||||
expect(screen.getByText("user_greeting")).toBeInTheDocument();
|
||||
expect(screen.getByText("system_summary")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
164
src/llama_stack/ui/components/prompts/prompt-list.tsx
Normal file
164
src/llama_stack/ui/components/prompts/prompt-list.tsx
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Edit, Search, Trash2 } from "lucide-react";
|
||||
import { Prompt, PromptFilters } from "./types";
|
||||
|
||||
interface PromptListProps {
|
||||
prompts: Prompt[];
|
||||
onEdit: (prompt: Prompt) => void;
|
||||
onDelete: (promptId: string) => void;
|
||||
}
|
||||
|
||||
export function PromptList({ prompts, onEdit, onDelete }: PromptListProps) {
|
||||
const [filters, setFilters] = useState<PromptFilters>({});
|
||||
|
||||
const filteredPrompts = prompts.filter(prompt => {
|
||||
if (
|
||||
filters.searchTerm &&
|
||||
!(
|
||||
prompt.prompt
|
||||
?.toLowerCase()
|
||||
.includes(filters.searchTerm.toLowerCase()) ||
|
||||
prompt.prompt_id
|
||||
.toLowerCase()
|
||||
.includes(filters.searchTerm.toLowerCase())
|
||||
)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{/* Filters */}
|
||||
<div className="flex flex-col sm:flex-row gap-4">
|
||||
<div className="relative flex-1">
|
||||
<Search className="absolute left-3 top-1/2 transform -translate-y-1/2 text-muted-foreground h-4 w-4" />
|
||||
<Input
|
||||
placeholder="Search prompts..."
|
||||
value={filters.searchTerm || ""}
|
||||
onChange={e =>
|
||||
setFilters(prev => ({ ...prev, searchTerm: e.target.value }))
|
||||
}
|
||||
className="pl-10"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Prompts Table */}
|
||||
<div className="overflow-auto">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>ID</TableHead>
|
||||
<TableHead>Content</TableHead>
|
||||
<TableHead>Variables</TableHead>
|
||||
<TableHead>Version</TableHead>
|
||||
<TableHead>Actions</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{filteredPrompts.map(prompt => (
|
||||
<TableRow key={prompt.prompt_id}>
|
||||
<TableCell className="max-w-48">
|
||||
<Button
|
||||
variant="link"
|
||||
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300 max-w-full justify-start"
|
||||
onClick={() => onEdit(prompt)}
|
||||
title={prompt.prompt_id}
|
||||
>
|
||||
<div className="truncate">{prompt.prompt_id}</div>
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
<div
|
||||
className="font-mono text-xs text-muted-foreground truncate"
|
||||
title={prompt.prompt || "No content"}
|
||||
>
|
||||
{prompt.prompt || "No content"}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{prompt.variables.length > 0 ? (
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{prompt.variables.map(variable => (
|
||||
<Badge
|
||||
key={variable}
|
||||
variant="outline"
|
||||
className="text-xs"
|
||||
>
|
||||
{variable}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<span className="text-muted-foreground text-sm">None</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="text-sm">
|
||||
{prompt.version}
|
||||
{prompt.is_default && (
|
||||
<Badge variant="secondary" className="text-xs ml-2">
|
||||
Default
|
||||
</Badge>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex gap-1">
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => onEdit(prompt)}
|
||||
className="h-8 w-8 p-0"
|
||||
>
|
||||
<Edit className="h-3 w-3" />
|
||||
</Button>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => {
|
||||
if (
|
||||
confirm(
|
||||
`Are you sure you want to delete this prompt? This action cannot be undone.`
|
||||
)
|
||||
) {
|
||||
onDelete(prompt.prompt_id);
|
||||
}
|
||||
}}
|
||||
className="h-8 w-8 p-0 text-destructive hover:text-destructive"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
|
||||
{filteredPrompts.length === 0 && (
|
||||
<div className="text-center py-12">
|
||||
<div className="text-muted-foreground">
|
||||
{prompts.length === 0
|
||||
? "No prompts yet"
|
||||
: "No prompts match your filters"}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
304
src/llama_stack/ui/components/prompts/prompt-management.test.tsx
Normal file
304
src/llama_stack/ui/components/prompts/prompt-management.test.tsx
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
import React from "react";
|
||||
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
|
||||
import "@testing-library/jest-dom";
|
||||
import { PromptManagement } from "./prompt-management";
|
||||
import type { Prompt } from "./types";
|
||||
|
||||
// Mock the auth client
|
||||
const mockPromptsClient = {
|
||||
list: jest.fn(),
|
||||
create: jest.fn(),
|
||||
update: jest.fn(),
|
||||
delete: jest.fn(),
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
useAuthClient: () => ({
|
||||
prompts: mockPromptsClient,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("PromptManagement", () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("Loading State", () => {
|
||||
test("renders loading state initially", () => {
|
||||
mockPromptsClient.list.mockReturnValue(new Promise(() => {})); // Never resolves
|
||||
render(<PromptManagement />);
|
||||
|
||||
expect(screen.getByText("Loading prompts...")).toBeInTheDocument();
|
||||
expect(screen.getByText("Prompts")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Empty State", () => {
|
||||
test("renders empty state when no prompts", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue([]);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("No prompts found.")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.getByText("Create Your First Prompt")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("opens modal when clicking 'Create Your First Prompt'", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue([]);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Create Your First Prompt")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("Create Your First Prompt"));
|
||||
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error State", () => {
|
||||
test("renders error state when API fails", async () => {
|
||||
const error = new Error("API not found");
|
||||
mockPromptsClient.list.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/Error:/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("renders specific error for 404", async () => {
|
||||
const error = new Error("404 Not found");
|
||||
mockPromptsClient.list.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/Prompts API endpoint not found/)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Prompts List", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}, how are you?",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
{
|
||||
prompt_id: "prompt_456",
|
||||
prompt: "Summarize this {{text}}",
|
||||
version: 2,
|
||||
variables: ["text"],
|
||||
is_default: false,
|
||||
},
|
||||
];
|
||||
|
||||
test("renders prompts list correctly", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.getByText("prompt_456")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("Hello {{name}}, how are you?")
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Summarize this {{text}}")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("opens modal when clicking 'New Prompt' button", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Modal Operations", () => {
|
||||
const mockPrompts: Prompt[] = [
|
||||
{
|
||||
prompt_id: "prompt_123",
|
||||
prompt: "Hello {{name}}",
|
||||
version: 1,
|
||||
variables: ["name"],
|
||||
is_default: true,
|
||||
},
|
||||
];
|
||||
|
||||
test("closes modal when clicking cancel", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
expect(screen.getByText("Create New Prompt")).toBeInTheDocument();
|
||||
|
||||
// Close modal
|
||||
fireEvent.click(screen.getByText("Cancel"));
|
||||
expect(screen.queryByText("Create New Prompt")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("creates new prompt successfully", async () => {
|
||||
const newPrompt: Prompt = {
|
||||
prompt_id: "prompt_new",
|
||||
prompt: "New prompt content",
|
||||
version: 1,
|
||||
variables: [],
|
||||
is_default: false,
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.create.mockResolvedValue(newPrompt);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
// Fill form
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, {
|
||||
target: { value: "New prompt content" },
|
||||
});
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.create).toHaveBeenCalledWith({
|
||||
prompt: "New prompt content",
|
||||
variables: [],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test("handles create error gracefully", async () => {
|
||||
const error = {
|
||||
detail: {
|
||||
errors: [{ msg: "Prompt contains undeclared variables: ['test']" }],
|
||||
},
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.create.mockRejectedValue(error);
|
||||
render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Open modal
|
||||
fireEvent.click(screen.getByText("New Prompt"));
|
||||
|
||||
// Fill form
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, { target: { value: "Hello {{test}}" } });
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Create Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("Prompt contains undeclared variables: ['test']")
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("updates existing prompt successfully", async () => {
|
||||
const updatedPrompt: Prompt = {
|
||||
...mockPrompts[0],
|
||||
prompt: "Updated content",
|
||||
};
|
||||
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.update.mockResolvedValue(updatedPrompt);
|
||||
const { container } = render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click edit button (first button in the action cell of the first row)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const editButton = firstActionCell?.querySelector("button");
|
||||
|
||||
expect(editButton).toBeInTheDocument();
|
||||
fireEvent.click(editButton!);
|
||||
|
||||
expect(screen.getByText("Edit Prompt")).toBeInTheDocument();
|
||||
|
||||
// Update content
|
||||
const promptInput = screen.getByLabelText("Prompt Content *");
|
||||
fireEvent.change(promptInput, { target: { value: "Updated content" } });
|
||||
|
||||
// Submit form
|
||||
fireEvent.click(screen.getByText("Update Prompt"));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.update).toHaveBeenCalledWith("prompt_123", {
|
||||
prompt: "Updated content",
|
||||
variables: ["name"],
|
||||
version: 1,
|
||||
set_as_default: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test("deletes prompt successfully", async () => {
|
||||
mockPromptsClient.list.mockResolvedValue(mockPrompts);
|
||||
mockPromptsClient.delete.mockResolvedValue(undefined);
|
||||
|
||||
// Mock window.confirm
|
||||
const originalConfirm = window.confirm;
|
||||
window.confirm = jest.fn(() => true);
|
||||
|
||||
const { container } = render(<PromptManagement />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("prompt_123")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Click delete button (second button in the action cell of the first row)
|
||||
const actionCells = container.querySelectorAll("td:last-child");
|
||||
const firstActionCell = actionCells[0];
|
||||
const buttons = firstActionCell?.querySelectorAll("button");
|
||||
const deleteButton = buttons?.[1]; // Second button should be delete
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
fireEvent.click(deleteButton!);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockPromptsClient.delete).toHaveBeenCalledWith("prompt_123");
|
||||
});
|
||||
|
||||
// Restore window.confirm
|
||||
window.confirm = originalConfirm;
|
||||
});
|
||||
});
|
||||
});
|
||||
233
src/llama_stack/ui/components/prompts/prompt-management.tsx
Normal file
233
src/llama_stack/ui/components/prompts/prompt-management.tsx
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Plus } from "lucide-react";
|
||||
import { PromptList } from "./prompt-list";
|
||||
import { PromptEditor } from "./prompt-editor";
|
||||
import { Prompt, PromptFormData } from "./types";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
|
||||
export function PromptManagement() {
|
||||
const [prompts, setPrompts] = useState<Prompt[]>([]);
|
||||
const [showPromptModal, setShowPromptModal] = useState(false);
|
||||
const [editingPrompt, setEditingPrompt] = useState<Prompt | undefined>();
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null); // For main page errors (loading, etc.)
|
||||
const [modalError, setModalError] = useState<string | null>(null); // For form submission errors
|
||||
const client = useAuthClient();
|
||||
|
||||
// Load prompts from API on component mount
|
||||
useEffect(() => {
|
||||
const fetchPrompts = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
const response = await client.prompts.list();
|
||||
setPrompts(response || []);
|
||||
} catch (err: unknown) {
|
||||
console.error("Failed to load prompts:", err);
|
||||
|
||||
// Handle different types of errors
|
||||
const error = err as Error & { status?: number };
|
||||
if (error?.message?.includes("404") || error?.status === 404) {
|
||||
setError(
|
||||
"Prompts API endpoint not found. Please ensure your Llama Stack server supports the prompts API."
|
||||
);
|
||||
} else if (
|
||||
error?.message?.includes("not implemented") ||
|
||||
error?.message?.includes("not supported")
|
||||
) {
|
||||
setError(
|
||||
"Prompts API is not yet implemented on this Llama Stack server."
|
||||
);
|
||||
} else {
|
||||
setError(
|
||||
`Failed to load prompts: ${error?.message || "Unknown error"}`
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchPrompts();
|
||||
}, [client]);
|
||||
|
||||
const handleSavePrompt = async (formData: PromptFormData) => {
|
||||
try {
|
||||
setModalError(null);
|
||||
|
||||
if (editingPrompt) {
|
||||
// Update existing prompt
|
||||
const response = await client.prompts.update(editingPrompt.prompt_id, {
|
||||
prompt: formData.prompt,
|
||||
variables: formData.variables,
|
||||
version: editingPrompt.version,
|
||||
set_as_default: true,
|
||||
});
|
||||
|
||||
// Update local state
|
||||
setPrompts(prev =>
|
||||
prev.map(p =>
|
||||
p.prompt_id === editingPrompt.prompt_id ? response : p
|
||||
)
|
||||
);
|
||||
} else {
|
||||
// Create new prompt
|
||||
const response = await client.prompts.create({
|
||||
prompt: formData.prompt,
|
||||
variables: formData.variables,
|
||||
});
|
||||
|
||||
// Add to local state
|
||||
setPrompts(prev => [response, ...prev]);
|
||||
}
|
||||
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
} catch (err) {
|
||||
console.error("Failed to save prompt:", err);
|
||||
|
||||
// Extract specific error message from API response
|
||||
const error = err as Error & {
|
||||
message?: string;
|
||||
detail?: { errors?: Array<{ msg?: string }> };
|
||||
};
|
||||
|
||||
// Try to parse JSON from error message if it's a string
|
||||
let parsedError = error;
|
||||
if (typeof error?.message === "string" && error.message.includes("{")) {
|
||||
try {
|
||||
const jsonMatch = error.message.match(/\d+\s+(.+)/);
|
||||
if (jsonMatch) {
|
||||
parsedError = JSON.parse(jsonMatch[1]);
|
||||
}
|
||||
} catch {
|
||||
// If parsing fails, use original error
|
||||
}
|
||||
}
|
||||
|
||||
// Try to get the specific validation error message
|
||||
const validationError = parsedError?.detail?.errors?.[0]?.msg;
|
||||
if (validationError) {
|
||||
// Clean up validation error messages (remove "Value error, " prefix if present)
|
||||
const cleanMessage = validationError.replace(/^Value error,\s*/i, "");
|
||||
setModalError(cleanMessage);
|
||||
} else {
|
||||
// For other errors, format them nicely with line breaks
|
||||
const statusMatch = error?.message?.match(/(\d+)\s+(.+)/);
|
||||
if (statusMatch) {
|
||||
const statusCode = statusMatch[1];
|
||||
const response = statusMatch[2];
|
||||
setModalError(
|
||||
`Failed to save prompt: Status Code ${statusCode}\n\nResponse: ${response}`
|
||||
);
|
||||
} else {
|
||||
const message = error?.message || error?.detail || "Unknown error";
|
||||
setModalError(`Failed to save prompt: ${message}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleEditPrompt = (prompt: Prompt) => {
|
||||
setEditingPrompt(prompt);
|
||||
setShowPromptModal(true);
|
||||
setModalError(null); // Clear any previous modal errors
|
||||
};
|
||||
|
||||
const handleDeletePrompt = async (promptId: string) => {
|
||||
try {
|
||||
setError(null);
|
||||
await client.prompts.delete(promptId);
|
||||
setPrompts(prev => prev.filter(p => p.prompt_id !== promptId));
|
||||
|
||||
// If we're deleting the currently editing prompt, close the modal
|
||||
if (editingPrompt && editingPrompt.prompt_id === promptId) {
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to delete prompt:", err);
|
||||
setError("Failed to delete prompt");
|
||||
}
|
||||
};
|
||||
|
||||
const handleCreateNew = () => {
|
||||
setEditingPrompt(undefined);
|
||||
setShowPromptModal(true);
|
||||
setModalError(null); // Clear any previous modal errors
|
||||
};
|
||||
|
||||
const handleCancel = () => {
|
||||
setShowPromptModal(false);
|
||||
setEditingPrompt(undefined);
|
||||
};
|
||||
|
||||
const renderContent = () => {
|
||||
if (loading) {
|
||||
return <div className="text-muted-foreground">Loading prompts...</div>;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return <div className="text-destructive">Error: {error}</div>;
|
||||
}
|
||||
|
||||
if (!prompts || prompts.length === 0) {
|
||||
return (
|
||||
<div className="text-center py-12">
|
||||
<p className="text-muted-foreground mb-4">No prompts found.</p>
|
||||
<Button onClick={handleCreateNew}>
|
||||
<Plus className="h-4 w-4 mr-2" />
|
||||
Create Your First Prompt
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<PromptList
|
||||
prompts={prompts}
|
||||
onEdit={handleEditPrompt}
|
||||
onDelete={handleDeletePrompt}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div className="flex items-center justify-between">
|
||||
<h1 className="text-2xl font-semibold">Prompts</h1>
|
||||
<Button onClick={handleCreateNew} disabled={loading}>
|
||||
<Plus className="h-4 w-4 mr-2" />
|
||||
New Prompt
|
||||
</Button>
|
||||
</div>
|
||||
{renderContent()}
|
||||
|
||||
{/* Create/Edit Prompt Modal */}
|
||||
{showPromptModal && (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-background border rounded-lg shadow-lg max-w-4xl w-full mx-4 max-h-[90vh] overflow-hidden">
|
||||
<div className="p-6 border-b">
|
||||
<h2 className="text-2xl font-bold">
|
||||
{editingPrompt ? "Edit Prompt" : "Create New Prompt"}
|
||||
</h2>
|
||||
</div>
|
||||
<div className="p-6 overflow-y-auto max-h-[calc(90vh-120px)]">
|
||||
<PromptEditor
|
||||
prompt={editingPrompt}
|
||||
onSave={handleSavePrompt}
|
||||
onCancel={handleCancel}
|
||||
onDelete={handleDeletePrompt}
|
||||
error={modalError}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
16
src/llama_stack/ui/components/prompts/types.ts
Normal file
16
src/llama_stack/ui/components/prompts/types.ts
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
export interface Prompt {
|
||||
prompt_id: string;
|
||||
prompt: string | null;
|
||||
version: number;
|
||||
variables: string[];
|
||||
is_default: boolean;
|
||||
}
|
||||
|
||||
export interface PromptFormData {
|
||||
prompt: string;
|
||||
variables: string[];
|
||||
}
|
||||
|
||||
export interface PromptFilters {
|
||||
searchTerm?: string;
|
||||
}
|
||||
36
src/llama_stack/ui/components/ui/badge.tsx
Normal file
36
src/llama_stack/ui/components/ui/badge.tsx
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
import * as React from "react";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const badgeVariants = cva(
|
||||
"inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default:
|
||||
"border-transparent bg-primary text-primary-foreground hover:bg-primary/80",
|
||||
secondary:
|
||||
"border-transparent bg-secondary text-secondary-foreground hover:bg-secondary/80",
|
||||
destructive:
|
||||
"border-transparent bg-destructive text-destructive-foreground hover:bg-destructive/80",
|
||||
outline: "text-foreground",
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: "default",
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export interface BadgeProps
|
||||
extends React.HTMLAttributes<HTMLDivElement>,
|
||||
VariantProps<typeof badgeVariants> {}
|
||||
|
||||
function Badge({ className, variant, ...props }: BadgeProps) {
|
||||
return (
|
||||
<div className={cn(badgeVariants({ variant }), className)} {...props} />
|
||||
);
|
||||
}
|
||||
|
||||
export { Badge, badgeVariants };
|
||||
24
src/llama_stack/ui/components/ui/label.tsx
Normal file
24
src/llama_stack/ui/components/ui/label.tsx
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import * as React from "react";
|
||||
import * as LabelPrimitive from "@radix-ui/react-label";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const labelVariants = cva(
|
||||
"text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70"
|
||||
);
|
||||
|
||||
const Label = React.forwardRef<
|
||||
React.ElementRef<typeof LabelPrimitive.Root>,
|
||||
React.ComponentPropsWithoutRef<typeof LabelPrimitive.Root> &
|
||||
VariantProps<typeof labelVariants>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<LabelPrimitive.Root
|
||||
ref={ref}
|
||||
className={cn(labelVariants(), className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
Label.displayName = LabelPrimitive.Root.displayName;
|
||||
|
||||
export { Label };
|
||||
53
src/llama_stack/ui/components/ui/tabs.tsx
Normal file
53
src/llama_stack/ui/components/ui/tabs.tsx
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import * as React from "react";
|
||||
import * as TabsPrimitive from "@radix-ui/react-tabs";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const Tabs = TabsPrimitive.Root;
|
||||
|
||||
const TabsList = React.forwardRef<
|
||||
React.ElementRef<typeof TabsPrimitive.List>,
|
||||
React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<TabsPrimitive.List
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"inline-flex h-10 items-center justify-center rounded-md bg-muted p-1 text-muted-foreground",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TabsList.displayName = TabsPrimitive.List.displayName;
|
||||
|
||||
const TabsTrigger = React.forwardRef<
|
||||
React.ElementRef<typeof TabsPrimitive.Trigger>,
|
||||
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Trigger>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<TabsPrimitive.Trigger
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"inline-flex items-center justify-center whitespace-nowrap rounded-sm px-3 py-1.5 text-sm font-medium ring-offset-background transition-all focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 data-[state=active]:bg-background data-[state=active]:text-foreground data-[state=active]:shadow-sm",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TabsTrigger.displayName = TabsPrimitive.Trigger.displayName;
|
||||
|
||||
const TabsContent = React.forwardRef<
|
||||
React.ElementRef<typeof TabsPrimitive.Content>,
|
||||
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Content>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<TabsPrimitive.Content
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"mt-2 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TabsContent.displayName = TabsPrimitive.Content.displayName;
|
||||
|
||||
export { Tabs, TabsList, TabsTrigger, TabsContent };
|
||||
23
src/llama_stack/ui/components/ui/textarea.tsx
Normal file
23
src/llama_stack/ui/components/ui/textarea.tsx
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import * as React from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export type TextareaProps = React.TextareaHTMLAttributes<HTMLTextAreaElement>;
|
||||
|
||||
const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
|
||||
({ className, ...props }, ref) => {
|
||||
return (
|
||||
<textarea
|
||||
className={cn(
|
||||
"flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
|
||||
className
|
||||
)}
|
||||
ref={ref}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
Textarea.displayName = "Textarea";
|
||||
|
||||
export { Textarea };
|
||||
62
src/llama_stack/ui/package-lock.json
generated
62
src/llama_stack/ui/package-lock.json
generated
|
|
@ -11,14 +11,16 @@
|
|||
"@radix-ui/react-collapsible": "^1.1.12",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.16",
|
||||
"@radix-ui/react-label": "^2.1.7",
|
||||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^12.23.24",
|
||||
"llama-stack-client": "^0.3.0",
|
||||
"llama-stack-client": "github:llamastack/llama-stack-client-typescript",
|
||||
"lucide-react": "^0.545.0",
|
||||
"next": "15.5.4",
|
||||
"next-auth": "^4.24.11",
|
||||
|
|
@ -2597,6 +2599,29 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-label": {
|
||||
"version": "2.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-label/-/react-label-2.1.7.tgz",
|
||||
"integrity": "sha512-YT1GqPSL8kJn20djelMX7/cTRp/Y9w5IZHvfxQTVHrOqa2yMl7i/UfMqKRU5V7mEyKTrUVgJXhNQPVCG8PBLoQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-primitive": "2.1.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-menu": {
|
||||
"version": "2.1.16",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-menu/-/react-menu-2.1.16.tgz",
|
||||
|
|
@ -2855,6 +2880,36 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tabs": {
|
||||
"version": "1.1.13",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-tabs/-/react-tabs-1.1.13.tgz",
|
||||
"integrity": "sha512-7xdcatg7/U+7+Udyoj2zodtI9H/IIopqo+YOIcZOq1nJwXWBZ9p8xiu5llXlekDbZkca79a/fozEYQXIA4sW6A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/primitive": "1.1.3",
|
||||
"@radix-ui/react-context": "1.1.2",
|
||||
"@radix-ui/react-direction": "1.1.1",
|
||||
"@radix-ui/react-id": "1.1.1",
|
||||
"@radix-ui/react-presence": "1.1.5",
|
||||
"@radix-ui/react-primitive": "2.1.3",
|
||||
"@radix-ui/react-roving-focus": "1.1.11",
|
||||
"@radix-ui/react-use-controllable-state": "1.2.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip": {
|
||||
"version": "1.2.8",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.8.tgz",
|
||||
|
|
@ -9629,9 +9684,8 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/llama-stack-client": {
|
||||
"version": "0.3.0",
|
||||
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.3.0.tgz",
|
||||
"integrity": "sha512-76K/t1doaGmlBbDxCADaral9Vccvys9P8pqAMIhwBhMAqWudCEORrMMhUSg+pjhamWmEKj3wa++d4zeOGbfN/w==",
|
||||
"version": "0.4.0-alpha.1",
|
||||
"resolved": "git+ssh://git@github.com/llamastack/llama-stack-client-typescript.git#78de4862c4b7d77939ac210fa9f9bde77a2c5c5f",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/node": "^18.11.18",
|
||||
|
|
|
|||
|
|
@ -16,14 +16,16 @@
|
|||
"@radix-ui/react-collapsible": "^1.1.12",
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.16",
|
||||
"@radix-ui/react-label": "^2.1.7",
|
||||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^12.23.24",
|
||||
"llama-stack-client": "^0.3.0",
|
||||
"llama-stack-client": "github:llamastack/llama-stack-client-typescript",
|
||||
"lucide-react": "^0.545.0",
|
||||
"next": "15.5.4",
|
||||
"next-auth": "^4.24.11",
|
||||
|
|
|
|||
|
|
@ -171,6 +171,10 @@ def pytest_addoption(parser):
|
|||
"--embedding-model",
|
||||
help="comma-separated list of embedding models. Fixture name: embedding_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--rerank-model",
|
||||
help="comma-separated list of rerank models. Fixture name: rerank_model_id",
|
||||
)
|
||||
parser.addoption(
|
||||
"--safety-shield",
|
||||
help="comma-separated list of safety shields. Fixture name: shield_id",
|
||||
|
|
@ -249,6 +253,7 @@ def pytest_generate_tests(metafunc):
|
|||
"shield_id": ("--safety-shield", "shield"),
|
||||
"judge_model_id": ("--judge-model", "judge"),
|
||||
"embedding_dimension": ("--embedding-dimension", "dim"),
|
||||
"rerank_model_id": ("--rerank-model", "rerank"),
|
||||
}
|
||||
|
||||
# Collect all parameters and their values
|
||||
|
|
|
|||
|
|
@ -153,6 +153,7 @@ def client_with_models(
|
|||
vision_model_id,
|
||||
embedding_model_id,
|
||||
judge_model_id,
|
||||
rerank_model_id,
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
|
|
@ -170,6 +171,9 @@ def client_with_models(
|
|||
|
||||
if embedding_model_id and embedding_model_id not in model_ids:
|
||||
raise ValueError(f"embedding_model_id {embedding_model_id} not found")
|
||||
|
||||
if rerank_model_id and rerank_model_id not in model_ids:
|
||||
raise ValueError(f"rerank_model_id {rerank_model_id} not found")
|
||||
return client
|
||||
|
||||
|
||||
|
|
@ -185,7 +189,14 @@ def model_providers(llama_stack_client):
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_model(request):
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
||||
model_fixtures = [
|
||||
"text_model_id",
|
||||
"vision_model_id",
|
||||
"embedding_model_id",
|
||||
"judge_model_id",
|
||||
"shield_id",
|
||||
"rerank_model_id",
|
||||
]
|
||||
test_func = request.node.function
|
||||
|
||||
actual_params = inspect.signature(test_func).parameters.keys()
|
||||
|
|
@ -230,6 +241,7 @@ def instantiate_llama_stack_client(session):
|
|||
|
||||
force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART") == "1"
|
||||
if force_restart:
|
||||
print(f"Forcing restart of the server on port {port}")
|
||||
stop_server_on_port(port)
|
||||
|
||||
# Check if port is available
|
||||
|
|
|
|||
|
|
@ -721,6 +721,6 @@ def test_openai_chat_completion_structured_output(openai_client, text_model_id,
|
|||
print(response.choices[0].message.content)
|
||||
answer = AnswerFormat.model_validate_json(response.choices[0].message.content)
|
||||
expected = tc["expected"]
|
||||
assert answer.first_name == expected["first_name"]
|
||||
assert answer.last_name == expected["last_name"]
|
||||
assert expected["first_name"].lower() in answer.first_name.lower()
|
||||
assert expected["last_name"].lower() in answer.last_name.lower()
|
||||
assert answer.year_of_birth == expected["year_of_birth"]
|
||||
|
|
|
|||
214
tests/integration/inference/test_rerank.py
Normal file
214
tests/integration/inference/test_rerank.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
|
||||
from llama_stack_client.types.alpha import InferenceRerankResponse
|
||||
from llama_stack_client.types.shared.interleaved_content import (
|
||||
ImageContentItem,
|
||||
ImageContentItemImage,
|
||||
ImageContentItemImageURL,
|
||||
TextContentItem,
|
||||
)
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
# Test data
|
||||
DUMMY_STRING = "string_1"
|
||||
DUMMY_STRING2 = "string_2"
|
||||
DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text")
|
||||
DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text")
|
||||
DUMMY_IMAGE_URL = ImageContentItem(
|
||||
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
|
||||
)
|
||||
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
|
||||
|
||||
PROVIDERS_SUPPORTING_MEDIA = {} # Providers that support media input for rerank models
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_rerank(inference_provider_type):
|
||||
supported_providers = {"remote::nvidia"}
|
||||
if inference_provider_type not in supported_providers:
|
||||
pytest.skip(f"{inference_provider_type} doesn't support rerank models")
|
||||
|
||||
|
||||
def _validate_rerank_response(response: InferenceRerankResponse, items: list) -> None:
|
||||
"""
|
||||
Validate that a rerank response has the correct structure and ordering.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
seen = set()
|
||||
last_score = float("inf")
|
||||
for d in response:
|
||||
assert 0 <= d.index < len(items), f"Index {d.index} out of bounds for {len(items)} items"
|
||||
assert d.index not in seen, f"Duplicate index {d.index} found"
|
||||
seen.add(d.index)
|
||||
assert isinstance(d.relevance_score, float), f"Score must be float, got {type(d.relevance_score)}"
|
||||
assert d.relevance_score <= last_score, f"Scores not in descending order: {d.relevance_score} > {last_score}"
|
||||
last_score = d.relevance_score
|
||||
|
||||
|
||||
def _validate_semantic_ranking(response: InferenceRerankResponse, items: list, expected_first_item: str) -> None:
|
||||
"""
|
||||
Validate that the expected most relevant item ranks first.
|
||||
|
||||
Args:
|
||||
response: The InferenceRerankResponse to validate
|
||||
items: The original items list that was ranked
|
||||
expected_first_item: The expected first item in the ranking
|
||||
|
||||
Raises:
|
||||
AssertionError: If any validation fails
|
||||
"""
|
||||
if not response:
|
||||
raise AssertionError("No ranking data returned in response")
|
||||
|
||||
actual_first_index = response[0].index
|
||||
actual_first_item = items[actual_first_index]
|
||||
assert actual_first_item == expected_first_item, (
|
||||
f"Expected '{expected_first_item}' to rank first, but '{actual_first_item}' ranked first instead."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_STRING, [DUMMY_STRING, DUMMY_STRING2]),
|
||||
(DUMMY_TEXT, [DUMMY_TEXT, DUMMY_TEXT2]),
|
||||
(DUMMY_STRING, [DUMMY_STRING2, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_STRING, DUMMY_TEXT2]),
|
||||
],
|
||||
ids=[
|
||||
"string-query-string-items",
|
||||
"text-query-text-items",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_text(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
assert isinstance(response, list)
|
||||
# TODO: Add type validation for response items once InferenceRerankResponseItem is exported from llama stack client.
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items",
|
||||
[
|
||||
(DUMMY_IMAGE_URL, [DUMMY_STRING]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL]),
|
||||
(DUMMY_IMAGE_BASE64, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
(DUMMY_TEXT, [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT]),
|
||||
],
|
||||
ids=[
|
||||
"image-query-url",
|
||||
"image-query-base64",
|
||||
"text-query-image-item",
|
||||
"mixed-content-1",
|
||||
"mixed-content-2",
|
||||
],
|
||||
)
|
||||
def test_rerank_image(client_with_models, rerank_model_id, query, items, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
if rerank_model_id not in PROVIDERS_SUPPORTING_MEDIA:
|
||||
error_type = (
|
||||
ValueError if isinstance(client_with_models, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
|
||||
)
|
||||
with pytest.raises(error_type):
|
||||
client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
else:
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items)
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2, DUMMY_TEXT, DUMMY_TEXT2]
|
||||
max_num_results = 2
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=max_num_results,
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) == max_num_results
|
||||
_validate_rerank_response(response, items)
|
||||
|
||||
|
||||
def test_rerank_max_results_larger_than_items(client_with_models, rerank_model_id, inference_provider_type):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
items = [DUMMY_STRING, DUMMY_STRING2]
|
||||
response = client_with_models.alpha.inference.rerank(
|
||||
model=rerank_model_id,
|
||||
query=DUMMY_STRING,
|
||||
items=items,
|
||||
max_num_results=10, # Larger than items length
|
||||
)
|
||||
|
||||
assert isinstance(response, list)
|
||||
assert len(response) <= len(items) # Should return at most len(items)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"query,items,expected_first_item",
|
||||
[
|
||||
(
|
||||
"What is a reranking model? ",
|
||||
[
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
"Machine learning algorithms learn patterns from data. ",
|
||||
"Python is a programming language. ",
|
||||
],
|
||||
"A reranking model reranks a list of items based on the query. ",
|
||||
),
|
||||
(
|
||||
"What is C++?",
|
||||
[
|
||||
"Learning new things is interesting. ",
|
||||
"C++ is a programming language. ",
|
||||
"Books provide knowledge and entertainment. ",
|
||||
],
|
||||
"C++ is a programming language. ",
|
||||
),
|
||||
(
|
||||
"What are good learning habits? ",
|
||||
[
|
||||
"Cooking pasta is a fun activity. ",
|
||||
"Plants need water and sunlight. ",
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
],
|
||||
"Good learning habits include reading daily and taking notes. ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rerank_semantic_correctness(
|
||||
client_with_models, rerank_model_id, query, items, expected_first_item, inference_provider_type
|
||||
):
|
||||
skip_if_provider_doesnt_support_rerank(inference_provider_type)
|
||||
|
||||
response = client_with_models.alpha.inference.rerank(model=rerank_model_id, query=query, items=items)
|
||||
|
||||
_validate_rerank_response(response, items)
|
||||
_validate_semantic_ranking(response, items, expected_first_item)
|
||||
|
|
@ -4,18 +4,75 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
from llama_stack import LlamaStackAsLibraryClient
|
||||
|
||||
|
||||
class TestInspect:
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
health = llama_stack_client.inspect.health()
|
||||
assert health is not None
|
||||
assert health.status == "OK"
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
version = llama_stack_client.inspect.version()
|
||||
assert version is not None
|
||||
assert version.version is not None
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_list_routes_default(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
"""Test list_routes with default filter (non-deprecated v1 routes)."""
|
||||
response = llama_stack_client.routes.list()
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
routes = response.data
|
||||
assert len(routes) > 0
|
||||
|
||||
# All routes should be non-deprecated
|
||||
# Check that we don't see any /openai/ routes (which are deprecated)
|
||||
openai_routes = [r for r in routes if "/openai/" in r.route]
|
||||
assert len(openai_routes) == 0, "Default filter should not include deprecated /openai/ routes"
|
||||
|
||||
# Should see standard v1 routes like /inspect/routes, /health, /version
|
||||
paths = [r.route for r in routes]
|
||||
assert "/inspect/routes" in paths or "/v1/inspect/routes" in paths
|
||||
assert "/health" in paths or "/v1/health" in paths
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_list_routes_filter_by_deprecated(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
"""Test list_routes with deprecated filter."""
|
||||
response = llama_stack_client.routes.list(api_filter="deprecated")
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
routes = response.data
|
||||
|
||||
# When filtering for deprecated, we should get deprecated routes
|
||||
# At minimum, we should see some /openai/ routes which are deprecated
|
||||
if len(routes) > 0:
|
||||
# If there are any deprecated routes, they should include openai routes
|
||||
openai_routes = [r for r in routes if "/openai/" in r.route]
|
||||
assert len(openai_routes) > 0, "Deprecated filter should include /openai/ routes"
|
||||
|
||||
@pytest.mark.skip(reason="inspect tests disabled")
|
||||
def test_list_routes_filter_by_v1(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
|
||||
"""Test list_routes with v1 filter."""
|
||||
response = llama_stack_client.routes.list(api_filter="v1")
|
||||
assert response is not None
|
||||
assert hasattr(response, "data")
|
||||
routes = response.data
|
||||
assert len(routes) > 0
|
||||
|
||||
# Should not include deprecated routes
|
||||
openai_routes = [r for r in routes if "/openai/" in r.route]
|
||||
assert len(openai_routes) == 0
|
||||
|
||||
# Should include v1 routes
|
||||
paths = [r.route for r in routes]
|
||||
assert any(
|
||||
"/v1/" in p or p.startswith("/inspect/") or p.startswith("/health") or p.startswith("/version")
|
||||
for p in paths
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import os
|
|||
|
||||
import pytest
|
||||
|
||||
import llama_stack.core.telemetry.telemetry as telemetry_module
|
||||
from llama_stack.testing.api_recorder import patch_httpx_for_test_id
|
||||
from tests.integration.fixtures.common import instantiate_llama_stack_client
|
||||
from tests.integration.telemetry.collectors import InMemoryTelemetryManager, OtlpHttpTestCollector
|
||||
|
|
@ -22,40 +21,26 @@ def telemetry_test_collector():
|
|||
stack_mode = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
|
||||
if stack_mode == "server":
|
||||
# In server mode, the collector must be started and the server is already running.
|
||||
# The integration test script (scripts/integration-tests.sh) should have set
|
||||
# LLAMA_STACK_TEST_COLLECTOR_PORT and OTEL_EXPORTER_OTLP_ENDPOINT before starting the server.
|
||||
try:
|
||||
collector = OtlpHttpTestCollector()
|
||||
except RuntimeError as exc:
|
||||
pytest.skip(str(exc))
|
||||
env_overrides = {
|
||||
"OTEL_EXPORTER_OTLP_ENDPOINT": collector.endpoint,
|
||||
"OTEL_EXPORTER_OTLP_PROTOCOL": "http/protobuf",
|
||||
"OTEL_BSP_SCHEDULE_DELAY": "200",
|
||||
"OTEL_BSP_EXPORT_TIMEOUT": "2000",
|
||||
"LLAMA_STACK_DISABLE_GUNICORN": "true", # Disable multi-process for telemetry collection
|
||||
}
|
||||
|
||||
previous_env = {key: os.environ.get(key) for key in env_overrides}
|
||||
previous_force_restart = os.environ.get("LLAMA_STACK_TEST_FORCE_SERVER_RESTART")
|
||||
|
||||
for key, value in env_overrides.items():
|
||||
os.environ[key] = value
|
||||
|
||||
os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = "1"
|
||||
telemetry_module._TRACER_PROVIDER = None
|
||||
# Verify the collector is listening on the expected endpoint
|
||||
expected_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
|
||||
if expected_endpoint and collector.endpoint != expected_endpoint:
|
||||
pytest.skip(
|
||||
f"Collector endpoint mismatch: expected {expected_endpoint}, got {collector.endpoint}. "
|
||||
"Server was likely started before collector."
|
||||
)
|
||||
|
||||
try:
|
||||
yield collector
|
||||
finally:
|
||||
collector.shutdown()
|
||||
for key, prior in previous_env.items():
|
||||
if prior is None:
|
||||
os.environ.pop(key, None)
|
||||
else:
|
||||
os.environ[key] = prior
|
||||
if previous_force_restart is None:
|
||||
os.environ.pop("LLAMA_STACK_TEST_FORCE_SERVER_RESTART", None)
|
||||
else:
|
||||
os.environ["LLAMA_STACK_TEST_FORCE_SERVER_RESTART"] = previous_force_restart
|
||||
else:
|
||||
manager = InMemoryTelemetryManager()
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -206,3 +206,65 @@ def test_parse_and_maybe_upgrade_config_invalid(invalid_config):
|
|||
def test_parse_and_maybe_upgrade_config_image_name_int(config_with_image_name_int):
|
||||
result = parse_and_maybe_upgrade_config(config_with_image_name_int)
|
||||
assert isinstance(result.image_name, str)
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_sets_external_providers_dir(up_to_date_config):
|
||||
"""Test that external_providers_dir is None when not specified (deprecated field)."""
|
||||
# Ensure the config doesn't have external_providers_dir set
|
||||
assert "external_providers_dir" not in up_to_date_config
|
||||
|
||||
result = parse_and_maybe_upgrade_config(up_to_date_config)
|
||||
|
||||
# Verify external_providers_dir is None (not set to default)
|
||||
# This aligns with the deprecation of external_providers_dir
|
||||
assert result.external_providers_dir is None
|
||||
|
||||
|
||||
def test_parse_and_maybe_upgrade_config_preserves_custom_external_providers_dir(up_to_date_config):
|
||||
"""Test that custom external_providers_dir values are preserved."""
|
||||
custom_dir = "/custom/providers/dir"
|
||||
up_to_date_config["external_providers_dir"] = custom_dir
|
||||
|
||||
result = parse_and_maybe_upgrade_config(up_to_date_config)
|
||||
|
||||
# Verify the custom value was preserved
|
||||
assert str(result.external_providers_dir) == custom_dir
|
||||
|
||||
|
||||
def test_generate_run_config_from_providers():
|
||||
"""Test that _generate_run_config_from_providers creates a valid config"""
|
||||
import argparse
|
||||
|
||||
from llama_stack.cli.stack.run import StackRun
|
||||
from llama_stack.core.datatypes import Provider
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers()
|
||||
stack_run = StackRun(subparsers)
|
||||
|
||||
providers = {
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_type="inline::meta-reference",
|
||||
provider_id="meta-reference",
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
config = stack_run._generate_run_config_from_providers(providers=providers)
|
||||
config_dict = config.model_dump(mode="json")
|
||||
|
||||
# Verify basic structure
|
||||
assert config_dict["image_name"] == "providers-run"
|
||||
assert "inference" in config_dict["apis"]
|
||||
assert "inference" in config_dict["providers"]
|
||||
|
||||
# Verify storage has all required stores including prompts
|
||||
assert "storage" in config_dict
|
||||
stores = config_dict["storage"]["stores"]
|
||||
assert "prompts" in stores
|
||||
assert stores["prompts"]["namespace"] == "prompts"
|
||||
|
||||
# Verify config can be parsed back
|
||||
parsed = parse_and_maybe_upgrade_config(config_dict)
|
||||
assert parsed.image_name == "providers-run"
|
||||
|
|
|
|||
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status=200, json_data=None, text_data="OK"):
|
||||
self.status = status
|
||||
self._json_data = json_data or {"rankings": []}
|
||||
self._text_data = text_data
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
async def text(self):
|
||||
return self._text_data
|
||||
|
||||
|
||||
class MockSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.post_calls = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append((url, kwargs))
|
||||
|
||||
class PostContext:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.response
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
return PostContext(self.response)
|
||||
|
||||
|
||||
def create_adapter(config=None, rerank_endpoints=None):
|
||||
if config is None:
|
||||
config = NVIDIAConfig(api_key="test-key")
|
||||
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
|
||||
class MockModel:
|
||||
provider_resource_id = "test-model"
|
||||
metadata = {}
|
||||
|
||||
adapter.model_store = AsyncMock()
|
||||
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
|
||||
|
||||
if rerank_endpoints is not None:
|
||||
adapter.config.rerank_model_to_url = rerank_endpoints
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
async def test_rerank_basic_functionality():
|
||||
adapter = create_adapter()
|
||||
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.5
|
||||
|
||||
url, kwargs = mock_session.post_calls[0]
|
||||
payload = kwargs["json"]
|
||||
assert payload["model"] == "test-model"
|
||||
assert payload["query"] == {"text": "test query"}
|
||||
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
|
||||
|
||||
|
||||
async def test_missing_rankings_key():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(json_data={}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
assert len(result.data) == 0
|
||||
|
||||
|
||||
async def test_hosted_with_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert url == "https://model.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_hosted_without_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
|
||||
rerank_endpoints={}, # No endpoint mapping for test-model
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
|
||||
|
||||
async def test_hosted_model_not_in_endpoint_mapping():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
assert url != "https://other.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "http://localhost:8000" in url
|
||||
assert "model.endpoint/rerank" not in url
|
||||
|
||||
|
||||
async def test_max_num_results():
|
||||
adapter = create_adapter()
|
||||
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
|
||||
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.8
|
||||
|
||||
|
||||
async def test_http_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_client_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_list_models_includes_configured_rerank_models():
|
||||
"""Test that list_models adds rerank models to the dynamic model list."""
|
||||
adapter = create_adapter()
|
||||
adapter.__provider_id__ = "nvidia"
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
|
||||
dynamic_ids = ["llm-1", "embedding-1"]
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
result = await adapter.list_models()
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Check that the rerank models are added
|
||||
model_ids = [m.identifier for m in result]
|
||||
assert "nv-rerank-qa-mistral-4b:1" in model_ids
|
||||
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
|
||||
|
||||
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
|
||||
|
||||
assert len(rerank_models) == 3
|
||||
|
||||
for m in rerank_models:
|
||||
assert m.provider_id == "nvidia"
|
||||
assert m.model_type == ModelType.rerank
|
||||
assert m.metadata == {}
|
||||
assert m.identifier in adapter._model_cache
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_has_no_duplicates():
|
||||
adapter = create_adapter()
|
||||
|
||||
dynamic_ids = [
|
||||
"llm-1",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3", # overlaps configured rerank ids
|
||||
"embedding-1",
|
||||
"llm-1",
|
||||
]
|
||||
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
assert len(ids) == len(set(ids))
|
||||
assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1
|
||||
assert "nv-rerank-qa-mistral-4b:1" in ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_uses_configured_on_dynamic_failure():
|
||||
adapter = create_adapter()
|
||||
|
||||
# Simulate dynamic listing failure
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
# Should still return configured rerank ids
|
||||
configured_ids = list(adapter.config.rerank_model_to_url.keys())
|
||||
assert set(ids) == set(configured_ids)
|
||||
Loading…
Add table
Add a link
Reference in a new issue