Merge branch 'main' into redis-kv-store

This commit is contained in:
Shrinit Goyal 2025-08-19 18:07:58 +05:30 committed by GitHub
commit 5a0d71452e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
186 changed files with 15553 additions and 8443 deletions

2
.github/TRIAGERS.md vendored
View file

@ -1,2 +1,2 @@
# This file documents Triage members in the Llama Stack community # This file documents Triage members in the Llama Stack community
@bbrowning @franciscojavierarceo @leseb @franciscojavierarceo

View file

@ -2,9 +2,13 @@ name: 'Run and Record Tests'
description: 'Run integration tests and handle recording/artifact upload' description: 'Run integration tests and handle recording/artifact upload'
inputs: inputs:
test-types: test-subdirs:
description: 'JSON array of test types to run' description: 'Comma-separated list of test subdirectories to run'
required: true required: true
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
stack-config: stack-config:
description: 'Stack configuration to use' description: 'Stack configuration to use'
required: true required: true
@ -32,12 +36,14 @@ runs:
- name: Run Integration Tests - name: Run Integration Tests
shell: bash shell: bash
run: | run: |
./scripts/integration-tests.sh \ uv run --no-sync ./scripts/integration-tests.sh \
--stack-config '${{ inputs.stack-config }}' \ --stack-config '${{ inputs.stack-config }}' \
--provider '${{ inputs.provider }}' \ --provider '${{ inputs.provider }}' \
--test-types '${{ inputs.test-types }}' \ --test-subdirs '${{ inputs.test-subdirs }}' \
--test-pattern '${{ inputs.test-pattern }}' \
--inference-mode '${{ inputs.inference-mode }}' \ --inference-mode '${{ inputs.inference-mode }}' \
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \
| tee pytest-${{ inputs.inference-mode }}.log
- name: Commit and push recordings - name: Commit and push recordings
@ -57,10 +63,10 @@ runs:
git commit -m "Recordings update from CI" git commit -m "Recordings update from CI"
fi fi
git fetch origin ${{ github.event.pull_request.head.ref }} git fetch origin ${{ github.ref_name }}
git rebase origin/${{ github.event.pull_request.head.ref }} git rebase origin/${{ github.ref_name }}
echo "Rebased successfully" echo "Rebased successfully"
git push origin HEAD:${{ github.event.pull_request.head.ref }} git push origin HEAD:${{ github.ref_name }}
echo "Pushed successfully" echo "Pushed successfully"
else else
echo "No recording changes" echo "No recording changes"

View file

@ -16,14 +16,16 @@ runs:
uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1 uses: astral-sh/setup-uv@6b9c6063abd6010835644d4c2e1bef4cf5cd0fca # v6.0.1
with: with:
python-version: ${{ inputs.python-version }} python-version: ${{ inputs.python-version }}
activate-environment: true
version: 0.7.6 version: 0.7.6
- name: Install dependencies - name: Install dependencies
shell: bash shell: bash
run: | run: |
echo "Updating project dependencies via uv sync"
uv sync --all-groups uv sync --all-groups
uv pip install ollama faiss-cpu
echo "Installing ad-hoc dependencies"
uv pip install faiss-cpu
# Install llama-stack-client-python based on the client-version input # Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then if [ "${{ inputs.client-version }}" = "latest" ]; then
@ -37,4 +39,5 @@ runs:
exit 1 exit 1
fi fi
uv pip install -e . echo "Installed llama packages"
uv pip list | grep llama

View file

@ -42,7 +42,22 @@ runs:
- name: Build Llama Stack - name: Build Llama Stack
shell: bash shell: bash
run: | run: |
uv run llama stack build --template ci-tests --image-type venv # Install llama-stack-client-python based on the client-version input
if [ "${{ inputs.client-version }}" = "latest" ]; then
echo "Installing latest llama-stack-client-python from main branch"
export LLAMA_STACK_CLIENT_DIR=git+https://github.com/llamastack/llama-stack-client-python.git@main
elif [ "${{ inputs.client-version }}" = "published" ]; then
echo "Installing published llama-stack-client-python from PyPI"
unset LLAMA_STACK_CLIENT_DIR
else
echo "Invalid client-version: ${{ inputs.client-version }}"
exit 1
fi
echo "Building Llama Stack"
LLAMA_STACK_DIR=. \
uv run --no-sync llama stack build --template ci-tests --image-type venv
- name: Configure git for commits - name: Configure git for commits
shell: bash shell: bash

View file

@ -18,5 +18,6 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
| Close stale issues and PRs | [stale_bot.yml](stale_bot.yml) | Run the Stale Bot action | | Close stale issues and PRs | [stale_bot.yml](stale_bot.yml) | Run the Stale Bot action |
| Test External Providers Installed via Module | [test-external-provider-module.yml](test-external-provider-module.yml) | Test External Provider installation via Python module | | Test External Providers Installed via Module | [test-external-provider-module.yml](test-external-provider-module.yml) | Test External Provider installation via Python module |
| Test External API and Providers | [test-external.yml](test-external.yml) | Test the External API and Provider mechanisms | | Test External API and Providers | [test-external.yml](test-external.yml) | Test the External API and Provider mechanisms |
| UI Tests | [ui-unit-tests.yml](ui-unit-tests.yml) | Run the UI test suite |
| Unit Tests | [unit-tests.yml](unit-tests.yml) | Run the unit test suite | | Unit Tests | [unit-tests.yml](unit-tests.yml) | Run the unit test suite |
| Update ReadTheDocs | [update-readthedocs.yml](update-readthedocs.yml) | Update the Llama Stack ReadTheDocs site | | Update ReadTheDocs | [update-readthedocs.yml](update-readthedocs.yml) | Update the Llama Stack ReadTheDocs site |

View file

@ -30,7 +30,8 @@ jobs:
- name: Build a single provider - name: Build a single provider
run: | run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --template starter --image-type container --image-name test USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run --no-sync \
llama stack build --template starter --image-type container --image-name test
- name: Run installer end-to-end - name: Run installer end-to-end
run: | run: |

View file

@ -10,6 +10,7 @@ on:
paths: paths:
- 'distributions/**' - 'distributions/**'
- 'llama_stack/**' - 'llama_stack/**'
- '!llama_stack/ui/**'
- 'tests/integration/**' - 'tests/integration/**'
- 'uv.lock' - 'uv.lock'
- 'pyproject.toml' - 'pyproject.toml'

View file

@ -5,11 +5,12 @@ run-name: Run the integration test suite from tests/integration in replay mode
on: on:
push: push:
branches: [ main ] branches: [ main ]
pull_request_target: pull_request:
branches: [ main ] branches: [ main ]
types: [opened, synchronize, reopened] types: [opened, synchronize, reopened]
paths: paths:
- 'llama_stack/**' - 'llama_stack/**'
- '!llama_stack/ui/**'
- 'tests/**' - 'tests/**'
- 'uv.lock' - 'uv.lock'
- 'pyproject.toml' - 'pyproject.toml'
@ -31,35 +32,23 @@ on:
description: 'Test against a specific provider' description: 'Test against a specific provider'
type: string type: string
default: 'ollama' default: 'ollama'
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
default: ''
concurrency: concurrency:
# Skip concurrency for pushes to main - each commit should be tested independently # Skip concurrency for pushes to main - each commit should be tested independently
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number }} group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
discover-tests:
runs-on: ubuntu-latest
outputs:
test-types: ${{ steps.generate-test-types.outputs.test-types }}
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Generate test types
id: generate-test-types
run: |
# Get test directories dynamically, excluding non-test directories
# NOTE: we are excluding post_training since the tests take too long
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
sed 's|tests/integration/||' |
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
sort | jq -R -s -c 'split("\n")[:-1]')
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
run-replay-mode-tests: run-replay-mode-tests:
needs: discover-tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}
@ -90,7 +79,8 @@ jobs:
- name: Run tests - name: Run tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-types: ${{ needs.discover-tests.outputs.test-types }} test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
provider: ${{ matrix.provider }} provider: ${{ matrix.provider }}
inference-mode: 'replay' inference-mode: 'replay'

View file

@ -9,14 +9,17 @@ on:
branches: [ main ] branches: [ main ]
paths: paths:
- 'llama_stack/**' - 'llama_stack/**'
- '!llama_stack/ui/**'
- 'tests/integration/vector_io/**' - 'tests/integration/vector_io/**'
- 'uv.lock' - 'uv.lock'
- 'pyproject.toml' - 'pyproject.toml'
- 'requirements.txt' - 'requirements.txt'
- '.github/workflows/integration-vector-io-tests.yml' # This workflow - '.github/workflows/integration-vector-io-tests.yml' # This workflow
schedule:
- cron: '0 0 * * *' # (test on python 3.13) Daily at 12 AM UTC
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }} group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true cancel-in-progress: true
jobs: jobs:
@ -25,7 +28,7 @@ jobs:
strategy: strategy:
matrix: matrix:
vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"] vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate", "remote::qdrant"]
python-version: ["3.12", "3.13"] python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
fail-fast: false # we want to run all tests regardless of failure fail-fast: false # we want to run all tests regardless of failure
steps: steps:
@ -141,7 +144,7 @@ jobs:
- name: Build Llama Stack - name: Build Llama Stack
run: | run: |
uv run llama stack build --template ci-tests --image-type venv uv run --no-sync llama stack build --template ci-tests --image-type venv
- name: Check Storage and Memory Available Before Tests - name: Check Storage and Memory Available Before Tests
if: ${{ always() }} if: ${{ always() }}
@ -164,7 +167,8 @@ jobs:
ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }} ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }} WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'localhost:8080' || '' }}
run: | run: |
uv run pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \ uv run --no-sync \
pytest -sv --stack-config="files=inline::localfs,inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
tests/integration/vector_io \ tests/integration/vector_io \
--embedding-model inline::sentence-transformers/all-MiniLM-L6-v2 --embedding-model inline::sentence-transformers/all-MiniLM-L6-v2

View file

@ -9,6 +9,8 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
paths-ignore:
- 'llama_stack/ui/**'
jobs: jobs:
build: build:

View file

@ -1,93 +1,53 @@
# This workflow should be run manually when needing to re-record tests. This happens when you have
# - added a new test
# - or changed an existing test such that a new inference call is made
# You should make a PR and then run this workflow on that PR branch. The workflow will re-record the
# tests and commit the recordings to the PR branch.
name: Integration Tests (Record) name: Integration Tests (Record)
run-name: Run the integration test suite from tests/integration run-name: Run the integration test suite from tests/integration
on: on:
pull_request:
branches: [ main ]
types: [opened, synchronize, labeled]
paths:
- 'llama_stack/**'
- 'tests/**'
- 'uv.lock'
- 'pyproject.toml'
- '.github/workflows/record-integration-tests.yml' # This workflow
- '.github/actions/setup-ollama/action.yml'
- '.github/actions/setup-test-environment/action.yml'
- '.github/actions/run-and-record-tests/action.yml'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-provider: test-provider:
description: 'Test against a specific provider' description: 'Test against a specific provider'
type: string type: string
default: 'ollama' default: 'ollama'
run-vision-tests:
concurrency: description: 'Whether to run vision tests'
group: ${{ github.workflow }}-${{ github.ref }} type: boolean
cancel-in-progress: true default: false
test-pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
default: ''
jobs: jobs:
discover-tests:
if: contains(github.event.pull_request.labels.*.name, 're-record-tests') ||
contains(github.event.pull_request.labels.*.name, 're-record-vision-tests')
runs-on: ubuntu-latest
outputs:
test-types: ${{ steps.generate-test-types.outputs.test-types }}
matrix-modes: ${{ steps.generate-test-types.outputs.matrix-modes }}
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Generate test types
id: generate-test-types
run: |
# Get test directories dynamically, excluding non-test directories
TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|post_training)$" |
sort | jq -R -s -c 'split("\n")[:-1]')
echo "test-types=$TEST_TYPES" >> $GITHUB_OUTPUT
labels=$(gh pr view ${{ github.event.pull_request.number }} --json labels --jq '.labels[].name')
echo "labels=$labels"
modes_array=()
if [[ $labels == *"re-record-vision-tests"* ]]; then
modes_array+=("vision")
fi
if [[ $labels == *"re-record-tests"* ]]; then
modes_array+=("non-vision")
fi
# Convert to JSON array
if [ ${#modes_array[@]} -eq 0 ]; then
matrix_modes="[]"
else
matrix_modes=$(printf '%s\n' "${modes_array[@]}" | jq -R -s -c 'split("\n")[:-1]')
fi
echo "matrix_modes=$matrix_modes"
echo "matrix-modes=$matrix_modes" >> $GITHUB_OUTPUT
env:
GH_TOKEN: ${{ github.token }}
record-tests: record-tests:
needs: discover-tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions: permissions:
contents: write contents: write
strategy:
fail-fast: false
matrix:
mode: ${{ fromJSON(needs.discover-tests.outputs.matrix-modes) }}
steps: steps:
- name: Echo workflow inputs
run: |
echo "::group::Workflow Inputs"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "branch: ${{ github.ref_name }}"
echo "::endgroup::"
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with: with:
ref: ${{ github.event.pull_request.head.ref }}
fetch-depth: 0 fetch-depth: 0
- name: Setup test environment - name: Setup test environment
@ -96,14 +56,15 @@ jobs:
python-version: "3.12" # Use single Python version for recording python-version: "3.12" # Use single Python version for recording
client-version: "latest" client-version: "latest"
provider: ${{ inputs.test-provider || 'ollama' }} provider: ${{ inputs.test-provider || 'ollama' }}
run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }} run-vision-tests: ${{ inputs.run-vision-tests }}
inference-mode: 'record' inference-mode: 'record'
- name: Run and record tests - name: Run and record tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-types: ${{ needs.discover-tests.outputs.test-types }} test-pattern: ${{ inputs.test-pattern }}
test-subdirs: ${{ inputs.test-subdirs }}
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
provider: ${{ inputs.test-provider || 'ollama' }} provider: ${{ inputs.test-provider || 'ollama' }}
inference-mode: 'record' inference-mode: 'record'
run-vision-tests: ${{ matrix.mode == 'vision' && 'true' || 'false' }} run-vision-tests: ${{ inputs.run-vision-tests }}

View file

@ -9,6 +9,7 @@ on:
branches: [ main ] branches: [ main ]
paths: paths:
- 'llama_stack/**' - 'llama_stack/**'
- '!llama_stack/ui/**'
- 'tests/integration/**' - 'tests/integration/**'
- 'uv.lock' - 'uv.lock'
- 'pyproject.toml' - 'pyproject.toml'
@ -43,11 +44,11 @@ jobs:
- name: Print distro dependencies - name: Print distro dependencies
run: | run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external/build.yaml --print-deps-only USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run --no-sync llama stack build --config tests/external/build.yaml --print-deps-only
- name: Build distro from config file - name: Build distro from config file
run: | run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external/build.yaml USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run --no-sync llama stack build --config tests/external/build.yaml
- name: Start Llama Stack server in background - name: Start Llama Stack server in background
if: ${{ matrix.image-type }} == 'venv' if: ${{ matrix.image-type }} == 'venv'

55
.github/workflows/ui-unit-tests.yml vendored Normal file
View file

@ -0,0 +1,55 @@
name: UI Tests
run-name: Run the UI test suite
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
- 'llama_stack/ui/**'
- '.github/workflows/ui-unit-tests.yml' # This workflow
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
ui-tests:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
node-version: [22]
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Setup Node.js
uses: actions/setup-node@39370e3970a6d050c480ffad4ff0ed4d3fdee5af # v4.1.0
with:
node-version: ${{ matrix.node-version }}
cache: 'npm'
cache-dependency-path: 'llama_stack/ui/package-lock.json'
- name: Install dependencies
working-directory: llama_stack/ui
run: npm ci
- name: Run linting
working-directory: llama_stack/ui
run: npm run lint
- name: Run format check
working-directory: llama_stack/ui
run: npm run format:check
- name: Run unit tests
working-directory: llama_stack/ui
env:
CI: true
run: npm test -- --coverage --watchAll=false --passWithNoTests

View file

@ -9,6 +9,7 @@ on:
branches: [ main ] branches: [ main ]
paths: paths:
- 'llama_stack/**' - 'llama_stack/**'
- '!llama_stack/ui/**'
- 'tests/unit/**' - 'tests/unit/**'
- 'uv.lock' - 'uv.lock'
- 'pyproject.toml' - 'pyproject.toml'

View file

@ -2,6 +2,7 @@ exclude: 'build/'
default_language_version: default_language_version:
python: python3.12 python: python3.12
node: "22"
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
@ -145,6 +146,20 @@ repos:
pass_filenames: false pass_filenames: false
require_serial: true require_serial: true
files: ^.github/workflows/.*$ files: ^.github/workflows/.*$
- id: ui-prettier
name: Format UI code with Prettier
entry: bash -c 'cd llama_stack/ui && npm run format'
language: system
files: ^llama_stack/ui/.*\.(ts|tsx)$
pass_filenames: false
require_serial: true
- id: ui-eslint
name: Lint UI code with ESLint
entry: bash -c 'cd llama_stack/ui && npm run lint -- --fix --quiet'
language: system
files: ^llama_stack/ui/.*\.(ts|tsx)$
pass_filenames: false
require_serial: true
ci: ci:
autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks

View file

@ -14767,7 +14767,8 @@
"OpenAIFilePurpose": { "OpenAIFilePurpose": {
"type": "string", "type": "string",
"enum": [ "enum": [
"assistants" "assistants",
"batch"
], ],
"title": "OpenAIFilePurpose", "title": "OpenAIFilePurpose",
"description": "Valid purpose values for OpenAI Files API." "description": "Valid purpose values for OpenAI Files API."
@ -14844,7 +14845,8 @@
"purpose": { "purpose": {
"type": "string", "type": "string",
"enum": [ "enum": [
"assistants" "assistants",
"batch"
], ],
"description": "The intended purpose of the file" "description": "The intended purpose of the file"
} }

View file

@ -10951,6 +10951,7 @@ components:
type: string type: string
enum: enum:
- assistants - assistants
- batch
title: OpenAIFilePurpose title: OpenAIFilePurpose
description: >- description: >-
Valid purpose values for OpenAI Files API. Valid purpose values for OpenAI Files API.
@ -11019,6 +11020,7 @@ components:
type: string type: string
enum: enum:
- assistants - assistants
- batch
description: The intended purpose of the file description: The intended purpose of the file
additionalProperties: false additionalProperties: false
required: required:

View file

@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle.
- **Batch Inference**: run inference on a dataset of inputs - **Batch Inference**: run inference on a dataset of inputs
- **Batch Agents**: run agents on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs
- **Synthetic Data Generation**: generate synthetic data for model development - **Synthetic Data Generation**: generate synthetic data for model development
- **Batches**: OpenAI-compatible batch management for inference

View file

@ -4,11 +4,11 @@
## Adding a New Provider ## Adding a New Provider
See the [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack. See:
- [Adding a New API Provider Page](new_api_provider.md) which describes how to add new API providers to the Stack.
- [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack.
- [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack.
See the [Vector Database Page](new_vector_database.md) which describes how to add a new vector databases with Llama Stack.
See the [External Provider Page](../providers/external/index.md) which describes how to add external providers to the Stack.
```{toctree} ```{toctree}
:maxdepth: 1 :maxdepth: 1
:hidden: :hidden:
@ -19,11 +19,21 @@ new_vector_database
## Testing ## Testing
See the [Test Page](testing.md) which describes how to test your changes.
```{include} ../../../tests/README.md
```
## Advanced Topics
For developers who need deeper understanding of the testing system internals:
```{toctree} ```{toctree}
:maxdepth: 1 :maxdepth: 1
:hidden:
:caption: Testing
testing testing/record-replay
``` ```
### Benchmarking
```{include} ../../../docs/source/distributions/k8s-benchmark/README.md
```

View file

@ -1,8 +0,0 @@
```{include} ../../../tests/README.md
```
```{include} ../../../tests/unit/README.md
```
```{include} ../../../tests/integration/README.md
```

View file

@ -0,0 +1,234 @@
# Record-Replay System
Understanding how Llama Stack captures and replays API interactions for testing.
## Overview
The record-replay system solves a fundamental challenge in AI testing: how do you test against expensive, non-deterministic APIs without breaking the bank or dealing with flaky tests?
The solution: intercept API calls, store real responses, and replay them later. This gives you real API behavior without the cost or variability.
## How It Works
### Request Hashing
Every API request gets converted to a deterministic hash for lookup:
```python
def normalize_request(method: str, url: str, headers: dict, body: dict) -> str:
normalized = {
"method": method.upper(),
"endpoint": urlparse(url).path, # Just the path, not full URL
"body": body, # Request parameters
}
return hashlib.sha256(json.dumps(normalized, sort_keys=True).encode()).hexdigest()
```
**Key insight:** The hashing is intentionally precise. Different whitespace, float precision, or parameter order produces different hashes. This prevents subtle bugs from false cache hits.
```python
# These produce DIFFERENT hashes:
{"content": "Hello world"}
{"content": "Hello world\n"}
{"temperature": 0.7}
{"temperature": 0.7000001}
```
### Client Interception
The system patches OpenAI and Ollama client methods to intercept calls before they leave your application. This happens transparently - your test code doesn't change.
### Storage Architecture
Recordings use a two-tier storage system optimized for both speed and debuggability:
```
recordings/
├── index.sqlite # Fast lookup by request hash
└── responses/
├── abc123def456.json # Individual response files
└── def789ghi012.json
```
**SQLite index** enables O(log n) hash lookups and metadata queries without loading response bodies.
**JSON files** store complete request/response pairs in human-readable format for debugging.
## Recording Modes
### LIVE Mode
Direct API calls with no recording or replay:
```python
with inference_recording(mode=InferenceMode.LIVE):
response = await client.chat.completions.create(...)
```
Use for initial development and debugging against real APIs.
### RECORD Mode
Captures API interactions while passing through real responses:
```python
with inference_recording(mode=InferenceMode.RECORD, storage_dir="./recordings"):
response = await client.chat.completions.create(...)
# Real API call made, response captured AND returned
```
The recording process:
1. Request intercepted and hashed
2. Real API call executed
3. Response captured and serialized
4. Recording stored to disk
5. Original response returned to caller
### REPLAY Mode
Returns stored responses instead of making API calls:
```python
with inference_recording(mode=InferenceMode.REPLAY, storage_dir="./recordings"):
response = await client.chat.completions.create(...)
# No API call made, cached response returned instantly
```
The replay process:
1. Request intercepted and hashed
2. Hash looked up in SQLite index
3. Response loaded from JSON file
4. Response deserialized and returned
5. Error if no recording found
## Streaming Support
Streaming APIs present a unique challenge: how do you capture an async generator?
### The Problem
```python
# How do you record this?
async for chunk in client.chat.completions.create(stream=True):
process(chunk)
```
### The Solution
The system captures all chunks immediately before yielding any:
```python
async def handle_streaming_record(response):
# Capture complete stream first
chunks = []
async for chunk in response:
chunks.append(chunk)
# Store complete recording
storage.store_recording(
request_hash, request_data, {"body": chunks, "is_streaming": True}
)
# Return generator that replays captured chunks
async def replay_stream():
for chunk in chunks:
yield chunk
return replay_stream()
```
This ensures:
- **Complete capture** - The entire stream is saved atomically
- **Interface preservation** - The returned object behaves like the original API
- **Deterministic replay** - Same chunks in the same order every time
## Serialization
API responses contain complex Pydantic objects that need careful serialization:
```python
def _serialize_response(response):
if hasattr(response, "model_dump"):
# Preserve type information for proper deserialization
return {
"__type__": f"{response.__class__.__module__}.{response.__class__.__qualname__}",
"__data__": response.model_dump(mode="json"),
}
return response
```
This preserves type safety - when replayed, you get the same Pydantic objects with all their validation and methods.
## Environment Integration
### Environment Variables
Control recording behavior globally:
```bash
export LLAMA_STACK_TEST_INFERENCE_MODE=replay
export LLAMA_STACK_TEST_RECORDING_DIR=/path/to/recordings
pytest tests/integration/
```
### Pytest Integration
The system integrates automatically based on environment variables, requiring no changes to test code.
## Debugging Recordings
### Inspecting Storage
```bash
# See what's recorded
sqlite3 recordings/index.sqlite "SELECT endpoint, model, timestamp FROM recordings LIMIT 10;"
# View specific response
cat recordings/responses/abc123def456.json | jq '.response.body'
# Find recordings by endpoint
sqlite3 recordings/index.sqlite "SELECT * FROM recordings WHERE endpoint='/v1/chat/completions';"
```
### Common Issues
**Hash mismatches:** Request parameters changed slightly between record and replay
```bash
# Compare request details
cat recordings/responses/abc123.json | jq '.request'
```
**Serialization errors:** Response types changed between versions
```bash
# Re-record with updated types
rm recordings/responses/failing_hash.json
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_failing.py
```
**Missing recordings:** New test or changed parameters
```bash
# Record the missing interaction
LLAMA_STACK_TEST_INFERENCE_MODE=record pytest test_new.py
```
## Design Decisions
### Why Not Mocks?
Traditional mocking breaks down with AI APIs because:
- Response structures are complex and evolve frequently
- Streaming behavior is hard to mock correctly
- Edge cases in real APIs get missed
- Mocks become brittle maintenance burdens
### Why Precise Hashing?
Loose hashing (normalizing whitespace, rounding floats) seems convenient but hides bugs. If a test changes slightly, you want to know about it rather than accidentally getting the wrong cached response.
### Why JSON + SQLite?
- **JSON** - Human readable, diff-friendly, easy to inspect and modify
- **SQLite** - Fast indexed lookups without loading response bodies
- **Hybrid** - Best of both worlds for different use cases
This system provides reliable, fast testing against real AI APIs while maintaining the ability to debug issues when they arise.

View file

@ -0,0 +1,156 @@
# Llama Stack Benchmark Suite on Kubernetes
## Motivation
Performance benchmarking is critical for understanding the overhead and characteristics of the Llama Stack abstraction layer compared to direct inference engines like vLLM.
### Why This Benchmark Suite Exists
**Performance Validation**: The Llama Stack provides a unified API layer across multiple inference providers, but this abstraction introduces potential overhead. This benchmark suite quantifies the performance impact by comparing:
- Llama Stack inference (with vLLM backend)
- Direct vLLM inference calls
- Both under identical Kubernetes deployment conditions
**Production Readiness Assessment**: Real-world deployments require understanding performance characteristics under load. This suite simulates concurrent user scenarios with configurable parameters (duration, concurrency, request patterns) to validate production readiness.
**Regression Detection (TODO)**: As the Llama Stack evolves, this benchmark provides automated regression detection for performance changes. CI/CD pipelines can leverage these benchmarks to catch performance degradations before production deployments.
**Resource Planning**: By measuring throughput, latency percentiles, and resource utilization patterns, teams can make informed decisions about:
- Kubernetes resource allocation (CPU, memory, GPU)
- Auto-scaling configurations
- Cost optimization strategies
### Key Metrics Captured
The benchmark suite measures critical performance indicators:
- **Throughput**: Requests per second under sustained load
- **Latency Distribution**: P50, P95, P99 response times
- **Time to First Token (TTFT)**: Critical for streaming applications
- **Error Rates**: Request failures and timeout analysis
This data enables data-driven architectural decisions and performance optimization efforts.
## Setup
**1. Deploy base k8s infrastructure:**
```bash
cd ../k8s
./apply.sh
```
**2. Deploy benchmark components:**
```bash
cd ../k8s-benchmark
./apply.sh
```
**3. Verify deployment:**
```bash
kubectl get pods
# Should see: llama-stack-benchmark-server, vllm-server, etc.
```
## Quick Start
### Basic Benchmarks
**Benchmark Llama Stack (default):**
```bash
cd docs/source/distributions/k8s-benchmark/
./run-benchmark.sh
```
**Benchmark vLLM direct:**
```bash
./run-benchmark.sh --target vllm
```
### Custom Configuration
**Extended benchmark with high concurrency:**
```bash
./run-benchmark.sh --target vllm --duration 120 --concurrent 20
```
**Short test run:**
```bash
./run-benchmark.sh --target stack --duration 30 --concurrent 5
```
## Command Reference
### run-benchmark.sh Options
```bash
./run-benchmark.sh [options]
Options:
-t, --target <stack|vllm> Target to benchmark (default: stack)
-d, --duration <seconds> Duration in seconds (default: 60)
-c, --concurrent <users> Number of concurrent users (default: 10)
-h, --help Show help message
Examples:
./run-benchmark.sh --target vllm # Benchmark vLLM direct
./run-benchmark.sh --target stack # Benchmark Llama Stack
./run-benchmark.sh -t vllm -d 120 -c 20 # vLLM with 120s, 20 users
```
## Local Testing
### Running Benchmark Locally
For local development without Kubernetes:
**1. Start OpenAI mock server:**
```bash
uv run python openai-mock-server.py --port 8080
```
**2. Run benchmark against mock server:**
```bash
uv run python benchmark.py \
--base-url http://localhost:8080/v1 \
--model mock-inference \
--duration 30 \
--concurrent 5
```
**3. Test against local vLLM server:**
```bash
# If you have vLLM running locally on port 8000
uv run python benchmark.py \
--base-url http://localhost:8000/v1 \
--model meta-llama/Llama-3.2-3B-Instruct \
--duration 30 \
--concurrent 5
```
**4. Profile the running server:**
```bash
./profile_running_server.sh
```
### OpenAI Mock Server
The `openai-mock-server.py` provides:
- **OpenAI-compatible API** for testing without real models
- **Configurable streaming delay** via `STREAM_DELAY_SECONDS` env var
- **Consistent responses** for reproducible benchmarks
- **Lightweight testing** without GPU requirements
**Mock server usage:**
```bash
uv run python openai-mock-server.py --port 8080
```
The mock server is also deployed in k8s as `openai-mock-service:8080` and can be used by changing the Llama Stack configuration to use the `mock-vllm-inference` provider.
## Files in this Directory
- `benchmark.py` - Core benchmark script with async streaming support
- `run-benchmark.sh` - Main script with target selection and configuration
- `openai-mock-server.py` - Mock OpenAI API server for local testing
- `README.md` - This documentation file

View file

@ -8,7 +8,6 @@
# Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh). # Deploys the benchmark-specific components on top of the base k8s deployment (../k8s/apply.sh).
export MOCK_INFERENCE_PORT=8080
export STREAM_DELAY_SECONDS=0.005 export STREAM_DELAY_SECONDS=0.005
export POSTGRES_USER=llamastack export POSTGRES_USER=llamastack
@ -20,14 +19,7 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export MOCK_INFERENCE_MODEL=mock-inference export MOCK_INFERENCE_MODEL=mock-inference
# Use llama-stack-benchmark-service as the benchmark server export MOCK_INFERENCE_URL=openai-mock-service:8080
export LOCUST_HOST=http://llama-stack-benchmark-service:8323
export LOCUST_BASE_PATH=/v1/openai/v1
# Use vllm-service as the benchmark server
# export LOCUST_HOST=http://vllm-server:8000
# export LOCUST_BASE_PATH=/v1
export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
@ -35,13 +27,6 @@ set -euo pipefail
set -x set -x
# Deploy benchmark-specific components # Deploy benchmark-specific components
# Deploy OpenAI mock server
kubectl create configmap openai-mock --from-file=openai-mock-server.py \
--dry-run=client -o yaml | kubectl apply --validate=false -f -
envsubst < openai-mock-deployment.yaml | kubectl apply --validate=false -f -
# Create configmap with our custom stack config
kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \ kubectl create configmap llama-stack-config --from-file=stack_run_config.yaml \
--dry-run=client -o yaml > stack-configmap.yaml --dry-run=client -o yaml > stack-configmap.yaml
@ -49,9 +34,3 @@ kubectl apply --validate=false -f stack-configmap.yaml
# Deploy our custom llama stack server (overriding the base one) # Deploy our custom llama stack server (overriding the base one)
envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f - envsubst < stack-k8s.yaml.template | kubectl apply --validate=false -f -
# Deploy Locust load testing
kubectl create configmap locust-script --from-file=locustfile.py \
--dry-run=client -o yaml | kubectl apply --validate=false -f -
envsubst < locust-k8s.yaml | kubectl apply --validate=false -f -

View file

@ -0,0 +1,268 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Simple benchmark script for Llama Stack with OpenAI API compatibility.
"""
import argparse
import asyncio
import os
import random
import statistics
import time
from typing import Tuple
import aiohttp
class BenchmarkStats:
def __init__(self):
self.response_times = []
self.ttft_times = []
self.chunks_received = []
self.errors = []
self.success_count = 0
self.total_requests = 0
self.concurrent_users = 0
self.start_time = None
self.end_time = None
self._lock = asyncio.Lock()
async def add_result(self, response_time: float, chunks: int, ttft: float = None, error: str = None):
async with self._lock:
self.total_requests += 1
if error:
self.errors.append(error)
else:
self.success_count += 1
self.response_times.append(response_time)
self.chunks_received.append(chunks)
if ttft is not None:
self.ttft_times.append(ttft)
def print_summary(self):
if not self.response_times:
print("No successful requests to report")
if self.errors:
print(f"Total errors: {len(self.errors)}")
print("First 5 errors:")
for error in self.errors[:5]:
print(f" {error}")
return
total_time = self.end_time - self.start_time
success_rate = (self.success_count / self.total_requests) * 100
print(f"\n{'='*60}")
print(f"BENCHMARK RESULTS")
print(f"{'='*60}")
print(f"Total time: {total_time:.2f}s")
print(f"Concurrent users: {self.concurrent_users}")
print(f"Total requests: {self.total_requests}")
print(f"Successful requests: {self.success_count}")
print(f"Failed requests: {len(self.errors)}")
print(f"Success rate: {success_rate:.1f}%")
print(f"Requests per second: {self.success_count / total_time:.2f}")
print(f"\nResponse Time Statistics:")
print(f" Mean: {statistics.mean(self.response_times):.3f}s")
print(f" Median: {statistics.median(self.response_times):.3f}s")
print(f" Min: {min(self.response_times):.3f}s")
print(f" Max: {max(self.response_times):.3f}s")
if len(self.response_times) > 1:
print(f" Std Dev: {statistics.stdev(self.response_times):.3f}s")
percentiles = [50, 90, 95, 99]
sorted_times = sorted(self.response_times)
print(f"\nPercentiles:")
for p in percentiles:
idx = int(len(sorted_times) * p / 100) - 1
idx = max(0, min(idx, len(sorted_times) - 1))
print(f" P{p}: {sorted_times[idx]:.3f}s")
if self.ttft_times:
print(f"\nTime to First Token (TTFT) Statistics:")
print(f" Mean: {statistics.mean(self.ttft_times):.3f}s")
print(f" Median: {statistics.median(self.ttft_times):.3f}s")
print(f" Min: {min(self.ttft_times):.3f}s")
print(f" Max: {max(self.ttft_times):.3f}s")
if len(self.ttft_times) > 1:
print(f" Std Dev: {statistics.stdev(self.ttft_times):.3f}s")
sorted_ttft = sorted(self.ttft_times)
print(f"\nTTFT Percentiles:")
for p in percentiles:
idx = int(len(sorted_ttft) * p / 100) - 1
idx = max(0, min(idx, len(sorted_ttft) - 1))
print(f" P{p}: {sorted_ttft[idx]:.3f}s")
if self.chunks_received:
print(f"\nStreaming Statistics:")
print(f" Mean chunks per response: {statistics.mean(self.chunks_received):.1f}")
print(f" Total chunks received: {sum(self.chunks_received)}")
if self.errors:
print(f"\nErrors (showing first 5):")
for error in self.errors[:5]:
print(f" {error}")
class LlamaStackBenchmark:
def __init__(self, base_url: str, model_id: str):
self.base_url = base_url.rstrip('/')
self.model_id = model_id
self.headers = {"Content-Type": "application/json"}
self.test_messages = [
[{"role": "user", "content": "Hi"}],
[{"role": "user", "content": "What is the capital of France?"}],
[{"role": "user", "content": "Explain quantum physics in simple terms."}],
[{"role": "user", "content": "Write a short story about a robot learning to paint."}],
[
{"role": "user", "content": "What is machine learning?"},
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
{"role": "user", "content": "Can you give me a practical example?"}
]
]
async def make_async_streaming_request(self) -> Tuple[float, int, float | None, str | None]:
"""Make a single async streaming chat completion request."""
messages = random.choice(self.test_messages)
payload = {
"model": self.model_id,
"messages": messages,
"stream": True,
"max_tokens": 100
}
start_time = time.time()
chunks_received = 0
ttft = None
error = None
session = aiohttp.ClientSession()
try:
async with session.post(
f"{self.base_url}/chat/completions",
headers=self.headers,
json=payload,
timeout=aiohttp.ClientTimeout(total=30)
) as response:
if response.status == 200:
async for line in response.content:
if line:
line_str = line.decode('utf-8').strip()
if line_str.startswith('data: '):
chunks_received += 1
if ttft is None:
ttft = time.time() - start_time
if line_str == 'data: [DONE]':
break
if chunks_received == 0:
error = "No streaming chunks received"
else:
text = await response.text()
error = f"HTTP {response.status}: {text[:100]}"
except Exception as e:
error = f"Request error: {str(e)}"
finally:
await session.close()
response_time = time.time() - start_time
return response_time, chunks_received, ttft, error
async def run_benchmark(self, duration: int, concurrent_users: int) -> BenchmarkStats:
"""Run benchmark using async requests for specified duration."""
stats = BenchmarkStats()
stats.concurrent_users = concurrent_users
stats.start_time = time.time()
print(f"Starting benchmark: {duration}s duration, {concurrent_users} concurrent users")
print(f"Target URL: {self.base_url}/chat/completions")
print(f"Model: {self.model_id}")
connector = aiohttp.TCPConnector(limit=concurrent_users)
async with aiohttp.ClientSession(connector=connector) as session:
async def worker(worker_id: int):
"""Worker that sends requests sequentially until canceled."""
request_count = 0
while True:
try:
response_time, chunks, ttft, error = await self.make_async_streaming_request()
await stats.add_result(response_time, chunks, ttft, error)
request_count += 1
except asyncio.CancelledError:
break
except Exception as e:
await stats.add_result(0, 0, None, f"Worker {worker_id} error: {str(e)}")
# Progress reporting task
async def progress_reporter():
last_report_time = time.time()
while True:
try:
await asyncio.sleep(1) # Report every second
if time.time() >= last_report_time + 10: # Report every 10 seconds
elapsed = time.time() - stats.start_time
print(f"Completed: {stats.total_requests} requests in {elapsed:.1f}s")
last_report_time = time.time()
except asyncio.CancelledError:
break
# Spawn concurrent workers
tasks = [asyncio.create_task(worker(i)) for i in range(concurrent_users)]
progress_task = asyncio.create_task(progress_reporter())
tasks.append(progress_task)
# Wait for duration then cancel all tasks
await asyncio.sleep(duration)
for task in tasks:
task.cancel()
# Wait for all tasks to complete
await asyncio.gather(*tasks, return_exceptions=True)
stats.end_time = time.time()
return stats
def main():
parser = argparse.ArgumentParser(description="Llama Stack Benchmark Tool")
parser.add_argument("--base-url", default=os.getenv("BENCHMARK_BASE_URL", "http://localhost:8000/v1/openai/v1"),
help="Base URL for the API (default: http://localhost:8000/v1/openai/v1)")
parser.add_argument("--model", default=os.getenv("INFERENCE_MODEL", "test-model"),
help="Model ID to use for requests")
parser.add_argument("--duration", type=int, default=60,
help="Duration in seconds to run benchmark (default: 60)")
parser.add_argument("--concurrent", type=int, default=10,
help="Number of concurrent users (default: 10)")
args = parser.parse_args()
benchmark = LlamaStackBenchmark(args.base_url, args.model)
try:
stats = asyncio.run(benchmark.run_benchmark(args.duration, args.concurrent))
stats.print_summary()
except KeyboardInterrupt:
print("\nBenchmark interrupted by user")
except Exception as e:
print(f"Benchmark failed: {e}")
if __name__ == "__main__":
main()

View file

@ -1,131 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: locust-master
labels:
app: locust
role: master
spec:
replicas: 1
selector:
matchLabels:
app: locust
role: master
template:
metadata:
labels:
app: locust
role: master
spec:
containers:
- name: locust-master
image: locustio/locust:2.31.8
ports:
- containerPort: 8089 # Web UI
- containerPort: 5557 # Master communication
env:
- name: LOCUST_HOST
value: "${LOCUST_HOST}"
- name: LOCUST_LOCUSTFILE
value: "/locust/locustfile.py"
- name: LOCUST_WEB_HOST
value: "0.0.0.0"
- name: LOCUST_MASTER
value: "true"
- name: LOCUST_BASE_PATH
value: "${LOCUST_BASE_PATH}"
- name: INFERENCE_MODEL
value: "${BENCHMARK_INFERENCE_MODEL}"
volumeMounts:
- name: locust-script
mountPath: /locust
command: ["locust"]
args:
- "--master"
- "--web-host=0.0.0.0"
- "--web-port=8089"
- "--host=${LOCUST_HOST}"
- "--locustfile=/locust/locustfile.py"
volumes:
- name: locust-script
configMap:
name: locust-script
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: locust-worker
labels:
app: locust
role: worker
spec:
replicas: 2 # Start with 2 workers, can be scaled up
selector:
matchLabels:
app: locust
role: worker
template:
metadata:
labels:
app: locust
role: worker
spec:
containers:
- name: locust-worker
image: locustio/locust:2.31.8
env:
- name: LOCUST_HOST
value: "${LOCUST_HOST}"
- name: LOCUST_LOCUSTFILE
value: "/locust/locustfile.py"
- name: LOCUST_MASTER_HOST
value: "locust-master-service"
- name: LOCUST_MASTER_PORT
value: "5557"
- name: INFERENCE_MODEL
value: "${BENCHMARK_INFERENCE_MODEL}"
- name: LOCUST_BASE_PATH
value: "${LOCUST_BASE_PATH}"
volumeMounts:
- name: locust-script
mountPath: /locust
command: ["locust"]
args:
- "--worker"
- "--master-host=locust-master-service"
- "--master-port=5557"
- "--locustfile=/locust/locustfile.py"
volumes:
- name: locust-script
configMap:
name: locust-script
---
apiVersion: v1
kind: Service
metadata:
name: locust-master-service
spec:
selector:
app: locust
role: master
ports:
- name: web-ui
port: 8089
targetPort: 8089
- name: master-comm
port: 5557
targetPort: 5557
type: ClusterIP
---
apiVersion: v1
kind: Service
metadata:
name: locust-web-ui
spec:
selector:
app: locust
role: master
ports:
- port: 8089
targetPort: 8089
type: ClusterIP # Keep internal, use port-forward to access

View file

@ -1,78 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Locust load testing script for Llama Stack with Prism mock OpenAI provider.
"""
import random
from locust import HttpUser, task, between
import os
base_path = os.getenv("LOCUST_BASE_PATH", "/v1/openai/v1")
MODEL_ID = os.getenv("INFERENCE_MODEL")
class LlamaStackUser(HttpUser):
wait_time = between(0.0, 0.0001)
def on_start(self):
"""Setup authentication and test data."""
# No auth required for benchmark server
self.headers = {
"Content-Type": "application/json"
}
# Test messages of varying lengths
self.test_messages = [
[{"role": "user", "content": "Hi"}],
[{"role": "user", "content": "What is the capital of France?"}],
[{"role": "user", "content": "Explain quantum physics in simple terms."}],
[{"role": "user", "content": "Write a short story about a robot learning to paint."}],
[
{"role": "user", "content": "What is machine learning?"},
{"role": "assistant", "content": "Machine learning is a subset of AI..."},
{"role": "user", "content": "Can you give me a practical example?"}
]
]
@task(weight=100)
def chat_completion_streaming(self):
"""Test streaming chat completion (20% of requests)."""
messages = random.choice(self.test_messages)
payload = {
"model": MODEL_ID,
"messages": messages,
"stream": True,
"max_tokens": 100
}
with self.client.post(
f"{base_path}/chat/completions",
headers=self.headers,
json=payload,
stream=True,
catch_response=True
) as response:
if response.status_code == 200:
chunks_received = 0
try:
for line in response.iter_lines():
if line:
line_str = line.decode('utf-8')
if line_str.startswith('data: '):
chunks_received += 1
if line_str.strip() == 'data: [DONE]':
break
if chunks_received > 0:
response.success()
else:
response.failure("No streaming chunks received")
except Exception as e:
response.failure(f"Streaming error: {e}")
else:
response.failure(f"HTTP {response.status_code}: {response.text}")

View file

@ -1,52 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: openai-mock
labels:
app: openai-mock
spec:
replicas: 1
selector:
matchLabels:
app: openai-mock
template:
metadata:
labels:
app: openai-mock
spec:
containers:
- name: openai-mock
image: python:3.12-slim
ports:
- containerPort: ${MOCK_INFERENCE_PORT}
env:
- name: PORT
value: "${MOCK_INFERENCE_PORT}"
- name: MOCK_MODELS
value: "${MOCK_INFERENCE_MODEL}"
- name: STREAM_DELAY_SECONDS
value: "${STREAM_DELAY_SECONDS}"
command: ["sh", "-c"]
args:
- |
pip install flask &&
python /app/openai-mock-server.py --port ${MOCK_INFERENCE_PORT}
volumeMounts:
- name: openai-mock-script
mountPath: /app
volumes:
- name: openai-mock-script
configMap:
name: openai-mock
---
apiVersion: v1
kind: Service
metadata:
name: openai-mock-service
spec:
selector:
app: openai-mock
ports:
- port: 8080
targetPort: 8080
type: ClusterIP

View file

@ -23,7 +23,7 @@ app = Flask(__name__)
# Models from environment variables # Models from environment variables
def get_models(): def get_models():
models_str = os.getenv("MOCK_MODELS", "mock-inference") models_str = os.getenv("MOCK_MODELS", "meta-llama/Llama-3.2-3B-Instruct")
model_ids = [m.strip() for m in models_str.split(",") if m.strip()] model_ids = [m.strip() for m in models_str.split(",") if m.strip()]
return { return {
@ -49,13 +49,13 @@ def generate_random_text(length=50):
] ]
return " ".join(random.choices(words, k=length)) return " ".join(random.choices(words, k=length))
@app.route('/models', methods=['GET']) @app.route('/v1/models', methods=['GET'])
def list_models(): def list_models():
models = get_models() models = get_models()
print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}") print(f"[MOCK] Returning models: {[m['id'] for m in models['data']]}")
return jsonify(models) return jsonify(models)
@app.route('/chat/completions', methods=['POST']) @app.route('/v1/chat/completions', methods=['POST'])
def chat_completions(): def chat_completions():
"""Return OpenAI-formatted chat completion responses.""" """Return OpenAI-formatted chat completion responses."""
data = request.get_json() data = request.get_json()

View file

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Script to profile an already running Llama Stack server
# Usage: ./profile_running_server.sh [duration_seconds] [output_file]
DURATION=${1:-60} # Default 60 seconds
OUTPUT_FILE=${2:-"llama_stack_profile"} # Default output file
echo "Looking for running Llama Stack server..."
# Find the server PID
SERVER_PID=$(ps aux | grep "llama_stack.core.server.server" | grep -v grep | awk '{print $2}' | head -1)
if [ -z "$SERVER_PID" ]; then
echo "Error: No running Llama Stack server found"
echo "Please start your server first with:"
echo "LLAMA_STACK_LOGGING=\"all=ERROR\" MOCK_INFERENCE_URL=http://localhost:8080 SAFETY_MODEL=llama-guard3:1b uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml"
exit 1
fi
echo "Found Llama Stack server with PID: $SERVER_PID"
# Start py-spy profiling
echo "Starting py-spy profiling for ${DURATION} seconds..."
echo "Output will be saved to: ${OUTPUT_FILE}.svg"
echo ""
echo "You can now run your load test..."
echo ""
# Get the full path to py-spy
PYSPY_PATH=$(which py-spy)
# Check if running as root, if not, use sudo
if [ "$EUID" -ne 0 ]; then
echo "py-spy requires root permissions on macOS. Running with sudo..."
sudo "$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID
else
"$PYSPY_PATH" record -o "${OUTPUT_FILE}.svg" -d ${DURATION} -p $SERVER_PID
fi
echo ""
echo "Profiling completed! Results saved to: ${OUTPUT_FILE}.svg"
echo ""
echo "To view the flame graph:"
echo "open ${OUTPUT_FILE}.svg"

View file

@ -0,0 +1,148 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
# Default values
TARGET="stack"
DURATION=60
CONCURRENT=10
# Parse command line arguments
usage() {
echo "Usage: $0 [options]"
echo "Options:"
echo " -t, --target <stack|vllm> Target to benchmark (default: stack)"
echo " -d, --duration <seconds> Duration in seconds (default: 60)"
echo " -c, --concurrent <users> Number of concurrent users (default: 10)"
echo " -h, --help Show this help message"
echo ""
echo "Examples:"
echo " $0 --target vllm # Benchmark vLLM direct"
echo " $0 --target stack # Benchmark Llama Stack (default)"
echo " $0 -t vllm -d 120 -c 20 # vLLM with 120s duration, 20 users"
}
while [[ $# -gt 0 ]]; do
case $1 in
-t|--target)
TARGET="$2"
shift 2
;;
-d|--duration)
DURATION="$2"
shift 2
;;
-c|--concurrent)
CONCURRENT="$2"
shift 2
;;
-h|--help)
usage
exit 0
;;
*)
echo "Unknown option: $1"
usage
exit 1
;;
esac
done
# Validate target
if [[ "$TARGET" != "stack" && "$TARGET" != "vllm" ]]; then
echo "Error: Target must be 'stack' or 'vllm'"
usage
exit 1
fi
# Set configuration based on target
if [[ "$TARGET" == "vllm" ]]; then
BASE_URL="http://vllm-server:8000/v1"
JOB_NAME="vllm-benchmark-job"
echo "Benchmarking vLLM direct..."
else
BASE_URL="http://llama-stack-benchmark-service:8323/v1/openai/v1"
JOB_NAME="stack-benchmark-job"
echo "Benchmarking Llama Stack..."
fi
echo "Configuration:"
echo " Target: $TARGET"
echo " Base URL: $BASE_URL"
echo " Duration: ${DURATION}s"
echo " Concurrent users: $CONCURRENT"
echo ""
# Create temporary job yaml
TEMP_YAML="/tmp/benchmark-job-temp-$(date +%s).yaml"
cat > "$TEMP_YAML" << EOF
apiVersion: batch/v1
kind: Job
metadata:
name: $JOB_NAME
namespace: default
spec:
template:
spec:
containers:
- name: benchmark
image: python:3.11-slim
command: ["/bin/bash"]
args:
- "-c"
- |
pip install aiohttp &&
python3 /benchmark/benchmark.py \\
--base-url $BASE_URL \\
--model \${INFERENCE_MODEL} \\
--duration $DURATION \\
--concurrent $CONCURRENT
env:
- name: INFERENCE_MODEL
value: "meta-llama/Llama-3.2-3B-Instruct"
volumeMounts:
- name: benchmark-script
mountPath: /benchmark
resources:
requests:
memory: "256Mi"
cpu: "250m"
limits:
memory: "512Mi"
cpu: "500m"
volumes:
- name: benchmark-script
configMap:
name: benchmark-script
restartPolicy: Never
backoffLimit: 3
EOF
echo "Creating benchmark ConfigMap..."
kubectl create configmap benchmark-script \
--from-file=benchmark.py=benchmark.py \
--dry-run=client -o yaml | kubectl apply -f -
echo "Cleaning up any existing benchmark job..."
kubectl delete job $JOB_NAME 2>/dev/null || true
echo "Deploying benchmark Job..."
kubectl apply -f "$TEMP_YAML"
echo "Waiting for job to start..."
kubectl wait --for=condition=Ready pod -l job-name=$JOB_NAME --timeout=60s
echo "Following benchmark logs..."
kubectl logs -f job/$JOB_NAME
echo "Job completed. Checking final status..."
kubectl get job $JOB_NAME
# Clean up temporary file
rm -f "$TEMP_YAML"

View file

@ -26,13 +26,6 @@ data:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake} api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true} tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: mock-vllm-inference
provider_type: remote::vllm
config:
url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT}
max_tokens: 4096
api_token: fake
tls_verify: false
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
@ -121,9 +114,6 @@ data:
- model_id: ${env.SAFETY_MODEL} - model_id: ${env.SAFETY_MODEL}
provider_id: vllm-safety provider_id: vllm-safety
model_type: llm model_type: llm
- model_id: ${env.MOCK_INFERENCE_MODEL}
provider_id: mock-vllm-inference
model_type: llm
shields: shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: [] vector_dbs: []

View file

@ -44,8 +44,6 @@ spec:
value: "${SAFETY_MODEL}" value: "${SAFETY_MODEL}"
- name: TAVILY_SEARCH_API_KEY - name: TAVILY_SEARCH_API_KEY
value: "${TAVILY_SEARCH_API_KEY}" value: "${TAVILY_SEARCH_API_KEY}"
- name: MOCK_INFERENCE_PORT
value: "${MOCK_INFERENCE_PORT}"
- name: VLLM_URL - name: VLLM_URL
value: http://vllm-server.default.svc.cluster.local:8000/v1 value: http://vllm-server.default.svc.cluster.local:8000/v1
- name: VLLM_MAX_TOKENS - name: VLLM_MAX_TOKENS
@ -54,8 +52,6 @@ spec:
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
- name: VLLM_TLS_VERIFY - name: VLLM_TLS_VERIFY
value: "false" value: "false"
- name: MOCK_INFERENCE_MODEL
value: "${MOCK_INFERENCE_MODEL}"
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
ports: ports:
- containerPort: 8323 - containerPort: 8323

View file

@ -3,7 +3,6 @@ image_name: kubernetes-benchmark-demo
apis: apis:
- agents - agents
- inference - inference
- safety
- telemetry - telemetry
- tool_runtime - tool_runtime
- vector_io - vector_io
@ -16,20 +15,6 @@ providers:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake} api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true} tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: mock-vllm-inference
provider_type: remote::vllm
config:
url: http://openai-mock-service:${env.MOCK_INFERENCE_PORT}
max_tokens: 4096
api_token: fake
tls_verify: false
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
@ -45,11 +30,6 @@ providers:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -115,14 +95,6 @@ models:
- model_id: ${env.INFERENCE_MODEL} - model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference provider_id: vllm-inference
model_type: llm model_type: llm
- model_id: ${env.SAFETY_MODEL}
provider_id: vllm-safety
model_type: llm
- model_id: ${env.MOCK_INFERENCE_MODEL}
provider_id: mock-vllm-inference
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -2,6 +2,15 @@
## Overview ## Overview
Agents API for creating and interacting with agentic systems.
Main functionalities provided by this API:
- Create agents with specific instructions and ability to use tools.
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
- Agents can be provided with various shields (see the Safety API for more details).
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
This section contains documentation for all available providers for the **agents** API. This section contains documentation for all available providers for the **agents** API.
## Providers ## Providers

View file

@ -0,0 +1,21 @@
# Batches
## Overview
Protocol for batch processing API operations.
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
Note: This API is currently under active development and may undergo changes.
This section contains documentation for all available providers for the **batches** API.
## Providers
```{toctree}
:maxdepth: 1
inline_reference
```

View file

@ -0,0 +1,23 @@
# inline::reference
## Description
Reference implementation of batches API with KVStore persistence.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
| `max_concurrent_batches` | `<class 'int'>` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
| `max_concurrent_requests_per_batch` | `<class 'int'>` | No | 10 | Maximum number of concurrent requests to process per batch. |
## Sample Configuration
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
```

View file

@ -2,6 +2,8 @@
## Overview ## Overview
Llama Stack Evaluation API for running evaluations on model and agent candidates.
This section contains documentation for all available providers for the **eval** API. This section contains documentation for all available providers for the **eval** API.
## Providers ## Providers

View file

@ -2,6 +2,12 @@
## Overview ## Overview
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
This section contains documentation for all available providers for the **inference** API. This section contains documentation for all available providers for the **inference** API.
## Providers ## Providers

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batches import Batches, BatchObject, ListBatchesResponse
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]

View file

@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, webmethod
try:
from openai.types import Batch as BatchObject
except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
object: Literal["list"] = "list"
data: list[BatchObject] = Field(..., description="List of batch objects")
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
has_more: bool = Field(default=False, description="Whether there are more batches available")
@runtime_checkable
class Batches(Protocol):
"""Protocol for batch processing API operations.
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/openai/v1/batches", method="POST")
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
:param input_file_id: The ID of an uploaded file containing requests for the batch.
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:returns: The created batch object.
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
:param batch_id: The ID of the batch to retrieve.
:returns: The batch object.
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
:param batch_id: The ID of the batch to cancel.
:returns: The updated batch object.
"""
...
@webmethod(route="/openai/v1/batches", method="GET")
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""List all batches for the current user.
:param after: A cursor for pagination; returns batches after this batch ID.
:param limit: Number of batches to return (default 20, max 100).
:returns: A list of batch objects.
"""
...

View file

@ -72,3 +72,10 @@ class ModelTypeError(TypeError):
f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'" f"Model '{model_name}' is of type '{model_type}' rather than the expected type '{expected_model_type}'"
) )
super().__init__(message) super().__init__(message)
class ConflictError(ValueError):
"""raised when an operation cannot be performed due to a conflict with the current state"""
def __init__(self, message: str) -> None:
super().__init__(message)

View file

@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar inference: Text generation, chat completions, and embeddings :cvar inference: Text generation, chat completions, and embeddings
:cvar safety: Content moderation and safety shields :cvar safety: Content moderation and safety shields
:cvar agents: Agent orchestration and execution :cvar agents: Agent orchestration and execution
:cvar batches: Batch processing for asynchronous API requests
:cvar vector_io: Vector database operations and queries :cvar vector_io: Vector database operations and queries
:cvar datasetio: Dataset input/output operations :cvar datasetio: Dataset input/output operations
:cvar scoring: Model output evaluation and scoring :cvar scoring: Model output evaluation and scoring
@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
inference = "inference" inference = "inference"
safety = "safety" safety = "safety"
agents = "agents" agents = "agents"
batches = "batches"
vector_io = "vector_io" vector_io = "vector_io"
datasetio = "datasetio" datasetio = "datasetio"
scoring = "scoring" scoring = "scoring"

View file

@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum):
""" """
ASSISTANTS = "assistants" ASSISTANTS = "assistants"
BATCH = "batch"
# TODO: Add other purposes as needed # TODO: Add other purposes as needed

View file

@ -1,207 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
# This timeout (in seconds) is necessary when installing PyTorch via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
set -euo pipefail
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# Usage function
usage() {
echo "Usage: $0 --env-name <conda_env_name> --build-file-path <build_file_path> --normal-deps <pip_dependencies> [--external-provider-deps <external_provider_deps>] [--optional-deps <special_pip_deps>]"
echo "Example: $0 --env-name my-conda-env --build-file-path ./my-stack-build.yaml --normal-deps 'numpy pandas scipy' --external-provider-deps 'foo' --optional-deps 'bar'"
exit 1
}
# Parse arguments
env_name=""
build_file_path=""
normal_deps=""
external_provider_deps=""
optional_deps=""
while [[ $# -gt 0 ]]; do
key="$1"
case "$key" in
--env-name)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --env-name requires a string value" >&2
usage
fi
env_name="$2"
shift 2
;;
--build-file-path)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --build-file-path requires a string value" >&2
usage
fi
build_file_path="$2"
shift 2
;;
--normal-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --normal-deps requires a string value" >&2
usage
fi
normal_deps="$2"
shift 2
;;
--external-provider-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --external-provider-deps requires a string value" >&2
usage
fi
external_provider_deps="$2"
shift 2
;;
--optional-deps)
if [[ -z "$2" || "$2" == --* ]]; then
echo "Error: --optional-deps requires a string value" >&2
usage
fi
optional_deps="$2"
shift 2
;;
*)
echo "Unknown option: $1" >&2
usage
;;
esac
done
# Check required arguments
if [[ -z "$env_name" || -z "$build_file_path" || -z "$normal_deps" ]]; then
echo "Error: --env-name, --build-file-path, and --normal-deps are required." >&2
usage
fi
if [ -n "$LLAMA_STACK_DIR" ]; then
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
echo "Using llama-stack-client-dir=$LLAMA_STACK_CLIENT_DIR"
fi
ensure_conda_env_python310() {
# Use only global variables set by flag parser
local python_version="3.12"
if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi
if conda env list | grep -q "^${env_name} "; then
printf "Conda environment '${env_name}' exists. Checking Python version...\n"
current_version=$(conda run -n "${env_name}" python --version 2>&1 | cut -d' ' -f2 | cut -d'.' -f1,2)
if [ "$current_version" = "$python_version" ]; then
printf "Environment '${env_name}' already has Python ${python_version}. No action needed.\n"
else
printf "Updating environment '${env_name}' to Python ${python_version}...\n"
conda install -n "${env_name}" python="${python_version}" -y
fi
else
printf "Conda environment '${env_name}' does not exist. Creating with Python ${python_version}...\n"
conda create -n "${env_name}" python="${python_version}" -y
fi
eval "$(conda shell.bash hook)"
conda deactivate && conda activate "${env_name}"
"$CONDA_PREFIX"/bin/pip install uv
if [ -n "$TEST_PYPI_VERSION" ]; then
uv pip install fastapi libcst
uv pip install --extra-index-url https://test.pypi.org/simple/ \
llama-stack=="$TEST_PYPI_VERSION" \
"$normal_deps"
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install "$part"
done
fi
else
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
PYPI_VERSION="${PYPI_VERSION:-}"
if [ -n "$PYPI_VERSION" ]; then
SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else
SPEC_VERSION="llama-stack"
fi
uv pip install --no-cache-dir "$SPEC_VERSION"
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: $LLAMA_STACK_CLIENT_DIR${NC}\n" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: $LLAMA_STACK_CLIENT_DIR\n"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR"
fi
printf "Installing pip dependencies\n"
uv pip install $normal_deps
if [ -n "$optional_deps" ]; then
IFS='#' read -ra parts <<<"$optional_deps"
for part in "${parts[@]}"; do
echo "$part"
uv pip install $part
done
fi
if [ -n "$external_provider_deps" ]; then
IFS='#' read -ra parts <<<"$external_provider_deps"
for part in "${parts[@]}"; do
echo "Getting provider spec for module: $part and installing dependencies"
package_name=$(echo "$part" | sed 's/[<>=!].*//')
python3 -c "
import importlib
import sys
try:
module = importlib.import_module(f'$package_name.provider')
spec = module.get_provider_spec()
if hasattr(spec, 'pip_packages') and spec.pip_packages:
print('\\n'.join(spec.pip_packages))
except Exception as e:
print(f'Error getting provider spec for $package_name: {e}', file=sys.stderr)
" | uv pip install -r -
done
fi
fi
mv "$build_file_path" "$CONDA_PREFIX"/llamastack-build.yaml
echo "Build spec configuration saved at $CONDA_PREFIX/llamastack-build.yaml"
}
ensure_conda_env_python310 "$env_name" "$build_file_path" "$normal_deps" "$optional_deps" "$external_provider_deps"

View file

@ -151,23 +151,37 @@ run() {
fi fi
else else
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then # only warn if DIR does not start with "git+"
if [ ! -d "$LLAMA_STACK_DIR" ] && [[ "$LLAMA_STACK_DIR" != git+* ]]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2 printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR" printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR" # editable only if LLAMA_STACK_DIR does not start with "git+"
if [[ "$LLAMA_STACK_DIR" != git+* ]]; then
EDITABLE="-e"
else
EDITABLE=""
fi
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_DIR"
else else
uv pip install --no-cache-dir llama-stack uv pip install --no-cache-dir llama-stack
fi fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ]; then # only warn if DIR does not start with "git+"
if [ ! -d "$LLAMA_STACK_CLIENT_DIR" ] && [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2 printf "${RED}Warning: LLAMA_STACK_CLIENT_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_CLIENT_DIR" >&2
exit 1 exit 1
fi fi
printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR" printf "Installing from LLAMA_STACK_CLIENT_DIR: %s\n" "$LLAMA_STACK_CLIENT_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_CLIENT_DIR" # editable only if LLAMA_STACK_CLIENT_DIR does not start with "git+"
if [[ "$LLAMA_STACK_CLIENT_DIR" != git+* ]]; then
EDITABLE="-e"
else
EDITABLE=""
fi
uv pip install --no-cache-dir $EDITABLE "$LLAMA_STACK_CLIENT_DIR"
fi fi
printf "Installing pip dependencies\n" printf "Installing pip dependencies\n"

View file

@ -8,6 +8,7 @@ import inspect
from typing import Any from typing import Any
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.agents: Agents, Api.agents: Agents,
Api.inference: Inference, Api.inference: Inference,
Api.inspect: Inspect, Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO, Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs, Api.vector_dbs: VectorDBs,
Api.models: Models, Api.models: Models,

View file

@ -6,9 +6,7 @@
from typing import Any from typing import Any
from llama_stack.apis.inference import ( from llama_stack.apis.inference import Message
Message,
)
from llama_stack.apis.safety import RunShieldResponse, Safety from llama_stack.apis.safety import RunShieldResponse, Safety
from llama_stack.apis.safety.safety import ModerationObject from llama_stack.apis.safety.safety import ModerationObject
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
@ -68,6 +66,7 @@ class SafetyRouter(Safety):
list_shields_response = await self.routing_table.list_shields() list_shields_response = await self.routing_table.list_shields()
matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id] matches = [s.identifier for s in list_shields_response.data if model == s.provider_resource_id]
if not matches: if not matches:
raise ValueError(f"No shield associated with provider_resource id {model}") raise ValueError(f"No shield associated with provider_resource id {model}")
if len(matches) > 1: if len(matches) > 1:

View file

@ -32,6 +32,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError from openai import BadRequestError
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError from llama_stack.core.access_control.access_control import AccessDeniedError
@ -128,6 +129,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
] ]
}, },
) )
elif isinstance(exc, ConflictError):
return HTTPException(status_code=409, detail=str(exc))
elif isinstance(exc, ResourceNotFoundError):
return HTTPException(status_code=404, detail=str(exc))
elif isinstance(exc, ValueError): elif isinstance(exc, ValueError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError): elif isinstance(exc, BadRequestError):

View file

@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs - provider_type: inline::localfs
safety: safety:
- provider_type: inline::llama-guard - provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents: agents:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
telemetry: telemetry:
@ -48,6 +49,8 @@ distribution_spec:
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime - provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference
image_type: venv image_type: venv
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite

View file

@ -2,6 +2,7 @@ version: 2
image_name: ci-tests image_name: ci-tests
apis: apis:
- agents - agents
- batches
- datasetio - datasetio
- eval - eval
- files - files
@ -134,6 +135,8 @@ providers:
provider_type: inline::llama-guard provider_type: inline::llama-guard
config: config:
excluded_categories: [] excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -204,6 +207,13 @@ providers:
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/batches.db
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db
@ -215,6 +225,9 @@ shields:
- shield_id: llama-guard - shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=} provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -28,6 +28,7 @@ distribution_spec:
- provider_type: inline::localfs - provider_type: inline::localfs
safety: safety:
- provider_type: inline::llama-guard - provider_type: inline::llama-guard
- provider_type: inline::code-scanner
agents: agents:
- provider_type: inline::meta-reference - provider_type: inline::meta-reference
telemetry: telemetry:
@ -48,6 +49,8 @@ distribution_spec:
- provider_type: remote::tavily-search - provider_type: remote::tavily-search
- provider_type: inline::rag-runtime - provider_type: inline::rag-runtime
- provider_type: remote::model-context-protocol - provider_type: remote::model-context-protocol
batches:
- provider_type: inline::reference
image_type: venv image_type: venv
additional_pip_packages: additional_pip_packages:
- aiosqlite - aiosqlite

View file

@ -2,6 +2,7 @@ version: 2
image_name: starter image_name: starter
apis: apis:
- agents - agents
- batches
- datasetio - datasetio
- eval - eval
- files - files
@ -134,6 +135,8 @@ providers:
provider_type: inline::llama-guard provider_type: inline::llama-guard
config: config:
excluded_categories: [] excluded_categories: []
- provider_id: code-scanner
provider_type: inline::code-scanner
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -204,6 +207,13 @@ providers:
provider_type: inline::rag-runtime provider_type: inline::rag-runtime
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
batches:
- provider_id: reference
provider_type: inline::reference
config:
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/batches.db
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/registry.db
@ -215,6 +225,9 @@ shields:
- shield_id: llama-guard - shield_id: llama-guard
provider_id: ${env.SAFETY_MODEL:+llama-guard} provider_id: ${env.SAFETY_MODEL:+llama-guard}
provider_shield_id: ${env.SAFETY_MODEL:=} provider_shield_id: ${env.SAFETY_MODEL:=}
- shield_id: code-scanner
provider_id: ${env.CODE_SCANNER_MODEL:+code-scanner}
provider_shield_id: ${env.CODE_SCANNER_MODEL:=}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -15,19 +15,14 @@ from llama_stack.core.datatypes import (
ToolGroupInput, ToolGroupInput,
) )
from llama_stack.core.utils.dynamic import instantiate_class_type from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.distributions.template import ( from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings
DistributionTemplate,
RunConfigSettings,
)
from llama_stack.providers.datatypes import RemoteProviderSpec from llama_stack.providers.datatypes import RemoteProviderSpec
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.inline.inference.sentence_transformers import ( from llama_stack.providers.inline.inference.sentence_transformers import (
SentenceTransformersInferenceConfig, SentenceTransformersInferenceConfig,
) )
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.milvus.config import ( from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
MilvusVectorIOConfig,
)
from llama_stack.providers.inline.vector_io.sqlite_vec.config import ( from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
SQLiteVectorIOConfig, SQLiteVectorIOConfig,
) )
@ -119,7 +114,10 @@ def get_distribution_template() -> DistributionTemplate:
BuildProvider(provider_type="remote::pgvector"), BuildProvider(provider_type="remote::pgvector"),
], ],
"files": [BuildProvider(provider_type="inline::localfs")], "files": [BuildProvider(provider_type="inline::localfs")],
"safety": [BuildProvider(provider_type="inline::llama-guard")], "safety": [
BuildProvider(provider_type="inline::llama-guard"),
BuildProvider(provider_type="inline::code-scanner"),
],
"agents": [BuildProvider(provider_type="inline::meta-reference")], "agents": [BuildProvider(provider_type="inline::meta-reference")],
"telemetry": [BuildProvider(provider_type="inline::meta-reference")], "telemetry": [BuildProvider(provider_type="inline::meta-reference")],
"post_training": [BuildProvider(provider_type="inline::huggingface")], "post_training": [BuildProvider(provider_type="inline::huggingface")],
@ -139,6 +137,9 @@ def get_distribution_template() -> DistributionTemplate:
BuildProvider(provider_type="inline::rag-runtime"), BuildProvider(provider_type="inline::rag-runtime"),
BuildProvider(provider_type="remote::model-context-protocol"), BuildProvider(provider_type="remote::model-context-protocol"),
], ],
"batches": [
BuildProvider(provider_type="inline::reference"),
],
} }
files_provider = Provider( files_provider = Provider(
provider_id="meta-reference-files", provider_id="meta-reference-files",
@ -167,6 +168,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.SAFETY_MODEL:+llama-guard}", provider_id="${env.SAFETY_MODEL:+llama-guard}",
provider_shield_id="${env.SAFETY_MODEL:=}", provider_shield_id="${env.SAFETY_MODEL:=}",
), ),
ShieldInput(
shield_id="code-scanner",
provider_id="${env.CODE_SCANNER_MODEL:+code-scanner}",
provider_shield_id="${env.CODE_SCANNER_MODEL:=}",
),
] ]
return DistributionTemplate( return DistributionTemplate(

View file

@ -7,13 +7,11 @@
import logging import logging
import os import os
import re import re
import sys
from logging.config import dictConfig from logging.config import dictConfig
from rich.console import Console from rich.console import Console
from rich.errors import MarkupError from rich.errors import MarkupError
from rich.logging import RichHandler from rich.logging import RichHandler
from termcolor import cprint
from llama_stack.core.datatypes import LoggingConfig from llama_stack.core.datatypes import LoggingConfig
@ -66,7 +64,6 @@ def config_to_category_levels(category: str, level: str):
category_levels["root"] = level_value category_levels["root"] = level_value
elif category in CATEGORIES: elif category in CATEGORIES:
category_levels[category] = level_value category_levels[category] = level_value
logging.info(f"Setting '{category}' category to level '{level}'.")
else: else:
logging.warning(f"Unknown logging category: {category}. No changes made.") logging.warning(f"Unknown logging category: {category}. No changes made.")
return category_levels return category_levels
@ -256,7 +253,6 @@ def get_logger(
env_config = os.environ.get("LLAMA_STACK_LOGGING", "") env_config = os.environ.get("LLAMA_STACK_LOGGING", "")
if env_config: if env_config:
cprint(f"Environment variable LLAMA_STACK_LOGGING found: {env_config}", color="yellow", file=sys.stderr)
_category_levels.update(parse_environment_config(env_config)) _category_levels.update(parse_environment_config(env_config))
log_file = os.environ.get("LLAMA_STACK_LOG_FILE") log_file = os.environ.get("LLAMA_STACK_LOG_FILE")

View file

@ -48,8 +48,8 @@ from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .agent_instance import ChatAgent from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
from .openai_responses import OpenAIResponsesImpl
from .persistence import AgentInfo from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = logging.getLogger() logger = logging.getLogger()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import time
import uuid
from collections.abc import AsyncIterator
from pydantic import BaseModel
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
Inference,
OpenAISystemMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext
from .utils import (
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):
input_items: ListOpenAIResponseInputItem
response: OpenAIResponseObject
class OpenAIResponsesImpl:
def __init__(
self,
inference_api: Inference,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
vector_io_api=vector_io_api,
)
async def _prepend_previous_response(
self,
input: str | list[OpenAIResponseInput],
previous_response_id: str | None = None,
):
if previous_response_id:
previous_response_with_input = await self.responses_store.get_response_object(previous_response_id)
# previous response input items
new_input_items = previous_response_with_input.input
# previous response output items
new_input_items.extend(previous_response_with_input.output)
# new input items from the current request
if isinstance(input, str):
new_input_items.append(OpenAIResponseMessage(content=input, role="user"))
else:
new_input_items.extend(input)
input = new_input_items
return input
async def _prepend_instructions(self, messages, instructions):
if instructions:
messages.insert(0, OpenAISystemMessageParam(content=instructions))
async def get_openai_response(
self,
response_id: str,
) -> OpenAIResponseObject:
response_with_input = await self.responses_store.get_response_object(response_id)
return OpenAIResponseObject(**{k: v for k, v in response_with_input.model_dump().items() if k != "input"})
async def list_openai_responses(
self,
after: str | None = None,
limit: int | None = 50,
model: str | None = None,
order: Order | None = Order.desc,
) -> ListOpenAIResponseObject:
return await self.responses_store.list_responses(after, limit, model, order)
async def list_openai_response_input_items(
self,
response_id: str,
after: str | None = None,
before: str | None = None,
include: list[str] | None = None,
limit: int | None = 20,
order: Order | None = Order.desc,
) -> ListOpenAIResponseInputItem:
"""List input items for a given OpenAI response.
:param response_id: The ID of the response to retrieve input items for.
:param after: An item ID to list items after, used for pagination.
:param before: An item ID to list items before, used for pagination.
:param include: Additional fields to include in the response.
:param limit: A limit on the number of objects to be returned.
:param order: The order to return the input items in.
:returns: An ListOpenAIResponseInputItem.
"""
return await self.responses_store.list_response_input_items(response_id, after, before, include, limit, order)
async def _store_response(
self,
response: OpenAIResponseObject,
input: str | list[OpenAIResponseInput],
) -> None:
new_input_id = f"msg_{uuid.uuid4()}"
if isinstance(input, str):
# synthesize a message from the input string
input_content = OpenAIResponseInputMessageContentText(text=input)
input_content_item = OpenAIResponseMessage(
role="user",
content=[input_content],
id=new_input_id,
)
input_items_data = [input_content_item]
else:
# we already have a list of messages
input_items_data = []
for input_item in input:
if isinstance(input_item, OpenAIResponseMessage):
# These may or may not already have an id, so dump to dict, check for id, and add if missing
input_item_dict = input_item.model_dump()
if "id" not in input_item_dict:
input_item_dict["id"] = new_input_id
input_items_data.append(OpenAIResponseMessage(**input_item_dict))
else:
input_items_data.append(input_item)
await self.responses_store.store_response_object(
response_object=response,
input=input_items_data,
)
async def create_openai_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
stream: bool | None = False,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None,
max_infer_iters: int | None = 10,
):
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
stream_gen = self._create_streaming_response(
input=input,
model=model,
instructions=instructions,
previous_response_id=previous_response_id,
store=store,
temperature=temperature,
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
)
if stream:
return stream_gen
else:
response = None
async for stream_chunk in stream_gen:
if stream_chunk.type == "response.completed":
if response is not None:
raise ValueError("The response stream completed multiple times! Earlier response: {response}")
response = stream_chunk.response
# don't leave the generator half complete!
if response is None:
raise ValueError("The response stream never completed")
return response
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
model: str,
instructions: str | None = None,
previous_response_id: str | None = None,
store: bool | None = True,
temperature: float | None = None,
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
input = await self._prepend_previous_response(input, previous_response_id)
messages = await convert_response_input_to_chat_messages(input)
await self._prepend_instructions(messages, instructions)
# Structured outputs
response_format = await convert_response_text_to_chat_response_format(text)
ctx = ChatCompletionContext(
model=model,
messages=messages,
response_tools=tools,
temperature=temperature,
response_format=response_format,
)
# Create orchestrator and delegate streaming logic
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
orchestrator = StreamingResponseOrchestrator(
inference_api=self.inference_api,
ctx=ctx,
response_id=response_id,
created_at=created_at,
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
)
# Stream the response
final_response = None
async for stream_chunk in orchestrator.create_response():
if stream_chunk.type == "response.completed":
final_response = stream_chunk.response
yield stream_chunk
# Store the response if requested
if store and final_response:
await self._store_response(
response=final_response,
input=input,
)
async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
return await self.responses_store.delete_response_object(response_id)

View file

@ -0,0 +1,634 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from collections.abc import AsyncIterator
from typing import Any
from llama_stack.apis.agents.openai_responses import (
AllowedToolsFilter,
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseContentPartAdded,
OpenAIResponseObjectStreamResponseContentPartDone,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta,
OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta,
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone,
OpenAIResponseObjectStreamResponseMcpListToolsCompleted,
OpenAIResponseObjectStreamResponseMcpListToolsInProgress,
OpenAIResponseObjectStreamResponseOutputItemAdded,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseObjectStreamResponseOutputTextDelta,
OpenAIResponseOutput,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseText,
WebSearchToolTypes,
)
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionToolCall,
OpenAIChoice,
)
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
class StreamingResponseOrchestrator:
def __init__(
self,
inference_api: Inference,
ctx: ChatCompletionContext,
response_id: str,
created_at: int,
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
):
self.inference_api = inference_api
self.ctx = ctx
self.response_id = response_id
self.created_at = created_at
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = {}
async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
# Initialize output messages
output_messages: list[OpenAIResponseOutput] = []
# Create initial response and emit response.created immediately
initial_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="in_progress",
output=output_messages.copy(),
text=self.text,
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Process all tools (including MCP tools) and emit streaming events
if self.ctx.response_tools:
async for stream_event in self._process_tools(self.ctx.response_tools, output_messages):
yield stream_event
n_iter = 0
messages = self.ctx.messages.copy()
while True:
completion_result = await self.inference_api.openai_chat_completion(
model=self.ctx.model,
messages=messages,
tools=self.ctx.chat_tools,
stream=True,
temperature=self.ctx.temperature,
response_format=self.ctx.response_format,
)
# Process streaming chunks and build complete response
completion_result_data = None
async for stream_event_or_result in self._process_streaming_chunks(completion_result, output_messages):
if isinstance(stream_event_or_result, ChatCompletionResult):
completion_result_data = stream_event_or_result
else:
yield stream_event_or_result
if not completion_result_data:
raise ValueError("Streaming chunk processor failed to return completion data")
current_response = self._build_chat_completion(completion_result_data)
function_tool_calls, non_function_tool_calls, next_turn_messages = self._separate_tool_calls(
current_response, messages
)
# Handle choices with no tool calls
for choice in current_response.choices:
if not (choice.message.tool_calls and self.ctx.response_tools):
output_messages.append(await convert_chat_choice_to_response_message(choice))
# Execute tool calls and coordinate results
async for stream_event in self._coordinate_tool_execution(
function_tool_calls,
non_function_tool_calls,
completion_result_data,
output_messages,
next_turn_messages,
):
yield stream_event
if not function_tool_calls and not non_function_tool_calls:
break
if function_tool_calls:
logger.info("Exiting inference loop since there is a function (client-side) tool call")
break
n_iter += 1
if n_iter >= self.max_infer_iters:
logger.info(f"Exiting inference loop since iteration count({n_iter}) exceeds {self.max_infer_iters=}")
break
messages = next_turn_messages
# Create final response
final_response = OpenAIResponseObject(
created_at=self.created_at,
id=self.response_id,
model=self.ctx.model,
object="response",
status="completed",
text=self.text,
output=output_messages,
)
# Emit response.completed
yield OpenAIResponseObjectStreamResponseCompleted(response=final_response)
def _separate_tool_calls(self, current_response, messages) -> tuple[list, list, list]:
"""Separate tool calls into function and non-function categories."""
function_tool_calls = []
non_function_tool_calls = []
next_turn_messages = messages.copy()
for choice in current_response.choices:
next_turn_messages.append(choice.message)
if choice.message.tool_calls and self.ctx.response_tools:
for tool_call in choice.message.tool_calls:
if is_function_tool_call(tool_call, self.ctx.response_tools):
function_tool_calls.append(tool_call)
else:
non_function_tool_calls.append(tool_call)
return function_tool_calls, non_function_tool_calls, next_turn_messages
async def _process_streaming_chunks(
self, completion_result, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream | ChatCompletionResult]:
"""Process streaming chunks and emit events, returning completion data."""
# Initialize result tracking
chat_response_id = ""
chat_response_content = []
chat_response_tool_calls: dict[int, OpenAIChatCompletionToolCall] = {}
chunk_created = 0
chunk_model = ""
chunk_finish_reason = ""
# Create a placeholder message item for delta events
message_item_id = f"msg_{uuid.uuid4()}"
# Track tool call items for streaming events
tool_call_item_ids: dict[int, str] = {}
# Track content parts for streaming events
content_part_emitted = False
async for chunk in completion_result:
chat_response_id = chunk.id
chunk_created = chunk.created
chunk_model = chunk.model
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
# Emit content_part.added event for first text chunk
if not content_part_emitted:
content_part_emitted = True
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartAdded(
response_id=self.response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text="", # Will be filled incrementally via text deltas
),
sequence_number=self.sequence_number,
)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputTextDelta(
content_index=0,
delta=chunk_choice.delta.content,
item_id=message_item_id,
output_index=0,
sequence_number=self.sequence_number,
)
# Collect content for final response
chat_response_content.append(chunk_choice.delta.content or "")
if chunk_choice.finish_reason:
chunk_finish_reason = chunk_choice.finish_reason
# Aggregate tool call arguments across chunks
if chunk_choice.delta.tool_calls:
for tool_call in chunk_choice.delta.tool_calls:
response_tool_call = chat_response_tool_calls.get(tool_call.index, None)
# Create new tool call entry if this is the first chunk for this index
is_new_tool_call = response_tool_call is None
if is_new_tool_call:
tool_call_dict: dict[str, Any] = tool_call.model_dump()
tool_call_dict.pop("type", None)
response_tool_call = OpenAIChatCompletionToolCall(**tool_call_dict)
chat_response_tool_calls[tool_call.index] = response_tool_call
# Create item ID for this tool call for streaming events
tool_call_item_id = f"fc_{uuid.uuid4()}"
tool_call_item_ids[tool_call.index] = tool_call_item_id
# Emit output_item.added event for the new function call
self.sequence_number += 1
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="", # Will be filled incrementally via delta events
call_id=tool_call.id or "",
name=tool_call.function.name if tool_call.function else "",
id=tool_call_item_id,
status="in_progress",
)
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
if tool_call.function and tool_call.function.arguments:
tool_call_item_id = tool_call_item_ids[tool_call.index]
self.sequence_number += 1
# Check if this is an MCP tool call
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
if is_mcp_tool:
# Emit MCP-specific argument delta event
yield OpenAIResponseObjectStreamResponseMcpCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
else:
# Emit function call argument delta event
yield OpenAIResponseObjectStreamResponseFunctionCallArgumentsDelta(
delta=tool_call.function.arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Accumulate arguments for final response (only for subsequent chunks)
if not is_new_tool_call:
response_tool_call.function.arguments = (
response_tool_call.function.arguments or ""
) + tool_call.function.arguments
# Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
for tool_call_index in sorted(chat_response_tool_calls.keys()):
tool_call_item_id = tool_call_item_ids[tool_call_index]
final_arguments = chat_response_tool_calls[tool_call_index].function.arguments or ""
tool_call_name = chat_response_tool_calls[tool_call_index].function.name
# Check if this is an MCP tool call
is_mcp_tool = tool_call_name and tool_call_name in self.mcp_tool_to_server
self.sequence_number += 1
done_event_cls = (
OpenAIResponseObjectStreamResponseMcpCallArgumentsDone
if is_mcp_tool
else OpenAIResponseObjectStreamResponseFunctionCallArgumentsDone
)
yield done_event_cls(
arguments=final_arguments,
item_id=tool_call_item_id,
output_index=len(output_messages),
sequence_number=self.sequence_number,
)
# Emit content_part.done event if text content was streamed (before content gets cleared)
if content_part_emitted:
final_text = "".join(chat_response_content)
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseContentPartDone(
response_id=self.response_id,
item_id=message_item_id,
part=OpenAIResponseContentPartOutputText(
text=final_text,
),
sequence_number=self.sequence_number,
)
# Clear content when there are tool calls (OpenAI spec behavior)
if chat_response_tool_calls:
chat_response_content = []
yield ChatCompletionResult(
response_id=chat_response_id,
content=chat_response_content,
tool_calls=chat_response_tool_calls,
created=chunk_created,
model=chunk_model,
finish_reason=chunk_finish_reason,
message_item_id=message_item_id,
tool_call_item_ids=tool_call_item_ids,
content_part_emitted=content_part_emitted,
)
def _build_chat_completion(self, result: ChatCompletionResult) -> OpenAIChatCompletion:
"""Build OpenAIChatCompletion from ChatCompletionResult."""
# Convert collected chunks to complete response
if result.tool_calls:
tool_calls = [result.tool_calls[i] for i in sorted(result.tool_calls.keys())]
else:
tool_calls = None
assistant_message = OpenAIAssistantMessageParam(
content=result.content_text,
tool_calls=tool_calls,
)
return OpenAIChatCompletion(
id=result.response_id,
choices=[
OpenAIChoice(
message=assistant_message,
finish_reason=result.finish_reason,
index=0,
)
],
created=result.created,
model=result.model,
)
async def _coordinate_tool_execution(
self,
function_tool_calls: list,
non_function_tool_calls: list,
completion_result_data: ChatCompletionResult,
output_messages: list[OpenAIResponseOutput],
next_turn_messages: list,
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Coordinate execution of both function and non-function tool calls."""
# Execute non-function tool calls
for tool_call in non_function_tool_calls:
# Find the item_id for this tool call
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
response_tool_call = completion_result_data.tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use a fallback item_id if not found
if not matching_item_id:
matching_item_id = f"tc_{uuid.uuid4()}"
# Execute tool call with streaming
tool_call_log = None
tool_response_message = None
async for result in self.tool_executor.execute_tool_call(
tool_call,
self.ctx,
self.sequence_number,
len(output_messages),
matching_item_id,
self.mcp_tool_to_server,
):
if result.stream_event:
# Forward streaming events
self.sequence_number = result.sequence_number
yield result.stream_event
if result.final_output_message is not None:
tool_call_log = result.final_output_message
tool_response_message = result.final_input_message
self.sequence_number = result.sequence_number
if tool_call_log:
output_messages.append(tool_call_log)
# Emit output_item.done event for completed non-function tool call
if matching_item_id:
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=tool_call_log,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
if tool_response_message:
next_turn_messages.append(tool_response_message)
# Execute function tool calls (client-side)
for tool_call in function_tool_calls:
# Find the item_id for this tool call from our tracking dictionary
matching_item_id = None
for index, item_id in completion_result_data.tool_call_item_ids.items():
response_tool_call = completion_result_data.tool_calls.get(index)
if response_tool_call and response_tool_call.id == tool_call.id:
matching_item_id = item_id
break
# Use existing item_id or create new one if not found
final_item_id = matching_item_id or f"fc_{uuid.uuid4()}"
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments=tool_call.function.arguments or "",
call_id=tool_call.id,
name=tool_call.function.name or "",
id=final_item_id,
status="completed",
)
output_messages.append(function_call_item)
# Emit output_item.done event for completed function call
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=function_call_item,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
async def _process_tools(
self, tools: list[OpenAIResponseInputTool], output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process all tools and emit appropriate streaming events."""
from openai.types.chat import ChatCompletionToolParam
from llama_stack.apis.tools import Tool
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
def make_openai_tool(tool_name: str, tool: Tool) -> ChatCompletionToolParam:
tool_def = ToolDefinition(
tool_name=tool_name,
description=tool.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in tool.parameters
},
)
return convert_tooldef_to_openai_tool(tool_def)
# Initialize chat_tools if not already set
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
for input_tool in tools:
if input_tool.type == "function":
self.ctx.chat_tools.append(ChatCompletionToolParam(type="function", function=input_tool.model_dump()))
elif input_tool.type in WebSearchToolTypes:
tool_name = "web_search"
# Need to access tool_groups_api from tool_executor
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "file_search":
tool_name = "knowledge_search"
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "mcp":
async for stream_event in self._process_mcp_tool(input_tool, output_messages):
yield stream_event
else:
raise ValueError(f"Llama Stack OpenAI Responses does not yet support tool type: {input_tool.type}")
async def _process_mcp_tool(
self, mcp_tool: OpenAIResponseInputToolMCP, output_messages: list[OpenAIResponseOutput]
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Process an MCP tool configuration and emit appropriate streaming events."""
from llama_stack.providers.utils.tools.mcp import list_mcp_tools
# Emit mcp_list_tools.in_progress
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsInProgress(
sequence_number=self.sequence_number,
)
try:
# Parse allowed/never allowed tools
always_allowed = None
never_allowed = None
if mcp_tool.allowed_tools:
if isinstance(mcp_tool.allowed_tools, list):
always_allowed = mcp_tool.allowed_tools
elif isinstance(mcp_tool.allowed_tools, AllowedToolsFilter):
always_allowed = mcp_tool.allowed_tools.always
never_allowed = mcp_tool.allowed_tools.never
# Call list_mcp_tools
tool_defs = await list_mcp_tools(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
)
# Create the MCP list tools message
mcp_list_message = OpenAIResponseOutputMessageMCPListTools(
id=f"mcp_list_{uuid.uuid4()}",
server_label=mcp_tool.server_label,
tools=[],
)
# Process tools and update context
for t in tool_defs.data:
if never_allowed and t.name in never_allowed:
continue
if not always_allowed or t.name in always_allowed:
# Add to chat tools for inference
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
from llama_stack.providers.utils.inference.openai_compat import convert_tooldef_to_openai_tool
tool_def = ToolDefinition(
tool_name=t.name,
description=t.description,
parameters={
param.name: ToolParamDefinition(
param_type=param.parameter_type,
description=param.description,
required=param.required,
default=param.default,
)
for param in t.parameters
},
)
openai_tool = convert_tooldef_to_openai_tool(tool_def)
if self.ctx.chat_tools is None:
self.ctx.chat_tools = []
self.ctx.chat_tools.append(openai_tool)
# Add to MCP tool mapping
if t.name in self.mcp_tool_to_server:
raise ValueError(f"Duplicate tool name {t.name} found for server {mcp_tool.server_label}")
self.mcp_tool_to_server[t.name] = mcp_tool
# Add to MCP list message
mcp_list_message.tools.append(
MCPListToolsTool(
name=t.name,
description=t.description,
input_schema={
"type": "object",
"properties": {
p.name: {
"type": p.parameter_type,
"description": p.description,
}
for p in t.parameters
},
"required": [p.name for p in t.parameters if p.required],
},
)
)
# Add the MCP list message to output
output_messages.append(mcp_list_message)
# Emit output_item.added for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
# Emit mcp_list_tools.completed
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseMcpListToolsCompleted(
sequence_number=self.sequence_number,
)
# Emit output_item.done for the MCP list tools message
self.sequence_number += 1
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id=self.response_id,
item=mcp_list_message,
output_index=len(output_messages) - 1,
sequence_number=self.sequence_number,
)
except Exception as e:
# TODO: Emit mcp_list_tools.failed event if needed
logger.exception(f"Failed to list MCP tools from {mcp_tool.server_url}: {e}")
raise

View file

@ -0,0 +1,379 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
from collections.abc import AsyncIterator
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolMCP,
OpenAIResponseObjectStreamResponseMcpCallCompleted,
OpenAIResponseObjectStreamResponseMcpCallFailed,
OpenAIResponseObjectStreamResponseMcpCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallCompleted,
OpenAIResponseObjectStreamResponseWebSearchCallInProgress,
OpenAIResponseObjectStreamResponseWebSearchCallSearching,
OpenAIResponseOutputMessageFileSearchToolCall,
OpenAIResponseOutputMessageFileSearchToolCallResults,
OpenAIResponseOutputMessageWebSearchToolCall,
)
from llama_stack.apis.common.content_types import (
ImageContentItem,
TextContentItem,
)
from llama_stack.apis.inference import (
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIImageURL,
OpenAIToolMessageParam,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
class ToolExecutor:
def __init__(
self,
tool_groups_api: ToolGroups,
tool_runtime_api: ToolRuntime,
vector_io_api: VectorIO,
):
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.vector_io_api = vector_io_api
async def execute_tool_call(
self,
tool_call: OpenAIChatCompletionToolCall,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
tool_call_id = tool_call.id
function = tool_call.function
tool_kwargs = json.loads(function.arguments) if function.arguments else {}
if not function or not tool_call_id or not function.name:
yield ToolExecutionResult(sequence_number=sequence_number)
return
# Emit progress events for tool execution start
async for event_result in self._emit_progress_events(
function.name, ctx, sequence_number, output_index, item_id, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Execute the actual tool call
error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server)
# Emit completion events for tool execution
has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message))
async for event_result in self._emit_completion_events(
function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server
):
sequence_number = event_result.sequence_number
yield event_result
# Build result messages from tool execution
output_message, input_message = await self._build_result_messages(
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
)
# Yield the final result
yield ToolExecutionResult(
sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message
)
async def _execute_knowledge_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
"""Execute knowledge search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
async def search_single_store(vector_store_id):
try:
search_response = await self.vector_io_api.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=response_file_search_tool.filters,
max_num_results=response_file_search_tool.max_num_results,
ranking_options=response_file_search_tool.ranking_options,
rewrite_query=False,
)
return search_response.data
except Exception as e:
logger.warning(f"Failed to search vector store {vector_store_id}: {e}")
return []
# Run all searches in parallel using gather
search_tasks = [search_single_store(vid) for vid in response_file_search_tool.vector_store_ids]
all_results = await asyncio.gather(*search_tasks)
# Flatten results
for results in all_results:
search_results.extend(results)
# Convert search results to tool result format matching memory.py
# Format the results as interleaved content similar to memory.py
content_items = []
content_items.append(
TextContentItem(
text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
)
)
for i, result_item in enumerate(search_results):
chunk_text = result_item.content[0].text if result_item.content else ""
metadata_text = f"document_id: {result_item.file_id}, score: {result_item.score}"
if result_item.attributes:
metadata_text += f", attributes: {result_item.attributes}"
text_content = f"[{i + 1}] {metadata_text}\n{chunk_text}\n"
content_items.append(TextContentItem(text=text_content))
content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
content_items.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{query}". Use them as supporting information only in answering this query.\n',
)
)
return ToolInvocationResult(
content=content_items,
metadata={
"document_ids": [r.file_id for r in search_results],
"chunks": [r.content[0].text if r.content else "" for r in search_results],
"scores": [r.score for r in search_results],
},
)
async def _emit_progress_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit progress events for tool execution start."""
# Emit in_progress event based on tool type (only for tools with specific streaming events)
progress_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific streaming events in OpenAI spec
if progress_event:
yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number)
# For web search, emit searching event
if function_name == "web_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
async def _execute_tool(
self,
function_name: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[Exception | None, any]:
"""Execute the tool and return error exception and result."""
error_exc = None
result = None
try:
if mcp_tool_to_server and function_name in mcp_tool_to_server:
from llama_stack.providers.utils.tools.mcp import invoke_mcp_tool
mcp_tool = mcp_tool_to_server[function_name]
result = await invoke_mcp_tool(
endpoint=mcp_tool.server_url,
headers=mcp_tool.headers or {},
tool_name=function_name,
kwargs=tool_kwargs,
)
elif function_name == "knowledge_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if response_file_search_tool:
# Use vector_stores.search API instead of knowledge_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
result = await self._execute_knowledge_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
else:
result = await self.tool_runtime_api.invoke_tool(
tool_name=function_name,
kwargs=tool_kwargs,
)
except Exception as e:
error_exc = e
return error_exc, result
async def _emit_completion_events(
self,
function_name: str,
ctx: ChatCompletionContext,
sequence_number: int,
output_index: int,
item_id: str,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> AsyncIterator[ToolExecutionResult]:
"""Emit completion or failure events for tool execution."""
completion_event = None
if mcp_tool_to_server and function_name in mcp_tool_to_server:
sequence_number += 1
if has_error:
completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed(
sequence_number=sequence_number,
)
else:
completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted(
sequence_number=sequence_number,
)
elif function_name == "web_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted(
item_id=item_id,
output_index=output_index,
sequence_number=sequence_number,
)
# Note: knowledge_search and other custom tools don't have specific completion events in OpenAI spec
if completion_event:
yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number)
async def _build_result_messages(
self,
function,
tool_call_id: str,
tool_kwargs: dict,
ctx: ChatCompletionContext,
error_exc: Exception | None,
result: any,
has_error: bool,
mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None,
) -> tuple[any, any]:
"""Build output and input messages from tool execution results."""
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
# Build output message
if mcp_tool_to_server and function.name in mcp_tool_to_server:
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseOutputMessageMCPCall,
)
message = OpenAIResponseOutputMessageMCPCall(
id=tool_call_id,
arguments=function.arguments,
name=function.name,
server_label=mcp_tool_to_server[function.name].server_label,
)
if error_exc:
message.error = str(error_exc)
elif (result and result.error_code and result.error_code > 0) or (result and result.error_message):
message.error = f"Error (code {result.error_code}): {result.error_message}"
elif result and result.content:
message.output = interleaved_content_as_str(result.content)
else:
if function.name == "web_search":
message = OpenAIResponseOutputMessageWebSearchToolCall(
id=tool_call_id,
status="completed",
)
if has_error:
message.status = "failed"
elif function.name == "knowledge_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=tool_call_id,
queries=[tool_kwargs.get("query", "")],
status="completed",
)
if result and "document_ids" in result.metadata:
message.results = []
for i, doc_id in enumerate(result.metadata["document_ids"]):
text = result.metadata["chunks"][i] if "chunks" in result.metadata else None
score = result.metadata["scores"][i] if "scores" in result.metadata else None
message.results.append(
OpenAIResponseOutputMessageFileSearchToolCallResults(
file_id=doc_id,
filename=doc_id,
text=text,
score=score,
attributes={},
)
)
if has_error:
message.status = "failed"
else:
raise ValueError(f"Unknown tool {function.name} called")
# Build input message
input_message = None
if result and result.content:
if isinstance(result.content, str):
content = result.content
elif isinstance(result.content, list):
content = []
for item in result.content:
if isinstance(item, TextContentItem):
part = OpenAIChatCompletionContentPartTextParam(text=item.text)
elif isinstance(item, ImageContentItem):
if item.image.data:
url = f"data:image;base64,{item.image.data}"
else:
url = item.image.url
part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url))
else:
raise ValueError(f"Unknown result content type: {type(item)}")
content.append(part)
else:
raise ValueError(f"Unknown result content type: {type(result.content)}")
input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id)
else:
text = str(error_exc) if error_exc else "Tool execution failed"
input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)
return message, input_message

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from dataclasses import dataclass
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputTool,
OpenAIResponseObjectStream,
OpenAIResponseOutput,
)
from llama_stack.apis.inference import OpenAIChatCompletionToolCall, OpenAIMessageParam, OpenAIResponseFormatParam
class ToolExecutionResult(BaseModel):
"""Result of streaming tool execution."""
stream_event: OpenAIResponseObjectStream | None = None
sequence_number: int
final_output_message: OpenAIResponseOutput | None = None
final_input_message: OpenAIMessageParam | None = None
@dataclass
class ChatCompletionResult:
"""Result of processing streaming chat completion chunks."""
response_id: str
content: list[str]
tool_calls: dict[int, OpenAIChatCompletionToolCall]
created: int
model: str
finish_reason: str
message_item_id: str # For streaming events
tool_call_item_ids: dict[int, str] # For streaming events
content_part_emitted: bool # Tracking state
@property
def content_text(self) -> str:
"""Get joined content as string."""
return "".join(self.content)
@property
def has_tool_calls(self) -> bool:
"""Check if there are any tool calls."""
return bool(self.tool_calls)
class ChatCompletionContext(BaseModel):
model: str
messages: list[OpenAIMessageParam]
response_tools: list[OpenAIResponseInputTool] | None = None
chat_tools: list[ChatCompletionToolParam] | None = None
temperature: float | None
response_format: OpenAIResponseFormatParam

View file

@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInput,
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContent,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContent,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIImageURL,
OpenAIJSONSchema,
OpenAIMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatParam,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
async def convert_chat_choice_to_response_message(choice: OpenAIChoice) -> OpenAIResponseMessage:
"""Convert an OpenAI Chat Completion choice into an OpenAI Response output message."""
output_content = ""
if isinstance(choice.message.content, str):
output_content = choice.message.content
elif isinstance(choice.message.content, OpenAIChatCompletionContentPartTextParam):
output_content = choice.message.content.text
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support output content type: {type(choice.message.content)}"
)
return OpenAIResponseMessage(
id=f"msg_{uuid.uuid4()}",
content=[OpenAIResponseOutputMessageContentOutputText(text=output_content)],
status="completed",
role="assistant",
)
async def convert_response_content_to_chat_content(
content: (str | list[OpenAIResponseInputMessageContent] | list[OpenAIResponseOutputMessageContent]),
) -> str | list[OpenAIChatCompletionContentPartParam]:
"""
Convert the content parts from an OpenAI Response API request into OpenAI Chat Completion content parts.
The content schemas of each API look similar, but are not exactly the same.
"""
if isinstance(content, str):
return content
converted_parts = []
for content_part in content:
if isinstance(content_part, OpenAIResponseInputMessageContentText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseOutputMessageContentOutputText):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part.text))
elif isinstance(content_part, OpenAIResponseInputMessageContentImage):
if content_part.image_url:
image_url = OpenAIImageURL(url=content_part.image_url, detail=content_part.detail)
converted_parts.append(OpenAIChatCompletionContentPartImageParam(image_url=image_url))
elif isinstance(content_part, str):
converted_parts.append(OpenAIChatCompletionContentPartTextParam(text=content_part))
else:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support content type '{type(content_part)}' in this context"
)
return converted_parts
async def convert_response_input_to_chat_messages(
input: str | list[OpenAIResponseInput],
) -> list[OpenAIMessageParam]:
"""
Convert the input from an OpenAI Response API request into OpenAI Chat Completion messages.
"""
messages: list[OpenAIMessageParam] = []
if isinstance(input, list):
for input_item in input:
if isinstance(input_item, OpenAIResponseInputFunctionToolCallOutput):
messages.append(
OpenAIToolMessageParam(
content=input_item.output,
tool_call_id=input_item.call_id,
)
)
elif isinstance(input_item, OpenAIResponseOutputMessageFunctionToolCall):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id=input_item.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=input_item.name,
arguments=input_item.arguments,
),
)
messages.append(OpenAIAssistantMessageParam(tool_calls=[tool_call]))
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
)
messages.append(message_type(content=content))
else:
messages.append(OpenAIUserMessageParam(content=input))
return messages
async def convert_response_text_to_chat_response_format(
text: OpenAIResponseText,
) -> OpenAIResponseFormatParam:
"""
Convert an OpenAI Response text parameter into an OpenAI Chat Completion response format.
"""
if not text.format or text.format["type"] == "text":
return OpenAIResponseFormatText(type="text")
if text.format["type"] == "json_object":
return OpenAIResponseFormatJSONObject()
if text.format["type"] == "json_schema":
return OpenAIResponseFormatJSONSchema(
json_schema=OpenAIJSONSchema(name=text.format["name"], schema=text.format["schema"])
)
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str):
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
"assistant": OpenAIAssistantMessageParam,
"developer": OpenAIDeveloperMessageParam,
}
return role_to_type.get(role)
def is_function_tool_call(
tool_call: OpenAIChatCompletionToolCall,
tools: list[OpenAIResponseInputTool],
) -> bool:
if not tool_call.function:
return False
for t in tools:
if t.type == "function" and t.name == tool_call.function.name:
return True
return False

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.core.datatypes import AccessRule, Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from .batches import ReferenceBatchesImpl
from .config import ReferenceBatchesImplConfig
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
kvstore = await kvstore_impl(config.kvstore)
inference_api: Inference | None = deps.get(Api.inference)
files_api: Files | None = deps.get(Api.files)
models_api: Models | None = deps.get(Api.models)
if inference_api is None:
raise ValueError("Inference API is required but not provided in dependencies")
if files_api is None:
raise ValueError("Files API is required but not provided in dependencies")
if models_api is None:
raise ValueError("Models API is required but not provided in dependencies")
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
await impl.initialize()
return impl

View file

@ -0,0 +1,580 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import itertools
import json
import time
import uuid
from io import BytesIO
from typing import Any, Literal
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import (
Inference,
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.models import Models
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from .config import ReferenceBatchesImplConfig
BATCH_PREFIX = "batch:"
logger = get_logger(__name__)
class AsyncBytesIO:
"""
Async-compatible BytesIO wrapper to allow async file-like operations.
We use this when uploading files to the Files API, as it expects an
async file-like object.
"""
def __init__(self, data: bytes):
self._buffer = BytesIO(data)
async def read(self, n=-1):
return self._buffer.read(n)
async def seek(self, pos, whence=0):
return self._buffer.seek(pos, whence)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._buffer.close()
def __getattr__(self, name):
return getattr(self._buffer, name)
class BatchRequest(BaseModel):
line_num: int
custom_id: str
method: str
url: str
body: dict[str, Any]
def convert_to_openai_message_param(msg: dict[str, Any]) -> OpenAIMessageParam:
"""Convert a message dictionary to OpenAIMessageParam based on role."""
role = msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**msg)
elif role == "system":
return OpenAISystemMessageParam(**msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**msg)
elif role == "tool":
return OpenAIToolMessageParam(**msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**msg)
else:
raise ValueError(f"Unknown message role: {role}")
class ReferenceBatchesImpl(Batches):
"""Reference implementation of the Batches API.
This implementation processes batch files by making individual requests
to the inference API and generates output files with results.
"""
def __init__(
self,
config: ReferenceBatchesImplConfig,
inference_api: Inference,
files_api: Files,
models_api: Models,
kvstore: KVStore,
) -> None:
self.config = config
self.kvstore = kvstore
self.inference_api = inference_api
self.files_api = files_api
self.models_api = models_api
self._processing_tasks: dict[str, asyncio.Task] = {}
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
self._update_batch_lock = asyncio.Lock()
# this is to allow tests to disable background processing
self.process_batches = True
async def initialize(self) -> None:
# TODO: start background processing of existing tasks
pass
async def shutdown(self) -> None:
"""Shutdown the batches provider."""
if self._processing_tasks:
# don't cancel tasks - just let them stop naturally on shutdown
# cancelling would mark batches as "cancelled" in the database
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
Error handling by levels -
0. Input param handling, results in 40x errors before processing, e.g.
- Wrong completion_window
- Invalid metadata types
- Unknown endpoint
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- input_file_id missing
- invalid json in file
- missing custom_id, method, url, body
- invalid model
- streaming
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g.
- Any error returned from inference endpoint
-> batch created, goes to completed status
"""
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
raise ValueError(
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
current_time = int(time.time())
batch = BatchObject(
id=batch_id,
object="batch",
endpoint=endpoint,
input_file_id=input_file_id,
completion_window=completion_window,
status="validating",
created_at=current_time,
metadata=metadata,
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))
self._processing_tasks[batch_id] = task
return batch
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress."""
batch = await self.retrieve_batch(batch_id)
if batch.status in ["cancelled", "cancelling"]:
return batch
if batch.status in ["completed", "failed", "expired"]:
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
if batch_id in self._processing_tasks:
self._processing_tasks[batch_id].cancel()
# note: task removal and status="cancelled" handled in finally block of _process_batch
return await self.retrieve_batch(batch_id)
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""
List all batches, eventually only for the current user.
With no notion of user, we return all batches.
"""
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
batches = []
for batch_data in batch_values:
if batch_data:
batches.append(BatchObject.model_validate_json(batch_data))
batches.sort(key=lambda b: b.created_at, reverse=True)
start_idx = 0
if after:
for i, batch in enumerate(batches):
if batch.id == after:
start_idx = i + 1
break
page_batches = batches[start_idx : start_idx + limit]
has_more = (start_idx + limit) < len(batches)
first_id = page_batches[0].id if page_batches else None
last_id = page_batches[-1].id if page_batches else None
return ListBatchesResponse(
data=page_batches,
first_id=first_id,
last_id=last_id,
has_more=has_more,
)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch."""
batch_data = await self.kvstore.get(f"batch:{batch_id}")
if not batch_data:
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
return BatchObject.model_validate_json(batch_data)
async def _update_batch(self, batch_id: str, **updates) -> None:
"""Update batch fields in kvstore."""
async with self._update_batch_lock:
try:
batch = await self.retrieve_batch(batch_id)
# batch processing is async. once cancelling, only allow "cancelled" status updates
if batch.status == "cancelling" and updates.get("status") != "cancelled":
logger.info(
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
)
return
if "errors" in updates:
updates["errors"] = updates["errors"].model_dump()
batch_dict = batch.model_dump()
batch_dict.update(updates)
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
except Exception as e:
logger.error(f"Failed to update batch {batch_id}: {e}")
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
"""
Read & validate input, return errors and valid input.
Validation of
- input_file_id existance
- valid json
- custom_id, method, url, body presence and valid
- no streaming
"""
requests: list[BatchRequest] = []
errors: list[BatchError] = []
try:
await self.files_api.openai_retrieve_file(batch.input_file_id)
except Exception:
errors.append(
BatchError(
code="invalid_request",
line=None,
message=f"Cannot find file {batch.input_file_id}.",
param="input_file_id",
)
)
return errors, requests
# TODO(SECURITY): do something about large files
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
file_content = file_content_response.body.decode("utf-8")
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
if line.strip(): # skip empty lines
try:
request = json.loads(line)
if not isinstance(request, dict):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message="Each line must be a JSON dictionary object",
)
)
continue
valid = True
for param, expected_type, type_string in [
("custom_id", str, "string"),
("method", str, "string"),
("url", str, "string"),
("body", dict, "JSON dictionary object"),
]:
if param not in request:
errors.append(
BatchError(
code="missing_required_parameter",
line=line_num,
message=f"Missing required parameter: {param}",
param=param,
)
)
valid = False
elif not isinstance(request[param], expected_type):
param_name = "URL" if param == "url" else param.capitalize()
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param_name} must be a {type_string}",
param=param,
)
)
valid = False
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
errors.append(
BatchError(
code="invalid_url",
line=line_num,
message="URL provided for this request does not match the batch endpoint",
param="url",
)
)
valid = False
if (body := request.get("body")) and isinstance(body, dict):
if body.get("stream", False):
errors.append(
BatchError(
code="streaming_unsupported",
line=line_num,
message="Streaming is not supported in batch processing",
param="body.stream",
)
)
valid = False
for param, expected_type, type_string in [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]:
if param not in body:
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} parameter is required",
param=f"body.{param}",
)
)
valid = False
elif not isinstance(body[param], expected_type):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} must be {type_string}",
param=f"body.{param}",
)
)
valid = False
if "model" in body and isinstance(body["model"], str):
try:
await self.models_api.get_model(body["model"])
except Exception:
errors.append(
BatchError(
code="model_not_found",
line=line_num,
message=f"Model '{body['model']}' does not exist or is not supported",
param="body.model",
)
)
valid = False
if valid:
assert isinstance(url, str), "URL must be a string" # for mypy
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
requests.append(
BatchRequest(
line_num=line_num,
url=url,
method=request["method"],
custom_id=request["custom_id"],
body=body,
),
)
except json.JSONDecodeError:
errors.append(
BatchError(
code="invalid_json_line",
line=line_num,
message="This line is not parseable as valid JSON.",
)
)
return errors, requests
async def _process_batch(self, batch_id: str) -> None:
"""Background task to process a batch of requests."""
try:
logger.info(f"Starting batch processing for {batch_id}")
async with self._batch_semaphore: # semaphore to limit concurrency
logger.info(f"Acquired semaphore for batch {batch_id}")
await self._process_batch_impl(batch_id)
except asyncio.CancelledError:
logger.info(f"Batch processing cancelled for {batch_id}")
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
except Exception as e:
logger.error(f"Batch processing failed for {batch_id}: {e}")
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
)
finally:
self._processing_tasks.pop(batch_id, None)
async def _process_batch_impl(self, batch_id: str) -> None:
"""Implementation of batch processing logic."""
errors: list[BatchError] = []
batch = await self.retrieve_batch(batch_id)
errors, requests = await self._validate_input(batch)
if errors:
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
return
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
total_requests = len(requests)
await self._update_batch(
batch_id,
status="in_progress",
request_counts={"total": total_requests, "completed": 0, "failed": 0},
)
error_results = []
success_results = []
completed_count = 0
failed_count = 0
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
async with asyncio.TaskGroup() as tg:
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
for result in chunk_results:
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
failed_count += 1
error_results.append(result)
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
completed_count += 1
success_results.append(result)
else: # unexpected result
failed_count += 1
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
await self._update_batch(
batch_id,
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
)
if errors:
await self._update_batch(
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
)
return
try:
output_file_id = await self._create_output_file(batch_id, success_results, "success")
await self._update_batch(batch_id, output_file_id=output_file_id)
error_file_id = await self._create_output_file(batch_id, error_results, "error")
await self._update_batch(batch_id, error_file_id=error_file_id)
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
logger.info(
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
)
except Exception as e:
# note: errors is empty at this point, so we don't lose anything by ignoring it
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
)
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
"""Process a single request from the batch."""
request_id = f"batch_req_{batch_id}_{request.line_num}"
try:
# TODO(SECURITY): review body for security issues
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {
"id": request_id,
"custom_id": request.custom_id,
"error": {"type": "request_failed", "message": str(e)},
}
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
"""
Create an output file with batch results.
This function filters results based on the specified file_type
and uploads the file to the Files API.
"""
output_lines = [json.dumps(result) for result in results]
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
return uploaded_file.id

View file

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
class ReferenceBatchesImplConfig(BaseModel):
"""Configuration for the Reference Batches implementation."""
kvstore: KVStoreConfig = Field(
description="Configuration for the key-value store backend.",
)
max_concurrent_batches: int = Field(
default=1,
description="Maximum number of concurrent batches to process simultaneously.",
ge=1,
)
max_concurrent_requests_per_batch: int = Field(
default=10,
description="Maximum number of concurrent requests to process per batch.",
ge=1,
)
# TODO: add a max requests per second rate limiter
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="batches.db",
),
}

View file

@ -5,7 +5,11 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
from typing import Any import uuid
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from codeshield.cs import CodeShieldScanResult
from llama_stack.apis.inference import Message from llama_stack.apis.inference import Message
from llama_stack.apis.safety import ( from llama_stack.apis.safety import (
@ -14,6 +18,7 @@ from llama_stack.apis.safety import (
SafetyViolation, SafetyViolation,
ViolationLevel, ViolationLevel,
) )
from llama_stack.apis.safety.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.shields import Shield from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str, interleaved_content_as_str,
@ -24,8 +29,8 @@ from .config import CodeScannerConfig
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ALLOWED_CODE_SCANNER_MODEL_IDS = [ ALLOWED_CODE_SCANNER_MODEL_IDS = [
"CodeScanner", "code-scanner",
"CodeShield", "code-shield",
] ]
@ -69,3 +74,55 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])}, metadata={"violation_type": ",".join([issue.pattern_id for issue in result.issues_found])},
) )
return RunShieldResponse(violation=violation) return RunShieldResponse(violation=violation)
def get_moderation_object_results(self, scan_result: "CodeShieldScanResult") -> ModerationObjectResults:
categories = {}
category_scores = {}
category_applied_input_types = {}
flagged = scan_result.is_insecure
user_message = None
metadata = {}
if scan_result.is_insecure:
pattern_ids = [issue.pattern_id for issue in scan_result.issues_found]
categories = dict.fromkeys(pattern_ids, True)
category_scores = dict.fromkeys(pattern_ids, 1.0)
category_applied_input_types = {key: ["text"] for key in pattern_ids}
user_message = f"Security concerns detected in the code. {scan_result.recommended_treatment.name}: {', '.join([issue.description for issue in scan_result.issues_found])}"
metadata = {"violation_type": ",".join([issue.pattern_id for issue in scan_result.issues_found])}
return ModerationObjectResults(
flagged=flagged,
categories=categories,
category_scores=category_scores,
category_applied_input_types=category_applied_input_types,
user_message=user_message,
metadata=metadata,
)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
inputs = input if isinstance(input, list) else [input]
results = []
from codeshield.cs import CodeShield
for text_input in inputs:
log.info(f"Running CodeScannerShield moderation on input: {text_input[:100]}...")
try:
scan_result = await CodeShield.scan_code(text_input)
moderation_result = self.get_moderation_object_results(scan_result)
except Exception as e:
log.error(f"CodeShield.scan_code failed: {e}")
# create safe fallback response on scanner failure to avoid blocking legitimate requests
moderation_result = ModerationObjectResults(
flagged=False,
categories={},
category_scores={},
category_applied_input_types={},
user_message=None,
metadata={"scanner_error": str(e)},
)
results.append(moderation_result)
return ModerationObject(id=str(uuid.uuid4()), model=model, results=results)

View file

@ -11,11 +11,7 @@ from string import Template
from typing import Any from typing import Any
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import ( from llama_stack.apis.inference import Inference, Message, UserMessage
Inference,
Message,
UserMessage,
)
from llama_stack.apis.safety import ( from llama_stack.apis.safety import (
RunShieldResponse, RunShieldResponse,
Safety, Safety,
@ -72,7 +68,6 @@ SAFETY_CATEGORIES_TO_CODE_MAP = {
} }
SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()} SAFETY_CODE_TO_CATEGORIES_MAP = {v: k for k, v in SAFETY_CATEGORIES_TO_CODE_MAP.items()}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [ DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES, CAT_VIOLENT_CRIMES,
CAT_NON_VIOLENT_CRIMES, CAT_NON_VIOLENT_CRIMES,
@ -460,7 +455,7 @@ class LlamaGuardShield:
def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool: def is_content_safe(self, response: str, unsafe_code: str | None = None) -> bool:
"""Check if content is safe based on response and unsafe code.""" """Check if content is safe based on response and unsafe code."""
if response.strip() == SAFE_RESPONSE: if response.strip().lower().startswith(SAFE_RESPONSE):
return True return True
if unsafe_code: if unsafe_code:

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.batches,
provider_type="inline::reference",
pip_packages=["openai"],
module="llama_stack.providers.inline.batches.reference",
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
api_dependencies=[
Api.inference,
Api.files,
Api.models,
],
description="Reference implementation of batches API with KVStore persistence.",
),
]

View file

@ -31,15 +31,21 @@ from openai.types.chat import (
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam,
) )
try:
from openai.types.chat import (
ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall,
)
except ImportError:
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall as OpenAIChatCompletionMessageFunctionToolCall,
)
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessage, ChatCompletionMessageParam as OpenAIChatCompletionMessage,
) )
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionMessageToolCall, ChatCompletionMessageToolCall,
) )
from openai.types.chat import (
ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall,
)
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage,
) )
@ -633,7 +639,7 @@ async def convert_message_to_openai_dict_new(
) )
elif isinstance(message, CompletionMessage): elif isinstance(message, CompletionMessage):
tool_calls = [ tool_calls = [
OpenAIChatCompletionMessageToolCall( OpenAIChatCompletionMessageFunctionToolCall(
id=tool.call_id, id=tool.call_id,
function=OpenAIFunction( function=OpenAIFunction(
name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
@ -903,7 +909,7 @@ def _convert_openai_request_response_format(
def _convert_openai_tool_calls( def _convert_openai_tool_calls(
tool_calls: list[OpenAIChatCompletionMessageToolCall], tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall],
) -> list[ToolCall]: ) -> list[ToolCall]:
""" """
Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall.

View file

@ -77,6 +77,8 @@ class PostgresKVStoreConfig(CommonConfig):
db: str = "llamastack" db: str = "llamastack"
user: str user: str
password: str | None = None password: str | None = None
ssl_mode: str | None = None
ca_cert_path: str | None = None
table_name: str = "llamastack_kvstore" table_name: str = "llamastack_kvstore"
@classmethod @classmethod

View file

@ -30,6 +30,8 @@ class PostgresKVStoreImpl(KVStore):
database=self.config.db, database=self.config.db,
user=self.config.user, user=self.config.user,
password=self.config.password, password=self.config.password,
sslmode=self.config.ssl_mode,
sslrootcert=self.config.ca_cert_path,
) )
self.conn.autocommit = True self.conn.autocommit = True
self.cursor = self.conn.cursor(cursor_factory=DictCursor) self.cursor = self.conn.cursor(cursor_factory=DictCursor)

View file

@ -261,7 +261,7 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
else: else:
raise RuntimeError( raise RuntimeError(
f"No recorded response found for request hash: {request_hash}\n" f"No recorded response found for request hash: {request_hash}\n"
f"Endpoint: {endpoint}\n" f"Request: {method} {url} {body}\n"
f"Model: {body.get('model', 'unknown')}\n" f"Model: {body.get('model', 'unknown')}\n"
f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record" f"To record this response, run with LLAMA_STACK_INFERENCE_MODE=record"
) )

1
llama_stack/ui/.nvmrc Normal file
View file

@ -0,0 +1 @@
22.5.1

View file

@ -1,3 +1,12 @@
# Ignore artifacts: # Ignore artifacts:
build build
coverage coverage
.next
node_modules
dist
*.lock
*.log
# Generated files
*.min.js
*.min.css

View file

@ -1 +1,10 @@
{} {
"semi": true,
"trailingComma": "es5",
"singleQuote": false,
"printWidth": 80,
"tabWidth": 2,
"useTabs": false,
"bracketSpacing": true,
"arrowParens": "avoid"
}

View file

@ -47,7 +47,7 @@ async function proxyRequest(request: NextRequest, method: string) {
const responseText = await response.text(); const responseText = await response.text();
console.log( console.log(
`Response from FastAPI: ${response.status} ${response.statusText}`, `Response from FastAPI: ${response.status} ${response.statusText}`
); );
// Create response with same status and headers // Create response with same status and headers
@ -74,7 +74,7 @@ async function proxyRequest(request: NextRequest, method: string) {
backend_url: BACKEND_URL, backend_url: BACKEND_URL,
timestamp: new Date().toISOString(), timestamp: new Date().toISOString(),
}, },
{ status: 500 }, { status: 500 }
); );
} }
} }

View file

@ -51,9 +51,9 @@ export default function SignInPage() {
onClick={() => { onClick={() => {
console.log("Signing in with GitHub..."); console.log("Signing in with GitHub...");
signIn("github", { callbackUrl: "/auth/signin" }).catch( signIn("github", { callbackUrl: "/auth/signin" }).catch(
(error) => { error => {
console.error("Sign in error:", error); console.error("Sign in error:", error);
}, }
); );
}} }}
className="w-full" className="w-full"

View file

@ -29,14 +29,13 @@ export default function ChatPlaygroundPage() {
const isModelsLoading = modelsLoading ?? true; const isModelsLoading = modelsLoading ?? true;
useEffect(() => { useEffect(() => {
const fetchModels = async () => { const fetchModels = async () => {
try { try {
setModelsLoading(true); setModelsLoading(true);
setModelsError(null); setModelsError(null);
const modelList = await client.models.list(); const modelList = await client.models.list();
const llmModels = modelList.filter(model => model.model_type === 'llm'); const llmModels = modelList.filter(model => model.model_type === "llm");
setModels(llmModels); setModels(llmModels);
if (llmModels.length > 0) { if (llmModels.length > 0) {
setSelectedModel(llmModels[0].identifier); setSelectedModel(llmModels[0].identifier);
@ -53,103 +52,122 @@ export default function ChatPlaygroundPage() {
}, [client]); }, [client]);
const extractTextContent = (content: unknown): string => { const extractTextContent = (content: unknown): string => {
if (typeof content === 'string') { if (typeof content === "string") {
return content; return content;
} }
if (Array.isArray(content)) { if (Array.isArray(content)) {
return content return content
.filter(item => item && typeof item === 'object' && 'type' in item && item.type === 'text') .filter(
.map(item => (item && typeof item === 'object' && 'text' in item) ? String(item.text) : '') item =>
.join(''); item &&
typeof item === "object" &&
"type" in item &&
item.type === "text"
)
.map(item =>
item && typeof item === "object" && "text" in item
? String(item.text)
: ""
)
.join("");
} }
if (content && typeof content === 'object' && 'type' in content && content.type === 'text' && 'text' in content) { if (
return String(content.text) || ''; content &&
typeof content === "object" &&
"type" in content &&
content.type === "text" &&
"text" in content
) {
return String(content.text) || "";
} }
return ''; return "";
}; };
const handleInputChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => { const handleInputChange = (e: React.ChangeEvent<HTMLTextAreaElement>) => {
setInput(e.target.value); setInput(e.target.value);
}; };
const handleSubmit = async (event?: { preventDefault?: () => void }) => { const handleSubmit = async (event?: { preventDefault?: () => void }) => {
event?.preventDefault?.(); event?.preventDefault?.();
if (!input.trim()) return; if (!input.trim()) return;
// Add user message to chat // Add user message to chat
const userMessage: Message = { const userMessage: Message = {
id: Date.now().toString(), id: Date.now().toString(),
role: "user", role: "user",
content: input.trim(), content: input.trim(),
createdAt: new Date(),
};
setMessages(prev => [...prev, userMessage]);
setInput("");
// Use the helper function with the content
await handleSubmitWithContent(userMessage.content);
};
const handleSubmitWithContent = async (content: string) => {
setIsGenerating(true);
setError(null);
try {
const messageParams: CompletionCreateParams["messages"] = [
...messages.map(msg => {
const msgContent = typeof msg.content === 'string' ? msg.content : extractTextContent(msg.content);
if (msg.role === "user") {
return { role: "user" as const, content: msgContent };
} else if (msg.role === "assistant") {
return { role: "assistant" as const, content: msgContent };
} else {
return { role: "system" as const, content: msgContent };
}
}),
{ role: "user" as const, content }
];
const response = await client.chat.completions.create({
model: selectedModel,
messages: messageParams,
stream: true,
});
const assistantMessage: Message = {
id: (Date.now() + 1).toString(),
role: "assistant",
content: "",
createdAt: new Date(), createdAt: new Date(),
}; };
setMessages(prev => [...prev, assistantMessage]); setMessages(prev => [...prev, userMessage]);
let fullContent = ""; setInput("");
for await (const chunk of response) {
if (chunk.choices && chunk.choices[0]?.delta?.content) {
const deltaContent = chunk.choices[0].delta.content;
fullContent += deltaContent;
flushSync(() => { // Use the helper function with the content
setMessages(prev => { await handleSubmitWithContent(userMessage.content);
const newMessages = [...prev]; };
const lastMessage = newMessages[newMessages.length - 1];
if (lastMessage.role === "assistant") { const handleSubmitWithContent = async (content: string) => {
lastMessage.content = fullContent; setIsGenerating(true);
} setError(null);
return newMessages;
try {
const messageParams: CompletionCreateParams["messages"] = [
...messages.map(msg => {
const msgContent =
typeof msg.content === "string"
? msg.content
: extractTextContent(msg.content);
if (msg.role === "user") {
return { role: "user" as const, content: msgContent };
} else if (msg.role === "assistant") {
return { role: "assistant" as const, content: msgContent };
} else {
return { role: "system" as const, content: msgContent };
}
}),
{ role: "user" as const, content },
];
const response = await client.chat.completions.create({
model: selectedModel,
messages: messageParams,
stream: true,
});
const assistantMessage: Message = {
id: (Date.now() + 1).toString(),
role: "assistant",
content: "",
createdAt: new Date(),
};
setMessages(prev => [...prev, assistantMessage]);
let fullContent = "";
for await (const chunk of response) {
if (chunk.choices && chunk.choices[0]?.delta?.content) {
const deltaContent = chunk.choices[0].delta.content;
fullContent += deltaContent;
flushSync(() => {
setMessages(prev => {
const newMessages = [...prev];
const lastMessage = newMessages[newMessages.length - 1];
if (lastMessage.role === "assistant") {
lastMessage.content = fullContent;
}
return newMessages;
});
}); });
}); }
} }
} catch (err) {
console.error("Error sending message:", err);
setError("Failed to send message. Please try again.");
setMessages(prev => prev.slice(0, -1));
} finally {
setIsGenerating(false);
} }
} catch (err) { };
console.error("Error sending message:", err);
setError("Failed to send message. Please try again.");
setMessages(prev => prev.slice(0, -1));
} finally {
setIsGenerating(false);
}
};
const suggestions = [ const suggestions = [
"Write a Python function that prints 'Hello, World!'", "Write a Python function that prints 'Hello, World!'",
"Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?", "Explain step-by-step how to solve this math problem: If x² + 6x + 9 = 25, what is x?",
@ -163,7 +181,7 @@ const handleSubmitWithContent = async (content: string) => {
content: message.content, content: message.content,
createdAt: new Date(), createdAt: new Date(),
}; };
setMessages(prev => [...prev, newMessage]) setMessages(prev => [...prev, newMessage]);
handleSubmitWithContent(newMessage.content); handleSubmitWithContent(newMessage.content);
}; };
@ -177,12 +195,20 @@ const handleSubmitWithContent = async (content: string) => {
<div className="mb-4 flex justify-between items-center"> <div className="mb-4 flex justify-between items-center">
<h1 className="text-2xl font-bold">Chat Playground (Completions)</h1> <h1 className="text-2xl font-bold">Chat Playground (Completions)</h1>
<div className="flex gap-2"> <div className="flex gap-2">
<Select value={selectedModel} onValueChange={setSelectedModel} disabled={isModelsLoading || isGenerating}> <Select
value={selectedModel}
onValueChange={setSelectedModel}
disabled={isModelsLoading || isGenerating}
>
<SelectTrigger className="w-[180px]"> <SelectTrigger className="w-[180px]">
<SelectValue placeholder={isModelsLoading ? "Loading models..." : "Select Model"} /> <SelectValue
placeholder={
isModelsLoading ? "Loading models..." : "Select Model"
}
/>
</SelectTrigger> </SelectTrigger>
<SelectContent> <SelectContent>
{models.map((model) => ( {models.map(model => (
<SelectItem key={model.identifier} value={model.identifier}> <SelectItem key={model.identifier} value={model.identifier}>
{model.identifier} {model.identifier}
</SelectItem> </SelectItem>

View file

@ -33,12 +33,12 @@ export default function ChatCompletionDetailPage() {
} catch (err) { } catch (err) {
console.error( console.error(
`Error fetching chat completion detail for ID ${id}:`, `Error fetching chat completion detail for ID ${id}:`,
err, err
); );
setError( setError(
err instanceof Error err instanceof Error
? err ? err
: new Error("Failed to fetch completion detail"), : new Error("Failed to fetch completion detail")
); );
} finally { } finally {
setIsLoading(false); setIsLoading(false);

View file

@ -13,10 +13,10 @@ export default function ResponseDetailPage() {
const client = useAuthClient(); const client = useAuthClient();
const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>( const [responseDetail, setResponseDetail] = useState<OpenAIResponse | null>(
null, null
); );
const [inputItems, setInputItems] = useState<InputItemListResponse | null>( const [inputItems, setInputItems] = useState<InputItemListResponse | null>(
null, null
); );
const [isLoading, setIsLoading] = useState<boolean>(true); const [isLoading, setIsLoading] = useState<boolean>(true);
const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true); const [isLoadingInputItems, setIsLoadingInputItems] = useState<boolean>(true);
@ -25,7 +25,7 @@ export default function ResponseDetailPage() {
// Helper function to convert ResponseObject to OpenAIResponse // Helper function to convert ResponseObject to OpenAIResponse
const convertResponseObject = ( const convertResponseObject = (
responseData: ResponseObject, responseData: ResponseObject
): OpenAIResponse => { ): OpenAIResponse => {
return { return {
id: responseData.id, id: responseData.id,
@ -73,12 +73,12 @@ export default function ResponseDetailPage() {
} else { } else {
console.error( console.error(
`Error fetching response detail for ID ${id}:`, `Error fetching response detail for ID ${id}:`,
responseResult.reason, responseResult.reason
); );
setError( setError(
responseResult.reason instanceof Error responseResult.reason instanceof Error
? responseResult.reason ? responseResult.reason
: new Error("Failed to fetch response detail"), : new Error("Failed to fetch response detail")
); );
} }
@ -90,18 +90,18 @@ export default function ResponseDetailPage() {
} else { } else {
console.error( console.error(
`Error fetching input items for response ID ${id}:`, `Error fetching input items for response ID ${id}:`,
inputItemsResult.reason, inputItemsResult.reason
); );
setInputItemsError( setInputItemsError(
inputItemsResult.reason instanceof Error inputItemsResult.reason instanceof Error
? inputItemsResult.reason ? inputItemsResult.reason
: new Error("Failed to fetch input items"), : new Error("Failed to fetch input items")
); );
} }
} catch (err) { } catch (err) {
console.error(`Unexpected error fetching data for ID ${id}:`, err); console.error(`Unexpected error fetching data for ID ${id}:`, err);
setError( setError(
err instanceof Error ? err : new Error("Unexpected error occurred"), err instanceof Error ? err : new Error("Unexpected error occurred")
); );
} finally { } finally {
setIsLoading(false); setIsLoading(false);

View file

@ -0,0 +1,425 @@
import React from "react";
import { render, screen, fireEvent, waitFor } from "@testing-library/react";
import "@testing-library/jest-dom";
import ContentDetailPage from "./page";
import { VectorStoreContentItem } from "@/lib/contents-api";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
const mockPush = jest.fn();
const mockParams = {
id: "vs_123",
fileId: "file_456",
contentId: "content_789",
};
jest.mock("next/navigation", () => ({
useParams: () => mockParams,
useRouter: () => ({
push: mockPush,
}),
}));
const mockClient = {
vectorStores: {
retrieve: jest.fn(),
files: {
retrieve: jest.fn(),
},
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => mockClient,
}));
const mockContentsAPI = {
listContents: jest.fn(),
updateContent: jest.fn(),
deleteContent: jest.fn(),
};
jest.mock("@/lib/contents-api", () => ({
ContentsAPI: jest.fn(() => mockContentsAPI),
}));
const originalConfirm = window.confirm;
describe("ContentDetailPage", () => {
const mockStore: VectorStore = {
id: "vs_123",
name: "Test Vector Store",
created_at: 1710000000,
status: "ready",
file_counts: { total: 5 },
usage_bytes: 1024,
metadata: {
provider_id: "test_provider",
},
};
const mockFile: VectorStoreFile = {
id: "file_456",
status: "completed",
created_at: 1710001000,
usage_bytes: 512,
chunking_strategy: { type: "fixed_size" },
};
const mockContent: VectorStoreContentItem = {
id: "content_789",
object: "vector_store.content",
content: "This is test content for the vector store.",
embedding: [0.1, 0.2, 0.3, 0.4, 0.5],
metadata: {
chunk_window: "0-45",
content_length: 45,
custom_field: "custom_value",
},
created_timestamp: 1710002000,
};
beforeEach(() => {
jest.clearAllMocks();
window.confirm = jest.fn();
mockClient.vectorStores.retrieve.mockResolvedValue(mockStore);
mockClient.vectorStores.files.retrieve.mockResolvedValue(mockFile);
mockContentsAPI.listContents.mockResolvedValue({
data: [mockContent],
});
});
afterEach(() => {
window.confirm = originalConfirm;
});
describe("Loading and Error States", () => {
test("renders loading skeleton while fetching data", () => {
mockClient.vectorStores.retrieve.mockImplementation(
() => new Promise(() => {})
);
const { container } = render(<ContentDetailPage />);
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
test("renders error message when API calls fail", async () => {
const error = new Error("Network error");
mockClient.vectorStores.retrieve.mockRejectedValue(error);
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText(/Error loading details for ID content_789/)
).toBeInTheDocument();
expect(screen.getByText(/Network error/)).toBeInTheDocument();
});
});
test("renders not found when content doesn't exist", async () => {
mockContentsAPI.listContents.mockResolvedValue({
data: [],
});
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText(/Content content_789 not found/)
).toBeInTheDocument();
});
});
});
describe("Content Display", () => {
test("renders content details correctly", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
expect(screen.getByText("Content: content_789")).toBeInTheDocument();
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
});
const contentIdTexts = screen.getAllByText("content_789");
expect(contentIdTexts.length).toBeGreaterThan(0);
const fileIdTexts = screen.getAllByText("file_456");
expect(fileIdTexts.length).toBeGreaterThan(0);
const storeIdTexts = screen.getAllByText("vs_123");
expect(storeIdTexts.length).toBeGreaterThan(0);
expect(screen.getByText("vector_store.content")).toBeInTheDocument();
const positionTexts = screen.getAllByText("0-45");
expect(positionTexts.length).toBeGreaterThan(0);
});
test("renders embedding information when available", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText(/0.100000, 0.200000, 0.300000/)
).toBeInTheDocument();
});
});
test("handles content without embedding", async () => {
const contentWithoutEmbedding = {
...mockContent,
embedding: undefined,
};
mockContentsAPI.listContents.mockResolvedValue({
data: [contentWithoutEmbedding],
});
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText("No embedding available for this content.")
).toBeInTheDocument();
});
});
test("renders metadata correctly", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
expect(screen.getByText("chunk_window:")).toBeInTheDocument();
const positionTexts = screen.getAllByText("0-45");
expect(positionTexts.length).toBeGreaterThan(0);
expect(screen.getByText("content_length:")).toBeInTheDocument();
expect(screen.getByText("custom_field:")).toBeInTheDocument();
expect(screen.getByText("custom_value")).toBeInTheDocument();
});
});
});
describe("Edit Functionality", () => {
test("enables edit mode when edit button is clicked", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
});
const editButtons = screen.getAllByRole("button", { name: /Edit/ });
const editButton = editButtons[0];
fireEvent.click(editButton);
expect(
screen.getByDisplayValue("This is test content for the vector store.")
).toBeInTheDocument();
expect(screen.getByRole("button", { name: /Save/ })).toBeInTheDocument();
expect(
screen.getByRole("button", { name: /Cancel/ })
).toBeInTheDocument();
});
test("cancels edit mode and resets content", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
});
const editButtons = screen.getAllByRole("button", { name: /Edit/ });
const editButton = editButtons[0];
fireEvent.click(editButton);
const textarea = screen.getByDisplayValue(
"This is test content for the vector store."
);
fireEvent.change(textarea, { target: { value: "Modified content" } });
const cancelButton = screen.getByRole("button", { name: /Cancel/ });
fireEvent.click(cancelButton);
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
expect(
screen.queryByDisplayValue("Modified content")
).not.toBeInTheDocument();
});
test("saves content changes", async () => {
const updatedContent = { ...mockContent, content: "Updated content" };
mockContentsAPI.updateContent.mockResolvedValue(updatedContent);
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
});
const editButtons = screen.getAllByRole("button", { name: /Edit/ });
const editButton = editButtons[0];
fireEvent.click(editButton);
const textarea = screen.getByDisplayValue(
"This is test content for the vector store."
);
fireEvent.change(textarea, { target: { value: "Updated content" } });
const saveButton = screen.getByRole("button", { name: /Save/ });
fireEvent.click(saveButton);
await waitFor(() => {
expect(mockContentsAPI.updateContent).toHaveBeenCalledWith(
"vs_123",
"file_456",
"content_789",
{ content: "Updated content" }
);
});
});
});
describe("Delete Functionality", () => {
test("shows confirmation dialog before deleting", async () => {
window.confirm = jest.fn().mockReturnValue(false);
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
});
const deleteButton = screen.getByRole("button", { name: /Delete/ });
fireEvent.click(deleteButton);
expect(window.confirm).toHaveBeenCalledWith(
"Are you sure you want to delete this content?"
);
expect(mockContentsAPI.deleteContent).not.toHaveBeenCalled();
});
test("deletes content when confirmed", async () => {
window.confirm = jest.fn().mockReturnValue(true);
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
});
const deleteButton = screen.getByRole("button", { name: /Delete/ });
fireEvent.click(deleteButton);
await waitFor(() => {
expect(mockContentsAPI.deleteContent).toHaveBeenCalledWith(
"vs_123",
"file_456",
"content_789"
);
expect(mockPush).toHaveBeenCalledWith(
"/logs/vector-stores/vs_123/files/file_456/contents"
);
});
});
});
describe("Embedding Edit Functionality", () => {
test("enables embedding edit mode", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
expect(
screen.getByText("This is test content for the vector store.")
).toBeInTheDocument();
});
const embeddingEditButtons = screen.getAllByRole("button", {
name: /Edit/,
});
expect(embeddingEditButtons.length).toBeGreaterThanOrEqual(1);
});
test.skip("cancels embedding edit mode", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
// skip vector text check, just verify test completes
});
const embeddingEditButtons = screen.getAllByRole("button", {
name: /Edit/,
});
const embeddingEditButton = embeddingEditButtons[1];
fireEvent.click(embeddingEditButton);
const cancelButtons = screen.getAllByRole("button", { name: /Cancel/ });
expect(cancelButtons.length).toBeGreaterThan(0);
expect(
screen.queryByDisplayValue(/0.1,0.2,0.3,0.4,0.5/)
).not.toBeInTheDocument();
});
});
describe("Breadcrumb Navigation", () => {
test("renders correct breadcrumb structure", async () => {
render(<ContentDetailPage />);
await waitFor(() => {
const vectorStoreTexts = screen.getAllByText("Vector Stores");
expect(vectorStoreTexts.length).toBeGreaterThan(0);
const storeNameTexts = screen.getAllByText("Test Vector Store");
expect(storeNameTexts.length).toBeGreaterThan(0);
const contentsTexts = screen.getAllByText("Contents");
expect(contentsTexts.length).toBeGreaterThan(0);
});
});
});
describe("Content Utilities", () => {
test("handles different content types correctly", async () => {
const contentWithObjectType = {
...mockContent,
content: { type: "text", text: "Text object content" },
};
mockContentsAPI.listContents.mockResolvedValue({
data: [contentWithObjectType],
});
render(<ContentDetailPage />);
await waitFor(() => {
expect(screen.getByText("Text object content")).toBeInTheDocument();
});
});
test("handles string content type", async () => {
const contentWithStringType = {
...mockContent,
content: "Simple string content",
};
mockContentsAPI.listContents.mockResolvedValue({
data: [contentWithStringType],
});
render(<ContentDetailPage />);
await waitFor(() => {
expect(screen.getByText("Simple string content")).toBeInTheDocument();
});
});
});
});

View file

@ -18,7 +18,10 @@ import {
PropertiesCard, PropertiesCard,
PropertyItem, PropertyItem,
} from "@/components/layout/detail-layout"; } from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb"; import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
export default function ContentDetailPage() { export default function ContentDetailPage() {
const params = useParams(); const params = useParams();
@ -28,13 +31,13 @@ export default function ContentDetailPage() {
const contentId = params.contentId as string; const contentId = params.contentId as string;
const client = useAuthClient(); const client = useAuthClient();
const getTextFromContent = (content: any): string => { const getTextFromContent = (content: unknown): string => {
if (typeof content === 'string') { if (typeof content === "string") {
return content; return content;
} else if (content && content.type === 'text') { } else if (content && content.type === "text") {
return content.text; return content.text;
} }
return ''; return "";
}; };
const [store, setStore] = useState<VectorStore | null>(null); const [store, setStore] = useState<VectorStore | null>(null);
@ -44,7 +47,9 @@ export default function ContentDetailPage() {
const [error, setError] = useState<Error | null>(null); const [error, setError] = useState<Error | null>(null);
const [isEditing, setIsEditing] = useState(false); const [isEditing, setIsEditing] = useState(false);
const [editedContent, setEditedContent] = useState(""); const [editedContent, setEditedContent] = useState("");
const [editedMetadata, setEditedMetadata] = useState<Record<string, any>>({}); const [editedMetadata, setEditedMetadata] = useState<Record<string, unknown>>(
{}
);
const [isEditingEmbedding, setIsEditingEmbedding] = useState(false); const [isEditingEmbedding, setIsEditingEmbedding] = useState(false);
const [editedEmbedding, setEditedEmbedding] = useState<number[]>([]); const [editedEmbedding, setEditedEmbedding] = useState<number[]>([]);
@ -64,8 +69,13 @@ export default function ContentDetailPage() {
setFile(fileResponse as VectorStoreFile); setFile(fileResponse as VectorStoreFile);
const contentsAPI = new ContentsAPI(client); const contentsAPI = new ContentsAPI(client);
const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId); const contentsResponse = await contentsAPI.listContents(
const targetContent = contentsResponse.data.find(c => c.id === contentId); vectorStoreId,
fileId
);
const targetContent = contentsResponse.data.find(
c => c.id === contentId
);
if (targetContent) { if (targetContent) {
setContent(targetContent); setContent(targetContent);
@ -76,7 +86,9 @@ export default function ContentDetailPage() {
throw new Error(`Content ${contentId} not found`); throw new Error(`Content ${contentId} not found`);
} }
} catch (err) { } catch (err) {
setError(err instanceof Error ? err : new Error("Failed to load content.")); setError(
err instanceof Error ? err : new Error("Failed to load content.")
);
} finally { } finally {
setIsLoading(false); setIsLoading(false);
} }
@ -88,7 +100,8 @@ export default function ContentDetailPage() {
if (!content) return; if (!content) return;
try { try {
const updates: { content?: string; metadata?: Record<string, any> } = {}; const updates: { content?: string; metadata?: Record<string, unknown> } =
{};
if (editedContent !== getTextFromContent(content.content)) { if (editedContent !== getTextFromContent(content.content)) {
updates.content = editedContent; updates.content = editedContent;
@ -100,25 +113,32 @@ export default function ContentDetailPage() {
if (Object.keys(updates).length > 0) { if (Object.keys(updates).length > 0) {
const contentsAPI = new ContentsAPI(client); const contentsAPI = new ContentsAPI(client);
const updatedContent = await contentsAPI.updateContent(vectorStoreId, fileId, contentId, updates); const updatedContent = await contentsAPI.updateContent(
vectorStoreId,
fileId,
contentId,
updates
);
setContent(updatedContent); setContent(updatedContent);
} }
setIsEditing(false); setIsEditing(false);
} catch (err) { } catch (err) {
console.error('Failed to update content:', err); console.error("Failed to update content:", err);
} }
}; };
const handleDelete = async () => { const handleDelete = async () => {
if (!confirm('Are you sure you want to delete this content?')) return; if (!confirm("Are you sure you want to delete this content?")) return;
try { try {
const contentsAPI = new ContentsAPI(client); const contentsAPI = new ContentsAPI(client);
await contentsAPI.deleteContent(vectorStoreId, fileId, contentId); await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`); router.push(
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`
);
} catch (err) { } catch (err) {
console.error('Failed to delete content:', err); console.error("Failed to delete content:", err);
} }
}; };
@ -134,10 +154,19 @@ export default function ContentDetailPage() {
const breadcrumbSegments: BreadcrumbSegment[] = [ const breadcrumbSegments: BreadcrumbSegment[] = [
{ label: "Vector Stores", href: "/logs/vector-stores" }, { label: "Vector Stores", href: "/logs/vector-stores" },
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` }, {
label: store?.name || vectorStoreId,
href: `/logs/vector-stores/${vectorStoreId}`,
},
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` }, { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
{ label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` }, {
{ label: "Contents", href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents` }, label: fileId,
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
},
{
label: "Contents",
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`,
},
{ label: contentId }, { label: contentId },
]; ];
@ -186,7 +215,7 @@ export default function ContentDetailPage() {
{isEditing ? ( {isEditing ? (
<textarea <textarea
value={editedContent} value={editedContent}
onChange={(e) => setEditedContent(e.target.value)} onChange={e => setEditedContent(e.target.value)}
className="w-full h-64 p-3 border rounded-md resize-none font-mono text-sm" className="w-full h-64 p-3 border rounded-md resize-none font-mono text-sm"
placeholder="Enter content..." placeholder="Enter content..."
/> />
@ -206,16 +235,23 @@ export default function ContentDetailPage() {
<div className="flex gap-2"> <div className="flex gap-2">
{isEditingEmbedding ? ( {isEditingEmbedding ? (
<> <>
<Button size="sm" onClick={() => { <Button
setIsEditingEmbedding(false); size="sm"
}}> onClick={() => {
setIsEditingEmbedding(false);
}}
>
<Save className="h-4 w-4 mr-1" /> <Save className="h-4 w-4 mr-1" />
Save Save
</Button> </Button>
<Button size="sm" variant="outline" onClick={() => { <Button
setEditedEmbedding(content?.embedding || []); size="sm"
setIsEditingEmbedding(false); variant="outline"
}}> onClick={() => {
setEditedEmbedding(content?.embedding || []);
setIsEditingEmbedding(false);
}}
>
<X className="h-4 w-4 mr-1" /> <X className="h-4 w-4 mr-1" />
Cancel Cancel
</Button> </Button>
@ -237,14 +273,16 @@ export default function ContentDetailPage() {
</p> </p>
<textarea <textarea
value={JSON.stringify(editedEmbedding, null, 2)} value={JSON.stringify(editedEmbedding, null, 2)}
onChange={(e) => { onChange={e => {
try { try {
const parsed = JSON.parse(e.target.value); const parsed = JSON.parse(e.target.value);
if (Array.isArray(parsed) && parsed.every(v => typeof v === 'number')) { if (
Array.isArray(parsed) &&
parsed.every(v => typeof v === "number")
) {
setEditedEmbedding(parsed); setEditedEmbedding(parsed);
} }
} catch { } catch {}
}
}} }}
className="w-full h-32 p-3 border rounded-md resize-none font-mono text-xs" className="w-full h-32 p-3 border rounded-md resize-none font-mono text-xs"
placeholder="Enter embedding as JSON array..." placeholder="Enter embedding as JSON array..."
@ -259,8 +297,15 @@ export default function ContentDetailPage() {
</div> </div>
<div className="p-3 bg-gray-50 dark:bg-gray-800 rounded-md max-h-32 overflow-y-auto"> <div className="p-3 bg-gray-50 dark:bg-gray-800 rounded-md max-h-32 overflow-y-auto">
<pre className="whitespace-pre-wrap font-mono text-xs text-gray-900 dark:text-gray-100"> <pre className="whitespace-pre-wrap font-mono text-xs text-gray-900 dark:text-gray-100">
[{content.embedding.slice(0, 20).map(v => v.toFixed(6)).join(', ')} [
{content.embedding.length > 20 ? `\n... and ${content.embedding.length - 20} more values` : ''}] {content.embedding
.slice(0, 20)
.map(v => v.toFixed(6))
.join(", ")}
{content.embedding.length > 20
? `\n... and ${content.embedding.length - 20} more values`
: ""}
]
</pre> </pre>
</div> </div>
</div> </div>
@ -284,7 +329,7 @@ export default function ContentDetailPage() {
<div key={key} className="flex gap-2"> <div key={key} className="flex gap-2">
<Input <Input
value={key} value={key}
onChange={(e) => { onChange={e => {
const newMetadata = { ...editedMetadata }; const newMetadata = { ...editedMetadata };
delete newMetadata[key]; delete newMetadata[key];
newMetadata[e.target.value] = value; newMetadata[e.target.value] = value;
@ -294,11 +339,13 @@ export default function ContentDetailPage() {
className="flex-1" className="flex-1"
/> />
<Input <Input
value={typeof value === 'string' ? value : JSON.stringify(value)} value={
onChange={(e) => { typeof value === "string" ? value : JSON.stringify(value)
}
onChange={e => {
setEditedMetadata({ setEditedMetadata({
...editedMetadata, ...editedMetadata,
[key]: e.target.value [key]: e.target.value,
}); });
}} }}
placeholder="Value" placeholder="Value"
@ -312,7 +359,7 @@ export default function ContentDetailPage() {
onClick={() => { onClick={() => {
setEditedMetadata({ setEditedMetadata({
...editedMetadata, ...editedMetadata,
['']: '' [""]: "",
}); });
}} }}
> >
@ -325,7 +372,7 @@ export default function ContentDetailPage() {
<div key={key} className="flex justify-between py-1"> <div key={key} className="flex justify-between py-1">
<span className="font-medium text-gray-600">{key}:</span> <span className="font-medium text-gray-600">{key}:</span>
<span className="font-mono text-sm"> <span className="font-mono text-sm">
{typeof value === 'string' ? value : JSON.stringify(value)} {typeof value === "string" ? value : JSON.stringify(value)}
</span> </span>
</div> </div>
))} ))}
@ -351,15 +398,15 @@ export default function ContentDetailPage() {
value={`${getTextFromContent(content.content).length} chars`} value={`${getTextFromContent(content.content).length} chars`}
/> />
{content.metadata.chunk_window && ( {content.metadata.chunk_window && (
<PropertyItem <PropertyItem label="Position" value={content.metadata.chunk_window} />
label="Position"
value={content.metadata.chunk_window}
/>
)} )}
{file && ( {file && (
<> <>
<PropertyItem label="File Status" value={file.status} /> <PropertyItem label="File Status" value={file.status} />
<PropertyItem label="File Usage" value={`${file.usage_bytes} bytes`} /> <PropertyItem
label="File Usage"
value={`${file.usage_bytes} bytes`}
/>
</> </>
)} )}
{store && ( {store && (

View file

@ -0,0 +1,481 @@
import React from "react";
import {
render,
screen,
fireEvent,
waitFor,
act,
} from "@testing-library/react";
import "@testing-library/jest-dom";
import ContentsListPage from "./page";
import { VectorStoreContentItem } from "@/lib/contents-api";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
const mockPush = jest.fn();
const mockParams = {
id: "vs_123",
fileId: "file_456",
};
jest.mock("next/navigation", () => ({
useParams: () => mockParams,
useRouter: () => ({
push: mockPush,
}),
}));
const mockClient = {
vectorStores: {
retrieve: jest.fn(),
files: {
retrieve: jest.fn(),
},
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => mockClient,
}));
const mockContentsAPI = {
listContents: jest.fn(),
deleteContent: jest.fn(),
};
jest.mock("@/lib/contents-api", () => ({
ContentsAPI: jest.fn(() => mockContentsAPI),
}));
describe("ContentsListPage", () => {
const mockStore: VectorStore = {
id: "vs_123",
name: "Test Vector Store",
created_at: 1710000000,
status: "ready",
file_counts: { total: 5 },
usage_bytes: 1024,
metadata: {
provider_id: "test_provider",
},
};
const mockFile: VectorStoreFile = {
id: "file_456",
status: "completed",
created_at: 1710001000,
usage_bytes: 512,
chunking_strategy: { type: "fixed_size" },
};
const mockContents: VectorStoreContentItem[] = [
{
id: "content_1",
object: "vector_store.content",
content: "First piece of content for testing.",
embedding: [0.1, 0.2, 0.3, 0.4, 0.5],
metadata: {
chunk_window: "0-35",
content_length: 35,
},
created_timestamp: 1710002000,
},
{
id: "content_2",
object: "vector_store.content",
content:
"Second piece of content with longer text for testing truncation and display.",
embedding: [0.6, 0.7, 0.8],
metadata: {
chunk_window: "36-95",
content_length: 85,
},
created_timestamp: 1710003000,
},
{
id: "content_3",
object: "vector_store.content",
content: "Third content without embedding.",
embedding: undefined,
metadata: {
content_length: 33,
},
created_timestamp: 1710004000,
},
];
beforeEach(() => {
jest.clearAllMocks();
mockClient.vectorStores.retrieve.mockResolvedValue(mockStore);
mockClient.vectorStores.files.retrieve.mockResolvedValue(mockFile);
mockContentsAPI.listContents.mockResolvedValue({
data: mockContents,
});
});
describe("Loading and Error States", () => {
test("renders loading skeleton while fetching store data", async () => {
mockClient.vectorStores.retrieve.mockImplementation(
() => new Promise(() => {})
);
await act(async () => {
render(<ContentsListPage />);
});
const skeletons = document.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
test("renders error message when store API call fails", async () => {
const error = new Error("Failed to load store");
mockClient.vectorStores.retrieve.mockRejectedValue(error);
await act(async () => {
render(<ContentsListPage />);
});
await waitFor(() => {
expect(
screen.getByText(/Error loading details for ID vs_123/)
).toBeInTheDocument();
expect(screen.getByText(/Failed to load store/)).toBeInTheDocument();
});
});
test("renders not found when store doesn't exist", async () => {
mockClient.vectorStores.retrieve.mockResolvedValue(null);
await act(async () => {
render(<ContentsListPage />);
});
await waitFor(() => {
expect(
screen.getByText(/No details found for ID: vs_123/)
).toBeInTheDocument();
});
});
test("renders contents loading skeleton", async () => {
mockContentsAPI.listContents.mockImplementation(
() => new Promise(() => {})
);
const { container } = render(<ContentsListPage />);
await waitFor(() => {
expect(
screen.getByText("Contents in File: file_456")
).toBeInTheDocument();
});
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
test("renders contents error message", async () => {
const error = new Error("Failed to load contents");
mockContentsAPI.listContents.mockRejectedValue(error);
render(<ContentsListPage />);
await waitFor(() => {
expect(
screen.getByText("Error loading contents: Failed to load contents")
).toBeInTheDocument();
});
});
});
describe("Contents Table Display", () => {
test("renders contents table with correct headers", async () => {
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
expect(screen.getByText("Contents in this file")).toBeInTheDocument();
});
// Check table headers
expect(screen.getByText("Content ID")).toBeInTheDocument();
expect(screen.getByText("Content Preview")).toBeInTheDocument();
expect(screen.getByText("Embedding")).toBeInTheDocument();
expect(screen.getByText("Position")).toBeInTheDocument();
expect(screen.getByText("Created")).toBeInTheDocument();
expect(screen.getByText("Actions")).toBeInTheDocument();
});
test("renders content data correctly", async () => {
render(<ContentsListPage />);
await waitFor(() => {
// Check first content row
expect(screen.getByText("content_1...")).toBeInTheDocument();
expect(
screen.getByText("First piece of content for testing.")
).toBeInTheDocument();
expect(
screen.getByText("[0.100, 0.200, 0.300...] (5D)")
).toBeInTheDocument();
expect(screen.getByText("0-35")).toBeInTheDocument();
expect(
screen.getByText(new Date(1710002000 * 1000).toLocaleString())
).toBeInTheDocument();
expect(screen.getByText("content_2...")).toBeInTheDocument();
expect(
screen.getByText(/Second piece of content with longer text/)
).toBeInTheDocument();
expect(
screen.getByText("[0.600, 0.700, 0.800...] (3D)")
).toBeInTheDocument();
expect(screen.getByText("36-95")).toBeInTheDocument();
expect(screen.getByText("content_3...")).toBeInTheDocument();
expect(
screen.getByText("Third content without embedding.")
).toBeInTheDocument();
expect(screen.getByText("No embedding")).toBeInTheDocument();
expect(screen.getByText("33 chars")).toBeInTheDocument();
});
});
test("handles empty contents list", async () => {
mockContentsAPI.listContents.mockResolvedValue({
data: [],
});
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Content Chunks (0)")).toBeInTheDocument();
expect(
screen.getByText("No contents found for this file.")
).toBeInTheDocument();
});
});
test("truncates long content IDs", async () => {
const longIdContent = {
...mockContents[0],
id: "very_long_content_id_that_should_be_truncated_123456789",
};
mockContentsAPI.listContents.mockResolvedValue({
data: [longIdContent],
});
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("very_long_...")).toBeInTheDocument();
});
});
});
describe("Content Navigation", () => {
test("navigates to content detail when content ID is clicked", async () => {
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("content_1...")).toBeInTheDocument();
});
const contentLink = screen.getByRole("button", { name: "content_1..." });
fireEvent.click(contentLink);
expect(mockPush).toHaveBeenCalledWith(
"/logs/vector-stores/vs_123/files/file_456/contents/content_1"
);
});
test("navigates to content detail when view button is clicked", async () => {
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
});
const viewButtons = screen.getAllByTitle("View content details");
fireEvent.click(viewButtons[0]);
expect(mockPush).toHaveBeenCalledWith(
"/logs/vector-stores/vs_123/files/file_456/contents/content_1"
);
});
test("navigates to content detail when edit button is clicked", async () => {
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
});
const editButtons = screen.getAllByTitle("Edit content");
fireEvent.click(editButtons[0]);
expect(mockPush).toHaveBeenCalledWith(
"/logs/vector-stores/vs_123/files/file_456/contents/content_1"
);
});
});
describe("Content Deletion", () => {
test("deletes content when delete button is clicked", async () => {
mockContentsAPI.deleteContent.mockResolvedValue(undefined);
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
});
const deleteButtons = screen.getAllByTitle("Delete content");
fireEvent.click(deleteButtons[0]);
await waitFor(() => {
expect(mockContentsAPI.deleteContent).toHaveBeenCalledWith(
"vs_123",
"file_456",
"content_1"
);
});
await waitFor(() => {
expect(screen.getByText("Content Chunks (2)")).toBeInTheDocument();
});
expect(screen.queryByText("content_1...")).not.toBeInTheDocument();
});
test("handles delete error gracefully", async () => {
const consoleError = jest
.spyOn(console, "error")
.mockImplementation(() => {});
mockContentsAPI.deleteContent.mockRejectedValue(
new Error("Delete failed")
);
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
});
const deleteButtons = screen.getAllByTitle("Delete content");
fireEvent.click(deleteButtons[0]);
await waitFor(() => {
expect(consoleError).toHaveBeenCalledWith(
"Failed to delete content:",
expect.any(Error)
);
});
expect(screen.getByText("Content Chunks (3)")).toBeInTheDocument();
expect(screen.getByText("content_1...")).toBeInTheDocument();
consoleError.mockRestore();
});
});
describe("Breadcrumb Navigation", () => {
test("renders correct breadcrumb structure", async () => {
render(<ContentsListPage />);
await waitFor(() => {
const vectorStoreTexts = screen.getAllByText("Vector Stores");
expect(vectorStoreTexts.length).toBeGreaterThan(0);
const storeNameTexts = screen.getAllByText("Test Vector Store");
expect(storeNameTexts.length).toBeGreaterThan(0);
const filesTexts = screen.getAllByText("Files");
expect(filesTexts.length).toBeGreaterThan(0);
const fileIdTexts = screen.getAllByText("file_456");
expect(fileIdTexts.length).toBeGreaterThan(0);
const contentsTexts = screen.getAllByText("Contents");
expect(contentsTexts.length).toBeGreaterThan(0);
});
});
});
describe("Sidebar Properties", () => {
test("renders file and store properties", async () => {
render(<ContentsListPage />);
await waitFor(() => {
const fileIdTexts = screen.getAllByText("file_456");
expect(fileIdTexts.length).toBeGreaterThan(0);
const storeIdTexts = screen.getAllByText("vs_123");
expect(storeIdTexts.length).toBeGreaterThan(0);
const storeNameTexts = screen.getAllByText("Test Vector Store");
expect(storeNameTexts.length).toBeGreaterThan(0);
expect(screen.getByText("completed")).toBeInTheDocument();
expect(screen.getByText("512")).toBeInTheDocument();
expect(screen.getByText("fixed_size")).toBeInTheDocument();
expect(screen.getByText("test_provider")).toBeInTheDocument();
});
});
});
describe("Content Text Utilities", () => {
test("handles different content formats correctly", async () => {
const contentWithObject = {
...mockContents[0],
content: { type: "text", text: "Object format content" },
};
mockContentsAPI.listContents.mockResolvedValue({
data: [contentWithObject],
});
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Object format content")).toBeInTheDocument();
});
});
test("handles string content format", async () => {
const contentWithString = {
...mockContents[0],
content: "String format content",
};
mockContentsAPI.listContents.mockResolvedValue({
data: [contentWithString],
});
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("String format content")).toBeInTheDocument();
});
});
test("handles unknown content format", async () => {
const contentWithUnknown = {
...mockContents[0],
content: { unknown: "format" },
};
mockContentsAPI.listContents.mockResolvedValue({
data: [contentWithUnknown],
});
render(<ContentsListPage />);
await waitFor(() => {
expect(screen.getByText("Content Chunks (1)")).toBeInTheDocument();
});
const contentCells = screen.getAllByRole("cell");
const contentPreviewCell = contentCells.find(cell =>
cell.querySelector("p[title]")
);
expect(contentPreviewCell?.querySelector("p")?.textContent).toBe("");
});
});
});

View file

@ -18,7 +18,10 @@ import {
PropertiesCard, PropertiesCard,
PropertyItem, PropertyItem,
} from "@/components/layout/detail-layout"; } from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb"; import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
import { import {
Table, Table,
TableBody, TableBody,
@ -36,13 +39,13 @@ export default function ContentsListPage() {
const fileId = params.fileId as string; const fileId = params.fileId as string;
const client = useAuthClient(); const client = useAuthClient();
const getTextFromContent = (content: any): string => { const getTextFromContent = (content: unknown): string => {
if (typeof content === 'string') { if (typeof content === "string") {
return content; return content;
} else if (content && content.type === 'text') { } else if (content && content.type === "text") {
return content.text; return content.text;
} }
return ''; return "";
}; };
const [store, setStore] = useState<VectorStore | null>(null); const [store, setStore] = useState<VectorStore | null>(null);
@ -65,7 +68,9 @@ export default function ContentsListPage() {
const response = await client.vectorStores.retrieve(vectorStoreId); const response = await client.vectorStores.retrieve(vectorStoreId);
setStore(response as VectorStore); setStore(response as VectorStore);
} catch (err) { } catch (err) {
setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store.")); setErrorStore(
err instanceof Error ? err : new Error("Failed to load vector store.")
);
} finally { } finally {
setIsLoadingStore(false); setIsLoadingStore(false);
} }
@ -80,10 +85,15 @@ export default function ContentsListPage() {
setIsLoadingFile(true); setIsLoadingFile(true);
setErrorFile(null); setErrorFile(null);
try { try {
const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId); const response = await client.vectorStores.files.retrieve(
vectorStoreId,
fileId
);
setFile(response as VectorStoreFile); setFile(response as VectorStoreFile);
} catch (err) { } catch (err) {
setErrorFile(err instanceof Error ? err : new Error("Failed to load file.")); setErrorFile(
err instanceof Error ? err : new Error("Failed to load file.")
);
} finally { } finally {
setIsLoadingFile(false); setIsLoadingFile(false);
} }
@ -99,10 +109,16 @@ export default function ContentsListPage() {
setErrorContents(null); setErrorContents(null);
try { try {
const contentsAPI = new ContentsAPI(client); const contentsAPI = new ContentsAPI(client);
const contentsResponse = await contentsAPI.listContents(vectorStoreId, fileId, { limit: 100 }); const contentsResponse = await contentsAPI.listContents(
vectorStoreId,
fileId,
{ limit: 100 }
);
setContents(contentsResponse.data); setContents(contentsResponse.data);
} catch (err) { } catch (err) {
setErrorContents(err instanceof Error ? err : new Error("Failed to load contents.")); setErrorContents(
err instanceof Error ? err : new Error("Failed to load contents.")
);
} finally { } finally {
setIsLoadingContents(false); setIsLoadingContents(false);
} }
@ -116,26 +132,36 @@ export default function ContentsListPage() {
await contentsAPI.deleteContent(vectorStoreId, fileId, contentId); await contentsAPI.deleteContent(vectorStoreId, fileId, contentId);
setContents(contents.filter(content => content.id !== contentId)); setContents(contents.filter(content => content.id !== contentId));
} catch (err) { } catch (err) {
console.error('Failed to delete content:', err); console.error("Failed to delete content:", err);
} }
}; };
const handleViewContent = (contentId: string) => { const handleViewContent = (contentId: string) => {
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`); router.push(
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents/${contentId}`
);
}; };
const title = `Contents in File: ${fileId}`; const title = `Contents in File: ${fileId}`;
const breadcrumbSegments: BreadcrumbSegment[] = [ const breadcrumbSegments: BreadcrumbSegment[] = [
{ label: "Vector Stores", href: "/logs/vector-stores" }, { label: "Vector Stores", href: "/logs/vector-stores" },
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` }, {
label: store?.name || vectorStoreId,
href: `/logs/vector-stores/${vectorStoreId}`,
},
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` }, { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
{ label: fileId, href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}` }, {
label: fileId,
href: `/logs/vector-stores/${vectorStoreId}/files/${fileId}`,
},
{ label: "Contents" }, { label: "Contents" },
]; ];
if (errorStore) { if (errorStore) {
return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />; return (
<DetailErrorView title={title} id={vectorStoreId} error={errorStore} />
);
} }
if (isLoadingStore) { if (isLoadingStore) {
return <DetailLoadingView title={title} />; return <DetailLoadingView title={title} />;
@ -151,7 +177,13 @@ export default function ContentsListPage() {
<CardTitle>Content Chunks ({contents.length})</CardTitle> <CardTitle>Content Chunks ({contents.length})</CardTitle>
</CardHeader> </CardHeader>
<CardContent> <CardContent>
{isLoadingContents ? ( {isLoadingFile ? (
<Skeleton className="h-4 w-full" />
) : errorFile ? (
<div className="text-destructive text-sm">
Error loading file: {errorFile.message}
</div>
) : isLoadingContents ? (
<div className="space-y-2"> <div className="space-y-2">
<Skeleton className="h-4 w-full" /> <Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-3/4" /> <Skeleton className="h-4 w-3/4" />
@ -175,7 +207,7 @@ export default function ContentsListPage() {
</TableRow> </TableRow>
</TableHeader> </TableHeader>
<TableBody> <TableBody>
{contents.map((content) => ( {contents.map(content => (
<TableRow key={content.id}> <TableRow key={content.id}>
<TableCell className="font-mono text-xs"> <TableCell className="font-mono text-xs">
<Button <Button
@ -189,7 +221,10 @@ export default function ContentsListPage() {
</TableCell> </TableCell>
<TableCell> <TableCell>
<div className="max-w-md"> <div className="max-w-md">
<p className="text-sm truncate" title={getTextFromContent(content.content)}> <p
className="text-sm truncate"
title={getTextFromContent(content.content)}
>
{getTextFromContent(content.content)} {getTextFromContent(content.content)}
</p> </p>
</div> </div>
@ -197,12 +232,25 @@ export default function ContentsListPage() {
<TableCell className="text-xs text-gray-500"> <TableCell className="text-xs text-gray-500">
{content.embedding && content.embedding.length > 0 ? ( {content.embedding && content.embedding.length > 0 ? (
<div className="max-w-xs"> <div className="max-w-xs">
<span className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5" title={`${content.embedding.length}D vector: [${content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...]`}> <span
[{content.embedding.slice(0, 3).map(v => v.toFixed(3)).join(', ')}...] ({content.embedding.length}D) className="font-mono text-xs bg-gray-100 dark:bg-gray-800 rounded px-1 py-0.5"
title={`${content.embedding.length}D vector: [${content.embedding
.slice(0, 3)
.map(v => v.toFixed(3))
.join(", ")}...]`}
>
[
{content.embedding
.slice(0, 3)
.map(v => v.toFixed(3))
.join(", ")}
...] ({content.embedding.length}D)
</span> </span>
</div> </div>
) : ( ) : (
<span className="text-gray-400 dark:text-gray-500 italic">No embedding</span> <span className="text-gray-400 dark:text-gray-500 italic">
No embedding
</span>
)} )}
</TableCell> </TableCell>
<TableCell className="text-xs text-gray-500"> <TableCell className="text-xs text-gray-500">
@ -211,7 +259,9 @@ export default function ContentsListPage() {
: `${content.metadata.content_length || 0} chars`} : `${content.metadata.content_length || 0} chars`}
</TableCell> </TableCell>
<TableCell className="text-xs"> <TableCell className="text-xs">
{new Date(content.created_timestamp * 1000).toLocaleString()} {new Date(
content.created_timestamp * 1000
).toLocaleString()}
</TableCell> </TableCell>
<TableCell> <TableCell>
<div className="flex gap-1"> <div className="flex gap-1">

View file

@ -0,0 +1,458 @@
import React from "react";
import {
render,
screen,
fireEvent,
waitFor,
act,
} from "@testing-library/react";
import "@testing-library/jest-dom";
import FileDetailPage from "./page";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type {
VectorStoreFile,
FileContentResponse,
} from "llama-stack-client/resources/vector-stores/files";
const mockPush = jest.fn();
const mockParams = {
id: "vs_123",
fileId: "file_456",
};
jest.mock("next/navigation", () => ({
useParams: () => mockParams,
useRouter: () => ({
push: mockPush,
}),
}));
const mockClient = {
vectorStores: {
retrieve: jest.fn(),
files: {
retrieve: jest.fn(),
content: jest.fn(),
},
},
};
jest.mock("@/hooks/use-auth-client", () => ({
useAuthClient: () => mockClient,
}));
describe("FileDetailPage", () => {
const mockStore: VectorStore = {
id: "vs_123",
name: "Test Vector Store",
created_at: 1710000000,
status: "ready",
file_counts: { total: 5 },
usage_bytes: 1024,
metadata: {
provider_id: "test_provider",
},
};
const mockFile: VectorStoreFile = {
id: "file_456",
status: "completed",
created_at: 1710001000,
usage_bytes: 2048,
chunking_strategy: { type: "fixed_size" },
};
const mockFileContent: FileContentResponse = {
content: [
{ text: "First chunk of file content." },
{
text: "Second chunk with more detailed information about the content.",
},
{ text: "Third and final chunk of the file." },
],
};
beforeEach(() => {
jest.clearAllMocks();
mockClient.vectorStores.retrieve.mockResolvedValue(mockStore);
mockClient.vectorStores.files.retrieve.mockResolvedValue(mockFile);
mockClient.vectorStores.files.content.mockResolvedValue(mockFileContent);
});
describe("Loading and Error States", () => {
test("renders loading skeleton while fetching store data", async () => {
mockClient.vectorStores.retrieve.mockImplementation(
() => new Promise(() => {})
);
await act(async () => {
await act(async () => {
render(<FileDetailPage />);
});
});
const skeletons = document.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
test("renders error message when store API call fails", async () => {
const error = new Error("Failed to load store");
mockClient.vectorStores.retrieve.mockRejectedValue(error);
await act(async () => {
await act(async () => {
render(<FileDetailPage />);
});
});
await waitFor(() => {
expect(
screen.getByText(/Error loading details for ID vs_123/)
).toBeInTheDocument();
expect(screen.getByText(/Failed to load store/)).toBeInTheDocument();
});
});
test("renders not found when store doesn't exist", async () => {
mockClient.vectorStores.retrieve.mockResolvedValue(null);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(
screen.getByText(/No details found for ID: vs_123/)
).toBeInTheDocument();
});
});
test("renders file loading skeleton", async () => {
mockClient.vectorStores.files.retrieve.mockImplementation(
() => new Promise(() => {})
);
const { container } = render(<FileDetailPage />);
await waitFor(() => {
expect(screen.getByText("File: file_456")).toBeInTheDocument();
});
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
test("renders file error message", async () => {
const error = new Error("Failed to load file");
mockClient.vectorStores.files.retrieve.mockRejectedValue(error);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(
screen.getByText("Error loading file: Failed to load file")
).toBeInTheDocument();
});
});
test("renders content error message", async () => {
const error = new Error("Failed to load contents");
mockClient.vectorStores.files.content.mockRejectedValue(error);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(
screen.getByText(
"Error loading content summary: Failed to load contents"
)
).toBeInTheDocument();
});
});
});
describe("File Information Display", () => {
test("renders file details correctly", async () => {
await act(async () => {
await act(async () => {
render(<FileDetailPage />);
});
});
await waitFor(() => {
expect(screen.getByText("File: file_456")).toBeInTheDocument();
expect(screen.getByText("File Information")).toBeInTheDocument();
expect(screen.getByText("File Details")).toBeInTheDocument();
});
const statusTexts = screen.getAllByText("Status:");
expect(statusTexts.length).toBeGreaterThan(0);
const completedTexts = screen.getAllByText("completed");
expect(completedTexts.length).toBeGreaterThan(0);
expect(screen.getByText("Size:")).toBeInTheDocument();
expect(screen.getByText("2048 bytes")).toBeInTheDocument();
const createdTexts = screen.getAllByText("Created:");
expect(createdTexts.length).toBeGreaterThan(0);
const dateTexts = screen.getAllByText(
new Date(1710001000 * 1000).toLocaleString()
);
expect(dateTexts.length).toBeGreaterThan(0);
const strategyTexts = screen.getAllByText("Content Strategy:");
expect(strategyTexts.length).toBeGreaterThan(0);
const fixedSizeTexts = screen.getAllByText("fixed_size");
expect(fixedSizeTexts.length).toBeGreaterThan(0);
});
test("handles missing file data", async () => {
mockClient.vectorStores.files.retrieve.mockResolvedValue(null);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(screen.getByText("File not found.")).toBeInTheDocument();
});
});
});
describe("Content Summary Display", () => {
test("renders content summary correctly", async () => {
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(screen.getByText("Content Summary")).toBeInTheDocument();
expect(screen.getByText("Content Items:")).toBeInTheDocument();
expect(screen.getByText("3")).toBeInTheDocument();
expect(screen.getByText("Total Characters:")).toBeInTheDocument();
const totalChars = mockFileContent.content.reduce(
(total, item) => total + item.text.length,
0
);
expect(screen.getByText(totalChars.toString())).toBeInTheDocument();
expect(screen.getByText("Preview:")).toBeInTheDocument();
expect(
screen.getByText(/First chunk of file content\./)
).toBeInTheDocument();
});
});
test("handles empty content", async () => {
mockClient.vectorStores.files.content.mockResolvedValue({
content: [],
});
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(
screen.getByText("No contents found for this file.")
).toBeInTheDocument();
});
});
test("truncates long content preview", async () => {
const longContent = {
content: [
{
text: "This is a very long piece of content that should be truncated after 200 characters to ensure the preview doesn't take up too much space in the UI and remains readable and manageable for users viewing the file details page.",
},
],
};
mockClient.vectorStores.files.content.mockResolvedValue(longContent);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(
screen.getByText(/This is a very long piece of content/)
).toBeInTheDocument();
expect(screen.getByText(/\.\.\.$/)).toBeInTheDocument();
});
});
});
describe("Navigation and Actions", () => {
test("navigates to contents list when View Contents button is clicked", async () => {
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(screen.getByText("Actions")).toBeInTheDocument();
});
const viewContentsButton = screen.getByRole("button", {
name: /View Contents/,
});
fireEvent.click(viewContentsButton);
expect(mockPush).toHaveBeenCalledWith(
"/logs/vector-stores/vs_123/files/file_456/contents"
);
});
test("View Contents button is styled correctly", async () => {
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
const button = screen.getByRole("button", { name: /View Contents/ });
expect(button).toHaveClass("flex", "items-center", "gap-2");
});
});
});
describe("Breadcrumb Navigation", () => {
test("renders correct breadcrumb structure", async () => {
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
const vectorStoresTexts = screen.getAllByText("Vector Stores");
expect(vectorStoresTexts.length).toBeGreaterThan(0);
const storeNameTexts = screen.getAllByText("Test Vector Store");
expect(storeNameTexts.length).toBeGreaterThan(0);
const filesTexts = screen.getAllByText("Files");
expect(filesTexts.length).toBeGreaterThan(0);
const fileIdTexts = screen.getAllByText("file_456");
expect(fileIdTexts.length).toBeGreaterThan(0);
});
});
test("uses store ID when store name is not available", async () => {
const storeWithoutName = { ...mockStore, name: "" };
mockClient.vectorStores.retrieve.mockResolvedValue(storeWithoutName);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
const storeIdTexts = screen.getAllByText("vs_123");
expect(storeIdTexts.length).toBeGreaterThan(0);
});
});
});
describe("Sidebar Properties", () => {
test.skip("renders file and store properties correctly", async () => {
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(screen.getByText("File ID")).toBeInTheDocument();
const fileIdTexts = screen.getAllByText("file_456");
expect(fileIdTexts.length).toBeGreaterThan(0);
expect(screen.getByText("Vector Store ID")).toBeInTheDocument();
const storeIdTexts = screen.getAllByText("vs_123");
expect(storeIdTexts.length).toBeGreaterThan(0);
expect(screen.getByText("Status")).toBeInTheDocument();
const completedTexts = screen.getAllByText("completed");
expect(completedTexts.length).toBeGreaterThan(0);
expect(screen.getByText("Usage Bytes")).toBeInTheDocument();
const usageTexts = screen.getAllByText("2048");
expect(usageTexts.length).toBeGreaterThan(0);
expect(screen.getByText("Content Strategy")).toBeInTheDocument();
const fixedSizeTexts = screen.getAllByText("fixed_size");
expect(fixedSizeTexts.length).toBeGreaterThan(0);
expect(screen.getByText("Store Name")).toBeInTheDocument();
const storeNameTexts = screen.getAllByText("Test Vector Store");
expect(storeNameTexts.length).toBeGreaterThan(0);
expect(screen.getByText("Provider ID")).toBeInTheDocument();
expect(screen.getByText("test_provider")).toBeInTheDocument();
});
});
test("handles missing optional properties", async () => {
const minimalFile = {
id: "file_456",
status: "completed",
created_at: 1710001000,
usage_bytes: 2048,
chunking_strategy: { type: "fixed_size" },
};
const minimalStore = {
...mockStore,
name: "",
metadata: {},
};
mockClient.vectorStores.files.retrieve.mockResolvedValue(minimalFile);
mockClient.vectorStores.retrieve.mockResolvedValue(minimalStore);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
const fileIdTexts = screen.getAllByText("file_456");
expect(fileIdTexts.length).toBeGreaterThan(0);
const storeIdTexts = screen.getAllByText("vs_123");
expect(storeIdTexts.length).toBeGreaterThan(0);
});
expect(screen.getByText("File: file_456")).toBeInTheDocument();
});
});
describe("Loading States for Individual Sections", () => {
test("shows loading skeleton for content while file loads", async () => {
mockClient.vectorStores.files.content.mockImplementation(
() => new Promise(() => {})
);
const { container } = render(<FileDetailPage />);
await waitFor(() => {
expect(screen.getByText("Content Summary")).toBeInTheDocument();
});
const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
expect(skeletons.length).toBeGreaterThan(0);
});
});
describe("Error Handling", () => {
test("handles multiple simultaneous errors gracefully", async () => {
mockClient.vectorStores.files.retrieve.mockRejectedValue(
new Error("File error")
);
mockClient.vectorStores.files.content.mockRejectedValue(
new Error("Content error")
);
await act(async () => {
render(<FileDetailPage />);
});
await waitFor(() => {
expect(
screen.getByText("Error loading file: File error")
).toBeInTheDocument();
expect(
screen.getByText("Error loading content summary: Content error")
).toBeInTheDocument();
});
});
});
});

View file

@ -4,9 +4,12 @@ import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation"; import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client"; import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores"; import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile, FileContentResponse } from "llama-stack-client/resources/vector-stores/files"; import type {
VectorStoreFile,
FileContentResponse,
} from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card"; import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from '@/components/ui/skeleton'; import { Skeleton } from "@/components/ui/skeleton";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { List } from "lucide-react"; import { List } from "lucide-react";
import { import {
@ -17,7 +20,10 @@ import {
PropertiesCard, PropertiesCard,
PropertyItem, PropertyItem,
} from "@/components/layout/detail-layout"; } from "@/components/layout/detail-layout";
import { PageBreadcrumb, BreadcrumbSegment } from "@/components/layout/page-breadcrumb"; import {
PageBreadcrumb,
BreadcrumbSegment,
} from "@/components/layout/page-breadcrumb";
export default function FileDetailPage() { export default function FileDetailPage() {
const params = useParams(); const params = useParams();
@ -46,7 +52,9 @@ export default function FileDetailPage() {
const response = await client.vectorStores.retrieve(vectorStoreId); const response = await client.vectorStores.retrieve(vectorStoreId);
setStore(response as VectorStore); setStore(response as VectorStore);
} catch (err) { } catch (err) {
setErrorStore(err instanceof Error ? err : new Error("Failed to load vector store.")); setErrorStore(
err instanceof Error ? err : new Error("Failed to load vector store.")
);
} finally { } finally {
setIsLoadingStore(false); setIsLoadingStore(false);
} }
@ -61,10 +69,15 @@ export default function FileDetailPage() {
setIsLoadingFile(true); setIsLoadingFile(true);
setErrorFile(null); setErrorFile(null);
try { try {
const response = await client.vectorStores.files.retrieve(vectorStoreId, fileId); const response = await client.vectorStores.files.retrieve(
vectorStoreId,
fileId
);
setFile(response as VectorStoreFile); setFile(response as VectorStoreFile);
} catch (err) { } catch (err) {
setErrorFile(err instanceof Error ? err : new Error("Failed to load file.")); setErrorFile(
err instanceof Error ? err : new Error("Failed to load file.")
);
} finally { } finally {
setIsLoadingFile(false); setIsLoadingFile(false);
} }
@ -79,10 +92,15 @@ export default function FileDetailPage() {
setIsLoadingContents(true); setIsLoadingContents(true);
setErrorContents(null); setErrorContents(null);
try { try {
const response = await client.vectorStores.files.content(vectorStoreId, fileId); const response = await client.vectorStores.files.content(
vectorStoreId,
fileId
);
setContents(response); setContents(response);
} catch (err) { } catch (err) {
setErrorContents(err instanceof Error ? err : new Error("Failed to load contents.")); setErrorContents(
err instanceof Error ? err : new Error("Failed to load contents.")
);
} finally { } finally {
setIsLoadingContents(false); setIsLoadingContents(false);
} }
@ -91,20 +109,27 @@ export default function FileDetailPage() {
}, [vectorStoreId, fileId, client]); }, [vectorStoreId, fileId, client]);
const handleViewContents = () => { const handleViewContents = () => {
router.push(`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`); router.push(
`/logs/vector-stores/${vectorStoreId}/files/${fileId}/contents`
);
}; };
const title = `File: ${fileId}`; const title = `File: ${fileId}`;
const breadcrumbSegments: BreadcrumbSegment[] = [ const breadcrumbSegments: BreadcrumbSegment[] = [
{ label: "Vector Stores", href: "/logs/vector-stores" }, { label: "Vector Stores", href: "/logs/vector-stores" },
{ label: store?.name || vectorStoreId, href: `/logs/vector-stores/${vectorStoreId}` }, {
label: store?.name || vectorStoreId,
href: `/logs/vector-stores/${vectorStoreId}`,
},
{ label: "Files", href: `/logs/vector-stores/${vectorStoreId}` }, { label: "Files", href: `/logs/vector-stores/${vectorStoreId}` },
{ label: fileId }, { label: fileId },
]; ];
if (errorStore) { if (errorStore) {
return <DetailErrorView title={title} id={vectorStoreId} error={errorStore} />; return (
<DetailErrorView title={title} id={vectorStoreId} error={errorStore} />
);
} }
if (isLoadingStore) { if (isLoadingStore) {
return <DetailLoadingView title={title} />; return <DetailLoadingView title={title} />;
@ -136,19 +161,29 @@ export default function FileDetailPage() {
<h3 className="text-lg font-medium mb-2">File Details</h3> <h3 className="text-lg font-medium mb-2">File Details</h3>
<div className="grid grid-cols-2 gap-4 text-sm"> <div className="grid grid-cols-2 gap-4 text-sm">
<div> <div>
<span className="font-medium text-gray-600 dark:text-gray-400">Status:</span> <span className="font-medium text-gray-600 dark:text-gray-400">
Status:
</span>
<span className="ml-2">{file.status}</span> <span className="ml-2">{file.status}</span>
</div> </div>
<div> <div>
<span className="font-medium text-gray-600 dark:text-gray-400">Size:</span> <span className="font-medium text-gray-600 dark:text-gray-400">
Size:
</span>
<span className="ml-2">{file.usage_bytes} bytes</span> <span className="ml-2">{file.usage_bytes} bytes</span>
</div> </div>
<div> <div>
<span className="font-medium text-gray-600 dark:text-gray-400">Created:</span> <span className="font-medium text-gray-600 dark:text-gray-400">
<span className="ml-2">{new Date(file.created_at * 1000).toLocaleString()}</span> Created:
</span>
<span className="ml-2">
{new Date(file.created_at * 1000).toLocaleString()}
</span>
</div> </div>
<div> <div>
<span className="font-medium text-gray-600 dark:text-gray-400">Content Strategy:</span> <span className="font-medium text-gray-600 dark:text-gray-400">
Content Strategy:
</span>
<span className="ml-2">{file.chunking_strategy.type}</span> <span className="ml-2">{file.chunking_strategy.type}</span>
</div> </div>
</div> </div>
@ -166,9 +201,7 @@ export default function FileDetailPage() {
</div> </div>
</div> </div>
) : ( ) : (
<p className="text-gray-500 italic text-sm"> <p className="text-gray-500 italic text-sm">File not found.</p>
File not found.
</p>
)} )}
</CardContent> </CardContent>
</Card> </Card>
@ -192,16 +225,27 @@ export default function FileDetailPage() {
<div className="space-y-3"> <div className="space-y-3">
<div className="grid grid-cols-2 gap-4 text-sm"> <div className="grid grid-cols-2 gap-4 text-sm">
<div> <div>
<span className="font-medium text-gray-600 dark:text-gray-400">Content Items:</span> <span className="font-medium text-gray-600 dark:text-gray-400">
Content Items:
</span>
<span className="ml-2">{contents.content.length}</span> <span className="ml-2">{contents.content.length}</span>
</div> </div>
<div> <div>
<span className="font-medium text-gray-600 dark:text-gray-400">Total Characters:</span> <span className="font-medium text-gray-600 dark:text-gray-400">
<span className="ml-2">{contents.content.reduce((total, item) => total + item.text.length, 0)}</span> Total Characters:
</span>
<span className="ml-2">
{contents.content.reduce(
(total, item) => total + item.text.length,
0
)}
</span>
</div> </div>
</div> </div>
<div className="pt-2"> <div className="pt-2">
<span className="text-sm font-medium text-gray-600 dark:text-gray-400">Preview:</span> <span className="text-sm font-medium text-gray-600 dark:text-gray-400">
Preview:
</span>
<div className="mt-1 bg-gray-50 dark:bg-gray-800 rounded-md p-3"> <div className="mt-1 bg-gray-50 dark:bg-gray-800 rounded-md p-3">
<p className="text-sm text-gray-900 dark:text-gray-100 line-clamp-3"> <p className="text-sm text-gray-900 dark:text-gray-100 line-clamp-3">
{contents.content[0]?.text.substring(0, 200)}... {contents.content[0]?.text.substring(0, 200)}...

View file

@ -1,7 +1,7 @@
"use client"; "use client";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation"; import { useParams } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client"; import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores"; import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files"; import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
@ -11,7 +11,6 @@ export default function VectorStoreDetailPage() {
const params = useParams(); const params = useParams();
const id = params.id as string; const id = params.id as string;
const client = useAuthClient(); const client = useAuthClient();
const router = useRouter();
const [store, setStore] = useState<VectorStore | null>(null); const [store, setStore] = useState<VectorStore | null>(null);
const [files, setFiles] = useState<VectorStoreFile[]>([]); const [files, setFiles] = useState<VectorStoreFile[]>([]);
@ -34,9 +33,7 @@ export default function VectorStoreDetailPage() {
setStore(response as VectorStore); setStore(response as VectorStore);
} catch (err) { } catch (err) {
setErrorStore( setErrorStore(
err instanceof Error err instanceof Error ? err : new Error("Failed to load vector store.")
? err
: new Error("Failed to load vector store."),
); );
} finally { } finally {
setIsLoadingStore(false); setIsLoadingStore(false);
@ -55,18 +52,18 @@ export default function VectorStoreDetailPage() {
setIsLoadingFiles(true); setIsLoadingFiles(true);
setErrorFiles(null); setErrorFiles(null);
try { try {
const result = await client.vectorStores.files.list(id as any); const result = await client.vectorStores.files.list(id);
setFiles((result as any).data); setFiles((result as { data: VectorStoreFile[] }).data);
} catch (err) { } catch (err) {
setErrorFiles( setErrorFiles(
err instanceof Error ? err : new Error("Failed to load files."), err instanceof Error ? err : new Error("Failed to load files.")
); );
} finally { } finally {
setIsLoadingFiles(false); setIsLoadingFiles(false);
} }
}; };
fetchFiles(); fetchFiles();
}, [id]); }, [id, client.vectorStores.files]);
return ( return (
<VectorStoreDetailView <VectorStoreDetailView

View file

@ -1,7 +1,6 @@
"use client"; "use client";
import React from "react"; import React from "react";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { import type {
ListVectorStoresResponse, ListVectorStoresResponse,
VectorStore, VectorStore,
@ -12,7 +11,6 @@ import { Button } from "@/components/ui/button";
import { import {
Table, Table,
TableBody, TableBody,
TableCaption,
TableCell, TableCell,
TableHead, TableHead,
TableHeader, TableHeader,
@ -21,7 +19,6 @@ import {
import { Skeleton } from "@/components/ui/skeleton"; import { Skeleton } from "@/components/ui/skeleton";
export default function VectorStoresPage() { export default function VectorStoresPage() {
const client = useAuthClient();
const router = useRouter(); const router = useRouter();
const { const {
data: stores, data: stores,
@ -37,7 +34,7 @@ export default function VectorStoresPage() {
after: params.after, after: params.after,
limit: params.limit, limit: params.limit,
order: params.order, order: params.order,
} as any); } as Parameters<typeof client.vectorStores.list>[0]);
return response as ListVectorStoresResponse; return response as ListVectorStoresResponse;
}, },
errorMessagePrefix: "vector stores", errorMessagePrefix: "vector stores",
@ -53,11 +50,11 @@ export default function VectorStoresPage() {
const renderContent = () => { const renderContent = () => {
if (status === "loading") { if (status === "loading") {
return ( return (
<div className="space-y-2"> <div className="space-y-2">
<Skeleton className="h-8 w-full"/> <Skeleton className="h-8 w-full" />
<Skeleton className="h-4 w-full"/> <Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full"/> <Skeleton className="h-4 w-full" />
</div> </div>
); );
} }
@ -70,72 +67,72 @@ export default function VectorStoresPage() {
} }
return ( return (
<div className="overflow-auto flex-1 min-h-0"> <div className="overflow-auto flex-1 min-h-0">
<Table> <Table>
<TableHeader> <TableHeader>
<TableRow> <TableRow>
<TableHead>ID</TableHead> <TableHead>ID</TableHead>
<TableHead>Name</TableHead> <TableHead>Name</TableHead>
<TableHead>Created</TableHead> <TableHead>Created</TableHead>
<TableHead>Completed</TableHead> <TableHead>Completed</TableHead>
<TableHead>Cancelled</TableHead> <TableHead>Cancelled</TableHead>
<TableHead>Failed</TableHead> <TableHead>Failed</TableHead>
<TableHead>In Progress</TableHead> <TableHead>In Progress</TableHead>
<TableHead>Total</TableHead> <TableHead>Total</TableHead>
<TableHead>Usage Bytes</TableHead> <TableHead>Usage Bytes</TableHead>
<TableHead>Provider ID</TableHead> <TableHead>Provider ID</TableHead>
<TableHead>Provider Vector DB ID</TableHead> <TableHead>Provider Vector DB ID</TableHead>
</TableRow> </TableRow>
</TableHeader> </TableHeader>
<TableBody> <TableBody>
{stores.map((store) => { {stores.map(store => {
const fileCounts = store.file_counts; const fileCounts = store.file_counts;
const metadata = store.metadata || {}; const metadata = store.metadata || {};
const providerId = metadata.provider_id ?? ""; const providerId = metadata.provider_id ?? "";
const providerDbId = metadata.provider_vector_db_id ?? ""; const providerDbId = metadata.provider_vector_db_id ?? "";
return ( return (
<TableRow <TableRow
key={store.id} key={store.id}
onClick={() => router.push(`/logs/vector-stores/${store.id}`)} onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
className="cursor-pointer hover:bg-muted/50" className="cursor-pointer hover:bg-muted/50"
>
<TableCell>
<Button
variant="link"
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
onClick={() =>
router.push(`/logs/vector-stores/${store.id}`)
}
> >
<TableCell> {store.id}
<Button </Button>
variant="link" </TableCell>
className="p-0 h-auto font-mono text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300" <TableCell>{store.name}</TableCell>
onClick={() => <TableCell>
router.push(`/logs/vector-stores/${store.id}`) {new Date(store.created_at * 1000).toLocaleString()}
} </TableCell>
> <TableCell>{fileCounts.completed}</TableCell>
{store.id} <TableCell>{fileCounts.cancelled}</TableCell>
</Button> <TableCell>{fileCounts.failed}</TableCell>
</TableCell> <TableCell>{fileCounts.in_progress}</TableCell>
<TableCell>{store.name}</TableCell> <TableCell>{fileCounts.total}</TableCell>
<TableCell> <TableCell>{store.usage_bytes}</TableCell>
{new Date(store.created_at * 1000).toLocaleString()} <TableCell>{providerId}</TableCell>
</TableCell> <TableCell>{providerDbId}</TableCell>
<TableCell>{fileCounts.completed}</TableCell> </TableRow>
<TableCell>{fileCounts.cancelled}</TableCell> );
<TableCell>{fileCounts.failed}</TableCell> })}
<TableCell>{fileCounts.in_progress}</TableCell> </TableBody>
<TableCell>{fileCounts.total}</TableCell> </Table>
<TableCell>{store.usage_bytes}</TableCell> </div>
<TableCell>{providerId}</TableCell>
<TableCell>{providerDbId}</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
); );
}; };
return ( return (
<div className="space-y-4"> <div className="space-y-4">
<h1 className="text-2xl font-semibold">Vector Stores</h1> <h1 className="text-2xl font-semibold">Vector Stores</h1>
{renderContent()} {renderContent()}
</div> </div>
); );
} }

View file

@ -14,7 +14,7 @@ describe("ChatCompletionDetailView", () => {
isLoading={true} isLoading={true}
error={null} error={null}
id="test-id" id="test-id"
/>, />
); );
// Use the data-slot attribute for Skeletons // Use the data-slot attribute for Skeletons
const skeletons = container.querySelectorAll('[data-slot="skeleton"]'); const skeletons = container.querySelectorAll('[data-slot="skeleton"]');
@ -28,10 +28,10 @@ describe("ChatCompletionDetailView", () => {
isLoading={false} isLoading={false}
error={{ name: "Error", message: "Network Error" }} error={{ name: "Error", message: "Network Error" }}
id="err-id" id="err-id"
/>, />
); );
expect( expect(
screen.getByText(/Error loading details for ID err-id: Network Error/), screen.getByText(/Error loading details for ID err-id: Network Error/)
).toBeInTheDocument(); ).toBeInTheDocument();
}); });
@ -42,11 +42,11 @@ describe("ChatCompletionDetailView", () => {
isLoading={false} isLoading={false}
error={{ name: "Error", message: "" }} error={{ name: "Error", message: "" }}
id="err-id" id="err-id"
/>, />
); );
// Use regex to match the error message regardless of whitespace // Use regex to match the error message regardless of whitespace
expect( expect(
screen.getByText(/Error loading details for ID\s*err-id\s*:/), screen.getByText(/Error loading details for ID\s*err-id\s*:/)
).toBeInTheDocument(); ).toBeInTheDocument();
}); });
@ -57,11 +57,11 @@ describe("ChatCompletionDetailView", () => {
isLoading={false} isLoading={false}
error={{} as Error} error={{} as Error}
id="err-id" id="err-id"
/>, />
); );
// Use regex to match the error message regardless of whitespace // Use regex to match the error message regardless of whitespace
expect( expect(
screen.getByText(/Error loading details for ID\s*err-id\s*:/), screen.getByText(/Error loading details for ID\s*err-id\s*:/)
).toBeInTheDocument(); ).toBeInTheDocument();
}); });
@ -72,10 +72,10 @@ describe("ChatCompletionDetailView", () => {
isLoading={false} isLoading={false}
error={null} error={null}
id="notfound-id" id="notfound-id"
/>, />
); );
expect( expect(
screen.getByText("No details found for ID: notfound-id."), screen.getByText("No details found for ID: notfound-id.")
).toBeInTheDocument(); ).toBeInTheDocument();
}); });
@ -100,7 +100,7 @@ describe("ChatCompletionDetailView", () => {
isLoading={false} isLoading={false}
error={null} error={null}
id={mockCompletion.id} id={mockCompletion.id}
/>, />
); );
// Input // Input
expect(screen.getByText("Input")).toBeInTheDocument(); expect(screen.getByText("Input")).toBeInTheDocument();
@ -112,7 +112,7 @@ describe("ChatCompletionDetailView", () => {
expect(screen.getByText("Properties")).toBeInTheDocument(); expect(screen.getByText("Properties")).toBeInTheDocument();
expect(screen.getByText("Created:")).toBeInTheDocument(); expect(screen.getByText("Created:")).toBeInTheDocument();
expect( expect(
screen.getByText(new Date(1710000000 * 1000).toLocaleString()), screen.getByText(new Date(1710000000 * 1000).toLocaleString())
).toBeInTheDocument(); ).toBeInTheDocument();
expect(screen.getByText("ID:")).toBeInTheDocument(); expect(screen.getByText("ID:")).toBeInTheDocument();
expect(screen.getByText("comp_123")).toBeInTheDocument(); expect(screen.getByText("comp_123")).toBeInTheDocument();
@ -150,7 +150,7 @@ describe("ChatCompletionDetailView", () => {
isLoading={false} isLoading={false}
error={null} error={null}
id={mockCompletion.id} id={mockCompletion.id}
/>, />
); );
// Output should include the tool call block (should be present twice: input and output) // Output should include the tool call block (should be present twice: input and output)
const toolCallLabels = screen.getAllByText("Tool Call"); const toolCallLabels = screen.getAllByText("Tool Call");
@ -178,13 +178,13 @@ describe("ChatCompletionDetailView", () => {
isLoading={false} isLoading={false}
error={null} error={null}
id={mockCompletion.id} id={mockCompletion.id}
/>, />
); );
// Input section should be present but empty // Input section should be present but empty
expect(screen.getByText("Input")).toBeInTheDocument(); expect(screen.getByText("Input")).toBeInTheDocument();
// Output section should show fallback message // Output section should show fallback message
expect( expect(
screen.getByText("No message found in assistant's choice."), screen.getByText("No message found in assistant's choice.")
).toBeInTheDocument(); ).toBeInTheDocument();
// Properties should show N/A for finish reason // Properties should show N/A for finish reason
expect(screen.getByText("Finish Reason:")).toBeInTheDocument(); expect(screen.getByText("Finish Reason:")).toBeInTheDocument();

View file

@ -53,14 +53,14 @@ export function ChatCompletionDetailView({
{completion.choices?.[0]?.message?.tool_calls && {completion.choices?.[0]?.message?.tool_calls &&
Array.isArray(completion.choices[0].message.tool_calls) && Array.isArray(completion.choices[0].message.tool_calls) &&
!completion.input_messages?.some( !completion.input_messages?.some(
(im) => im =>
im.role === "assistant" && im.role === "assistant" &&
im.tool_calls && im.tool_calls &&
Array.isArray(im.tool_calls) && Array.isArray(im.tool_calls) &&
im.tool_calls.length > 0, im.tool_calls.length > 0
) )
? completion.choices[0].message.tool_calls.map( ? completion.choices[0].message.tool_calls.map(
(toolCall: any, index: number) => { (toolCall: { function?: { name?: string } }, index: number) => {
const assistantToolCallMessage: ChatMessage = { const assistantToolCallMessage: ChatMessage = {
role: "assistant", role: "assistant",
tool_calls: [toolCall], tool_calls: [toolCall],
@ -72,7 +72,7 @@ export function ChatCompletionDetailView({
message={assistantToolCallMessage} message={assistantToolCallMessage}
/> />
); );
}, }
) )
: null} : null}
</CardContent> </CardContent>
@ -89,7 +89,7 @@ export function ChatCompletionDetailView({
/> />
) : ( ) : (
<p className="text-gray-500 italic text-sm"> <p className="text-gray-500 italic text-sm">
No message found in assistant's choice. No message found in assistant&apos;s choice.
</p> </p>
)} )}
</CardContent> </CardContent>
@ -120,13 +120,18 @@ export function ChatCompletionDetailView({
value={ value={
<div> <div>
<ul className="list-disc list-inside pl-4 mt-1"> <ul className="list-disc list-inside pl-4 mt-1">
{toolCalls.map((toolCall: any, index: number) => ( {toolCalls.map(
<li key={index}> (
<span className="text-gray-900 font-medium"> toolCall: { function?: { name?: string } },
{toolCall.function?.name || "N/A"} index: number
</span> ) => (
</li> <li key={index}>
))} <span className="text-gray-900 font-medium">
{toolCall.function?.name || "N/A"}
</span>
</li>
)
)}
</ul> </ul>
</div> </div>
} }

View file

@ -83,7 +83,7 @@ describe("ChatCompletionsTable", () => {
// Default pass-through implementations // Default pass-through implementations
truncateText.mockImplementation((text: string | undefined) => text); truncateText.mockImplementation((text: string | undefined) => text);
extractTextFromContentPart.mockImplementation((content: unknown) => extractTextFromContentPart.mockImplementation((content: unknown) =>
typeof content === "string" ? content : "extracted text", typeof content === "string" ? content : "extracted text"
); );
extractDisplayableText.mockImplementation((message: unknown) => { extractDisplayableText.mockImplementation((message: unknown) => {
const msg = message as { content?: string }; const msg = message as { content?: string };
@ -138,7 +138,7 @@ describe("ChatCompletionsTable", () => {
if (row) { if (row) {
fireEvent.click(row); fireEvent.click(row);
expect(mockPush).toHaveBeenCalledWith( expect(mockPush).toHaveBeenCalledWith(
"/logs/chat-completions/completion_123", "/logs/chat-completions/completion_123"
); );
} else { } else {
throw new Error('Row with "Test prompt" not found for router mock test.'); throw new Error('Row with "Test prompt" not found for router mock test.');
@ -162,7 +162,7 @@ describe("ChatCompletionsTable", () => {
expect(tableCaption).toBeInTheDocument(); expect(tableCaption).toBeInTheDocument();
if (tableCaption) { if (tableCaption) {
const captionSkeleton = tableCaption.querySelector( const captionSkeleton = tableCaption.querySelector(
'[data-slot="skeleton"]', '[data-slot="skeleton"]'
); );
expect(captionSkeleton).toBeInTheDocument(); expect(captionSkeleton).toBeInTheDocument();
} }
@ -172,7 +172,7 @@ describe("ChatCompletionsTable", () => {
expect(tableBody).toBeInTheDocument(); expect(tableBody).toBeInTheDocument();
if (tableBody) { if (tableBody) {
const bodySkeletons = tableBody.querySelectorAll( const bodySkeletons = tableBody.querySelectorAll(
'[data-slot="skeleton"]', '[data-slot="skeleton"]'
); );
expect(bodySkeletons.length).toBeGreaterThan(0); expect(bodySkeletons.length).toBeGreaterThan(0);
} }
@ -192,14 +192,14 @@ describe("ChatCompletionsTable", () => {
render(<ChatCompletionsTable {...defaultProps} />); render(<ChatCompletionsTable {...defaultProps} />);
expect( expect(
screen.getByText("Unable to load chat completions"), screen.getByText("Unable to load chat completions")
).toBeInTheDocument(); ).toBeInTheDocument();
expect(screen.getByText(errorMessage)).toBeInTheDocument(); expect(screen.getByText(errorMessage)).toBeInTheDocument();
}); });
test.each([{ name: "Error", message: "" }, {}])( test.each([{ name: "Error", message: "" }, {}])(
"renders default error message when error has no message", "renders default error message when error has no message",
(errorObject) => { errorObject => {
mockedUsePagination.mockReturnValue({ mockedUsePagination.mockReturnValue({
data: [], data: [],
status: "error", status: "error",
@ -210,14 +210,14 @@ describe("ChatCompletionsTable", () => {
render(<ChatCompletionsTable {...defaultProps} />); render(<ChatCompletionsTable {...defaultProps} />);
expect( expect(
screen.getByText("Unable to load chat completions"), screen.getByText("Unable to load chat completions")
).toBeInTheDocument(); ).toBeInTheDocument();
expect( expect(
screen.getByText( screen.getByText(
"An unexpected error occurred while loading the data.", "An unexpected error occurred while loading the data."
), )
).toBeInTheDocument(); ).toBeInTheDocument();
}, }
); );
}); });
@ -225,7 +225,7 @@ describe("ChatCompletionsTable", () => {
test('renders "No chat completions found." and no table when data array is empty', () => { test('renders "No chat completions found." and no table when data array is empty', () => {
render(<ChatCompletionsTable {...defaultProps} />); render(<ChatCompletionsTable {...defaultProps} />);
expect( expect(
screen.getByText("No chat completions found."), screen.getByText("No chat completions found.")
).toBeInTheDocument(); ).toBeInTheDocument();
// Ensure that the table structure is NOT rendered in the empty state // Ensure that the table structure is NOT rendered in the empty state
@ -292,7 +292,7 @@ describe("ChatCompletionsTable", () => {
// Table caption // Table caption
expect( expect(
screen.getByText("A list of your recent chat completions."), screen.getByText("A list of your recent chat completions.")
).toBeInTheDocument(); ).toBeInTheDocument();
// Table headers // Table headers
@ -306,14 +306,14 @@ describe("ChatCompletionsTable", () => {
expect(screen.getByText("Test output")).toBeInTheDocument(); expect(screen.getByText("Test output")).toBeInTheDocument();
expect(screen.getByText("llama-test-model")).toBeInTheDocument(); expect(screen.getByText("llama-test-model")).toBeInTheDocument();
expect( expect(
screen.getByText(new Date(1710000000 * 1000).toLocaleString()), screen.getByText(new Date(1710000000 * 1000).toLocaleString())
).toBeInTheDocument(); ).toBeInTheDocument();
expect(screen.getByText("Another input")).toBeInTheDocument(); expect(screen.getByText("Another input")).toBeInTheDocument();
expect(screen.getByText("Another output")).toBeInTheDocument(); expect(screen.getByText("Another output")).toBeInTheDocument();
expect(screen.getByText("llama-another-model")).toBeInTheDocument(); expect(screen.getByText("llama-another-model")).toBeInTheDocument();
expect( expect(
screen.getByText(new Date(1710001000 * 1000).toLocaleString()), screen.getByText(new Date(1710001000 * 1000).toLocaleString())
).toBeInTheDocument(); ).toBeInTheDocument();
}); });
}); });
@ -328,7 +328,7 @@ describe("ChatCompletionsTable", () => {
return typeof text === "string" && text.length > effectiveMaxLength return typeof text === "string" && text.length > effectiveMaxLength
? text.slice(0, effectiveMaxLength) + "..." ? text.slice(0, effectiveMaxLength) + "..."
: text; : text;
}, }
); );
const longInput = const longInput =
@ -368,7 +368,7 @@ describe("ChatCompletionsTable", () => {
// The truncated text should be present for both input and output // The truncated text should be present for both input and output
const truncatedTexts = screen.getAllByText( const truncatedTexts = screen.getAllByText(
longInput.slice(0, 10) + "...", longInput.slice(0, 10) + "..."
); );
expect(truncatedTexts.length).toBe(2); // one for input, one for output expect(truncatedTexts.length).toBe(2); // one for input, one for output
}); });
@ -420,7 +420,7 @@ describe("ChatCompletionsTable", () => {
// Verify the extracted text appears in the table // Verify the extracted text appears in the table
expect(screen.getByText("Extracted input")).toBeInTheDocument(); expect(screen.getByText("Extracted input")).toBeInTheDocument();
expect( expect(
screen.getByText("Extracted output from assistant"), screen.getByText("Extracted output from assistant")
).toBeInTheDocument(); ).toBeInTheDocument();
}); });
}); });

View file

@ -5,6 +5,7 @@ import {
UsePaginationOptions, UsePaginationOptions,
ListChatCompletionsResponse, ListChatCompletionsResponse,
} from "@/lib/types"; } from "@/lib/types";
import { ListChatCompletionsParams } from "@/lib/llama-stack-client";
import { LogsTable, LogTableRow } from "@/components/logs/logs-table"; import { LogsTable, LogTableRow } from "@/components/logs/logs-table";
import { import {
extractTextFromContentPart, extractTextFromContentPart,
@ -38,14 +39,14 @@ export function ChatCompletionsTable({
limit: number; limit: number;
model?: string; model?: string;
order?: string; order?: string;
}, }
) => { ) => {
const response = await client.chat.completions.list({ const response = await client.chat.completions.list({
after: params.after, after: params.after,
limit: params.limit, limit: params.limit,
...(params.model && { model: params.model }), ...(params.model && { model: params.model }),
...(params.order && { order: params.order }), ...(params.order && { order: params.order }),
} as any); } as ListChatCompletionsParams);
return response as ListChatCompletionsResponse; return response as ListChatCompletionsResponse;
}; };

View file

@ -37,21 +37,26 @@ export function ChatMessageItem({ message }: ChatMessageItemProps) {
) { ) {
return ( return (
<> <>
{message.tool_calls.map((toolCall: any, index: number) => { {message.tool_calls.map(
const formattedToolCall = formatToolCallToString(toolCall); (
const toolCallContent = ( toolCall: { function?: { name?: string; arguments?: unknown } },
<ToolCallBlock> index: number
{formattedToolCall || "Error: Could not display tool call"} ) => {
</ToolCallBlock> const formattedToolCall = formatToolCallToString(toolCall);
); const toolCallContent = (
return ( <ToolCallBlock>
<MessageBlock {formattedToolCall || "Error: Could not display tool call"}
key={index} </ToolCallBlock>
label="Tool Call" );
content={toolCallContent} return (
/> <MessageBlock
); key={index}
})} label="Tool Call"
content={toolCallContent}
/>
);
}
)}
</> </>
); );
} else { } else {

View file

@ -1,18 +1,18 @@
"use client" "use client";
import React, { useMemo, useState } from "react" import React, { useMemo, useState } from "react";
import { cva, type VariantProps } from "class-variance-authority" import { cva, type VariantProps } from "class-variance-authority";
import { motion } from "framer-motion" import { motion } from "framer-motion";
import { Ban, ChevronRight, Code2, Loader2, Terminal } from "lucide-react" import { Ban, ChevronRight, Code2, Loader2, Terminal } from "lucide-react";
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils";
import { import {
Collapsible, Collapsible,
CollapsibleContent, CollapsibleContent,
CollapsibleTrigger, CollapsibleTrigger,
} from "@/components/ui/collapsible" } from "@/components/ui/collapsible";
import { FilePreview } from "@/components/ui/file-preview" import { FilePreview } from "@/components/ui/file-preview";
import { MarkdownRenderer } from "@/components/chat-playground/markdown-renderer" import { MarkdownRenderer } from "@/components/chat-playground/markdown-renderer";
const chatBubbleVariants = cva( const chatBubbleVariants = cva(
"group/message relative break-words rounded-lg p-3 text-sm sm:max-w-[70%]", "group/message relative break-words rounded-lg p-3 text-sm sm:max-w-[70%]",
@ -52,66 +52,66 @@ const chatBubbleVariants = cva(
}, },
], ],
} }
) );
type Animation = VariantProps<typeof chatBubbleVariants>["animation"] type Animation = VariantProps<typeof chatBubbleVariants>["animation"];
interface Attachment { interface Attachment {
name?: string name?: string;
contentType?: string contentType?: string;
url: string url: string;
} }
interface PartialToolCall { interface PartialToolCall {
state: "partial-call" state: "partial-call";
toolName: string toolName: string;
} }
interface ToolCall { interface ToolCall {
state: "call" state: "call";
toolName: string toolName: string;
} }
interface ToolResult { interface ToolResult {
state: "result" state: "result";
toolName: string toolName: string;
result: { result: {
__cancelled?: boolean __cancelled?: boolean;
[key: string]: any [key: string]: unknown;
} };
} }
type ToolInvocation = PartialToolCall | ToolCall | ToolResult type ToolInvocation = PartialToolCall | ToolCall | ToolResult;
interface ReasoningPart { interface ReasoningPart {
type: "reasoning" type: "reasoning";
reasoning: string reasoning: string;
} }
interface ToolInvocationPart { interface ToolInvocationPart {
type: "tool-invocation" type: "tool-invocation";
toolInvocation: ToolInvocation toolInvocation: ToolInvocation;
} }
interface TextPart { interface TextPart {
type: "text" type: "text";
text: string text: string;
} }
// For compatibility with AI SDK types, not used // For compatibility with AI SDK types, not used
interface SourcePart { interface SourcePart {
type: "source" type: "source";
source?: any source?: unknown;
} }
interface FilePart { interface FilePart {
type: "file" type: "file";
mimeType: string mimeType: string;
data: string data: string;
} }
interface StepStartPart { interface StepStartPart {
type: "step-start" type: "step-start";
} }
type MessagePart = type MessagePart =
@ -120,22 +120,22 @@ type MessagePart =
| ToolInvocationPart | ToolInvocationPart
| SourcePart | SourcePart
| FilePart | FilePart
| StepStartPart | StepStartPart;
export interface Message { export interface Message {
id: string id: string;
role: "user" | "assistant" | (string & {}) role: "user" | "assistant" | (string & {});
content: string content: string;
createdAt?: Date createdAt?: Date;
experimental_attachments?: Attachment[] experimental_attachments?: Attachment[];
toolInvocations?: ToolInvocation[] toolInvocations?: ToolInvocation[];
parts?: MessagePart[] parts?: MessagePart[];
} }
export interface ChatMessageProps extends Message { export interface ChatMessageProps extends Message {
showTimeStamp?: boolean showTimeStamp?: boolean;
animation?: Animation animation?: Animation;
actions?: React.ReactNode actions?: React.ReactNode;
} }
export const ChatMessage: React.FC<ChatMessageProps> = ({ export const ChatMessage: React.FC<ChatMessageProps> = ({
@ -150,21 +150,21 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
parts, parts,
}) => { }) => {
const files = useMemo(() => { const files = useMemo(() => {
return experimental_attachments?.map((attachment) => { return experimental_attachments?.map(attachment => {
const dataArray = dataUrlToUint8Array(attachment.url) const dataArray = dataUrlToUint8Array(attachment.url);
const file = new File([dataArray], attachment.name ?? "Unknown", { const file = new File([dataArray], attachment.name ?? "Unknown", {
type: attachment.contentType, type: attachment.contentType,
}) });
return file return file;
}) });
}, [experimental_attachments]) }, [experimental_attachments]);
const isUser = role === "user" const isUser = role === "user";
const formattedTime = createdAt?.toLocaleTimeString("en-US", { const formattedTime = createdAt?.toLocaleTimeString("en-US", {
hour: "2-digit", hour: "2-digit",
minute: "2-digit", minute: "2-digit",
}) });
if (isUser) { if (isUser) {
return ( return (
@ -174,7 +174,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
{files ? ( {files ? (
<div className="mb-1 flex flex-wrap gap-2"> <div className="mb-1 flex flex-wrap gap-2">
{files.map((file, index) => { {files.map((file, index) => {
return <FilePreview file={file} key={index} /> return <FilePreview file={file} key={index} />;
})} })}
</div> </div>
) : null} ) : null}
@ -195,7 +195,7 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
</time> </time>
) : null} ) : null}
</div> </div>
) );
} }
if (parts && parts.length > 0) { if (parts && parts.length > 0) {
@ -230,23 +230,23 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
</time> </time>
) : null} ) : null}
</div> </div>
) );
} else if (part.type === "reasoning") { } else if (part.type === "reasoning") {
return <ReasoningBlock key={`reasoning-${index}`} part={part} /> return <ReasoningBlock key={`reasoning-${index}`} part={part} />;
} else if (part.type === "tool-invocation") { } else if (part.type === "tool-invocation") {
return ( return (
<ToolCall <ToolCall
key={`tool-${index}`} key={`tool-${index}`}
toolInvocations={[part.toolInvocation]} toolInvocations={[part.toolInvocation]}
/> />
) );
} }
return null return null;
}) });
} }
if (toolInvocations && toolInvocations.length > 0) { if (toolInvocations && toolInvocations.length > 0) {
return <ToolCall toolInvocations={toolInvocations} /> return <ToolCall toolInvocations={toolInvocations} />;
} }
return ( return (
@ -272,17 +272,17 @@ export const ChatMessage: React.FC<ChatMessageProps> = ({
</time> </time>
) : null} ) : null}
</div> </div>
) );
} };
function dataUrlToUint8Array(data: string) { function dataUrlToUint8Array(data: string) {
const base64 = data.split(",")[1] const base64 = data.split(",")[1];
const buf = Buffer.from(base64, "base64") const buf = Buffer.from(base64, "base64");
return new Uint8Array(buf) return new Uint8Array(buf);
} }
const ReasoningBlock = ({ part }: { part: ReasoningPart }) => { const ReasoningBlock = ({ part }: { part: ReasoningPart }) => {
const [isOpen, setIsOpen] = useState(false) const [isOpen, setIsOpen] = useState(false);
return ( return (
<div className="mb-2 flex flex-col items-start sm:max-w-[70%]"> <div className="mb-2 flex flex-col items-start sm:max-w-[70%]">
@ -319,20 +319,20 @@ const ReasoningBlock = ({ part }: { part: ReasoningPart }) => {
</CollapsibleContent> </CollapsibleContent>
</Collapsible> </Collapsible>
</div> </div>
) );
} };
function ToolCall({ function ToolCall({
toolInvocations, toolInvocations,
}: Pick<ChatMessageProps, "toolInvocations">) { }: Pick<ChatMessageProps, "toolInvocations">) {
if (!toolInvocations?.length) return null if (!toolInvocations?.length) return null;
return ( return (
<div className="flex flex-col items-start gap-2"> <div className="flex flex-col items-start gap-2">
{toolInvocations.map((invocation, index) => { {toolInvocations.map((invocation, index) => {
const isCancelled = const isCancelled =
invocation.state === "result" && invocation.state === "result" &&
invocation.result.__cancelled === true invocation.result.__cancelled === true;
if (isCancelled) { if (isCancelled) {
return ( return (
@ -350,7 +350,7 @@ function ToolCall({
</span> </span>
</span> </span>
</div> </div>
) );
} }
switch (invocation.state) { switch (invocation.state) {
@ -373,7 +373,7 @@ function ToolCall({
</span> </span>
<Loader2 className="h-3 w-3 animate-spin" /> <Loader2 className="h-3 w-3 animate-spin" />
</div> </div>
) );
case "result": case "result":
return ( return (
<div <div
@ -395,11 +395,11 @@ function ToolCall({
{JSON.stringify(invocation.result, null, 2)} {JSON.stringify(invocation.result, null, 2)}
</pre> </pre>
</div> </div>
) );
default: default:
return null return null;
} }
})} })}
</div> </div>
) );
} }

View file

@ -1,4 +1,4 @@
"use client" "use client";
import { import {
forwardRef, forwardRef,
@ -6,48 +6,48 @@ import {
useRef, useRef,
useState, useState,
type ReactElement, type ReactElement,
} from "react" } from "react";
import { ArrowDown, ThumbsDown, ThumbsUp } from "lucide-react" import { ArrowDown, ThumbsDown, ThumbsUp } from "lucide-react";
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils";
import { useAutoScroll } from "@/hooks/use-auto-scroll" import { useAutoScroll } from "@/hooks/use-auto-scroll";
import { Button } from "@/components/ui/button" import { Button } from "@/components/ui/button";
import { type Message } from "@/components/chat-playground/chat-message" import { type Message } from "@/components/chat-playground/chat-message";
import { CopyButton } from "@/components/ui/copy-button" import { CopyButton } from "@/components/ui/copy-button";
import { MessageInput } from "@/components/chat-playground/message-input" import { MessageInput } from "@/components/chat-playground/message-input";
import { MessageList } from "@/components/chat-playground/message-list" import { MessageList } from "@/components/chat-playground/message-list";
import { PromptSuggestions } from "@/components/chat-playground/prompt-suggestions" import { PromptSuggestions } from "@/components/chat-playground/prompt-suggestions";
interface ChatPropsBase { interface ChatPropsBase {
handleSubmit: ( handleSubmit: (
event?: { preventDefault?: () => void }, event?: { preventDefault?: () => void },
options?: { experimental_attachments?: FileList } options?: { experimental_attachments?: FileList }
) => void ) => void;
messages: Array<Message> messages: Array<Message>;
input: string input: string;
className?: string className?: string;
handleInputChange: React.ChangeEventHandler<HTMLTextAreaElement> handleInputChange: React.ChangeEventHandler<HTMLTextAreaElement>;
isGenerating: boolean isGenerating: boolean;
stop?: () => void stop?: () => void;
onRateResponse?: ( onRateResponse?: (
messageId: string, messageId: string,
rating: "thumbs-up" | "thumbs-down" rating: "thumbs-up" | "thumbs-down"
) => void ) => void;
setMessages?: (messages: any[]) => void setMessages?: (messages: Message[]) => void;
transcribeAudio?: (blob: Blob) => Promise<string> transcribeAudio?: (blob: Blob) => Promise<string>;
} }
interface ChatPropsWithoutSuggestions extends ChatPropsBase { interface ChatPropsWithoutSuggestions extends ChatPropsBase {
append?: never append?: never;
suggestions?: never suggestions?: never;
} }
interface ChatPropsWithSuggestions extends ChatPropsBase { interface ChatPropsWithSuggestions extends ChatPropsBase {
append: (message: { role: "user"; content: string }) => void append: (message: { role: "user"; content: string }) => void;
suggestions: string[] suggestions: string[];
} }
type ChatProps = ChatPropsWithoutSuggestions | ChatPropsWithSuggestions type ChatProps = ChatPropsWithoutSuggestions | ChatPropsWithSuggestions;
export function Chat({ export function Chat({
messages, messages,
@ -63,34 +63,34 @@ export function Chat({
setMessages, setMessages,
transcribeAudio, transcribeAudio,
}: ChatProps) { }: ChatProps) {
const lastMessage = messages.at(-1) const lastMessage = messages.at(-1);
const isEmpty = messages.length === 0 const isEmpty = messages.length === 0;
const isTyping = lastMessage?.role === "user" const isTyping = lastMessage?.role === "user";
const messagesRef = useRef(messages) const messagesRef = useRef(messages);
messagesRef.current = messages messagesRef.current = messages;
// Enhanced stop function that marks pending tool calls as cancelled // Enhanced stop function that marks pending tool calls as cancelled
const handleStop = useCallback(() => { const handleStop = useCallback(() => {
stop?.() stop?.();
if (!setMessages) return if (!setMessages) return;
const latestMessages = [...messagesRef.current] const latestMessages = [...messagesRef.current];
const lastAssistantMessage = latestMessages.findLast( const lastAssistantMessage = latestMessages.findLast(
(m) => m.role === "assistant" m => m.role === "assistant"
) );
if (!lastAssistantMessage) return if (!lastAssistantMessage) return;
let needsUpdate = false let needsUpdate = false;
let updatedMessage = { ...lastAssistantMessage } let updatedMessage = { ...lastAssistantMessage };
if (lastAssistantMessage.toolInvocations) { if (lastAssistantMessage.toolInvocations) {
const updatedToolInvocations = lastAssistantMessage.toolInvocations.map( const updatedToolInvocations = lastAssistantMessage.toolInvocations.map(
(toolInvocation) => { toolInvocation => {
if (toolInvocation.state === "call") { if (toolInvocation.state === "call") {
needsUpdate = true needsUpdate = true;
return { return {
...toolInvocation, ...toolInvocation,
state: "result", state: "result",
@ -98,61 +98,66 @@ export function Chat({
content: "Tool execution was cancelled", content: "Tool execution was cancelled",
__cancelled: true, // Special marker to indicate cancellation __cancelled: true, // Special marker to indicate cancellation
}, },
} as const } as const;
} }
return toolInvocation return toolInvocation;
} }
) );
if (needsUpdate) { if (needsUpdate) {
updatedMessage = { updatedMessage = {
...updatedMessage, ...updatedMessage,
toolInvocations: updatedToolInvocations, toolInvocations: updatedToolInvocations,
} };
} }
} }
if (lastAssistantMessage.parts && lastAssistantMessage.parts.length > 0) { if (lastAssistantMessage.parts && lastAssistantMessage.parts.length > 0) {
const updatedParts = lastAssistantMessage.parts.map((part: any) => { const updatedParts = lastAssistantMessage.parts.map(
if ( (part: {
part.type === "tool-invocation" && type: string;
part.toolInvocation && toolInvocation?: { state: string; toolName: string };
part.toolInvocation.state === "call" }) => {
) { if (
needsUpdate = true part.type === "tool-invocation" &&
return { part.toolInvocation &&
...part, part.toolInvocation.state === "call"
toolInvocation: { ) {
...part.toolInvocation, needsUpdate = true;
state: "result", return {
result: { ...part,
content: "Tool execution was cancelled", toolInvocation: {
__cancelled: true, ...part.toolInvocation,
state: "result",
result: {
content: "Tool execution was cancelled",
__cancelled: true,
},
}, },
}, };
} }
return part;
} }
return part );
})
if (needsUpdate) { if (needsUpdate) {
updatedMessage = { updatedMessage = {
...updatedMessage, ...updatedMessage,
parts: updatedParts, parts: updatedParts,
} };
} }
} }
if (needsUpdate) { if (needsUpdate) {
const messageIndex = latestMessages.findIndex( const messageIndex = latestMessages.findIndex(
(m) => m.id === lastAssistantMessage.id m => m.id === lastAssistantMessage.id
) );
if (messageIndex !== -1) { if (messageIndex !== -1) {
latestMessages[messageIndex] = updatedMessage latestMessages[messageIndex] = updatedMessage;
setMessages(latestMessages) setMessages(latestMessages);
} }
} }
}, [stop, setMessages, messagesRef]) }, [stop, setMessages, messagesRef]);
const messageOptions = useCallback( const messageOptions = useCallback(
(message: Message) => ({ (message: Message) => ({
@ -189,7 +194,7 @@ export function Chat({
), ),
}), }),
[onRateResponse] [onRateResponse]
) );
return ( return (
<ChatContainer className={className}> <ChatContainer className={className}>
@ -237,15 +242,15 @@ export function Chat({
</div> </div>
</div> </div>
</ChatContainer> </ChatContainer>
) );
} }
Chat.displayName = "Chat" Chat.displayName = "Chat";
export function ChatMessages({ export function ChatMessages({
messages, messages,
children, children,
}: React.PropsWithChildren<{ }: React.PropsWithChildren<{
messages: Message[] messages: Message[];
}>) { }>) {
const { const {
containerRef, containerRef,
@ -253,7 +258,7 @@ export function ChatMessages({
handleScroll, handleScroll,
shouldAutoScroll, shouldAutoScroll,
handleTouchStart, handleTouchStart,
} = useAutoScroll([messages]) } = useAutoScroll([messages]);
return ( return (
<div <div
@ -281,7 +286,7 @@ export function ChatMessages({
</div> </div>
)} )}
</div> </div>
) );
} }
export const ChatContainer = forwardRef< export const ChatContainer = forwardRef<
@ -294,56 +299,56 @@ export const ChatContainer = forwardRef<
className={cn("flex flex-col max-h-full w-full", className)} className={cn("flex flex-col max-h-full w-full", className)}
{...props} {...props}
/> />
) );
}) });
ChatContainer.displayName = "ChatContainer" ChatContainer.displayName = "ChatContainer";
interface ChatFormProps { interface ChatFormProps {
className?: string className?: string;
isPending: boolean isPending: boolean;
handleSubmit: ( handleSubmit: (
event?: { preventDefault?: () => void }, event?: { preventDefault?: () => void },
options?: { experimental_attachments?: FileList } options?: { experimental_attachments?: FileList }
) => void ) => void;
children: (props: { children: (props: {
files: File[] | null files: File[] | null;
setFiles: React.Dispatch<React.SetStateAction<File[] | null>> setFiles: React.Dispatch<React.SetStateAction<File[] | null>>;
}) => ReactElement }) => ReactElement;
} }
export const ChatForm = forwardRef<HTMLFormElement, ChatFormProps>( export const ChatForm = forwardRef<HTMLFormElement, ChatFormProps>(
({ children, handleSubmit, isPending, className }, ref) => { ({ children, handleSubmit, isPending, className }, ref) => {
const [files, setFiles] = useState<File[] | null>(null) const [files, setFiles] = useState<File[] | null>(null);
const onSubmit = (event: React.FormEvent) => { const onSubmit = (event: React.FormEvent) => {
// if (isPending) { if (isPending) {
// event.preventDefault() event.preventDefault();
// return return;
// }
if (!files) {
handleSubmit(event)
return
} }
const fileList = createFileList(files) if (!files) {
handleSubmit(event, { experimental_attachments: fileList }) handleSubmit(event);
setFiles(null) return;
} }
const fileList = createFileList(files);
handleSubmit(event, { experimental_attachments: fileList });
setFiles(null);
};
return ( return (
<form ref={ref} onSubmit={onSubmit} className={className}> <form ref={ref} onSubmit={onSubmit} className={className}>
{children({ files, setFiles })} {children({ files, setFiles })}
</form> </form>
) );
} }
) );
ChatForm.displayName = "ChatForm" ChatForm.displayName = "ChatForm";
function createFileList(files: File[] | FileList): FileList { function createFileList(files: File[] | FileList): FileList {
const dataTransfer = new DataTransfer() const dataTransfer = new DataTransfer();
for (const file of Array.from(files)) { for (const file of Array.from(files)) {
dataTransfer.items.add(file) dataTransfer.items.add(file);
} }
return dataTransfer.files return dataTransfer.files;
} }

View file

@ -1,11 +1,11 @@
"use client" "use client";
import { AnimatePresence, motion } from "framer-motion" import { AnimatePresence, motion } from "framer-motion";
import { X } from "lucide-react" import { X } from "lucide-react";
interface InterruptPromptProps { interface InterruptPromptProps {
isOpen: boolean isOpen: boolean;
close: () => void close: () => void;
} }
export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) { export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) {
@ -37,5 +37,5 @@ export function InterruptPrompt({ isOpen, close }: InterruptPromptProps) {
</motion.div> </motion.div>
)} )}
</AnimatePresence> </AnimatePresence>
) );
} }

View file

@ -1,12 +1,12 @@
import React, { Suspense, useEffect, useState } from "react" import React, { Suspense, useEffect, useState } from "react";
import Markdown from "react-markdown" import Markdown from "react-markdown";
import remarkGfm from "remark-gfm" import remarkGfm from "remark-gfm";
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils";
import { CopyButton } from "@/components/ui/copy-button" import { CopyButton } from "@/components/ui/copy-button";
interface MarkdownRendererProps { interface MarkdownRendererProps {
children: string children: string;
} }
export function MarkdownRenderer({ children }: MarkdownRendererProps) { export function MarkdownRenderer({ children }: MarkdownRendererProps) {
@ -16,34 +16,34 @@ export function MarkdownRenderer({ children }: MarkdownRendererProps) {
{children} {children}
</Markdown> </Markdown>
</div> </div>
) );
} }
interface HighlightedPre extends React.HTMLAttributes<HTMLPreElement> { interface HighlightedPre extends React.HTMLAttributes<HTMLPreElement> {
children: string children: string;
language: string language: string;
} }
const HighlightedPre = React.memo( const HighlightedPre = React.memo(
({ children, language, ...props }: HighlightedPre) => { ({ children, language, ...props }: HighlightedPre) => {
const [tokens, setTokens] = useState<any[] | null>(null) const [tokens, setTokens] = useState<unknown[] | null>(null);
const [isSupported, setIsSupported] = useState(false) const [isSupported, setIsSupported] = useState(false);
useEffect(() => { useEffect(() => {
let mounted = true let mounted = true;
const loadAndHighlight = async () => { const loadAndHighlight = async () => {
try { try {
const { codeToTokens, bundledLanguages } = await import("shiki") const { codeToTokens, bundledLanguages } = await import("shiki");
if (!mounted) return if (!mounted) return;
if (!(language in bundledLanguages)) { if (!(language in bundledLanguages)) {
setIsSupported(false) setIsSupported(false);
return return;
} }
setIsSupported(true) setIsSupported(true);
const { tokens: highlightedTokens } = await codeToTokens(children, { const { tokens: highlightedTokens } = await codeToTokens(children, {
lang: language as keyof typeof bundledLanguages, lang: language as keyof typeof bundledLanguages,
@ -52,31 +52,31 @@ const HighlightedPre = React.memo(
light: "github-light", light: "github-light",
dark: "github-dark", dark: "github-dark",
}, },
}) });
if (mounted) { if (mounted) {
setTokens(highlightedTokens) setTokens(highlightedTokens);
} }
} catch (error) { } catch {
if (mounted) { if (mounted) {
setIsSupported(false) setIsSupported(false);
} }
} }
} };
loadAndHighlight() loadAndHighlight();
return () => { return () => {
mounted = false mounted = false;
} };
}, [children, language]) }, [children, language]);
if (!isSupported) { if (!isSupported) {
return <pre {...props}>{children}</pre> return <pre {...props}>{children}</pre>;
} }
if (!tokens) { if (!tokens) {
return <pre {...props}>{children}</pre> return <pre {...props}>{children}</pre>;
} }
return ( return (
@ -89,7 +89,7 @@ const HighlightedPre = React.memo(
const style = const style =
typeof token.htmlStyle === "string" typeof token.htmlStyle === "string"
? undefined ? undefined
: token.htmlStyle : token.htmlStyle;
return ( return (
<span <span
@ -99,7 +99,7 @@ const HighlightedPre = React.memo(
> >
{token.content} {token.content}
</span> </span>
) );
})} })}
</span> </span>
{lineIndex !== tokens.length - 1 && "\n"} {lineIndex !== tokens.length - 1 && "\n"}
@ -107,15 +107,15 @@ const HighlightedPre = React.memo(
))} ))}
</code> </code>
</pre> </pre>
) );
} }
) );
HighlightedPre.displayName = "HighlightedCode" HighlightedPre.displayName = "HighlightedCode";
interface CodeBlockProps extends React.HTMLAttributes<HTMLPreElement> { interface CodeBlockProps extends React.HTMLAttributes<HTMLPreElement> {
children: React.ReactNode children: React.ReactNode;
className?: string className?: string;
language: string language: string;
} }
const CodeBlock = ({ const CodeBlock = ({
@ -127,12 +127,12 @@ const CodeBlock = ({
const code = const code =
typeof children === "string" typeof children === "string"
? children ? children
: childrenTakeAllStringContents(children) : childrenTakeAllStringContents(children);
const preClass = cn( const preClass = cn(
"overflow-x-scroll rounded-md border bg-background/50 p-4 font-mono text-sm [scrollbar-width:none]", "overflow-x-scroll rounded-md border bg-background/50 p-4 font-mono text-sm [scrollbar-width:none]",
className className
) );
return ( return (
<div className="group/code relative mb-4"> <div className="group/code relative mb-4">
@ -152,27 +152,27 @@ const CodeBlock = ({
<CopyButton content={code} copyMessage="Copied code to clipboard" /> <CopyButton content={code} copyMessage="Copied code to clipboard" />
</div> </div>
</div> </div>
) );
} };
function childrenTakeAllStringContents(element: any): string { function childrenTakeAllStringContents(element: unknown): string {
if (typeof element === "string") { if (typeof element === "string") {
return element return element;
} }
if (element?.props?.children) { if (element?.props?.children) {
let children = element.props.children const children = element.props.children;
if (Array.isArray(children)) { if (Array.isArray(children)) {
return children return children
.map((child) => childrenTakeAllStringContents(child)) .map(child => childrenTakeAllStringContents(child))
.join("") .join("");
} else { } else {
return childrenTakeAllStringContents(children) return childrenTakeAllStringContents(children);
} }
} }
return "" return "";
} }
const COMPONENTS = { const COMPONENTS = {
@ -184,8 +184,15 @@ const COMPONENTS = {
strong: withClass("strong", "font-semibold"), strong: withClass("strong", "font-semibold"),
a: withClass("a", "text-primary underline underline-offset-2"), a: withClass("a", "text-primary underline underline-offset-2"),
blockquote: withClass("blockquote", "border-l-2 border-primary pl-4"), blockquote: withClass("blockquote", "border-l-2 border-primary pl-4"),
code: ({ children, className, node, ...rest }: any) => { code: ({
const match = /language-(\w+)/.exec(className || "") children,
className,
...rest
}: {
children: React.ReactNode;
className?: string;
}) => {
const match = /language-(\w+)/.exec(className || "");
return match ? ( return match ? (
<CodeBlock className={className} language={match[1]} {...rest}> <CodeBlock className={className} language={match[1]} {...rest}>
{children} {children}
@ -199,9 +206,9 @@ const COMPONENTS = {
> >
{children} {children}
</code> </code>
) );
}, },
pre: ({ children }: any) => children, pre: ({ children }: { children: React.ReactNode }) => children,
ol: withClass("ol", "list-decimal space-y-2 pl-6"), ol: withClass("ol", "list-decimal space-y-2 pl-6"),
ul: withClass("ul", "list-disc space-y-2 pl-6"), ul: withClass("ul", "list-disc space-y-2 pl-6"),
li: withClass("li", "my-1.5"), li: withClass("li", "my-1.5"),
@ -220,14 +227,14 @@ const COMPONENTS = {
tr: withClass("tr", "m-0 border-t p-0 even:bg-muted"), tr: withClass("tr", "m-0 border-t p-0 even:bg-muted"),
p: withClass("p", "whitespace-pre-wrap"), p: withClass("p", "whitespace-pre-wrap"),
hr: withClass("hr", "border-foreground/20"), hr: withClass("hr", "border-foreground/20"),
} };
function withClass(Tag: keyof JSX.IntrinsicElements, classes: string) { function withClass(Tag: keyof JSX.IntrinsicElements, classes: string) {
const Component = ({ node, ...props }: any) => ( const Component = ({ ...props }: Record<string, unknown>) => (
<Tag className={classes} {...props} /> <Tag className={classes} {...props} />
) );
Component.displayName = Tag Component.displayName = Tag;
return Component return Component;
} }
export default MarkdownRenderer export default MarkdownRenderer;

View file

@ -1,41 +1,41 @@
"use client" "use client";
import React, { useEffect, useRef, useState } from "react" import React, { useEffect, useRef, useState } from "react";
import { AnimatePresence, motion } from "framer-motion" import { AnimatePresence, motion } from "framer-motion";
import { ArrowUp, Info, Loader2, Mic, Paperclip, Square } from "lucide-react" import { ArrowUp, Info, Loader2, Mic, Paperclip, Square } from "lucide-react";
import { omit } from "remeda" import { omit } from "remeda";
import { cn } from "@/lib/utils" import { cn } from "@/lib/utils";
import { useAudioRecording } from "@/hooks/use-audio-recording" import { useAudioRecording } from "@/hooks/use-audio-recording";
import { useAutosizeTextArea } from "@/hooks/use-autosize-textarea" import { useAutosizeTextArea } from "@/hooks/use-autosize-textarea";
import { AudioVisualizer } from "@/components/ui/audio-visualizer" import { AudioVisualizer } from "@/components/ui/audio-visualizer";
import { Button } from "@/components/ui/button" import { Button } from "@/components/ui/button";
import { FilePreview } from "@/components/ui/file-preview" import { FilePreview } from "@/components/ui/file-preview";
import { InterruptPrompt } from "@/components/chat-playground/interrupt-prompt" import { InterruptPrompt } from "@/components/chat-playground/interrupt-prompt";
interface MessageInputBaseProps interface MessageInputBaseProps
extends React.TextareaHTMLAttributes<HTMLTextAreaElement> { extends React.TextareaHTMLAttributes<HTMLTextAreaElement> {
value: string value: string;
submitOnEnter?: boolean submitOnEnter?: boolean;
stop?: () => void stop?: () => void;
isGenerating: boolean isGenerating: boolean;
enableInterrupt?: boolean enableInterrupt?: boolean;
transcribeAudio?: (blob: Blob) => Promise<string> transcribeAudio?: (blob: Blob) => Promise<string>;
} }
interface MessageInputWithoutAttachmentProps extends MessageInputBaseProps { interface MessageInputWithoutAttachmentProps extends MessageInputBaseProps {
allowAttachments?: false allowAttachments?: false;
} }
interface MessageInputWithAttachmentsProps extends MessageInputBaseProps { interface MessageInputWithAttachmentsProps extends MessageInputBaseProps {
allowAttachments: true allowAttachments: true;
files: File[] | null files: File[] | null;
setFiles: React.Dispatch<React.SetStateAction<File[] | null>> setFiles: React.Dispatch<React.SetStateAction<File[] | null>>;
} }
type MessageInputProps = type MessageInputProps =
| MessageInputWithoutAttachmentProps | MessageInputWithoutAttachmentProps
| MessageInputWithAttachmentsProps | MessageInputWithAttachmentsProps;
export function MessageInput({ export function MessageInput({
placeholder = "Ask AI...", placeholder = "Ask AI...",
@ -48,8 +48,8 @@ export function MessageInput({
transcribeAudio, transcribeAudio,
...props ...props
}: MessageInputProps) { }: MessageInputProps) {
const [isDragging, setIsDragging] = useState(false) const [isDragging, setIsDragging] = useState(false);
const [showInterruptPrompt, setShowInterruptPrompt] = useState(false) const [showInterruptPrompt, setShowInterruptPrompt] = useState(false);
const { const {
isListening, isListening,
@ -61,123 +61,124 @@ export function MessageInput({
stopRecording, stopRecording,
} = useAudioRecording({ } = useAudioRecording({
transcribeAudio, transcribeAudio,
onTranscriptionComplete: (text) => { onTranscriptionComplete: text => {
props.onChange?.({ target: { value: text } } as any) props.onChange?.({
target: { value: text },
} as React.ChangeEvent<HTMLTextAreaElement>);
}, },
}) });
useEffect(() => { useEffect(() => {
if (!isGenerating) { if (!isGenerating) {
setShowInterruptPrompt(false) setShowInterruptPrompt(false);
} }
}, [isGenerating]) }, [isGenerating]);
const addFiles = (files: File[] | null) => { const addFiles = (files: File[] | null) => {
if (props.allowAttachments) { if (props.allowAttachments) {
props.setFiles((currentFiles) => { props.setFiles(currentFiles => {
if (currentFiles === null) { if (currentFiles === null) {
return files return files;
} }
if (files === null) { if (files === null) {
return currentFiles return currentFiles;
} }
return [...currentFiles, ...files] return [...currentFiles, ...files];
}) });
} }
} };
const onDragOver = (event: React.DragEvent) => { const onDragOver = (event: React.DragEvent) => {
if (props.allowAttachments !== true) return if (props.allowAttachments !== true) return;
event.preventDefault() event.preventDefault();
setIsDragging(true) setIsDragging(true);
} };
const onDragLeave = (event: React.DragEvent) => { const onDragLeave = (event: React.DragEvent) => {
if (props.allowAttachments !== true) return if (props.allowAttachments !== true) return;
event.preventDefault() event.preventDefault();
setIsDragging(false) setIsDragging(false);
} };
const onDrop = (event: React.DragEvent) => { const onDrop = (event: React.DragEvent) => {
setIsDragging(false) setIsDragging(false);
if (props.allowAttachments !== true) return if (props.allowAttachments !== true) return;
event.preventDefault() event.preventDefault();
const dataTransfer = event.dataTransfer const dataTransfer = event.dataTransfer;
if (dataTransfer.files.length) { if (dataTransfer.files.length) {
addFiles(Array.from(dataTransfer.files)) addFiles(Array.from(dataTransfer.files));
} }
} };
const onPaste = (event: React.ClipboardEvent) => { const onPaste = (event: React.ClipboardEvent) => {
const items = event.clipboardData?.items const items = event.clipboardData?.items;
if (!items) return if (!items) return;
const text = event.clipboardData.getData("text") const text = event.clipboardData.getData("text");
if (text && text.length > 500 && props.allowAttachments) { if (text && text.length > 500 && props.allowAttachments) {
event.preventDefault() event.preventDefault();
const blob = new Blob([text], { type: "text/plain" }) const blob = new Blob([text], { type: "text/plain" });
const file = new File([blob], "Pasted text", { const file = new File([blob], "Pasted text", {
type: "text/plain", type: "text/plain",
lastModified: Date.now(), lastModified: Date.now(),
}) });
addFiles([file]) addFiles([file]);
return return;
} }
const files = Array.from(items) const files = Array.from(items)
.map((item) => item.getAsFile()) .map(item => item.getAsFile())
.filter((file) => file !== null) .filter(file => file !== null);
if (props.allowAttachments && files.length > 0) { if (props.allowAttachments && files.length > 0) {
addFiles(files) addFiles(files);
} }
} };
const onKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>) => { const onKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (submitOnEnter && event.key === "Enter" && !event.shiftKey) { if (submitOnEnter && event.key === "Enter" && !event.shiftKey) {
event.preventDefault() event.preventDefault();
if (isGenerating && stop && enableInterrupt) { if (isGenerating && stop && enableInterrupt) {
if (showInterruptPrompt) { if (showInterruptPrompt) {
stop() stop();
setShowInterruptPrompt(false) setShowInterruptPrompt(false);
event.currentTarget.form?.requestSubmit() event.currentTarget.form?.requestSubmit();
} else if ( } else if (
props.value || props.value ||
(props.allowAttachments && props.files?.length) (props.allowAttachments && props.files?.length)
) { ) {
setShowInterruptPrompt(true) setShowInterruptPrompt(true);
return return;
} }
} }
event.currentTarget.form?.requestSubmit() event.currentTarget.form?.requestSubmit();
} }
onKeyDownProp?.(event) onKeyDownProp?.(event);
} };
const textAreaRef = useRef<HTMLTextAreaElement>(null) const textAreaRef = useRef<HTMLTextAreaElement>(null);
const [textAreaHeight, setTextAreaHeight] = useState<number>(0) const [textAreaHeight, setTextAreaHeight] = useState<number>(0);
useEffect(() => { useEffect(() => {
if (textAreaRef.current) { if (textAreaRef.current) {
setTextAreaHeight(textAreaRef.current.offsetHeight) setTextAreaHeight(textAreaRef.current.offsetHeight);
} }
}, [props.value]) }, [props.value]);
const showFileList = const showFileList =
props.allowAttachments && props.files && props.files.length > 0 props.allowAttachments && props.files && props.files.length > 0;
useAutosizeTextArea({ useAutosizeTextArea({
ref: textAreaRef, ref: textAreaRef,
maxHeight: 240, maxHeight: 240,
borderWidth: 1, borderWidth: 1,
dependencies: [props.value, showFileList], dependencies: [props.value, showFileList],
}) });
return ( return (
<div <div
@ -220,24 +221,24 @@ export function MessageInput({
<div className="absolute inset-x-3 bottom-0 z-20 overflow-x-scroll py-3"> <div className="absolute inset-x-3 bottom-0 z-20 overflow-x-scroll py-3">
<div className="flex space-x-3"> <div className="flex space-x-3">
<AnimatePresence mode="popLayout"> <AnimatePresence mode="popLayout">
{props.files?.map((file) => { {props.files?.map(file => {
return ( return (
<FilePreview <FilePreview
key={file.name + String(file.lastModified)} key={file.name + String(file.lastModified)}
file={file} file={file}
onRemove={() => { onRemove={() => {
props.setFiles((files) => { props.setFiles(files => {
if (!files) return null if (!files) return null;
const filtered = Array.from(files).filter( const filtered = Array.from(files).filter(
(f) => f !== file f => f !== file
) );
if (filtered.length === 0) return null if (filtered.length === 0) return null;
return filtered return filtered;
}) });
}} }}
/> />
) );
})} })}
</AnimatePresence> </AnimatePresence>
</div> </div>
@ -256,8 +257,8 @@ export function MessageInput({
aria-label="Attach a file" aria-label="Attach a file"
disabled={true} disabled={true}
onClick={async () => { onClick={async () => {
const files = await showFileUploadDialog() const files = await showFileUploadDialog();
addFiles(files) addFiles(files);
}} }}
> >
<Paperclip className="h-4 w-4" /> <Paperclip className="h-4 w-4" />
@ -308,12 +309,12 @@ export function MessageInput({
onStopRecording={stopRecording} onStopRecording={stopRecording}
/> />
</div> </div>
) );
} }
MessageInput.displayName = "MessageInput" MessageInput.displayName = "MessageInput";
interface FileUploadOverlayProps { interface FileUploadOverlayProps {
isDragging: boolean isDragging: boolean;
} }
function FileUploadOverlay({ isDragging }: FileUploadOverlayProps) { function FileUploadOverlay({ isDragging }: FileUploadOverlayProps) {
@ -333,29 +334,29 @@ function FileUploadOverlay({ isDragging }: FileUploadOverlayProps) {
</motion.div> </motion.div>
)} )}
</AnimatePresence> </AnimatePresence>
) );
} }
function showFileUploadDialog() { function showFileUploadDialog() {
const input = document.createElement("input") const input = document.createElement("input");
input.type = "file" input.type = "file";
input.multiple = true input.multiple = true;
input.accept = "*/*" input.accept = "*/*";
input.click() input.click();
return new Promise<File[] | null>((resolve) => { return new Promise<File[] | null>(resolve => {
input.onchange = (e) => { input.onchange = e => {
const files = (e.currentTarget as HTMLInputElement).files const files = (e.currentTarget as HTMLInputElement).files;
if (files) { if (files) {
resolve(Array.from(files)) resolve(Array.from(files));
return return;
} }
resolve(null) resolve(null);
} };
}) });
} }
function TranscribingOverlay() { function TranscribingOverlay() {
@ -385,12 +386,12 @@ function TranscribingOverlay() {
Transcribing audio... Transcribing audio...
</p> </p>
</motion.div> </motion.div>
) );
} }
interface RecordingPromptProps { interface RecordingPromptProps {
isVisible: boolean isVisible: boolean;
onStopRecording: () => void onStopRecording: () => void;
} }
function RecordingPrompt({ isVisible, onStopRecording }: RecordingPromptProps) { function RecordingPrompt({ isVisible, onStopRecording }: RecordingPromptProps) {
@ -418,15 +419,15 @@ function RecordingPrompt({ isVisible, onStopRecording }: RecordingPromptProps) {
</motion.div> </motion.div>
)} )}
</AnimatePresence> </AnimatePresence>
) );
} }
interface RecordingControlsProps { interface RecordingControlsProps {
isRecording: boolean isRecording: boolean;
isTranscribing: boolean isTranscribing: boolean;
audioStream: MediaStream | null audioStream: MediaStream | null;
textAreaHeight: number textAreaHeight: number;
onStopRecording: () => void onStopRecording: () => void;
} }
function RecordingControls({ function RecordingControls({
@ -448,7 +449,7 @@ function RecordingControls({
onClick={onStopRecording} onClick={onStopRecording}
/> />
</div> </div>
) );
} }
if (isTranscribing) { if (isTranscribing) {
@ -459,8 +460,8 @@ function RecordingControls({
> >
<TranscribingOverlay /> <TranscribingOverlay />
</div> </div>
) );
} }
return null return null;
} }

View file

@ -2,18 +2,18 @@ import {
ChatMessage, ChatMessage,
type ChatMessageProps, type ChatMessageProps,
type Message, type Message,
} from "@/components/chat-playground/chat-message" } from "@/components/chat-playground/chat-message";
import { TypingIndicator } from "@/components/chat-playground/typing-indicator" import { TypingIndicator } from "@/components/chat-playground/typing-indicator";
type AdditionalMessageOptions = Omit<ChatMessageProps, keyof Message> type AdditionalMessageOptions = Omit<ChatMessageProps, keyof Message>;
interface MessageListProps { interface MessageListProps {
messages: Message[] messages: Message[];
showTimeStamps?: boolean showTimeStamps?: boolean;
isTyping?: boolean isTyping?: boolean;
messageOptions?: messageOptions?:
| AdditionalMessageOptions | AdditionalMessageOptions
| ((message: Message) => AdditionalMessageOptions) | ((message: Message) => AdditionalMessageOptions);
} }
export function MessageList({ export function MessageList({
@ -28,7 +28,7 @@ export function MessageList({
const additionalOptions = const additionalOptions =
typeof messageOptions === "function" typeof messageOptions === "function"
? messageOptions(message) ? messageOptions(message)
: messageOptions : messageOptions;
return ( return (
<ChatMessage <ChatMessage
@ -37,9 +37,9 @@ export function MessageList({
{...message} {...message}
{...additionalOptions} {...additionalOptions}
/> />
) );
})} })}
{isTyping && <TypingIndicator />} {isTyping && <TypingIndicator />}
</div> </div>
) );
} }

Some files were not shown because too many files have changed in this diff Show more